diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 0bfb86bf1f..5d45689c8c 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -62,12 +62,15 @@ class AuthTestCase(unittest.TestCase):
# this is overridden for the appservice tests
self.store.get_app_service_by_token = Mock(return_value=None)
+ self.store.insert_client_ip = Mock(return_value=defer.succeed(None))
self.store.is_support_user = Mock(return_value=defer.succeed(False))
@defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self):
user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
- self.store.get_user_by_access_token = Mock(return_value=user_info)
+ self.store.get_user_by_access_token = Mock(
+ return_value=defer.succeed(user_info)
+ )
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
@@ -76,23 +79,25 @@ class AuthTestCase(unittest.TestCase):
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self):
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self):
user_info = {"name": self.test_user, "token_id": "ditto"}
- self.store.get_user_by_access_token = Mock(return_value=user_info)
+ self.store.get_user_by_access_token = Mock(
+ return_value=defer.succeed(user_info)
+ )
request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, MissingClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@@ -103,7 +108,7 @@ class AuthTestCase(unittest.TestCase):
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
@@ -123,7 +128,7 @@ class AuthTestCase(unittest.TestCase):
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "192.168.10.10"
@@ -142,25 +147,25 @@ class AuthTestCase(unittest.TestCase):
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "131.111.8.42"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
@@ -168,11 +173,11 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_appservice_missing_token(self):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, MissingClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@@ -185,7 +190,11 @@ class AuthTestCase(unittest.TestCase):
)
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ # This just needs to return a truth-y value.
+ self.store.get_user_by_id = Mock(
+ return_value=defer.succeed({"is_guest": False})
+ )
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
@@ -204,20 +213,22 @@ class AuthTestCase(unittest.TestCase):
)
app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
self.failureResultOf(d, AuthError)
@defer.inlineCallbacks
def test_get_user_from_macaroon(self):
self.store.get_user_by_access_token = Mock(
- return_value={"name": "@baldrick:matrix.org", "device_id": "device"}
+ return_value=defer.succeed(
+ {"name": "@baldrick:matrix.org", "device_id": "device"}
+ )
)
user_id = "@baldrick:matrix.org"
@@ -241,8 +252,8 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self):
- self.store.get_user_by_id = Mock(return_value={"is_guest": True})
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True}))
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
@@ -282,16 +293,20 @@ class AuthTestCase(unittest.TestCase):
def get_user(tok):
if token != tok:
- return None
- return {
- "name": USER_ID,
- "is_guest": False,
- "token_id": 1234,
- "device_id": "DEVICE",
- }
+ return defer.succeed(None)
+ return defer.succeed(
+ {
+ "name": USER_ID,
+ "is_guest": False,
+ "token_id": 1234,
+ "device_id": "DEVICE",
+ }
+ )
self.store.get_user_by_access_token = get_user
- self.store.get_user_by_id = Mock(return_value={"is_guest": False})
+ self.store.get_user_by_id = Mock(
+ return_value=defer.succeed({"is_guest": False})
+ )
# check the token works
request = Mock(args={})
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 4e67503cf0..1fab1d6b69 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -375,8 +375,10 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(sender="@foo:bar", type="m.profile")
events = [event]
- user_filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
+ user_filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart, filter_id=filter_id
+ )
)
results = user_filter.filter_presence(events=events)
@@ -396,8 +398,10 @@ class FilteringTestCase(unittest.TestCase):
)
events = [event]
- user_filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart + "2", filter_id=filter_id
+ user_filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart + "2", filter_id=filter_id
+ )
)
results = user_filter.filter_presence(events=events)
@@ -412,8 +416,10 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
events = [event]
- user_filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
+ user_filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart, filter_id=filter_id
+ )
)
results = user_filter.filter_room_state(events=events)
@@ -430,8 +436,10 @@ class FilteringTestCase(unittest.TestCase):
)
events = [event]
- user_filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
+ user_filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart, filter_id=filter_id
+ )
)
results = user_filter.filter_room_state(events)
@@ -465,8 +473,10 @@ class FilteringTestCase(unittest.TestCase):
self.assertEquals(
user_filter_json,
(
- yield self.datastore.get_user_filter(
- user_localpart=user_localpart, filter_id=0
+ yield defer.ensureDeferred(
+ self.datastore.get_user_filter(
+ user_localpart=user_localpart, filter_id=0
+ )
)
),
)
@@ -479,8 +489,10 @@ class FilteringTestCase(unittest.TestCase):
user_localpart=user_localpart, user_filter=user_filter_json
)
- filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
+ filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart, filter_id=filter_id
+ )
)
self.assertEquals(filter.get_filter_json(), user_filter_json)
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index d580e729c5..1e1f30d790 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -1,4 +1,6 @@
from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
+from synapse.appservice import ApplicationService
+from synapse.types import create_requester
from tests import unittest
@@ -18,6 +20,77 @@ class TestRatelimiter(unittest.TestCase):
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
+ def test_allowed_user_via_can_requester_do_action(self):
+ user_requester = create_requester("@user:example.com")
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ allowed, time_allowed = limiter.can_requester_do_action(
+ user_requester, _time_now_s=0
+ )
+ self.assertTrue(allowed)
+ self.assertEquals(10.0, time_allowed)
+
+ allowed, time_allowed = limiter.can_requester_do_action(
+ user_requester, _time_now_s=5
+ )
+ self.assertFalse(allowed)
+ self.assertEquals(10.0, time_allowed)
+
+ allowed, time_allowed = limiter.can_requester_do_action(
+ user_requester, _time_now_s=10
+ )
+ self.assertTrue(allowed)
+ self.assertEquals(20.0, time_allowed)
+
+ def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
+ appservice = ApplicationService(
+ None, "example.com", id="foo", rate_limited=True,
+ )
+ as_requester = create_requester("@user:example.com", app_service=appservice)
+
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ allowed, time_allowed = limiter.can_requester_do_action(
+ as_requester, _time_now_s=0
+ )
+ self.assertTrue(allowed)
+ self.assertEquals(10.0, time_allowed)
+
+ allowed, time_allowed = limiter.can_requester_do_action(
+ as_requester, _time_now_s=5
+ )
+ self.assertFalse(allowed)
+ self.assertEquals(10.0, time_allowed)
+
+ allowed, time_allowed = limiter.can_requester_do_action(
+ as_requester, _time_now_s=10
+ )
+ self.assertTrue(allowed)
+ self.assertEquals(20.0, time_allowed)
+
+ def test_allowed_appservice_via_can_requester_do_action(self):
+ appservice = ApplicationService(
+ None, "example.com", id="foo", rate_limited=False,
+ )
+ as_requester = create_requester("@user:example.com", app_service=appservice)
+
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ allowed, time_allowed = limiter.can_requester_do_action(
+ as_requester, _time_now_s=0
+ )
+ self.assertTrue(allowed)
+ self.assertEquals(-1, time_allowed)
+
+ allowed, time_allowed = limiter.can_requester_do_action(
+ as_requester, _time_now_s=5
+ )
+ self.assertTrue(allowed)
+ self.assertEquals(-1, time_allowed)
+
+ allowed, time_allowed = limiter.can_requester_do_action(
+ as_requester, _time_now_s=10
+ )
+ self.assertTrue(allowed)
+ self.assertEquals(-1, time_allowed)
+
def test_allowed_via_ratelimit(self):
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index be20a89682..641093d349 100644
--- a/tests/app/test_frontend_proxy.py
+++ b/tests/app/test_frontend_proxy.py
@@ -30,6 +30,16 @@ class FrontendProxyTests(HomeserverTestCase):
def default_config(self):
c = super().default_config()
c["worker_app"] = "synapse.app.frontend_proxy"
+
+ c["worker_listeners"] = [
+ {
+ "type": "http",
+ "port": 8080,
+ "bind_addresses": ["0.0.0.0"],
+ "resources": [{"names": ["client"]}],
+ }
+ ]
+
return c
def test_listen_http_with_presence_enabled(self):
@@ -39,14 +49,8 @@ class FrontendProxyTests(HomeserverTestCase):
# Presence is on
self.hs.config.use_presence = True
- config = {
- "port": 8080,
- "bind_addresses": ["0.0.0.0"],
- "resources": [{"names": ["client"]}],
- }
-
# Listen with the config
- self.hs._listen_http(config)
+ self.hs._listen_http(self.hs.config.worker.worker_listeners[0])
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
@@ -67,14 +71,8 @@ class FrontendProxyTests(HomeserverTestCase):
# Presence is off
self.hs.config.use_presence = False
- config = {
- "port": 8080,
- "bind_addresses": ["0.0.0.0"],
- "resources": [{"names": ["client"]}],
- }
-
# Listen with the config
- self.hs._listen_http(config)
+ self.hs._listen_http(self.hs.config.worker.worker_listeners[0])
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 7364f9f1ec..0f016c32eb 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -18,6 +18,7 @@ from parameterized import parameterized
from synapse.app.generic_worker import GenericWorkerServer
from synapse.app.homeserver import SynapseHomeServer
+from synapse.config.server import parse_listener_def
from tests.unittest import HomeserverTestCase
@@ -35,6 +36,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
# have to tell the FederationHandler not to try to access stuff that is only
# in the primary store.
conf["worker_app"] = "yes"
+
return conf
@parameterized.expand(
@@ -53,12 +55,13 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
"""
config = {
"port": 8080,
+ "type": "http",
"bind_addresses": ["0.0.0.0"],
"resources": [{"names": names}],
}
# Listen with the config
- self.hs._listen_http(config)
+ self.hs._listen_http(parse_listener_def(config))
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
@@ -101,12 +104,13 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
"""
config = {
"port": 8080,
+ "type": "http",
"bind_addresses": ["0.0.0.0"],
"resources": [{"names": names}],
}
# Listen with the config
- self.hs._listener_http(config, config)
+ self.hs._listener_http(self.hs.get_config(), parse_listener_def(config))
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 4003869ed6..236b608d58 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -50,13 +50,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_regex_user_id_prefix_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@irc_foobar:matrix.org"
- self.assertTrue((yield self.service.is_interested(self.event)))
+ self.assertTrue(
+ (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ )
@defer.inlineCallbacks
def test_regex_user_id_prefix_no_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@someone_else:matrix.org"
- self.assertFalse((yield self.service.is_interested(self.event)))
+ self.assertFalse(
+ (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ )
@defer.inlineCallbacks
def test_regex_room_member_is_checked(self):
@@ -64,7 +68,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.event.sender = "@someone_else:matrix.org"
self.event.type = "m.room.member"
self.event.state_key = "@irc_foobar:matrix.org"
- self.assertTrue((yield self.service.is_interested(self.event)))
+ self.assertTrue(
+ (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ )
@defer.inlineCallbacks
def test_regex_room_id_match(self):
@@ -72,7 +78,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("!some_prefix.*some_suffix:matrix.org")
)
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
- self.assertTrue((yield self.service.is_interested(self.event)))
+ self.assertTrue(
+ (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ )
@defer.inlineCallbacks
def test_regex_room_id_no_match(self):
@@ -80,19 +88,26 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("!some_prefix.*some_suffix:matrix.org")
)
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
- self.assertFalse((yield self.service.is_interested(self.event)))
+ self.assertFalse(
+ (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ )
@defer.inlineCallbacks
def test_regex_alias_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
- self.store.get_aliases_for_room.return_value = [
- "#irc_foobar:matrix.org",
- "#athing:matrix.org",
- ]
- self.store.get_users_in_room.return_value = []
- self.assertTrue((yield self.service.is_interested(self.event, self.store)))
+ self.store.get_aliases_for_room.return_value = defer.succeed(
+ ["#irc_foobar:matrix.org", "#athing:matrix.org"]
+ )
+ self.store.get_users_in_room.return_value = defer.succeed([])
+ self.assertTrue(
+ (
+ yield defer.ensureDeferred(
+ self.service.is_interested(self.event, self.store)
+ )
+ )
+ )
def test_non_exclusive_alias(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
@@ -135,12 +150,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
- self.store.get_aliases_for_room.return_value = [
- "#xmpp_foobar:matrix.org",
- "#athing:matrix.org",
- ]
- self.store.get_users_in_room.return_value = []
- self.assertFalse((yield self.service.is_interested(self.event, self.store)))
+ self.store.get_aliases_for_room.return_value = defer.succeed(
+ ["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
+ )
+ self.store.get_users_in_room.return_value = defer.succeed([])
+ self.assertFalse(
+ (
+ yield defer.ensureDeferred(
+ self.service.is_interested(self.event, self.store)
+ )
+ )
+ )
@defer.inlineCallbacks
def test_regex_multiple_matches(self):
@@ -149,9 +169,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
)
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@irc_foobar:matrix.org"
- self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"]
- self.store.get_users_in_room.return_value = []
- self.assertTrue((yield self.service.is_interested(self.event, self.store)))
+ self.store.get_aliases_for_room.return_value = defer.succeed(
+ ["#irc_barfoo:matrix.org"]
+ )
+ self.store.get_users_in_room.return_value = defer.succeed([])
+ self.assertTrue(
+ (
+ yield defer.ensureDeferred(
+ self.service.is_interested(self.event, self.store)
+ )
+ )
+ )
@defer.inlineCallbacks
def test_interested_in_self(self):
@@ -161,19 +189,24 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.event.type = "m.room.member"
self.event.content = {"membership": "invite"}
self.event.state_key = self.service.sender
- self.assertTrue((yield self.service.is_interested(self.event)))
+ self.assertTrue(
+ (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ )
@defer.inlineCallbacks
def test_member_list_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
- self.store.get_users_in_room.return_value = [
- "@alice:here",
- "@irc_fo:here", # AS user
- "@bob:here",
- ]
- self.store.get_aliases_for_room.return_value = []
+ # Note that @irc_fo:here is the AS user.
+ self.store.get_users_in_room.return_value = defer.succeed(
+ ["@alice:here", "@irc_fo:here", "@bob:here"]
+ )
+ self.store.get_aliases_for_room.return_value = defer.succeed([])
self.event.sender = "@xmpp_foobar:matrix.org"
self.assertTrue(
- (yield self.service.is_interested(event=self.event, store=self.store))
+ (
+ yield defer.ensureDeferred(
+ self.service.is_interested(event=self.event, store=self.store)
+ )
+ )
)
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 52f89d3f83..68a4caabbf 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -25,6 +25,7 @@ from synapse.appservice.scheduler import (
from synapse.logging.context import make_deferred_yieldable
from tests import unittest
+from tests.test_utils import make_awaitable
from ..utils import MockClock
@@ -52,11 +53,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.store.get_appservice_state = Mock(
return_value=defer.succeed(ApplicationServiceState.UP)
)
- txn.send = Mock(return_value=defer.succeed(True))
+ txn.send = Mock(return_value=make_awaitable(True))
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
# actual call
- self.txnctrl.send(service, events)
+ self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events # txn made and saved
@@ -77,7 +78,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
# actual call
- self.txnctrl.send(service, events)
+ self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events # txn made and saved
@@ -98,11 +99,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
return_value=defer.succeed(ApplicationServiceState.UP)
)
self.store.set_appservice_state = Mock(return_value=defer.succeed(True))
- txn.send = Mock(return_value=defer.succeed(False)) # fails to send
+ txn.send = Mock(return_value=make_awaitable(False)) # fails to send
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
# actual call
- self.txnctrl.send(service, events)
+ self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events
@@ -144,7 +145,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.recoverer.recover()
# shouldn't have called anything prior to waiting for exp backoff
self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
- txn.send = Mock(return_value=True)
+ txn.send = Mock(return_value=make_awaitable(True))
+ txn.complete.return_value = make_awaitable(None)
# wait for exp backoff
self.clock.advance_time(2)
self.assertEquals(1, txn.send.call_count)
@@ -169,7 +171,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.recoverer.recover()
self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
- txn.send = Mock(return_value=False)
+ txn.send = Mock(return_value=make_awaitable(False))
+ txn.complete.return_value = make_awaitable(None)
self.clock.advance_time(2)
self.assertEquals(1, txn.send.call_count)
self.assertEquals(0, txn.complete.call_count)
@@ -182,7 +185,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.assertEquals(3, txn.send.call_count)
self.assertEquals(0, txn.complete.call_count)
self.assertEquals(0, self.callback.call_count)
- txn.send = Mock(return_value=True) # successfully send the txn
+ txn.send = Mock(return_value=make_awaitable(True)) # successfully send the txn
pop_txn = True # returns the txn the first time, then no more.
self.clock.advance_time(16)
self.assertEquals(1, txn.send.call_count) # new mock reset call count
diff --git a/tests/config/test_base.py b/tests/config/test_base.py
new file mode 100644
index 0000000000..42ee5f56d9
--- /dev/null
+++ b/tests/config/test_base.py
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path
+import tempfile
+
+from synapse.config import ConfigError
+from synapse.util.stringutils import random_string
+
+from tests import unittest
+
+
+class BaseConfigTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.hs = hs
+
+ def test_loading_missing_templates(self):
+ # Use a temporary directory that exists on the system, but that isn't likely to
+ # contain template files
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ # Attempt to load an HTML template from our custom template directory
+ template = self.hs.config.read_templates(["sso_error.html"], tmp_dir)[0]
+
+ # If no errors, we should've gotten the default template instead
+
+ # Render the template
+ a_random_string = random_string(5)
+ html_content = template.render({"error_description": a_random_string})
+
+ # Check that our string exists in the template
+ self.assertIn(
+ a_random_string,
+ html_content,
+ "Template file did not contain our test string",
+ )
+
+ def test_loading_custom_templates(self):
+ # Use a temporary directory that exists on the system
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ # Create a temporary bogus template file
+ with tempfile.NamedTemporaryFile(dir=tmp_dir) as tmp_template:
+ # Get temporary file's filename
+ template_filename = os.path.basename(tmp_template.name)
+
+ # Write a custom HTML template
+ contents = b"{{ test_variable }}"
+ tmp_template.write(contents)
+ tmp_template.flush()
+
+ # Attempt to load the template from our custom template directory
+ template = (
+ self.hs.config.read_templates([template_filename], tmp_dir)
+ )[0]
+
+ # Render the template
+ a_random_string = random_string(5)
+ html_content = template.render({"test_variable": a_random_string})
+
+ # Check that our string exists in the template
+ self.assertIn(
+ a_random_string,
+ html_content,
+ "Template file did not contain our test string",
+ )
+
+ def test_loading_template_from_nonexistent_custom_directory(self):
+ with self.assertRaises(ConfigError):
+ self.hs.config.read_templates(
+ ["some_filename.html"], "a_nonexistent_directory"
+ )
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 70c8e72303..0d4b05304b 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -40,6 +40,7 @@ from synapse.logging.context import (
from synapse.storage.keys import FetchKeyResult
from tests import unittest
+from tests.test_utils import make_awaitable
class MockPerspectiveServer(object):
@@ -102,11 +103,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
}
persp_deferred = defer.Deferred()
- @defer.inlineCallbacks
- def get_perspectives(**kwargs):
+ async def get_perspectives(**kwargs):
self.assertEquals(current_context().request, "11")
with PreserveLoggingContext():
- yield persp_deferred
+ await persp_deferred
return persp_resp
self.http_client.post_json.side_effect = get_perspectives
@@ -192,7 +192,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
self.failureResultOf(d, SynapseError)
- # should suceed on a signed object
+ # should succeed on a signed object
d = _verify_json_for_server(kr, "server9", json1, 500, "test signed")
# self.assertFalse(d.called)
self.get_success(d)
@@ -202,7 +202,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
with a null `ts_valid_until_ms`
"""
mock_fetcher = keyring.KeyFetcher()
- mock_fetcher.get_keys = Mock(return_value=defer.succeed({}))
+ mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
kr = keyring.Keyring(
self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
@@ -245,17 +245,15 @@ class KeyringTestCase(unittest.HomeserverTestCase):
"""Two requests for the same key should be deduped."""
key1 = signedjson.key.generate_signing_key(1)
- def get_keys(keys_to_fetch):
+ async def get_keys(keys_to_fetch):
# there should only be one request object (with the max validity)
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
- return defer.succeed(
- {
- "server1": {
- get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
- }
+ return {
+ "server1": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
}
- )
+ }
mock_fetcher = keyring.KeyFetcher()
mock_fetcher.get_keys = Mock(side_effect=get_keys)
@@ -282,25 +280,19 @@ class KeyringTestCase(unittest.HomeserverTestCase):
"""If the first fetcher cannot provide a recent enough key, we fall back"""
key1 = signedjson.key.generate_signing_key(1)
- def get_keys1(keys_to_fetch):
+ async def get_keys1(keys_to_fetch):
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
- return defer.succeed(
- {
- "server1": {
- get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)
- }
- }
- )
+ return {
+ "server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
+ }
- def get_keys2(keys_to_fetch):
+ async def get_keys2(keys_to_fetch):
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
- return defer.succeed(
- {
- "server1": {
- get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
- }
+ return {
+ "server1": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
}
- )
+ }
mock_fetcher1 = keyring.KeyFetcher()
mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
@@ -355,7 +347,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
}
signedjson.sign.sign_json(response, SERVER_NAME, testkey)
- def get_json(destination, path, **kwargs):
+ async def get_json(destination, path, **kwargs):
self.assertEqual(destination, SERVER_NAME)
self.assertEqual(path, "/_matrix/key/v2/server/key1")
return response
@@ -444,7 +436,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
Tell the mock http client to expect a perspectives-server key query
"""
- def post_json(destination, path, data, **kwargs):
+ async def post_json(destination, path, data, **kwargs):
self.assertEqual(destination, self.mock_perspective_server.server_name)
self.assertEqual(path, "/_matrix/key/v2/query")
@@ -580,14 +572,12 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
# remove the perspectives server's signature
response = build_response()
del response["signatures"][self.mock_perspective_server.server_name]
- self.http_client.post_json.return_value = {"server_keys": [response]}
keys = get_key_from_perspectives(response)
self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig")
# remove the origin server's signature
response = build_response()
del response["signatures"][SERVER_NAME]
- self.http_client.post_json.return_value = {"server_keys": [response]}
keys = get_key_from_perspectives(response)
self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index 640f5f3bce..3a80626224 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -41,8 +41,10 @@ class TestEventContext(unittest.HomeserverTestCase):
serialize/deserialize.
"""
- event, context = create_event(
- self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
+ event, context = self.get_success(
+ create_event(
+ self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
+ )
)
self._check_serialize_deserialize(event, context)
@@ -51,12 +53,14 @@ class TestEventContext(unittest.HomeserverTestCase):
"""Test that an EventContext for a state event (with not previous entry)
is the same after serialize/deserialize.
"""
- event, context = create_event(
- self.hs,
- room_id=self.room_id,
- type="m.test",
- sender=self.user_id,
- state_key="",
+ event, context = self.get_success(
+ create_event(
+ self.hs,
+ room_id=self.room_id,
+ type="m.test",
+ sender=self.user_id,
+ state_key="",
+ )
)
self._check_serialize_deserialize(event, context)
@@ -65,13 +69,15 @@ class TestEventContext(unittest.HomeserverTestCase):
"""Test that an EventContext for a state event (which replaces a
previous entry) is the same after serialize/deserialize.
"""
- event, context = create_event(
- self.hs,
- room_id=self.room_id,
- type="m.room.member",
- sender=self.user_id,
- state_key=self.user_id,
- content={"membership": "leave"},
+ event, context = self.get_success(
+ create_event(
+ self.hs,
+ room_id=self.room_id,
+ type="m.room.member",
+ sender=self.user_id,
+ state_key=self.user_id,
+ content={"membership": "leave"},
+ )
)
self._check_serialize_deserialize(event, context)
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 0c9987be54..9bd515080c 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -23,6 +23,7 @@ from synapse.rest.client.v1 import login, room
from synapse.types import UserID
from tests import unittest
+from tests.test_utils import make_awaitable
class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
@@ -78,9 +79,44 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
+ fed_transport.client.get_json = Mock(
+ side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
+ )
+ handler.federation_handler.do_invite_join = Mock(
+ side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
+ )
+
+ d = handler._remote_join(
+ None,
+ ["other.example.com"],
+ "roomid",
+ UserID.from_string(u1),
+ {"membership": "join"},
+ )
+
+ self.pump()
+
+ # The request failed with a SynapseError saying the resource limit was
+ # exceeded.
+ f = self.get_failure(d, SynapseError)
+ self.assertEqual(f.value.code, 400, f.value)
+ self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+ def test_join_too_large_admin(self):
+ # Check whether an admin can join if option "admins_can_join" is undefined,
+ # this option defaults to false, so the join should fail.
+
+ u1 = self.register_user("u1", "pass", admin=True)
+
+ handler = self.hs.get_room_member_handler()
+ fed_transport = self.hs.get_federation_transport_client()
+
+ # Mock out some things, because we don't want to test the whole join
+ fed_transport.client.get_json = Mock(
+ side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
+ )
handler.federation_handler.do_invite_join = Mock(
- return_value=defer.succeed(("", 1))
+ side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
)
d = handler._remote_join(
@@ -116,9 +152,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(return_value=defer.succeed(None))
+ fed_transport.client.get_json = Mock(
+ side_effect=lambda *args, **kwargs: make_awaitable(None)
+ )
handler.federation_handler.do_invite_join = Mock(
- return_value=defer.succeed(("", 1))
+ side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
)
# Artificially raise the complexity
@@ -141,3 +179,85 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
f = self.get_failure(d, SynapseError)
self.assertEqual(f.value.code, 400)
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+
+class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
+ # Test the behavior of joining rooms which exceed the complexity if option
+ # limit_remote_rooms.admins_can_join is True.
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def default_config(self):
+ config = super().default_config()
+ config["limit_remote_rooms"] = {
+ "enabled": True,
+ "complexity": 0.05,
+ "admins_can_join": True,
+ }
+ return config
+
+ def test_join_too_large_no_admin(self):
+ # A user which is not an admin should not be able to join a remote room
+ # which is too complex.
+
+ u1 = self.register_user("u1", "pass")
+
+ handler = self.hs.get_room_member_handler()
+ fed_transport = self.hs.get_federation_transport_client()
+
+ # Mock out some things, because we don't want to test the whole join
+ fed_transport.client.get_json = Mock(
+ side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
+ )
+ handler.federation_handler.do_invite_join = Mock(
+ side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
+ )
+
+ d = handler._remote_join(
+ None,
+ ["other.example.com"],
+ "roomid",
+ UserID.from_string(u1),
+ {"membership": "join"},
+ )
+
+ self.pump()
+
+ # The request failed with a SynapseError saying the resource limit was
+ # exceeded.
+ f = self.get_failure(d, SynapseError)
+ self.assertEqual(f.value.code, 400, f.value)
+ self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+ def test_join_too_large_admin(self):
+ # An admin should be able to join rooms where a complexity check fails.
+
+ u1 = self.register_user("u1", "pass", admin=True)
+
+ handler = self.hs.get_room_member_handler()
+ fed_transport = self.hs.get_federation_transport_client()
+
+ # Mock out some things, because we don't want to test the whole join
+ fed_transport.client.get_json = Mock(
+ side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
+ )
+ handler.federation_handler.do_invite_join = Mock(
+ side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
+ )
+
+ d = handler._remote_join(
+ None,
+ ["other.example.com"],
+ "roomid",
+ UserID.from_string(u1),
+ {"membership": "join"},
+ )
+
+ self.pump()
+
+ # The request success since the user is an admin
+ self.get_success(d)
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 1a9bd5f37d..5f512ff8bf 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -26,31 +26,34 @@ from synapse.rest import admin
from synapse.rest.client.v1 import login
from synapse.types import JsonDict, ReadReceipt
+from tests.test_utils import make_awaitable
from tests.unittest import HomeserverTestCase, override_config
class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
+ mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
+ # Ensure a new Awaitable is created for each call.
+ mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_Id: make_awaitable(
+ ["test", "host2"]
+ )
return self.setup_test_homeserver(
- state_handler=Mock(spec=["get_current_hosts_in_room"]),
+ state_handler=mock_state_handler,
federation_transport_client=Mock(spec=["send_transaction"]),
)
@override_config({"send_federation": True})
def test_send_receipts(self):
- mock_state_handler = self.hs.get_state_handler()
- mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
)
- mock_send_transaction.return_value = defer.succeed({})
+ mock_send_transaction.return_value = make_awaitable({})
sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
)
- self.successResultOf(sender.send_read_receipt(receipt))
+ self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
self.pump()
@@ -81,19 +84,16 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def test_send_receipts_with_backoff(self):
"""Send two receipts in quick succession; the second should be flushed, but
only after 20ms"""
- mock_state_handler = self.hs.get_state_handler()
- mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
)
- mock_send_transaction.return_value = defer.succeed({})
+ mock_send_transaction.return_value = make_awaitable({})
sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
)
- self.successResultOf(sender.send_read_receipt(receipt))
+ self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
self.pump()
@@ -125,7 +125,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["other_id"], {"ts": 1234}
)
- self.successResultOf(sender.send_read_receipt(receipt))
+ self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
self.pump()
mock_send_transaction.assert_not_called()
@@ -164,7 +164,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
return self.setup_test_homeserver(
- state_handler=Mock(spec=["get_current_hosts_in_room"]),
federation_transport_client=Mock(spec=["send_transaction"]),
)
@@ -174,10 +173,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
return c
def prepare(self, reactor, clock, hs):
- # stub out get_current_hosts_in_room
- mock_state_handler = hs.get_state_handler()
- mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
# stub out get_users_who_share_room_with_user so that it claims that
# `@user2:host2` is in the room
def get_users_who_share_room_with_user(user_id):
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index ba7148ec01..2a0b7c1b56 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.handlers.appservice import ApplicationServicesHandler
+from tests.test_utils import make_awaitable
from tests.utils import MockClock
from .. import unittest
@@ -32,10 +33,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_as_api = Mock()
self.mock_scheduler = Mock()
hs = Mock()
- hs.get_datastore = Mock(return_value=self.mock_store)
- self.mock_store.get_received_ts.return_value = 0
- hs.get_application_service_api = Mock(return_value=self.mock_as_api)
- hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler)
+ hs.get_datastore.return_value = self.mock_store
+ self.mock_store.get_received_ts.return_value = defer.succeed(0)
+ self.mock_store.set_appservice_last_pos.return_value = defer.succeed(None)
+ hs.get_application_service_api.return_value = self.mock_as_api
+ hs.get_application_service_scheduler.return_value = self.mock_scheduler
hs.get_clock.return_value = MockClock()
self.handler = ApplicationServicesHandler(hs)
@@ -48,18 +50,18 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self._mkservice(is_interested=False),
]
- self.mock_store.get_app_services = Mock(return_value=services)
- self.mock_store.get_user_by_id = Mock(return_value=[])
+ self.mock_as_api.query_user.return_value = defer.succeed(True)
+ self.mock_store.get_app_services.return_value = services
+ self.mock_store.get_user_by_id.return_value = defer.succeed([])
event = Mock(
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
)
self.mock_store.get_new_events_for_appservice.side_effect = [
- (0, [event]),
- (0, []),
+ defer.succeed((0, [event])),
+ defer.succeed((0, [])),
]
- self.mock_as_api.push = Mock()
- yield self.handler.notify_interested_services(0)
+ yield defer.ensureDeferred(self.handler.notify_interested_services(0))
self.mock_scheduler.submit_event_for_as.assert_called_once_with(
interested_service, event
)
@@ -68,36 +70,34 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_query_user_exists_unknown_user(self):
user_id = "@someone:anywhere"
services = [self._mkservice(is_interested=True)]
- services[0].is_interested_in_user = Mock(return_value=True)
- self.mock_store.get_app_services = Mock(return_value=services)
- self.mock_store.get_user_by_id = Mock(return_value=None)
+ services[0].is_interested_in_user.return_value = True
+ self.mock_store.get_app_services.return_value = services
+ self.mock_store.get_user_by_id.return_value = defer.succeed(None)
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
- self.mock_as_api.push = Mock()
- self.mock_as_api.query_user = Mock()
+ self.mock_as_api.query_user.return_value = defer.succeed(True)
self.mock_store.get_new_events_for_appservice.side_effect = [
- (0, [event]),
- (0, []),
+ defer.succeed((0, [event])),
+ defer.succeed((0, [])),
]
- yield self.handler.notify_interested_services(0)
+ yield defer.ensureDeferred(self.handler.notify_interested_services(0))
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
@defer.inlineCallbacks
def test_query_user_exists_known_user(self):
user_id = "@someone:anywhere"
services = [self._mkservice(is_interested=True)]
- services[0].is_interested_in_user = Mock(return_value=True)
- self.mock_store.get_app_services = Mock(return_value=services)
- self.mock_store.get_user_by_id = Mock(return_value={"name": user_id})
+ services[0].is_interested_in_user.return_value = True
+ self.mock_store.get_app_services.return_value = services
+ self.mock_store.get_user_by_id.return_value = defer.succeed({"name": user_id})
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
- self.mock_as_api.push = Mock()
- self.mock_as_api.query_user = Mock()
+ self.mock_as_api.query_user.return_value = defer.succeed(True)
self.mock_store.get_new_events_for_appservice.side_effect = [
- (0, [event]),
- (0, []),
+ defer.succeed((0, [event])),
+ defer.succeed((0, [])),
]
- yield self.handler.notify_interested_services(0)
+ yield defer.ensureDeferred(self.handler.notify_interested_services(0))
self.assertFalse(
self.mock_as_api.query_user.called,
"query_user called when it shouldn't have been.",
@@ -107,7 +107,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_query_room_alias_exists(self):
room_alias_str = "#foo:bar"
room_alias = Mock()
- room_alias.to_string = Mock(return_value=room_alias_str)
+ room_alias.to_string.return_value = room_alias_str
room_id = "!alpha:bet"
servers = ["aperture"]
@@ -118,12 +118,15 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self._mkservice_alias(is_interested_in_alias=False),
]
- self.mock_store.get_app_services = Mock(return_value=services)
- self.mock_store.get_association_from_room_alias = Mock(
- return_value=Mock(room_id=room_id, servers=servers)
+ self.mock_as_api.query_alias.return_value = make_awaitable(True)
+ self.mock_store.get_app_services.return_value = services
+ self.mock_store.get_association_from_room_alias.return_value = make_awaitable(
+ Mock(room_id=room_id, servers=servers)
)
- result = yield self.handler.query_room_alias_exists(room_alias)
+ result = yield defer.ensureDeferred(
+ self.handler.query_room_alias_exists(room_alias)
+ )
self.mock_as_api.query_alias.assert_called_once_with(
interested_service, room_alias_str
@@ -133,14 +136,14 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def _mkservice(self, is_interested):
service = Mock()
- service.is_interested = Mock(return_value=is_interested)
+ service.is_interested.return_value = make_awaitable(is_interested)
service.token = "mock_service_token"
service.url = "mock_service_url"
return service
def _mkservice_alias(self, is_interested_in_alias):
service = Mock()
- service.is_interested_in_alias = Mock(return_value=is_interested_in_alias)
+ service.is_interested_in_alias.return_value = is_interested_in_alias
service.token = "mock_service_token"
service.url = "mock_service_url"
return service
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 62b47f6574..6aa322bf3a 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -142,10 +142,8 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.delete_device(user1, "abc"))
# check the device was deleted
- res = self.handler.get_device(user1, "abc")
- self.pump()
- self.assertIsInstance(
- self.failureResultOf(res).value, synapse.api.errors.NotFoundError
+ self.get_failure(
+ self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError
)
# we'd like to check the access token was invalidated, but that's a
@@ -180,10 +178,9 @@ class DeviceTestCase(unittest.HomeserverTestCase):
def test_update_unknown_device(self):
update = {"display_name": "new_display"}
- res = self.handler.update_device("user_id", "unknown_device_id", update)
- self.pump()
- self.assertIsInstance(
- self.failureResultOf(res).value, synapse.api.errors.NotFoundError
+ self.get_failure(
+ self.handler.update_device("user_id", "unknown_device_id", update),
+ synapse.api.errors.NotFoundError,
)
def _record_users(self):
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 00bb776271..bc0c5aefdc 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -16,8 +16,6 @@
from mock import Mock
-from twisted.internet import defer
-
import synapse
import synapse.api.errors
from synapse.api.constants import EventTypes
@@ -26,6 +24,7 @@ from synapse.rest.client.v1 import directory, login, room
from synapse.types import RoomAlias, create_requester
from tests import unittest
+from tests.test_utils import make_awaitable
class DirectoryTestCase(unittest.HomeserverTestCase):
@@ -71,7 +70,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
def test_get_remote_association(self):
- self.mock_federation.make_query.return_value = defer.succeed(
+ self.mock_federation.make_query.return_value = make_awaitable(
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
)
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 6c1dc72bd1..210ddcbb88 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -14,11 +14,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import mock
-import signedjson.key as key
-import signedjson.sign as sign
+from signedjson import key as key, sign as sign
from twisted.internet import defer
@@ -48,7 +46,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"""If the user has no devices, we expect an empty list.
"""
local_user = "@boris:" + self.hs.hostname
- res = yield self.handler.query_local_devices({local_user: None})
+ res = yield defer.ensureDeferred(
+ self.handler.query_local_devices({local_user: None})
+ )
self.assertDictEqual(res, {local_user: {}})
@defer.inlineCallbacks
@@ -62,15 +62,19 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"alg2:k3": {"key": "key3"},
}
- res = yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": keys}
+ res = yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": keys}
+ )
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
# we should be able to change the signature without a problem
keys["alg2:k2"]["signatures"]["k1"] = "sig2"
- res = yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": keys}
+ res = yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": keys}
+ )
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
@@ -86,44 +90,56 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"alg2:k3": {"key": "key3"},
}
- res = yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": keys}
+ res = yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": keys}
+ )
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
try:
- yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
+ )
)
self.fail("No error when changing string key")
except errors.SynapseError:
pass
try:
- yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
+ )
)
self.fail("No error when replacing dict key with string")
except errors.SynapseError:
pass
try:
- yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id,
+ {"one_time_keys": {"alg1:k1": {"key": "key"}}},
+ )
)
self.fail("No error when replacing string key with dict")
except errors.SynapseError:
pass
try:
- yield self.handler.upload_keys_for_user(
- local_user,
- device_id,
- {
- "one_time_keys": {
- "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
- }
- },
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id,
+ {
+ "one_time_keys": {
+ "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
+ }
+ },
+ )
)
self.fail("No error when replacing dict key")
except errors.SynapseError:
@@ -135,13 +151,17 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_id = "xyz"
keys = {"alg1:k1": "key1"}
- res = yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": keys}
+ res = yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": keys}
+ )
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
- res2 = yield self.handler.claim_one_time_keys(
- {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ res2 = yield defer.ensureDeferred(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ )
)
self.assertEqual(
res2,
@@ -165,7 +185,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
}
}
- yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(local_user, keys1)
+ )
keys2 = {
"master_key": {
@@ -177,10 +199,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
}
}
- yield self.handler.upload_signing_keys_for_user(local_user, keys2)
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(local_user, keys2)
+ )
- devices = yield self.handler.query_devices(
- {"device_keys": {local_user: []}}, 0, local_user
+ devices = yield defer.ensureDeferred(
+ self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
)
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
@@ -217,7 +241,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
"2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0",
)
- yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(local_user, keys1)
+ )
# upload two device keys, which will be signed later by the self-signing key
device_key_1 = {
@@ -247,18 +273,24 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"signatures": {local_user: {"ed25519:def": "base64+signature"}},
}
- yield self.handler.upload_keys_for_user(
- local_user, "abc", {"device_keys": device_key_1}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, "abc", {"device_keys": device_key_1}
+ )
)
- yield self.handler.upload_keys_for_user(
- local_user, "def", {"device_keys": device_key_2}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, "def", {"device_keys": device_key_2}
+ )
)
# sign the first device key and upload it
del device_key_1["signatures"]
sign.sign_json(device_key_1, local_user, signing_key)
- yield self.handler.upload_signatures_for_device_keys(
- local_user, {local_user: {"abc": device_key_1}}
+ yield defer.ensureDeferred(
+ self.handler.upload_signatures_for_device_keys(
+ local_user, {local_user: {"abc": device_key_1}}
+ )
)
# sign the second device key and upload both device keys. The server
@@ -266,14 +298,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
# signature for it
del device_key_2["signatures"]
sign.sign_json(device_key_2, local_user, signing_key)
- yield self.handler.upload_signatures_for_device_keys(
- local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
+ yield defer.ensureDeferred(
+ self.handler.upload_signatures_for_device_keys(
+ local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
+ )
)
device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
- devices = yield self.handler.query_devices(
- {"device_keys": {local_user: []}}, 0, local_user
+ devices = yield defer.ensureDeferred(
+ self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
)
del devices["device_keys"][local_user]["abc"]["unsigned"]
del devices["device_keys"][local_user]["def"]["unsigned"]
@@ -294,20 +328,26 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
}
}
- yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(local_user, keys1)
+ )
res = None
try:
- yield self.hs.get_device_handler().check_device_registered(
- user_id=local_user,
- device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
- initial_device_display_name="new display name",
+ yield defer.ensureDeferred(
+ self.hs.get_device_handler().check_device_registered(
+ user_id=local_user,
+ device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
+ initial_device_display_name="new display name",
+ )
)
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 400)
- res = yield self.handler.query_local_devices({local_user: None})
+ res = yield defer.ensureDeferred(
+ self.handler.query_local_devices({local_user: None})
+ )
self.assertDictEqual(res, {local_user: {}})
@defer.inlineCallbacks
@@ -333,8 +373,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA"
)
- yield self.handler.upload_keys_for_user(
- local_user, device_id, {"device_keys": device_key}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"device_keys": device_key}
+ )
)
# private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
@@ -374,7 +416,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"user_signing_key": usersigning_key,
"self_signing_key": selfsigning_key,
}
- yield self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
+ )
# set up another user with a master key. This user will be signed by
# the first user
@@ -386,76 +430,90 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"usage": ["master"],
"keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
}
- yield self.handler.upload_signing_keys_for_user(
- other_user, {"master_key": other_master_key}
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(
+ other_user, {"master_key": other_master_key}
+ )
)
# test various signature failures (see below)
- ret = yield self.handler.upload_signatures_for_device_keys(
- local_user,
- {
- local_user: {
- # fails because the signature is invalid
- # should fail with INVALID_SIGNATURE
- device_id: {
- "user_id": local_user,
- "device_id": device_id,
- "algorithms": [
- "m.olm.curve25519-aes-sha2",
- RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
- ],
- "keys": {
- "curve25519:xyz": "curve25519+key",
- # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
- "ed25519:xyz": device_pubkey,
- },
- "signatures": {
- local_user: {"ed25519:" + selfsigning_pubkey: "something"}
+ ret = yield defer.ensureDeferred(
+ self.handler.upload_signatures_for_device_keys(
+ local_user,
+ {
+ local_user: {
+ # fails because the signature is invalid
+ # should fail with INVALID_SIGNATURE
+ device_id: {
+ "user_id": local_user,
+ "device_id": device_id,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
+ "keys": {
+ "curve25519:xyz": "curve25519+key",
+ # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
+ "ed25519:xyz": device_pubkey,
+ },
+ "signatures": {
+ local_user: {
+ "ed25519:" + selfsigning_pubkey: "something"
+ }
+ },
},
- },
- # fails because device is unknown
- # should fail with NOT_FOUND
- "unknown": {
- "user_id": local_user,
- "device_id": "unknown",
- "signatures": {
- local_user: {"ed25519:" + selfsigning_pubkey: "something"}
+ # fails because device is unknown
+ # should fail with NOT_FOUND
+ "unknown": {
+ "user_id": local_user,
+ "device_id": "unknown",
+ "signatures": {
+ local_user: {
+ "ed25519:" + selfsigning_pubkey: "something"
+ }
+ },
},
- },
- # fails because the signature is invalid
- # should fail with INVALID_SIGNATURE
- master_pubkey: {
- "user_id": local_user,
- "usage": ["master"],
- "keys": {"ed25519:" + master_pubkey: master_pubkey},
- "signatures": {
- local_user: {"ed25519:" + device_pubkey: "something"}
+ # fails because the signature is invalid
+ # should fail with INVALID_SIGNATURE
+ master_pubkey: {
+ "user_id": local_user,
+ "usage": ["master"],
+ "keys": {"ed25519:" + master_pubkey: master_pubkey},
+ "signatures": {
+ local_user: {"ed25519:" + device_pubkey: "something"}
+ },
},
},
- },
- other_user: {
- # fails because the device is not the user's master-signing key
- # should fail with NOT_FOUND
- "unknown": {
- "user_id": other_user,
- "device_id": "unknown",
- "signatures": {
- local_user: {"ed25519:" + usersigning_pubkey: "something"}
+ other_user: {
+ # fails because the device is not the user's master-signing key
+ # should fail with NOT_FOUND
+ "unknown": {
+ "user_id": other_user,
+ "device_id": "unknown",
+ "signatures": {
+ local_user: {
+ "ed25519:" + usersigning_pubkey: "something"
+ }
+ },
},
- },
- other_master_pubkey: {
- # fails because the key doesn't match what the server has
- # should fail with UNKNOWN
- "user_id": other_user,
- "usage": ["master"],
- "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
- "something": "random",
- "signatures": {
- local_user: {"ed25519:" + usersigning_pubkey: "something"}
+ other_master_pubkey: {
+ # fails because the key doesn't match what the server has
+ # should fail with UNKNOWN
+ "user_id": other_user,
+ "usage": ["master"],
+ "keys": {
+ "ed25519:" + other_master_pubkey: other_master_pubkey
+ },
+ "something": "random",
+ "signatures": {
+ local_user: {
+ "ed25519:" + usersigning_pubkey: "something"
+ }
+ },
},
},
},
- },
+ )
)
user_failures = ret["failures"][local_user]
@@ -480,19 +538,23 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
sign.sign_json(device_key, local_user, selfsigning_signing_key)
sign.sign_json(master_key, local_user, device_signing_key)
sign.sign_json(other_master_key, local_user, usersigning_signing_key)
- ret = yield self.handler.upload_signatures_for_device_keys(
- local_user,
- {
- local_user: {device_id: device_key, master_pubkey: master_key},
- other_user: {other_master_pubkey: other_master_key},
- },
+ ret = yield defer.ensureDeferred(
+ self.handler.upload_signatures_for_device_keys(
+ local_user,
+ {
+ local_user: {device_id: device_key, master_pubkey: master_key},
+ other_user: {other_master_pubkey: other_master_key},
+ },
+ )
)
self.assertEqual(ret["failures"], {})
# fetch the signed keys/devices and make sure that the signatures are there
- ret = yield self.handler.query_devices(
- {"device_keys": {local_user: [], other_user: []}}, 0, local_user
+ ret = yield defer.ensureDeferred(
+ self.handler.query_devices(
+ {"device_keys": {local_user: [], other_user: []}}, 0, local_user
+ )
)
self.assertEqual(
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 70f172eb02..3362050ce0 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -66,7 +66,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.get_version_info(self.local_user)
+ yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -78,7 +78,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.get_version_info(self.local_user, "bogus_version")
+ yield defer.ensureDeferred(
+ self.handler.get_version_info(self.local_user, "bogus_version")
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -87,15 +89,21 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_create_version(self):
"""Check that we can create and then retrieve versions.
"""
- res = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ res = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(res, "1")
# check we can retrieve it as the current version
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
version_etag = res["etag"]
+ self.assertIsInstance(version_etag, str)
del res["etag"]
self.assertDictEqual(
res,
@@ -108,7 +116,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
# check we can retrieve it as a specific version
- res = yield self.handler.get_version_info(self.local_user, "1")
+ res = yield defer.ensureDeferred(
+ self.handler.get_version_info(self.local_user, "1")
+ )
self.assertEqual(res["etag"], version_etag)
del res["etag"]
self.assertDictEqual(
@@ -122,17 +132,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
# upload a new one...
- res = yield self.handler.create_version(
- self.local_user,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "second_version_auth_data",
- },
+ res = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "second_version_auth_data",
+ },
+ )
)
self.assertEqual(res, "2")
# check we can retrieve it as the current version
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
del res["etag"]
self.assertDictEqual(
res,
@@ -148,25 +160,32 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_update_version(self):
"""Check that we can update versions.
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- res = yield self.handler.update_version(
- self.local_user,
- version,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- "version": version,
- },
+ res = yield defer.ensureDeferred(
+ self.handler.update_version(
+ self.local_user,
+ version,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": version,
+ },
+ )
)
self.assertDictEqual(res, {})
# check we can retrieve it as the current version
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
del res["etag"]
self.assertDictEqual(
res,
@@ -184,14 +203,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.update_version(
- self.local_user,
- "1",
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- "version": "1",
- },
+ yield defer.ensureDeferred(
+ self.handler.update_version(
+ self.local_user,
+ "1",
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": "1",
+ },
+ )
)
except errors.SynapseError as e:
res = e.code
@@ -201,23 +222,30 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_update_omitted_version(self):
"""Check that the update succeeds if the version is missing from the body
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- yield self.handler.update_version(
- self.local_user,
- version,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- },
+ yield defer.ensureDeferred(
+ self.handler.update_version(
+ self.local_user,
+ version,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ },
+ )
)
# check we can retrieve it as the current version
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
del res["etag"] # etag is opaque, so don't test its contents
self.assertDictEqual(
res,
@@ -233,22 +261,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_update_bad_version(self):
"""Check that we get a 400 if the version in the body doesn't match
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
res = None
try:
- yield self.handler.update_version(
- self.local_user,
- version,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- "version": "incorrect",
- },
+ yield defer.ensureDeferred(
+ self.handler.update_version(
+ self.local_user,
+ version,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": "incorrect",
+ },
+ )
)
except errors.SynapseError as e:
res = e.code
@@ -260,7 +295,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.delete_version(self.local_user, "1")
+ yield defer.ensureDeferred(
+ self.handler.delete_version(self.local_user, "1")
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -271,7 +308,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.delete_version(self.local_user)
+ yield defer.ensureDeferred(self.handler.delete_version(self.local_user))
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -280,19 +317,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_delete_version(self):
"""Check that we can create and then delete versions.
"""
- res = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ res = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(res, "1")
# check we can delete it
- yield self.handler.delete_version(self.local_user, "1")
+ yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1"))
# check that it's gone
res = None
try:
- yield self.handler.get_version_info(self.local_user, "1")
+ yield defer.ensureDeferred(
+ self.handler.get_version_info(self.local_user, "1")
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -303,7 +347,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.get_room_keys(self.local_user, "bogus_version")
+ yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, "bogus_version")
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -312,13 +358,20 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_get_missing_room_keys(self):
"""Check we get an empty response from an empty backup
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- res = yield self.handler.get_room_keys(self.local_user, version)
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertDictEqual(res, {"rooms": {}})
# TODO: test the locking semantics when uploading room_keys,
@@ -330,8 +383,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.upload_room_keys(
- self.local_user, "no_version", room_keys
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, "no_version", room_keys)
)
except errors.SynapseError as e:
res = e.code
@@ -342,16 +395,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""Check that we get a 404 on uploading keys when an nonexistent version
is specified
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
res = None
try:
- yield self.handler.upload_room_keys(
- self.local_user, "bogus_version", room_keys
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(
+ self.local_user, "bogus_version", room_keys
+ )
)
except errors.SynapseError as e:
res = e.code
@@ -361,24 +421,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_upload_room_keys_wrong_version(self):
"""Check that we get a 403 on uploading keys for an old version
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- version = yield self.handler.create_version(
- self.local_user,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "second_version_auth_data",
- },
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "second_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "2")
res = None
try:
- yield self.handler.upload_room_keys(self.local_user, "1", room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, "1", room_keys)
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 403)
@@ -387,26 +456,39 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_upload_room_keys_insert(self):
"""Check that we can insert and retrieve keys for a session
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, room_keys)
+ )
- res = yield self.handler.get_room_keys(self.local_user, version)
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertDictEqual(res, room_keys)
# check getting room_keys for a given room
- res = yield self.handler.get_room_keys(
- self.local_user, version, room_id="!abc:matrix.org"
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org"
+ )
)
self.assertDictEqual(res, room_keys)
# check getting room_keys for a given session_id
- res = yield self.handler.get_room_keys(
- self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ )
)
self.assertDictEqual(res, room_keys)
@@ -414,16 +496,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_upload_room_keys_merge(self):
"""Check that we can upload a new room_key for an existing session and
have it correctly merged"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, room_keys)
+ )
# get the etag to compare to future versions
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
backup_etag = res["etag"]
self.assertEqual(res["count"], 1)
@@ -433,29 +522,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# test that increasing the message_index doesn't replace the existing session
new_room_key["first_message_index"] = 2
new_room_key["session_data"] = "new"
- yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ )
- res = yield self.handler.get_room_keys(self.local_user, version)
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
"SSBBTSBBIEZJU0gK",
)
# the etag should be the same since the session did not change
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
self.assertEqual(res["etag"], backup_etag)
# test that marking the session as verified however /does/ replace it
new_room_key["is_verified"] = True
- yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ )
- res = yield self.handler.get_room_keys(self.local_user, version)
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
)
# the etag should NOT be equal now, since the key changed
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
self.assertNotEqual(res["etag"], backup_etag)
backup_etag = res["etag"]
@@ -463,15 +560,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# with a lower forwarding count
new_room_key["forwarded_count"] = 2
new_room_key["session_data"] = "other"
- yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ )
- res = yield self.handler.get_room_keys(self.local_user, version)
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
)
# the etag should be the same since the session did not change
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
self.assertEqual(res["etag"], backup_etag)
# TODO: check edge cases as well as the common variations here
@@ -480,36 +581,59 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_delete_room_keys(self):
"""Check that we can insert and delete keys for a session
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
# check for bulk-delete
- yield self.handler.upload_room_keys(self.local_user, version, room_keys)
- yield self.handler.delete_room_keys(self.local_user, version)
- res = yield self.handler.get_room_keys(
- self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, room_keys)
+ )
+ yield defer.ensureDeferred(
+ self.handler.delete_room_keys(self.local_user, version)
+ )
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ )
)
self.assertDictEqual(res, {"rooms": {}})
# check for bulk-delete per room
- yield self.handler.upload_room_keys(self.local_user, version, room_keys)
- yield self.handler.delete_room_keys(
- self.local_user, version, room_id="!abc:matrix.org"
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, room_keys)
+ )
+ yield defer.ensureDeferred(
+ self.handler.delete_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org"
+ )
)
- res = yield self.handler.get_room_keys(
- self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ )
)
self.assertDictEqual(res, {"rooms": {}})
# check for bulk-delete per session
- yield self.handler.upload_room_keys(self.local_user, version, room_keys)
- yield self.handler.delete_room_keys(
- self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, room_keys)
+ )
+ yield defer.ensureDeferred(
+ self.handler.delete_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ )
)
- res = yield self.handler.get_room_keys(
- self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ )
)
self.assertDictEqual(res, {"rooms": {}})
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 1bb25ab684..f92f3b8c15 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -374,12 +374,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
self.handler._auth_handler.complete_sso_login = simple_async_mock()
- request = Mock(spec=["args", "getCookie", "addCookie"])
+ request = Mock(
+ spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+ )
code = "code"
state = "state"
nonce = "nonce"
client_redirect_url = "http://client/redirect"
+ user_agent = "Browser"
+ ip_address = "10.0.0.1"
session = self.handler._generate_oidc_session_token(
state=state,
nonce=nonce,
@@ -392,6 +396,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")]
+ request.requestHeaders = Mock(spec=["getRawHeaders"])
+ request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
+ request.getClientIP.return_value = ip_address
+
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
@@ -399,7 +407,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._map_userinfo_to_user.assert_called_once_with(
+ userinfo, token, user_agent, ip_address
+ )
self.handler._fetch_userinfo.assert_not_called()
self.handler._render_error.assert_not_called()
@@ -431,7 +441,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called()
- self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._map_userinfo_to_user.assert_called_once_with(
+ userinfo, token, user_agent, ip_address
+ )
self.handler._fetch_userinfo.assert_called_once_with(token)
self.handler._render_error.assert_not_called()
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 05ea40a7de..306dcfe944 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -19,6 +19,7 @@ from mock import Mock, call
from signedjson.key import generate_signing_key
from synapse.api.constants import EventTypes, Membership, PresenceState
+from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events.builder import EventBuilder
from synapse.handlers.presence import (
@@ -32,7 +33,6 @@ from synapse.handlers.presence import (
handle_update,
)
from synapse.rest.client.v1 import room
-from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, get_domain_from_id
from tests import unittest
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 29dd7d9c6e..b609b30d4a 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -24,6 +24,7 @@ from synapse.handlers.profile import MasterProfileHandler
from synapse.types import UserID
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
@@ -63,7 +64,7 @@ class ProfileTestCase(unittest.TestCase):
self.bob = UserID.from_string("@4567:test")
self.alice = UserID.from_string("@alice:remote")
- yield self.store.create_profile(self.frank.localpart)
+ yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
self.handler = hs.get_profile_handler()
self.hs = hs
@@ -72,7 +73,9 @@ class ProfileTestCase(unittest.TestCase):
def test_get_my_name(self):
yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
- displayname = yield self.handler.get_displayname(self.frank)
+ displayname = yield defer.ensureDeferred(
+ self.handler.get_displayname(self.frank)
+ )
self.assertEquals("Frank", displayname)
@@ -136,11 +139,13 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_other_name(self):
- self.mock_federation.make_query.return_value = defer.succeed(
+ self.mock_federation.make_query.return_value = make_awaitable(
{"displayname": "Alice"}
)
- displayname = yield self.handler.get_displayname(self.alice)
+ displayname = yield defer.ensureDeferred(
+ self.handler.get_displayname(self.alice)
+ )
self.assertEquals(displayname, "Alice")
self.mock_federation.make_query.assert_called_with(
@@ -152,11 +157,13 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_incoming_fed_query(self):
- yield self.store.create_profile("caroline")
+ yield defer.ensureDeferred(self.store.create_profile("caroline"))
yield self.store.set_profile_displayname("caroline", "Caroline")
- response = yield self.query_handlers["profile"](
- {"user_id": "@caroline:test", "field": "displayname"}
+ response = yield defer.ensureDeferred(
+ self.query_handlers["profile"](
+ {"user_id": "@caroline:test", "field": "displayname"}
+ )
)
self.assertEquals({"displayname": "Caroline"}, response)
@@ -166,8 +173,7 @@ class ProfileTestCase(unittest.TestCase):
yield self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
)
-
- avatar_url = yield self.handler.get_avatar_url(self.frank)
+ avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
self.assertEquals("http://my.server/me.png", avatar_url)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index ca32f993a3..5c92d0e8c9 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -17,15 +17,21 @@ from mock import Mock
from twisted.internet import defer
+from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler
+from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, UserID, create_requester
+from tests.test_utils import make_awaitable
+from tests.unittest import override_config
+from tests.utils import mock_getRawHeaders
+
from .. import unittest
-class RegistrationHandlers(object):
+class RegistrationHandlers:
def __init__(self, hs):
self.registration_handler = RegistrationHandler(hs)
@@ -145,9 +151,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
+ @override_config({"auto_join_rooms": ["#room:test"]})
def test_auto_create_auto_join_rooms(self):
room_alias_str = "#room:test"
- self.hs.config.auto_join_rooms = [room_alias_str]
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
directory_handler = self.hs.get_handlers().directory_handler
@@ -185,7 +191,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
- self.store.is_real_user = Mock(return_value=defer.succeed(False))
+ self.store.is_real_user = Mock(return_value=make_awaitable(False))
user_id = self.get_success(self.handler.register_user(localpart="support"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
@@ -193,12 +199,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias = RoomAlias.from_string(room_alias_str)
self.get_failure(directory_handler.get_association(room_alias), SynapseError)
+ @override_config({"auto_join_rooms": ["#room:test"]})
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self):
room_alias_str = "#room:test"
- self.hs.config.auto_join_rooms = [room_alias_str]
- self.store.count_real_users = Mock(return_value=defer.succeed(1))
- self.store.is_real_user = Mock(return_value=defer.succeed(True))
+ self.store.count_real_users = Mock(return_value=make_awaitable(1))
+ self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
directory_handler = self.hs.get_handlers().directory_handler
@@ -212,12 +218,218 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
- self.store.count_real_users = Mock(return_value=defer.succeed(2))
- self.store.is_real_user = Mock(return_value=defer.succeed(True))
+ self.store.count_real_users = Mock(return_value=make_awaitable(2))
+ self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
+ @override_config(
+ {
+ "auto_join_rooms": ["#room:test"],
+ "autocreate_auto_join_rooms_federated": False,
+ }
+ )
+ def test_auto_create_auto_join_rooms_federated(self):
+ """
+ Auto-created rooms that are private require an invite to go to the user
+ (instead of directly joining it).
+ """
+ room_alias_str = "#room:test"
+ user_id = self.get_success(self.handler.register_user(localpart="jeff"))
+
+ # Ensure the room was created.
+ directory_handler = self.hs.get_handlers().directory_handler
+ room_alias = RoomAlias.from_string(room_alias_str)
+ room_id = self.get_success(directory_handler.get_association(room_alias))
+
+ # Ensure the room is properly not federated.
+ room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ self.assertFalse(room["federatable"])
+ self.assertFalse(room["public"])
+ self.assertEqual(room["join_rules"], "public")
+ self.assertIsNone(room["guest_access"])
+
+ # The user should be in the room.
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertIn(room_id["room_id"], rooms)
+
+ @override_config(
+ {"auto_join_rooms": ["#room:test"], "auto_join_mxid_localpart": "support"}
+ )
+ def test_auto_join_mxid_localpart(self):
+ """
+ Ensure the user still needs up in the room created by a different user.
+ """
+ # Ensure the support user exists.
+ inviter = "@support:test"
+
+ room_alias_str = "#room:test"
+ user_id = self.get_success(self.handler.register_user(localpart="jeff"))
+
+ # Ensure the room was created.
+ directory_handler = self.hs.get_handlers().directory_handler
+ room_alias = RoomAlias.from_string(room_alias_str)
+ room_id = self.get_success(directory_handler.get_association(room_alias))
+
+ # Ensure the room is properly a public room.
+ room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ self.assertEqual(room["join_rules"], "public")
+
+ # Both users should be in the room.
+ rooms = self.get_success(self.store.get_rooms_for_user(inviter))
+ self.assertIn(room_id["room_id"], rooms)
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertIn(room_id["room_id"], rooms)
+
+ # Register a second user, which should also end up in the room.
+ user_id = self.get_success(self.handler.register_user(localpart="bob"))
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertIn(room_id["room_id"], rooms)
+
+ @override_config(
+ {
+ "auto_join_rooms": ["#room:test"],
+ "autocreate_auto_join_room_preset": "private_chat",
+ "auto_join_mxid_localpart": "support",
+ }
+ )
+ def test_auto_create_auto_join_room_preset(self):
+ """
+ Auto-created rooms that are private require an invite to go to the user
+ (instead of directly joining it).
+ """
+ # Ensure the support user exists.
+ inviter = "@support:test"
+
+ room_alias_str = "#room:test"
+ user_id = self.get_success(self.handler.register_user(localpart="jeff"))
+
+ # Ensure the room was created.
+ directory_handler = self.hs.get_handlers().directory_handler
+ room_alias = RoomAlias.from_string(room_alias_str)
+ room_id = self.get_success(directory_handler.get_association(room_alias))
+
+ # Ensure the room is properly a private room.
+ room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ self.assertFalse(room["public"])
+ self.assertEqual(room["join_rules"], "invite")
+ self.assertEqual(room["guest_access"], "can_join")
+
+ # Both users should be in the room.
+ rooms = self.get_success(self.store.get_rooms_for_user(inviter))
+ self.assertIn(room_id["room_id"], rooms)
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertIn(room_id["room_id"], rooms)
+
+ # Register a second user, which should also end up in the room.
+ user_id = self.get_success(self.handler.register_user(localpart="bob"))
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertIn(room_id["room_id"], rooms)
+
+ @override_config(
+ {
+ "auto_join_rooms": ["#room:test"],
+ "autocreate_auto_join_room_preset": "private_chat",
+ "auto_join_mxid_localpart": "support",
+ }
+ )
+ def test_auto_create_auto_join_room_preset_guest(self):
+ """
+ Auto-created rooms that are private require an invite to go to the user
+ (instead of directly joining it).
+
+ This should also work for guests.
+ """
+ inviter = "@support:test"
+
+ room_alias_str = "#room:test"
+ user_id = self.get_success(
+ self.handler.register_user(localpart="jeff", make_guest=True)
+ )
+
+ # Ensure the room was created.
+ directory_handler = self.hs.get_handlers().directory_handler
+ room_alias = RoomAlias.from_string(room_alias_str)
+ room_id = self.get_success(directory_handler.get_association(room_alias))
+
+ # Ensure the room is properly a private room.
+ room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ self.assertFalse(room["public"])
+ self.assertEqual(room["join_rules"], "invite")
+ self.assertEqual(room["guest_access"], "can_join")
+
+ # Both users should be in the room.
+ rooms = self.get_success(self.store.get_rooms_for_user(inviter))
+ self.assertIn(room_id["room_id"], rooms)
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertIn(room_id["room_id"], rooms)
+
+ @override_config(
+ {
+ "auto_join_rooms": ["#room:test"],
+ "autocreate_auto_join_room_preset": "private_chat",
+ "auto_join_mxid_localpart": "support",
+ }
+ )
+ def test_auto_create_auto_join_room_preset_invalid_permissions(self):
+ """
+ Auto-created rooms that are private require an invite, check that
+ registration doesn't completely break if the inviter doesn't have proper
+ permissions.
+ """
+ inviter = "@support:test"
+
+ # Register an initial user to create the room and such (essentially this
+ # is a subset of test_auto_create_auto_join_room_preset).
+ room_alias_str = "#room:test"
+ user_id = self.get_success(self.handler.register_user(localpart="jeff"))
+
+ # Ensure the room was created.
+ directory_handler = self.hs.get_handlers().directory_handler
+ room_alias = RoomAlias.from_string(room_alias_str)
+ room_id = self.get_success(directory_handler.get_association(room_alias))
+
+ # Ensure the room exists.
+ self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+
+ # Both users should be in the room.
+ rooms = self.get_success(self.store.get_rooms_for_user(inviter))
+ self.assertIn(room_id["room_id"], rooms)
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertIn(room_id["room_id"], rooms)
+
+ # Lower the permissions of the inviter.
+ event_creation_handler = self.hs.get_event_creation_handler()
+ requester = create_requester(inviter)
+ event, context = self.get_success(
+ event_creation_handler.create_event(
+ requester,
+ {
+ "type": "m.room.power_levels",
+ "state_key": "",
+ "room_id": room_id["room_id"],
+ "content": {"invite": 100, "users": {inviter: 0}},
+ "sender": inviter,
+ },
+ )
+ )
+ self.get_success(
+ event_creation_handler.send_nonmember_event(requester, event, context)
+ )
+
+ # Register a second user, which won't be be in the room (or even have an invite)
+ # since the inviter no longer has the proper permissions.
+ user_id = self.get_success(self.handler.register_user(localpart="bob"))
+
+ # This user should not be in any rooms.
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(user_id)
+ )
+ self.assertEqual(rooms, set())
+ self.assertEqual(invited_rooms, [])
+
def test_auto_create_auto_join_where_no_consent(self):
"""Test to ensure that the first user is not auto-joined to a room if
they have not given general consent.
@@ -266,6 +478,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
+ def test_spam_checker_deny(self):
+ """A spam checker can deny registration, which results in an error."""
+
+ class DenyAll:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info
+ ):
+ return RegistrationBehaviour.DENY
+
+ # Configure a spam checker that denies all users.
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [DenyAll()]
+
+ self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
+
+ def test_spam_checker_shadow_ban(self):
+ """A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
+
+ class BanAll:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info
+ ):
+ return RegistrationBehaviour.SHADOW_BAN
+
+ # Configure a spam checker that denies all users.
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [BanAll()]
+
+ user_id = self.get_success(self.handler.register_user(localpart="user"))
+
+ # Get an access token.
+ token = self.macaroon_generator.generate_access_token(user_id)
+ self.get_success(
+ self.store.add_access_token_to_user(
+ user_id=user_id, token=token, device_id=None, valid_until_ms=None
+ )
+ )
+
+ # Ensure the user was marked as shadow-banned.
+ request = Mock(args={})
+ request.args[b"access_token"] = [token.encode("ascii")]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ auth = Auth(self.hs)
+ requester = self.get_success(auth.get_user_by_req(request))
+
+ self.assertTrue(requester.shadow_banned)
+
async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None
):
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index d9d312f0fb..0e666492f6 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -15,7 +15,7 @@
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
-from synapse.storage.data_stores.main import stats
+from synapse.storage.databases.main import stats
from tests import unittest
@@ -42,36 +42,36 @@ class StatsRoomTests(unittest.HomeserverTestCase):
Add the background updates we need to run.
"""
# Ugh, have to reset this flag
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{"update_name": "populate_stats_prepare", "progress_json": "{}"},
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
- "update_name": "populate_stats_process_rooms",
+ "update_name": "populate_stats_process_rooms_2",
"progress_json": "{}",
"depends_on": "populate_stats_prepare",
},
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_users",
"progress_json": "{}",
- "depends_on": "populate_stats_process_rooms",
+ "depends_on": "populate_stats_process_rooms_2",
},
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
@@ -82,7 +82,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
def get_all_room_state(self):
- return self.store.db.simple_select_list(
+ return self.store.db_pool.simple_select_list(
"room_stats_state", None, retcols=("name", "topic", "canonical_alias")
)
@@ -96,7 +96,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000)
return self.get_success(
- self.store.db.simple_select_one(
+ self.store.db_pool.simple_select_one(
table + "_historical",
{id_col: stat_id, end_ts: end_ts},
cols,
@@ -109,10 +109,10 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self._add_background_updates()
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
def test_initial_room(self):
@@ -146,10 +146,10 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self._add_background_updates()
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
r = self.get_success(self.get_all_room_state())
@@ -186,9 +186,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# the position that the deltas should begin at, once they take over.
self.hs.config.stats_enabled = True
self.handler.stats_enabled = True
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
self.get_success(
- self.store.db.simple_update_one(
+ self.store.db_pool.simple_update_one(
table="stats_incremental_position",
keyvalues={},
updatevalues={"stream_id": 0},
@@ -196,17 +196,17 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{"update_name": "populate_stats_prepare", "progress_json": "{}"},
)
)
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
# Now, before the table is actually ingested, add some more events.
@@ -217,28 +217,31 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Now do the initial ingestion.
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
- {"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
+ {
+ "update_name": "populate_stats_process_rooms_2",
+ "progress_json": "{}",
+ },
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
"progress_json": "{}",
- "depends_on": "populate_stats_process_rooms",
+ "depends_on": "populate_stats_process_rooms_2",
},
)
)
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
self.reactor.advance(86401)
@@ -346,6 +349,37 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+ def test_updating_profile_information_does_not_increase_joined_members_count(self):
+ """
+ Check that the joined_members count does not increase when a user changes their
+ profile information (which is done by sending another join membership event into
+ the room.
+ """
+ self._perform_background_initial_update()
+
+ # Create a user and room
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+
+ # Get the current room stats
+ r1stats_ante = self._get_current_stats("room", r1)
+
+ # Send a profile update into the room
+ new_profile = {"displayname": "bob"}
+ self.helper.change_membership(
+ r1, u1, u1, "join", extra_data=new_profile, tok=u1token
+ )
+
+ # Get the new room stats
+ r1stats_post = self._get_current_stats("room", r1)
+
+ # Ensure that the user count did not changed
+ self.assertEqual(r1stats_post["joined_members"], r1stats_ante["joined_members"])
+ self.assertEqual(
+ r1stats_post["local_users_in_room"], r1stats_ante["local_users_in_room"]
+ )
+
def test_send_state_event_nonoverwriting(self):
"""
When we send a non-overwriting state event, it increments total_events AND current_state_events
@@ -669,15 +703,15 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# preparation stage of the initial background update
# Ugh, have to reset this flag
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
self.get_success(
- self.store.db.simple_delete(
+ self.store.db_pool.simple_delete(
"room_stats_current", {"1": 1}, "test_delete_stats"
)
)
self.get_success(
- self.store.db.simple_delete(
+ self.store.db_pool.simple_delete(
"user_stats_current", {"1": 1}, "test_delete_stats"
)
)
@@ -689,29 +723,29 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# now do the background updates
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
- "update_name": "populate_stats_process_rooms",
+ "update_name": "populate_stats_process_rooms_2",
"progress_json": "{}",
"depends_on": "populate_stats_prepare",
},
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_users",
"progress_json": "{}",
- "depends_on": "populate_stats_process_rooms",
+ "depends_on": "populate_stats_process_rooms_2",
},
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
@@ -722,10 +756,10 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
r1stats_complete = self._get_current_stats("room", r1)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 2fa8d4739b..e01de158e5 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -24,6 +24,7 @@ from synapse.api.errors import AuthError
from synapse.types import UserID
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.unittest import override_config
from tests.utils import register_federation_servlets
@@ -115,7 +116,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
retry_timings_res
)
- self.datastore.get_device_updates_by_remote.return_value = defer.succeed(
+ self.datastore.get_device_updates_by_remote.side_effect = lambda destination, from_stream_id, limit: make_awaitable(
(0, [])
)
@@ -126,9 +127,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.room_members = []
- def check_user_in_room(room_id, user_id):
+ async def check_user_in_room(room_id, user_id):
if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room")
+ return None
hs.get_auth().check_user_in_room = check_user_in_room
@@ -137,10 +139,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
- def get_current_users_in_room(room_id):
- return {str(u) for u in self.room_members}
+ def get_users_in_room(room_id):
+ return defer.succeed({str(u) for u in self.room_members})
- hs.get_state_handler().get_current_users_in_room = get_current_users_in_room
+ self.datastore.get_users_in_room = get_users_in_room
self.datastore.get_user_directory_stream_pos.return_value = (
# we deliberately return a non-None stream pos to avoid doing an initial_spam
@@ -150,11 +152,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_current_state_deltas.return_value = (0, None)
self.datastore.get_to_device_stream_token = lambda: 0
- self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed(
+ self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
([], 0)
)
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
- self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
+ self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
None
)
@@ -163,7 +165,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 0)
- self.successResultOf(
+ self.get_success(
self.handler.started_typing(
target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000
)
@@ -190,7 +192,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
def test_started_typing_remote_send(self):
self.room_members = [U_APPLE, U_ONION]
- self.successResultOf(
+ self.get_success(
self.handler.started_typing(
target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000
)
@@ -265,7 +267,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 0)
- self.successResultOf(
+ self.get_success(
self.handler.stopped_typing(
target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID
)
@@ -305,7 +307,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 0)
- self.successResultOf(
+ self.get_success(
self.handler.started_typing(
target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000
)
@@ -344,7 +346,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
# SYN-230 - see if we can still set after timeout
- self.successResultOf(
+ self.get_success(
self.handler.started_typing(
target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000
)
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 23fcc372dd..87be94111f 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -238,7 +238,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def test_spam_checker(self):
"""
- A user which fails to the spam checks will not appear in search results.
+ A user which fails the spam checks will not appear in search results.
"""
u1 = self.register_user("user1", "pass")
u1_token = self.login(u1, "pass")
@@ -269,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Configure a spam checker that does not filter any users.
spam_checker = self.hs.get_spam_checker()
- class AllowAll(object):
+ class AllowAll:
def check_username_for_spam(self, user_profile):
# Allow all users.
return False
@@ -282,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users.
- class BlockAll(object):
+ class BlockAll:
def check_username_for_spam(self, user_profile):
# All users are spammy.
return True
@@ -339,7 +339,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def get_users_in_public_rooms(self):
r = self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
"users_in_public_rooms", None, ("user_id", "room_id")
)
)
@@ -350,7 +350,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def get_users_who_share_private_rooms(self):
return self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
"users_who_share_private_rooms",
None,
["user_id", "other_user_id", "room_id"],
@@ -362,10 +362,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
Add the background updates we need to run.
"""
# Ugh, have to reset this flag
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_createtables",
@@ -374,7 +374,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_process_rooms",
@@ -384,7 +384,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_process_users",
@@ -394,7 +394,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_cleanup",
@@ -437,10 +437,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self._add_background_updates()
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
shares_private = self.get_users_who_share_private_rooms()
@@ -476,10 +476,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self._add_background_updates()
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
shares_private = self.get_users_who_share_private_rooms()
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 562397cdda..69945a8f98 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -67,6 +67,14 @@ def get_connection_factory():
return test_server_connection_factory
+# Once Async Mocks or lambdas are supported this can go away.
+def generate_resolve_service(result):
+ async def resolve_service(_):
+ return result
+
+ return resolve_service
+
+
class MatrixFederationAgentTests(unittest.TestCase):
def setUp(self):
self.reactor = ThreadedMemoryReactorClock()
@@ -86,6 +94,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.well_known_resolver = WellKnownResolver(
self.reactor,
Agent(self.reactor, contextFactory=self.tls_factory),
+ b"test-agent",
well_known_cache=self.well_known_cache,
had_well_known_cache=self.had_well_known_cache,
)
@@ -93,6 +102,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.agent = MatrixFederationAgent(
reactor=self.reactor,
tls_client_options_factory=self.tls_factory,
+ user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided.
_srv_resolver=self.mock_resolver,
_well_known_resolver=self.well_known_resolver,
)
@@ -186,6 +196,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
# check the .well-known request and send a response
self.assertEqual(len(well_known_server.requests), 1)
request = well_known_server.requests[0]
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"]
+ )
self._send_well_known_response(request, content, headers=response_headers)
return well_known_server
@@ -231,6 +244,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(
request.requestHeaders.getRawHeaders(b"host"), [b"testserv:8448"]
)
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"]
+ )
content = request.content.read()
self.assertEqual(content, b"")
@@ -365,7 +381,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""
Test the behaviour when the certificate on the server doesn't match the hostname
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv1"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv1/foo/bar")
@@ -448,7 +464,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
Test the behaviour when the server name has no port, no SRV, and no well-known
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -502,7 +518,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""Test the behaviour when the .well-known delegates elsewhere
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f"
@@ -564,7 +580,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""Test the behaviour when the server name has no port and no SRV record, but
the .well-known has a 300 redirect
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f"
@@ -653,7 +669,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
Test the behaviour when the server name has an *invalid* well-known (and no SRV)
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -709,7 +725,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
# the config left to the default, which will not trust it (since the
# presented cert is signed by a test CA)
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
config = default_config("test", parse=True)
@@ -719,10 +735,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
agent = MatrixFederationAgent(
reactor=self.reactor,
tls_client_options_factory=tls_factory,
+ user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below.
_srv_resolver=self.mock_resolver,
_well_known_resolver=WellKnownResolver(
self.reactor,
Agent(self.reactor, contextFactory=tls_factory),
+ b"test-agent",
well_known_cache=self.well_known_cache,
had_well_known_cache=self.had_well_known_cache,
),
@@ -754,9 +772,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""
Test the behaviour when there is a single SRV record
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: [
- Server(host=b"srvtarget", port=8443)
- ]
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+ [Server(host=b"srvtarget", port=8443)]
+ )
self.reactor.lookups["srvtarget"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -809,9 +827,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 443)
- self.mock_resolver.resolve_service.side_effect = lambda _: [
- Server(host=b"srvtarget", port=8443)
- ]
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+ [Server(host=b"srvtarget", port=8443)]
+ )
self._handle_well_known_connection(
client_factory,
@@ -851,7 +869,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_idna_servername(self):
"""test the behaviour when the server name has idna chars in"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
# the resolver is always called with the IDNA hostname as a native string.
self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4"
@@ -912,9 +930,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_idna_srv_target(self):
"""test the behaviour when the target of a SRV record has idna chars"""
- self.mock_resolver.resolve_service.side_effect = lambda _: [
- Server(host=b"xn--trget-3qa.com", port=8443) # târget.com
- ]
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+ [Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com
+ )
self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar")
@@ -1077,11 +1095,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_srv_fallbacks(self):
"""Test that other SRV results are tried if the first one fails.
"""
-
- self.mock_resolver.resolve_service.side_effect = lambda _: [
- Server(host=b"target.com", port=8443),
- Server(host=b"target.com", port=8444),
- ]
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+ [
+ Server(host=b"target.com", port=8443),
+ Server(host=b"target.com", port=8444),
+ ]
+ )
self.reactor.lookups["target.com"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index babc201643..fee2985d35 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError
from twisted.names import dns, error
from synapse.http.federation.srv_resolver import SrvResolver
-from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
+from synapse.logging.context import LoggingContext, current_context
from tests import unittest
from tests.utils import MockClock
@@ -50,13 +50,7 @@ class SrvResolverTestCase(unittest.TestCase):
with LoggingContext("one") as ctx:
resolve_d = resolver.resolve_service(service_name)
-
- self.assertNoResult(resolve_d)
-
- # should have reset to the sentinel context
- self.assertIs(current_context(), SENTINEL_CONTEXT)
-
- result = yield resolve_d
+ result = yield defer.ensureDeferred(resolve_d)
# should have restored our context
self.assertIs(current_context(), ctx)
@@ -91,7 +85,7 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {service_name: [entry]}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- servers = yield resolver.resolve_service(service_name)
+ servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
dns_client_mock.lookupService.assert_called_once_with(service_name)
@@ -117,7 +111,7 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client=dns_client_mock, cache=cache, get_time=clock.time
)
- servers = yield resolver.resolve_service(service_name)
+ servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
self.assertFalse(dns_client_mock.lookupService.called)
@@ -136,7 +130,7 @@ class SrvResolverTestCase(unittest.TestCase):
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
with self.assertRaises(error.DNSServerError):
- yield resolver.resolve_service(service_name)
+ yield defer.ensureDeferred(resolver.resolve_service(service_name))
@defer.inlineCallbacks
def test_name_error(self):
@@ -149,7 +143,7 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- servers = yield resolver.resolve_service(service_name)
+ servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
self.assertEquals(len(servers), 0)
self.assertEquals(len(cache), 0)
@@ -166,8 +160,8 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- resolve_d = resolver.resolve_service(service_name)
- self.assertNoResult(resolve_d)
+ # Old versions of Twisted don't have an ensureDeferred in failureResultOf.
+ resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name))
# returning a single "." should make the lookup fail with a ConenctError
lookup_deferred.callback(
@@ -192,8 +186,8 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- resolve_d = resolver.resolve_service(service_name)
- self.assertNoResult(resolve_d)
+ # Old versions of Twisted don't have an ensureDeferred in successResultOf.
+ resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name))
lookup_deferred.callback(
(
diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py
new file mode 100644
index 0000000000..62d36c2906
--- /dev/null
+++ b/tests/http/test_additional_resource.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.http.additional_resource import AdditionalResource
+from synapse.http.server import respond_with_json
+
+from tests.unittest import HomeserverTestCase
+
+
+class _AsyncTestCustomEndpoint:
+ def __init__(self, config, module_api):
+ pass
+
+ async def handle_request(self, request):
+ respond_with_json(request, 200, {"some_key": "some_value_async"})
+
+
+class _SyncTestCustomEndpoint:
+ def __init__(self, config, module_api):
+ pass
+
+ async def handle_request(self, request):
+ respond_with_json(request, 200, {"some_key": "some_value_sync"})
+
+
+class AdditionalResourceTests(HomeserverTestCase):
+ """Very basic tests that `AdditionalResource` works correctly with sync
+ and async handlers.
+ """
+
+ def test_async(self):
+ handler = _AsyncTestCustomEndpoint({}, None).handle_request
+ self.resource = AdditionalResource(self.hs, handler)
+
+ request, channel = self.make_request("GET", "/")
+ self.render(request)
+
+ self.assertEqual(request.code, 200)
+ self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
+
+ def test_sync(self):
+ handler = _SyncTestCustomEndpoint({}, None).handle_request
+ self.resource = AdditionalResource(self.hs, handler)
+
+ request, channel = self.make_request("GET", "/")
+ self.render(request)
+
+ self.assertEqual(request.code, 200)
+ self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index fff4f0cbf4..ac598249e4 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -58,7 +58,9 @@ class FederationClientTests(HomeserverTestCase):
@defer.inlineCallbacks
def do_request():
with LoggingContext("one") as context:
- fetch_d = self.cl.get_json("testserv:8008", "foo/bar")
+ fetch_d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar")
+ )
# Nothing happened yet
self.assertNoResult(fetch_d)
@@ -120,7 +122,9 @@ class FederationClientTests(HomeserverTestCase):
"""
If the DNS lookup returns an error, it will bubble up.
"""
- d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
+ )
self.pump()
f = self.failureResultOf(d)
@@ -128,7 +132,9 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
def test_client_connection_refused(self):
- d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ )
self.pump()
@@ -154,7 +160,9 @@ class FederationClientTests(HomeserverTestCase):
If the HTTP request is not connected and is timed out, it'll give a
ConnectingCancelledError or TimeoutError.
"""
- d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ )
self.pump()
@@ -184,7 +192,9 @@ class FederationClientTests(HomeserverTestCase):
If the HTTP request is connected, but gets no response before being
timed out, it'll give a ResponseNeverReceived.
"""
- d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ )
self.pump()
@@ -226,7 +236,7 @@ class FederationClientTests(HomeserverTestCase):
# Try making a GET request to a blacklisted IPv4 address
# ------------------------------------------------------
# Make the request
- d = cl.get_json("internal:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(cl.get_json("internal:8008", "foo/bar", timeout=10000))
# Nothing happened yet
self.assertNoResult(d)
@@ -244,7 +254,9 @@ class FederationClientTests(HomeserverTestCase):
# Try making a POST request to a blacklisted IPv6 address
# -------------------------------------------------------
# Make the request
- d = cl.post_json("internalv6:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ cl.post_json("internalv6:8008", "foo/bar", timeout=10000)
+ )
# Nothing has happened yet
self.assertNoResult(d)
@@ -263,7 +275,7 @@ class FederationClientTests(HomeserverTestCase):
# Try making a GET request to a non-blacklisted IPv4 address
# ----------------------------------------------------------
# Make the request
- d = cl.post_json("fine:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(cl.post_json("fine:8008", "foo/bar", timeout=10000))
# Nothing has happened yet
self.assertNoResult(d)
@@ -286,7 +298,7 @@ class FederationClientTests(HomeserverTestCase):
request = MatrixFederationRequest(
method="GET", destination="testserv:8008", path="foo/bar"
)
- d = self.cl._send_request(request, timeout=10000)
+ d = defer.ensureDeferred(self.cl._send_request(request, timeout=10000))
self.pump()
@@ -310,7 +322,9 @@ class FederationClientTests(HomeserverTestCase):
If the HTTP request is connected, but gets no response before being
timed out, it'll give a ResponseNeverReceived.
"""
- d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
+ )
self.pump()
@@ -342,7 +356,9 @@ class FederationClientTests(HomeserverTestCase):
requiring a trailing slash. We need to retry the request with a
trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622.
"""
- d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+ )
# Send the request
self.pump()
@@ -395,7 +411,9 @@ class FederationClientTests(HomeserverTestCase):
See test_client_requires_trailing_slashes() for context.
"""
- d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+ )
# Send the request
self.pump()
@@ -432,7 +450,11 @@ class FederationClientTests(HomeserverTestCase):
self.failureResultOf(d)
def test_client_sends_body(self):
- self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"})
+ defer.ensureDeferred(
+ self.cl.post_json(
+ "testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}
+ )
+ )
self.pump()
@@ -453,7 +475,7 @@ class FederationClientTests(HomeserverTestCase):
def test_closes_connection(self):
"""Check that the client closes unused HTTP connections"""
- d = self.cl.get_json("testserv:8008", "foo/bar")
+ d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
self.pump()
diff --git a/tests/module_api/__init__.py b/tests/module_api/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/module_api/__init__.py
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
new file mode 100644
index 0000000000..807cd65dd6
--- /dev/null
+++ b/tests/module_api/test_api.py
@@ -0,0 +1,54 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.module_api import ModuleApi
+
+from tests.unittest import HomeserverTestCase
+
+
+class ModuleApiTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, homeserver):
+ self.store = homeserver.get_datastore()
+ self.module_api = ModuleApi(homeserver, homeserver.get_auth_handler())
+
+ def test_can_register_user(self):
+ """Tests that an external module can register a user"""
+ # Register a new user
+ user_id, access_token = self.get_success(
+ self.module_api.register(
+ "bob", displayname="Bobberino", emails=["bob@bobinator.bob"]
+ )
+ )
+
+ # Check that the new user exists with all provided attributes
+ self.assertEqual(user_id, "@bob:test")
+ self.assertTrue(access_token)
+ self.assertTrue(self.store.get_user_by_id(user_id))
+
+ # Check that the email was assigned
+ emails = self.get_success(self.store.user_get_threepids(user_id))
+ self.assertEqual(len(emails), 1)
+
+ email = emails[0]
+ self.assertEqual(email["medium"], "email")
+ self.assertEqual(email["address"], "bob@bobinator.bob")
+
+ # Should these be 0?
+ self.assertEqual(email["validated_at"], 0)
+ self.assertEqual(email["added_at"], 0)
+
+ # Check that the displayname was assigned
+ displayname = self.get_success(self.store.get_profile_displayname("bob"))
+ self.assertEqual(displayname, "Bobberino")
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index baf9c785f4..b567868b02 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -25,7 +25,6 @@ from tests.unittest import HomeserverTestCase
class HTTPPusherTests(HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -35,7 +34,6 @@ class HTTPPusherTests(HomeserverTestCase):
hijack_auth = False
def make_homeserver(self, reactor, clock):
-
self.push_attempts = []
m = Mock()
@@ -90,9 +88,6 @@ class HTTPPusherTests(HomeserverTestCase):
# Create a room
room = self.helper.create_room_as(user_id, tok=access_token)
- # Invite the other person
- self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id)
-
# The other user joins
self.helper.join(room=room, user=other_user_id, tok=other_access_token)
@@ -157,3 +152,350 @@ class HTTPPusherTests(HomeserverTestCase):
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
+
+ def test_sends_high_priority_for_encrypted(self):
+ """
+ The HTTP pusher will send pushes at high priority if they correspond
+ to an encrypted message.
+ This will happen both in 1:1 rooms and larger rooms.
+ """
+ # Register the user who gets notified
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Register the user who sends the message
+ other_user_id = self.register_user("otheruser", "pass")
+ other_access_token = self.login("otheruser", "pass")
+
+ # Register a third user
+ yet_another_user_id = self.register_user("yetanotheruser", "pass")
+ yet_another_access_token = self.login("yetanotheruser", "pass")
+
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # The other user joins
+ self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+ # Register the pusher
+ user_tuple = self.get_success(
+ self.hs.get_datastore().get_user_by_access_token(access_token)
+ )
+ token_id = user_tuple["token_id"]
+
+ self.get_success(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=user_id,
+ access_token=token_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data={"url": "example.com"},
+ )
+ )
+
+ # Send an encrypted event
+ # I know there'd normally be set-up of an encrypted room first
+ # but this will do for our purposes
+ self.helper.send_event(
+ room,
+ "m.room.encrypted",
+ content={
+ "algorithm": "m.megolm.v1.aes-sha2",
+ "sender_key": "6lImKbzK51MzWLwHh8tUM3UBBSBrLlgup/OOCGTvumM",
+ "ciphertext": "AwgAErABoRxwpMipdgiwXgu46rHiWQ0DmRj0qUlPrMraBUDk"
+ "leTnJRljpuc7IOhsYbLY3uo2WI0ab/ob41sV+3JEIhODJPqH"
+ "TK7cEZaIL+/up9e+dT9VGF5kRTWinzjkeqO8FU5kfdRjm+3w"
+ "0sy3o1OCpXXCfO+faPhbV/0HuK4ndx1G+myNfK1Nk/CxfMcT"
+ "BT+zDS/Df/QePAHVbrr9uuGB7fW8ogW/ulnydgZPRluusFGv"
+ "J3+cg9LoPpZPAmv5Me3ec7NtdlfN0oDZ0gk3TiNkkhsxDG9Y"
+ "YcNzl78USI0q8+kOV26Bu5dOBpU4WOuojXZHJlP5lMgdzLLl"
+ "EQ0",
+ "session_id": "IigqfNWLL+ez/Is+Duwp2s4HuCZhFG9b9CZKTYHtQ4A",
+ "device_id": "AHQDUSTAAA",
+ },
+ tok=other_access_token,
+ )
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+
+ # Make the push succeed
+ self.push_attempts[0][0].callback({})
+ self.pump()
+
+ # Check our push made it with high priority
+ self.assertEqual(len(self.push_attempts), 1)
+ self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
+
+ # Add yet another person — we want to make this room not a 1:1
+ # (as encrypted messages in a 1:1 currently have tweaks applied
+ # so it doesn't properly exercise the condition of all encrypted
+ # messages need to be high).
+ self.helper.join(
+ room=room, user=yet_another_user_id, tok=yet_another_access_token
+ )
+
+ # Check no push notifications are sent regarding the membership changes
+ # (that would confuse the test)
+ self.pump()
+ self.assertEqual(len(self.push_attempts), 1)
+
+ # Send another encrypted event
+ self.helper.send_event(
+ room,
+ "m.room.encrypted",
+ content={
+ "ciphertext": "AwgAEoABtEuic/2DF6oIpNH+q/PonzlhXOVho8dTv0tzFr5m"
+ "9vTo50yabx3nxsRlP2WxSqa8I07YftP+EKWCWJvTkg6o7zXq"
+ "6CK+GVvLQOVgK50SfvjHqJXN+z1VEqj+5mkZVN/cAgJzoxcH"
+ "zFHkwDPJC8kQs47IHd8EO9KBUK4v6+NQ1uE/BIak4qAf9aS/"
+ "kI+f0gjn9IY9K6LXlah82A/iRyrIrxkCkE/n0VfvLhaWFecC"
+ "sAWTcMLoF6fh1Jpke95mljbmFSpsSd/eEQw",
+ "device_id": "SRCFTWTHXO",
+ "session_id": "eMA+bhGczuTz1C5cJR1YbmrnnC6Goni4lbvS5vJ1nG4",
+ "algorithm": "m.megolm.v1.aes-sha2",
+ "sender_key": "rC/XSIAiYrVGSuaHMop8/pTZbku4sQKBZwRwukgnN1c",
+ },
+ tok=other_access_token,
+ )
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+ self.assertEqual(len(self.push_attempts), 2)
+ self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
+
+ def test_sends_high_priority_for_one_to_one_only(self):
+ """
+ The HTTP pusher will send pushes at high priority if they correspond
+ to a message in a one-to-one room.
+ """
+ # Register the user who gets notified
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Register the user who sends the message
+ other_user_id = self.register_user("otheruser", "pass")
+ other_access_token = self.login("otheruser", "pass")
+
+ # Register a third user
+ yet_another_user_id = self.register_user("yetanotheruser", "pass")
+ yet_another_access_token = self.login("yetanotheruser", "pass")
+
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # The other user joins
+ self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+ # Register the pusher
+ user_tuple = self.get_success(
+ self.hs.get_datastore().get_user_by_access_token(access_token)
+ )
+ token_id = user_tuple["token_id"]
+
+ self.get_success(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=user_id,
+ access_token=token_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data={"url": "example.com"},
+ )
+ )
+
+ # Send a message
+ self.helper.send(room, body="Hi!", tok=other_access_token)
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+
+ # Make the push succeed
+ self.push_attempts[0][0].callback({})
+ self.pump()
+
+ # Check our push made it with high priority — this is a one-to-one room
+ self.assertEqual(len(self.push_attempts), 1)
+ self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
+
+ # Yet another user joins
+ self.helper.join(
+ room=room, user=yet_another_user_id, tok=yet_another_access_token
+ )
+
+ # Check no push notifications are sent regarding the membership changes
+ # (that would confuse the test)
+ self.pump()
+ self.assertEqual(len(self.push_attempts), 1)
+
+ # Send another event
+ self.helper.send(room, body="Welcome!", tok=other_access_token)
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+ self.assertEqual(len(self.push_attempts), 2)
+ self.assertEqual(self.push_attempts[1][1], "example.com")
+
+ # check that this is low-priority
+ self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
+
+ def test_sends_high_priority_for_mention(self):
+ """
+ The HTTP pusher will send pushes at high priority if they correspond
+ to a message containing the user's display name.
+ """
+ # Register the user who gets notified
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Register the user who sends the message
+ other_user_id = self.register_user("otheruser", "pass")
+ other_access_token = self.login("otheruser", "pass")
+
+ # Register a third user
+ yet_another_user_id = self.register_user("yetanotheruser", "pass")
+ yet_another_access_token = self.login("yetanotheruser", "pass")
+
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # The other users join
+ self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+ self.helper.join(
+ room=room, user=yet_another_user_id, tok=yet_another_access_token
+ )
+
+ # Register the pusher
+ user_tuple = self.get_success(
+ self.hs.get_datastore().get_user_by_access_token(access_token)
+ )
+ token_id = user_tuple["token_id"]
+
+ self.get_success(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=user_id,
+ access_token=token_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data={"url": "example.com"},
+ )
+ )
+
+ # Send a message
+ self.helper.send(room, body="Oh, user, hello!", tok=other_access_token)
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+
+ # Make the push succeed
+ self.push_attempts[0][0].callback({})
+ self.pump()
+
+ # Check our push made it with high priority
+ self.assertEqual(len(self.push_attempts), 1)
+ self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
+
+ # Send another event, this time with no mention
+ self.helper.send(room, body="Are you there?", tok=other_access_token)
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+ self.assertEqual(len(self.push_attempts), 2)
+ self.assertEqual(self.push_attempts[1][1], "example.com")
+
+ # check that this is low-priority
+ self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
+
+ def test_sends_high_priority_for_atroom(self):
+ """
+ The HTTP pusher will send pushes at high priority if they correspond
+ to a message that contains @room.
+ """
+ # Register the user who gets notified
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Register the user who sends the message
+ other_user_id = self.register_user("otheruser", "pass")
+ other_access_token = self.login("otheruser", "pass")
+
+ # Register a third user
+ yet_another_user_id = self.register_user("yetanotheruser", "pass")
+ yet_another_access_token = self.login("yetanotheruser", "pass")
+
+ # Create a room (as other_user so the power levels are compatible with
+ # other_user sending @room).
+ room = self.helper.create_room_as(other_user_id, tok=other_access_token)
+
+ # The other users join
+ self.helper.join(room=room, user=user_id, tok=access_token)
+ self.helper.join(
+ room=room, user=yet_another_user_id, tok=yet_another_access_token
+ )
+
+ # Register the pusher
+ user_tuple = self.get_success(
+ self.hs.get_datastore().get_user_by_access_token(access_token)
+ )
+ token_id = user_tuple["token_id"]
+
+ self.get_success(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=user_id,
+ access_token=token_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data={"url": "example.com"},
+ )
+ )
+
+ # Send a message
+ self.helper.send(
+ room,
+ body="@room eeek! There's a spider on the table!",
+ tok=other_access_token,
+ )
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+
+ # Make the push succeed
+ self.push_attempts[0][0].callback({})
+ self.pump()
+
+ # Check our push made it with high priority
+ self.assertEqual(len(self.push_attempts), 1)
+ self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
+
+ # Send another event, this time as someone without the power of @room
+ self.helper.send(
+ room, body="@room the spider is gone", tok=yet_another_access_token
+ )
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+ self.assertEqual(len(self.push_attempts), 2)
+ self.assertEqual(self.push_attempts[1][1], "example.com")
+
+ # check that this is low-priority
+ self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 9ae6a87d7b..1f4b5ca2ac 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -15,13 +15,14 @@
from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent
+from synapse.push import push_rule_evaluator
from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent
from tests import unittest
class PushRuleEvaluatorTestCase(unittest.TestCase):
- def setUp(self):
+ def _get_evaluator(self, content):
event = FrozenEvent(
{
"event_id": "$event_id",
@@ -29,37 +30,74 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"sender": "@user:test",
"state_key": "",
"room_id": "@room:test",
- "content": {"body": "foo bar baz"},
+ "content": content,
},
RoomVersions.V1,
)
room_member_count = 0
sender_power_level = 0
power_levels = {}
- self.evaluator = PushRuleEvaluatorForEvent(
+ return PushRuleEvaluatorForEvent(
event, room_member_count, sender_power_level, power_levels
)
def test_display_name(self):
"""Check for a matching display name in the body of the event."""
+ evaluator = self._get_evaluator({"body": "foo bar baz"})
+
condition = {
"kind": "contains_display_name",
}
# Blank names are skipped.
- self.assertFalse(self.evaluator.matches(condition, "@user:test", ""))
+ self.assertFalse(evaluator.matches(condition, "@user:test", ""))
# Check a display name that doesn't match.
- self.assertFalse(self.evaluator.matches(condition, "@user:test", "not found"))
+ self.assertFalse(evaluator.matches(condition, "@user:test", "not found"))
# Check a display name which matches.
- self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo"))
+ self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
# A display name that matches, but not a full word does not result in a match.
- self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba"))
+ self.assertFalse(evaluator.matches(condition, "@user:test", "ba"))
# A display name should not be interpreted as a regular expression.
- self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba[rz]"))
+ self.assertFalse(evaluator.matches(condition, "@user:test", "ba[rz]"))
# A display name with spaces should work fine.
- self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo bar"))
+ self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
+
+ def test_no_body(self):
+ """Not having a body shouldn't break the evaluator."""
+ evaluator = self._get_evaluator({})
+
+ condition = {
+ "kind": "contains_display_name",
+ }
+ self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
+
+ def test_invalid_body(self):
+ """A non-string body should not break the evaluator."""
+ condition = {
+ "kind": "contains_display_name",
+ }
+
+ for body in (1, True, {"foo": "bar"}):
+ evaluator = self._get_evaluator({"body": body})
+ self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
+
+ def test_tweaks_for_actions(self):
+ """
+ This tests the behaviour of tweaks_for_actions.
+ """
+
+ actions = [
+ {"set_tweak": "sound", "value": "default"},
+ {"set_tweak": "highlight"},
+ "notify",
+ ]
+
+ self.assertEqual(
+ push_rule_evaluator.tweaks_for_actions(actions),
+ {"sound": "default", "highlight": True},
+ )
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 9d4f0bbe44..ae60874ec3 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Any, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Tuple
import attr
@@ -26,8 +26,9 @@ from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
GenericWorkerServer,
)
+from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
-from synapse.replication.http import streams
+from synapse.replication.http import ReplicationRestResource, streams
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -35,7 +36,7 @@ from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
-from tests.server import FakeTransport
+from tests.server import FakeTransport, render
logger = logging.getLogger(__name__)
@@ -64,7 +65,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Since we use sqlite in memory databases we need to make sure the
# databases objects are the same.
- self.worker_hs.get_datastore().db = hs.get_datastore().db
+ self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool
self.test_handler = self._build_replication_data_handler()
self.worker_hs.replication_data_handler = self.test_handler
@@ -180,6 +181,159 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(request.method, b"GET")
+class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
+ """Base class for tests running multiple workers.
+
+ Automatically handle HTTP replication requests from workers to master,
+ unlike `BaseStreamTestCase`.
+ """
+
+ servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]]
+
+ def setUp(self):
+ super().setUp()
+
+ # build a replication server
+ self.server_factory = ReplicationStreamProtocolFactory(self.hs)
+ self.streamer = self.hs.get_replication_streamer()
+
+ store = self.hs.get_datastore()
+ self.database_pool = store.db_pool
+
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ self._worker_hs_to_resource = {}
+
+ # When we see a connection attempt to the master replication listener we
+ # automatically set up the connection. This is so that tests don't
+ # manually have to go and explicitly set it up each time (plus sometimes
+ # it is impossible to write the handling explicitly in the tests).
+ self.reactor.add_tcp_client_callback(
+ "1.2.3.4", 8765, self._handle_http_replication_attempt
+ )
+
+ def create_test_json_resource(self):
+ """Overrides `HomeserverTestCase.create_test_json_resource`.
+ """
+ # We override this so that it automatically registers all the HTTP
+ # replication servlets, without having to explicitly do that in all
+ # subclassses.
+
+ resource = ReplicationRestResource(self.hs)
+
+ for servlet in self.servlets:
+ servlet(self.hs, resource)
+
+ return resource
+
+ def make_worker_hs(
+ self, worker_app: str, extra_config: dict = {}, **kwargs
+ ) -> HomeServer:
+ """Make a new worker HS instance, correctly connecting replcation
+ stream to the master HS.
+
+ Args:
+ worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
+ extra_config: Any extra config to use for this instances.
+ **kwargs: Options that get passed to `self.setup_test_homeserver`,
+ useful to e.g. pass some mocks for things like `http_client`
+
+ Returns:
+ The new worker HomeServer instance.
+ """
+
+ config = self._get_worker_hs_config()
+ config["worker_app"] = worker_app
+ config.update(extra_config)
+
+ worker_hs = self.setup_test_homeserver(
+ homeserverToUse=GenericWorkerServer,
+ config=config,
+ reactor=self.reactor,
+ **kwargs
+ )
+
+ store = worker_hs.get_datastore()
+ store.db_pool._db_pool = self.database_pool._db_pool
+
+ repl_handler = ReplicationCommandHandler(worker_hs)
+ client = ClientReplicationStreamProtocol(
+ worker_hs, "client", "test", self.clock, repl_handler,
+ )
+ server = self.server_factory.buildProtocol(None)
+
+ client_transport = FakeTransport(server, self.reactor)
+ client.makeConnection(client_transport)
+
+ server_transport = FakeTransport(client, self.reactor)
+ server.makeConnection(server_transport)
+
+ # Set up a resource for the worker
+ resource = ReplicationRestResource(self.hs)
+
+ for servlet in self.servlets:
+ servlet(worker_hs, resource)
+
+ self._worker_hs_to_resource[worker_hs] = resource
+
+ return worker_hs
+
+ def _get_worker_hs_config(self) -> dict:
+ config = self.default_config()
+ config["worker_replication_host"] = "testserv"
+ config["worker_replication_http_port"] = "8765"
+ return config
+
+ def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
+ render(request, self._worker_hs_to_resource[worker_hs], self.reactor)
+
+ def replicate(self):
+ """Tell the master side of replication that something has happened, and then
+ wait for the replication to occur.
+ """
+ self.streamer.on_notifier_poke()
+ self.pump()
+
+ def _handle_http_replication_attempt(self):
+ """Handles a connection attempt to the master replication HTTP
+ listener.
+ """
+
+ # We should have at least one outbound connection attempt, where the
+ # last is one to the HTTP repication IP/port.
+ clients = self.reactor.tcpClients
+ self.assertGreaterEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8765)
+
+ # Set up client side protocol
+ client_protocol = client_factory.buildProtocol(None)
+
+ request_factory = OneShotRequestFactory()
+
+ # Set up the server side protocol
+ channel = _PushHTTPChannel(self.reactor)
+ channel.requestFactory = request_factory
+ channel.site = self.site
+
+ # Connect client to server and vice versa.
+ client_to_server_transport = FakeTransport(
+ channel, self.reactor, client_protocol
+ )
+ client_protocol.makeConnection(client_to_server_transport)
+
+ server_to_client_transport = FakeTransport(
+ client_protocol, self.reactor, channel
+ )
+ channel.makeConnection(server_to_client_transport)
+
+ # Note: at this point we've wired everything up, but we need to return
+ # before the data starts flowing over the connections as this is called
+ # inside `connecTCP` before the connection has been passed back to the
+ # code that requested the TCP connection.
+
+
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
@@ -241,6 +395,14 @@ class _PushHTTPChannel(HTTPChannel):
# We need to manually stop the _PullToPushProducer.
self._pull_to_push_producer.stop()
+ def checkPersistence(self, request, version):
+ """Check whether the connection can be re-used
+ """
+ # We hijack this to always say no for ease of wiring stuff up in
+ # `handle_http_replication_attempt`.
+ request.responseHeaders.setRawHeaders(b"connection", [b"close"])
+ return False
+
class _PullToPushProducer:
"""A push producer that wraps a pull producer.
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 1a88c7fb80..0b5204654c 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -366,7 +366,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
state_handler = self.hs.get_state_handler()
context = self.get_success(state_handler.compute_event_context(event))
- self.master_store.add_push_actions_to_staging(
- event.event_id, {user_id: actions for user_id, actions in push_actions}
+ self.get_success(
+ self.master_store.add_push_actions_to_staging(
+ event.event_id, {user_id: actions for user_id, actions in push_actions}
+ )
)
return event, context
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 51bf0ef4e9..c9998e88e6 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -17,6 +17,7 @@ from typing import List, Optional
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
+from synapse.replication.tcp.commands import RdataCommand
from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
from synapse.replication.tcp.streams.events import (
EventsStreamCurrentStateRow,
@@ -66,11 +67,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
# also one state event
state_event = self._inject_state_event()
- # tell the notifier to catch up to avoid duplicate rows.
- # workaround for https://github.com/matrix-org/synapse/issues/7360
- # FIXME remove this when the above is fixed
- self.replicate()
-
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -123,7 +119,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
OTHER_USER = "@other_user:localhost"
# have the user join
- inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
+ self.get_success(
+ inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
+ )
# Update existing power levels with mod at PL50
pls = self.helper.get_state(
@@ -161,24 +159,21 @@ class EventsStreamTestCase(BaseStreamTestCase):
# roll back all the state by de-modding the user
prev_events = fork_point
pls["users"][OTHER_USER] = 0
- pl_event = inject_event(
- self.hs,
- prev_event_ids=prev_events,
- type=EventTypes.PowerLevels,
- state_key="",
- sender=self.user_id,
- room_id=self.room_id,
- content=pls,
+ pl_event = self.get_success(
+ inject_event(
+ self.hs,
+ prev_event_ids=prev_events,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ sender=self.user_id,
+ room_id=self.room_id,
+ content=pls,
+ )
)
# one more bit of state that doesn't get rolled back
state2 = self._inject_state_event()
- # tell the notifier to catch up to avoid duplicate rows.
- # workaround for https://github.com/matrix-org/synapse/issues/7360
- # FIXME remove this when the above is fixed
- self.replicate()
-
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -277,7 +272,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
# have the users join
for u in user_ids:
- inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
+ self.get_success(
+ inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
+ )
# Update existing power levels with mod at PL50
pls = self.helper.get_state(
@@ -315,23 +312,20 @@ class EventsStreamTestCase(BaseStreamTestCase):
pl_events = []
for u in user_ids:
pls["users"][u] = 0
- e = inject_event(
- self.hs,
- prev_event_ids=prev_events,
- type=EventTypes.PowerLevels,
- state_key="",
- sender=self.user_id,
- room_id=self.room_id,
- content=pls,
+ e = self.get_success(
+ inject_event(
+ self.hs,
+ prev_event_ids=prev_events,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ sender=self.user_id,
+ room_id=self.room_id,
+ content=pls,
+ )
)
prev_events = [e.event_id]
pl_events.append(e)
- # tell the notifier to catch up to avoid duplicate rows.
- # workaround for https://github.com/matrix-org/synapse/issues/7360
- # FIXME remove this when the above is fixed
- self.replicate()
-
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -378,6 +372,64 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows)
+ def test_backwards_stream_id(self):
+ """
+ Test that RDATA that comes after the current position should be discarded.
+ """
+ # disconnect, so that we can stack up some changes
+ self.disconnect()
+
+ # Generate an events. We inject them using inject_event so that they are
+ # not send out over replication until we call self.replicate().
+ event = self._inject_test_event()
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+ # We should have received the expected single row (as well as various
+ # cache invalidation updates which we ignore).
+ received_rows = [
+ row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+ ]
+
+ # There should be a single received row.
+ self.assertEqual(len(received_rows), 1)
+
+ stream_name, token, row = received_rows[0]
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "ev")
+ self.assertIsInstance(row.data, EventsStreamEventRow)
+ self.assertEqual(row.data.event_id, event.event_id)
+
+ # Reset the data.
+ self.test_handler.received_rdata_rows = []
+
+ # Save the current token for later.
+ worker_events_stream = self.worker_hs.get_replication_streams()["events"]
+ prev_token = worker_events_stream.current_token("master")
+
+ # Manually send an old RDATA command, which should get dropped. This
+ # re-uses the row from above, but with an earlier stream token.
+ self.hs.get_tcp_replication().send_command(
+ RdataCommand("events", "master", 1, row)
+ )
+
+ # No updates have been received (because it was discard as old).
+ received_rows = [
+ row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+ ]
+ self.assertEqual(len(received_rows), 0)
+
+ # Ensure the stream has not gone backwards.
+ current_token = worker_events_stream.current_token("master")
+ self.assertGreaterEqual(current_token, prev_token)
+
event_count = 0
def _inject_test_event(
@@ -390,13 +442,15 @@ class EventsStreamTestCase(BaseStreamTestCase):
body = "event %i" % (self.event_count,)
self.event_count += 1
- return inject_event(
- self.hs,
- room_id=self.room_id,
- sender=sender,
- type="test_event",
- content={"body": body},
- **kwargs
+ return self.get_success(
+ inject_event(
+ self.hs,
+ room_id=self.room_id,
+ sender=sender,
+ type="test_event",
+ content={"body": body},
+ **kwargs
+ )
)
def _inject_state_event(
@@ -415,11 +469,13 @@ class EventsStreamTestCase(BaseStreamTestCase):
if body is None:
body = "state event %s" % (state_key,)
- return inject_event(
- self.hs,
- room_id=self.room_id,
- sender=sender,
- type="test_state_event",
- state_key=state_key,
- content={"body": body},
+ return self.get_success(
+ inject_event(
+ self.hs,
+ room_id=self.room_id,
+ sender=sender,
+ type="test_state_event",
+ state_key=state_key,
+ content={"body": body},
+ )
)
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index fd62b26356..5acfb3e53e 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -16,10 +16,15 @@ from mock import Mock
from synapse.handlers.typing import RoomMember
from synapse.replication.tcp.streams import TypingStream
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from tests.replication._base import BaseStreamTestCase
USER_ID = "@feeling:blue"
+USER_ID_2 = "@da-ba-dee:blue"
+
+ROOM_ID = "!bar:blue"
+ROOM_ID_2 = "!foo:blue"
class TypingStreamTestCase(BaseStreamTestCase):
@@ -29,11 +34,9 @@ class TypingStreamTestCase(BaseStreamTestCase):
def test_typing(self):
typing = self.hs.get_typing_handler()
- room_id = "!bar:blue"
-
self.reconnect()
- typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
+ typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True)
self.reactor.advance(0)
@@ -46,7 +49,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
- self.assertEqual(room_id, row.room_id)
+ self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([USER_ID], row.user_ids)
# Now let's disconnect and insert some data.
@@ -54,7 +57,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.test_handler.on_rdata.reset_mock()
- typing._push_update(member=RoomMember(room_id, USER_ID), typing=False)
+ typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)
self.test_handler.on_rdata.assert_not_called()
@@ -73,5 +76,78 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0]
- self.assertEqual(room_id, row.room_id)
+ self.assertEqual(ROOM_ID, row.room_id)
+ self.assertEqual([], row.user_ids)
+
+ def test_reset(self):
+ """
+ Test what happens when a typing stream resets.
+
+ This is emulated by jumping the stream ahead, then reconnecting (which
+ sends the proper position and RDATA).
+ """
+ typing = self.hs.get_typing_handler()
+
+ self.reconnect()
+
+ typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True)
+
+ self.reactor.advance(0)
+
+ # We should now see an attempt to connect to the master
+ request = self.handle_http_replication_attempt()
+ self.assert_request_is_get_repl_stream_updates(request, "typing")
+
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "typing")
+ self.assertEqual(1, len(rdata_rows))
+ row = rdata_rows[0] # type: TypingStream.TypingStreamRow
+ self.assertEqual(ROOM_ID, row.room_id)
+ self.assertEqual([USER_ID], row.user_ids)
+
+ # Push the stream forward a bunch so it can be reset.
+ for i in range(100):
+ typing._push_update(
+ member=RoomMember(ROOM_ID, "@test%s:blue" % i), typing=True
+ )
+ self.reactor.advance(0)
+
+ # Disconnect.
+ self.disconnect()
+
+ # Reset the typing handler
+ self.hs.get_replication_streams()["typing"].last_token = 0
+ self.hs.get_tcp_replication()._streams["typing"].last_token = 0
+ typing._latest_room_serial = 0
+ typing._typing_stream_change_cache = StreamChangeCache(
+ "TypingStreamChangeCache", typing._latest_room_serial
+ )
+ typing._reset()
+
+ # Reconnect.
+ self.reconnect()
+ self.pump(0.1)
+
+ # We should now see an attempt to connect to the master
+ request = self.handle_http_replication_attempt()
+ self.assert_request_is_get_repl_stream_updates(request, "typing")
+
+ # Reset the test code.
+ self.test_handler.on_rdata.reset_mock()
+ self.test_handler.on_rdata.assert_not_called()
+
+ # Push additional data.
+ typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
+ self.reactor.advance(0)
+
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "typing")
+ self.assertEqual(1, len(rdata_rows))
+ row = rdata_rows[0]
+ self.assertEqual(ROOM_ID_2, row.room_id)
self.assertEqual([], row.user_ids)
+
+ # The token should have been reset.
+ self.assertEqual(token, 1)
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
new file mode 100644
index 0000000000..86c03fd89c
--- /dev/null
+++ b/tests/replication/test_client_reader_shard.py
@@ -0,0 +1,96 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.api.constants import LoginType
+from synapse.http.site import SynapseRequest
+from synapse.rest.client.v2_alpha import register
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
+from tests.server import FakeChannel
+
+logger = logging.getLogger(__name__)
+
+
+class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
+ """Base class for tests of the replication streams"""
+
+ servlets = [register.register_servlets]
+
+ def prepare(self, reactor, clock, hs):
+ self.recaptcha_checker = DummyRecaptchaChecker(hs)
+ auth_handler = hs.get_auth_handler()
+ auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
+
+ def _get_worker_hs_config(self) -> dict:
+ config = self.default_config()
+ config["worker_app"] = "synapse.app.client_reader"
+ config["worker_replication_host"] = "testserv"
+ config["worker_replication_http_port"] = "8765"
+ return config
+
+ def test_register_single_worker(self):
+ """Test that registration works when using a single client reader worker.
+ """
+ worker_hs = self.make_worker_hs("synapse.app.client_reader")
+
+ request_1, channel_1 = self.make_request(
+ "POST",
+ "register",
+ {"username": "user", "type": "m.login.password", "password": "bar"},
+ ) # type: SynapseRequest, FakeChannel
+ self.render_on_worker(worker_hs, request_1)
+ self.assertEqual(request_1.code, 401)
+
+ # Grab the session
+ session = channel_1.json_body["session"]
+
+ # also complete the dummy auth
+ request_2, channel_2 = self.make_request(
+ "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
+ ) # type: SynapseRequest, FakeChannel
+ self.render_on_worker(worker_hs, request_2)
+ self.assertEqual(request_2.code, 200)
+
+ # We're given a registered user.
+ self.assertEqual(channel_2.json_body["user_id"], "@user:test")
+
+ def test_register_multi_worker(self):
+ """Test that registration works when using multiple client reader workers.
+ """
+ worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
+ worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
+
+ request_1, channel_1 = self.make_request(
+ "POST",
+ "register",
+ {"username": "user", "type": "m.login.password", "password": "bar"},
+ ) # type: SynapseRequest, FakeChannel
+ self.render_on_worker(worker_hs_1, request_1)
+ self.assertEqual(request_1.code, 401)
+
+ # Grab the session
+ session = channel_1.json_body["session"]
+
+ # also complete the dummy auth
+ request_2, channel_2 = self.make_request(
+ "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
+ ) # type: SynapseRequest, FakeChannel
+ self.render_on_worker(worker_hs_2, request_2)
+ self.assertEqual(request_2.code, 200)
+
+ # We're given a registered user.
+ self.assertEqual(channel_2.json_body["user_id"], "@user:test")
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 5448d9f0dc..23be1167a3 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -32,6 +32,7 @@ class FederationAckTestCase(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer)
+
return hs
def test_federation_ack_sent(self):
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
new file mode 100644
index 0000000000..83f9aa291c
--- /dev/null
+++ b/tests/replication/test_federation_sender_shard.py
@@ -0,0 +1,234 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from mock import Mock
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.events.builder import EventBuilderFactory
+from synapse.rest.admin import register_servlets_for_client_rest_resource
+from synapse.rest.client.v1 import login, room
+from synapse.types import UserID
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.test_utils import make_awaitable
+
+logger = logging.getLogger(__name__)
+
+
+class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
+ servlets = [
+ login.register_servlets,
+ register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ ]
+
+ def default_config(self):
+ conf = super().default_config()
+ conf["send_federation"] = False
+ return conf
+
+ def test_send_event_single_sender(self):
+ """Test that using a single federation sender worker correctly sends a
+ new event.
+ """
+ mock_client = Mock(spec=["put_json"])
+ mock_client.put_json.side_effect = lambda *_, **__: make_awaitable({})
+
+ self.make_worker_hs(
+ "synapse.app.federation_sender",
+ {"send_federation": True},
+ http_client=mock_client,
+ )
+
+ user = self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ room = self.create_room_with_remote_server(user, token)
+
+ mock_client.put_json.reset_mock()
+
+ self.create_and_send_event(room, UserID.from_string(user))
+ self.replicate()
+
+ # Assert that the event was sent out over federation.
+ mock_client.put_json.assert_called()
+ self.assertEqual(mock_client.put_json.call_args[0][0], "other_server")
+ self.assertTrue(mock_client.put_json.call_args[1]["data"].get("pdus"))
+
+ def test_send_event_sharded(self):
+ """Test that using two federation sender workers correctly sends
+ new events.
+ """
+ mock_client1 = Mock(spec=["put_json"])
+ mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ self.make_worker_hs(
+ "synapse.app.federation_sender",
+ {
+ "send_federation": True,
+ "worker_name": "sender1",
+ "federation_sender_instances": ["sender1", "sender2"],
+ },
+ http_client=mock_client1,
+ )
+
+ mock_client2 = Mock(spec=["put_json"])
+ mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ self.make_worker_hs(
+ "synapse.app.federation_sender",
+ {
+ "send_federation": True,
+ "worker_name": "sender2",
+ "federation_sender_instances": ["sender1", "sender2"],
+ },
+ http_client=mock_client2,
+ )
+
+ user = self.register_user("user2", "pass")
+ token = self.login("user2", "pass")
+
+ sent_on_1 = False
+ sent_on_2 = False
+ for i in range(20):
+ server_name = "other_server_%d" % (i,)
+ room = self.create_room_with_remote_server(user, token, server_name)
+ mock_client1.reset_mock() # type: ignore[attr-defined]
+ mock_client2.reset_mock() # type: ignore[attr-defined]
+
+ self.create_and_send_event(room, UserID.from_string(user))
+ self.replicate()
+
+ if mock_client1.put_json.called:
+ sent_on_1 = True
+ mock_client2.put_json.assert_not_called()
+ self.assertEqual(mock_client1.put_json.call_args[0][0], server_name)
+ self.assertTrue(mock_client1.put_json.call_args[1]["data"].get("pdus"))
+ elif mock_client2.put_json.called:
+ sent_on_2 = True
+ mock_client1.put_json.assert_not_called()
+ self.assertEqual(mock_client2.put_json.call_args[0][0], server_name)
+ self.assertTrue(mock_client2.put_json.call_args[1]["data"].get("pdus"))
+ else:
+ raise AssertionError(
+ "Expected send transaction from one or the other sender"
+ )
+
+ if sent_on_1 and sent_on_2:
+ break
+
+ self.assertTrue(sent_on_1)
+ self.assertTrue(sent_on_2)
+
+ def test_send_typing_sharded(self):
+ """Test that using two federation sender workers correctly sends
+ new typing EDUs.
+ """
+ mock_client1 = Mock(spec=["put_json"])
+ mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ self.make_worker_hs(
+ "synapse.app.federation_sender",
+ {
+ "send_federation": True,
+ "worker_name": "sender1",
+ "federation_sender_instances": ["sender1", "sender2"],
+ },
+ http_client=mock_client1,
+ )
+
+ mock_client2 = Mock(spec=["put_json"])
+ mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ self.make_worker_hs(
+ "synapse.app.federation_sender",
+ {
+ "send_federation": True,
+ "worker_name": "sender2",
+ "federation_sender_instances": ["sender1", "sender2"],
+ },
+ http_client=mock_client2,
+ )
+
+ user = self.register_user("user3", "pass")
+ token = self.login("user3", "pass")
+
+ typing_handler = self.hs.get_typing_handler()
+
+ sent_on_1 = False
+ sent_on_2 = False
+ for i in range(20):
+ server_name = "other_server_%d" % (i,)
+ room = self.create_room_with_remote_server(user, token, server_name)
+ mock_client1.reset_mock() # type: ignore[attr-defined]
+ mock_client2.reset_mock() # type: ignore[attr-defined]
+
+ self.get_success(
+ typing_handler.started_typing(
+ target_user=UserID.from_string(user),
+ auth_user=UserID.from_string(user),
+ room_id=room,
+ timeout=20000,
+ )
+ )
+
+ self.replicate()
+
+ if mock_client1.put_json.called:
+ sent_on_1 = True
+ mock_client2.put_json.assert_not_called()
+ self.assertEqual(mock_client1.put_json.call_args[0][0], server_name)
+ self.assertTrue(mock_client1.put_json.call_args[1]["data"].get("edus"))
+ elif mock_client2.put_json.called:
+ sent_on_2 = True
+ mock_client1.put_json.assert_not_called()
+ self.assertEqual(mock_client2.put_json.call_args[0][0], server_name)
+ self.assertTrue(mock_client2.put_json.call_args[1]["data"].get("edus"))
+ else:
+ raise AssertionError(
+ "Expected send transaction from one or the other sender"
+ )
+
+ if sent_on_1 and sent_on_2:
+ break
+
+ self.assertTrue(sent_on_1)
+ self.assertTrue(sent_on_2)
+
+ def create_room_with_remote_server(self, user, token, remote_server="other_server"):
+ room = self.helper.create_room_as(user, tok=token)
+ store = self.hs.get_datastore()
+ federation = self.hs.get_handlers().federation_handler
+
+ prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
+ room_version = self.get_success(store.get_room_version(room))
+
+ factory = EventBuilderFactory(self.hs)
+ factory.hostname = remote_server
+
+ user_id = UserID("user", remote_server).to_string()
+
+ event_dict = {
+ "type": EventTypes.Member,
+ "state_key": user_id,
+ "content": {"membership": Membership.JOIN},
+ "sender": user_id,
+ "room_id": room,
+ }
+
+ builder = factory.for_room_version(room_version, event_dict)
+ join_event = self.get_success(builder.build(prev_event_ids))
+
+ self.get_success(federation.on_send_join_request(remote_server, join_event))
+ self.replicate()
+
+ return room
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
new file mode 100644
index 0000000000..2bdc6edbb1
--- /dev/null
+++ b/tests/replication/test_pusher_shard.py
@@ -0,0 +1,193 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+logger = logging.getLogger(__name__)
+
+
+class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
+ """Checks pusher sharding works
+ """
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ # Register a user who sends a message that we'll get notified about
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+ def default_config(self):
+ conf = super().default_config()
+ conf["start_pushers"] = False
+ return conf
+
+ def _create_pusher_and_send_msg(self, localpart):
+ # Create a user that will get push notifications
+ user_id = self.register_user(localpart, "pass")
+ access_token = self.login(localpart, "pass")
+
+ # Register a pusher
+ user_dict = self.get_success(
+ self.hs.get_datastore().get_user_by_access_token(access_token)
+ )
+ token_id = user_dict["token_id"]
+
+ self.get_success(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=user_id,
+ access_token=token_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data={"url": "https://push.example.com/push"},
+ )
+ )
+
+ self.pump()
+
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # The other user joins
+ self.helper.join(
+ room=room, user=self.other_user_id, tok=self.other_access_token
+ )
+
+ # The other user sends some messages
+ response = self.helper.send(room, body="Hi!", tok=self.other_access_token)
+ event_id = response["event_id"]
+
+ return event_id
+
+ def test_send_push_single_worker(self):
+ """Test that registration works when using a pusher worker.
+ """
+ http_client_mock = Mock(spec_set=["post_json_get_json"])
+ http_client_mock.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
+ {}
+ )
+
+ self.make_worker_hs(
+ "synapse.app.pusher",
+ {"start_pushers": True},
+ proxied_http_client=http_client_mock,
+ )
+
+ event_id = self._create_pusher_and_send_msg("user")
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+
+ http_client_mock.post_json_get_json.assert_called_once()
+ self.assertEqual(
+ http_client_mock.post_json_get_json.call_args[0][0],
+ "https://push.example.com/push",
+ )
+ self.assertEqual(
+ event_id,
+ http_client_mock.post_json_get_json.call_args[0][1]["notification"][
+ "event_id"
+ ],
+ )
+
+ def test_send_push_multiple_workers(self):
+ """Test that registration works when using sharded pusher workers.
+ """
+ http_client_mock1 = Mock(spec_set=["post_json_get_json"])
+ http_client_mock1.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
+ {}
+ )
+
+ self.make_worker_hs(
+ "synapse.app.pusher",
+ {
+ "start_pushers": True,
+ "worker_name": "pusher1",
+ "pusher_instances": ["pusher1", "pusher2"],
+ },
+ proxied_http_client=http_client_mock1,
+ )
+
+ http_client_mock2 = Mock(spec_set=["post_json_get_json"])
+ http_client_mock2.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
+ {}
+ )
+
+ self.make_worker_hs(
+ "synapse.app.pusher",
+ {
+ "start_pushers": True,
+ "worker_name": "pusher2",
+ "pusher_instances": ["pusher1", "pusher2"],
+ },
+ proxied_http_client=http_client_mock2,
+ )
+
+ # We choose a user name that we know should go to pusher1.
+ event_id = self._create_pusher_and_send_msg("user2")
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+
+ http_client_mock1.post_json_get_json.assert_called_once()
+ http_client_mock2.post_json_get_json.assert_not_called()
+ self.assertEqual(
+ http_client_mock1.post_json_get_json.call_args[0][0],
+ "https://push.example.com/push",
+ )
+ self.assertEqual(
+ event_id,
+ http_client_mock1.post_json_get_json.call_args[0][1]["notification"][
+ "event_id"
+ ],
+ )
+
+ http_client_mock1.post_json_get_json.reset_mock()
+ http_client_mock2.post_json_get_json.reset_mock()
+
+ # Now we choose a user name that we know should go to pusher2.
+ event_id = self._create_pusher_and_send_msg("user4")
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+
+ http_client_mock1.post_json_get_json.assert_not_called()
+ http_client_mock2.post_json_get_json.assert_called_once()
+ self.assertEqual(
+ http_client_mock2.post_json_get_json.call_args[0][0],
+ "https://push.example.com/push",
+ )
+ self.assertEqual(
+ event_id,
+ http_client_mock2.post_json_get_json.call_args[0][1]["notification"][
+ "event_id"
+ ],
+ )
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 977615ebef..0f1144fe1e 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -178,7 +178,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self.fetches = []
- def get_file(destination, path, output_stream, args=None, max_size=None):
+ async def get_file(destination, path, output_stream, args=None, max_size=None):
"""
Returns tuple[int,dict,str,int] of file length, response headers,
absolute URI, and response code.
@@ -192,7 +192,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
d = Deferred()
d.addCallback(write_to)
self.fetches.append((d, destination, path, args))
- return make_deferred_yieldable(d)
+ return await make_deferred_yieldable(d)
client = Mock()
client.get_file = get_file
@@ -220,6 +220,24 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
return hs
+ def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
+ """Ensure a piece of media is quarantined when trying to access it."""
+ request, channel = self.make_request(
+ "GET", server_and_media_id, shorthand=False, access_token=admin_user_tok,
+ )
+ request.render(self.download_resource)
+ self.pump(1.0)
+
+ # Should be quarantined
+ self.assertEqual(
+ 404,
+ int(channel.code),
+ msg=(
+ "Expected to receive a 404 on accessing quarantined media: %s"
+ % server_and_media_id
+ ),
+ )
+
def test_quarantine_media_requires_admin(self):
self.register_user("nonadmin", "pass", admin=False)
non_admin_user_tok = self.login("nonadmin", "pass")
@@ -292,24 +310,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
# Attempt to access the media
- request, channel = self.make_request(
- "GET",
- server_name_and_media_id,
- shorthand=False,
- access_token=admin_user_tok,
- )
- request.render(self.download_resource)
- self.pump(1.0)
-
- # Should be quarantined
- self.assertEqual(
- 404,
- int(channel.code),
- msg=(
- "Expected to receive a 404 on accessing quarantined media: %s"
- % server_name_and_media_id
- ),
- )
+ self._ensure_quarantined(admin_user_tok, server_name_and_media_id)
def test_quarantine_all_media_in_room(self, override_url_template=None):
self.register_user("room_admin", "pass", admin=True)
@@ -371,45 +372,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
server_and_media_id_2 = mxc_2[6:]
# Test that we cannot download any of the media anymore
- request, channel = self.make_request(
- "GET",
- server_and_media_id_1,
- shorthand=False,
- access_token=non_admin_user_tok,
- )
- request.render(self.download_resource)
- self.pump(1.0)
-
- # Should be quarantined
- self.assertEqual(
- 404,
- int(channel.code),
- msg=(
- "Expected to receive a 404 on accessing quarantined media: %s"
- % server_and_media_id_1
- ),
- )
-
- request, channel = self.make_request(
- "GET",
- server_and_media_id_2,
- shorthand=False,
- access_token=non_admin_user_tok,
- )
- request.render(self.download_resource)
- self.pump(1.0)
-
- # Should be quarantined
- self.assertEqual(
- 404,
- int(channel.code),
- msg=(
- "Expected to receive a 404 on accessing quarantined media: %s"
- % server_and_media_id_2
- ),
- )
+ self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
+ self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
- def test_quaraantine_all_media_in_room_deprecated_api_path(self):
+ def test_quarantine_all_media_in_room_deprecated_api_path(self):
# Perform the above test with the deprecated API path
self.test_quarantine_all_media_in_room("/_synapse/admin/v1/quarantine_media/%s")
@@ -449,25 +415,52 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
)
# Attempt to access each piece of media
+ self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
+ self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
+
+ def test_cannot_quarantine_safe_media(self):
+ self.register_user("user_admin", "pass", admin=True)
+ admin_user_tok = self.login("user_admin", "pass")
+
+ non_admin_user = self.register_user("user_nonadmin", "pass", admin=False)
+ non_admin_user_tok = self.login("user_nonadmin", "pass")
+
+ # Upload some media
+ response_1 = self.helper.upload_media(
+ self.upload_resource, self.image_data, tok=non_admin_user_tok
+ )
+ response_2 = self.helper.upload_media(
+ self.upload_resource, self.image_data, tok=non_admin_user_tok
+ )
+
+ # Extract media IDs
+ server_and_media_id_1 = response_1["content_uri"][6:]
+ server_and_media_id_2 = response_2["content_uri"][6:]
+
+ # Mark the second item as safe from quarantine.
+ _, media_id_2 = server_and_media_id_2.split("/")
+ self.get_success(self.store.mark_local_media_as_safe(media_id_2))
+
+ # Quarantine all media by this user
+ url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
+ non_admin_user
+ )
request, channel = self.make_request(
- "GET",
- server_and_media_id_1,
- shorthand=False,
- access_token=non_admin_user_tok,
+ "POST", url.encode("ascii"), access_token=admin_user_tok,
)
- request.render(self.download_resource)
+ self.render(request)
self.pump(1.0)
-
- # Should be quarantined
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
- 404,
- int(channel.code),
- msg=(
- "Expected to receive a 404 on accessing quarantined media: %s"
- % server_and_media_id_1,
- ),
+ json.loads(channel.result["body"].decode("utf-8")),
+ {"num_quarantined": 1},
+ "Expected 1 quarantined item",
)
+ # Attempt to access each piece of media, the first should fail, the
+ # second should succeed.
+ self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
+
# Attempt to access each piece of media
request, channel = self.make_request(
"GET",
@@ -478,12 +471,12 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
request.render(self.download_resource)
self.pump(1.0)
- # Should be quarantined
+ # Shouldn't be quarantined
self.assertEqual(
- 404,
+ 200,
int(channel.code),
msg=(
- "Expected to receive a 404 on accessing quarantined media: %s"
+ "Expected to receive a 200 on accessing not-quarantined media: %s"
% server_and_media_id_2
),
)
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 54cd24bf64..408c568a27 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1,1007 +1,1500 @@
-# -*- coding: utf-8 -*-
-# Copyright 2020 Dirk Klimpel
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-import urllib.parse
-from typing import List, Optional
-
-from mock import Mock
-
-import synapse.rest.admin
-from synapse.api.errors import Codes
-from synapse.rest.client.v1 import directory, events, login, room
-
-from tests import unittest
-
-"""Tests admin REST events for /rooms paths."""
-
-
-class ShutdownRoomTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- login.register_servlets,
- events.register_servlets,
- room.register_servlets,
- room.register_deprecated_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.event_creation_handler = hs.get_event_creation_handler()
- hs.config.user_consent_version = "1"
-
- consent_uri_builder = Mock()
- consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
- self.event_creation_handler._consent_uri_builder = consent_uri_builder
-
- self.store = hs.get_datastore()
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- self.other_user = self.register_user("user", "pass")
- self.other_user_token = self.login("user", "pass")
-
- # Mark the admin user as having consented
- self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
-
- def test_shutdown_room_consent(self):
- """Test that we can shutdown rooms with local users who have not
- yet accepted the privacy policy. This used to fail when we tried to
- force part the user from the old room.
- """
- self.event_creation_handler._block_events_without_consent_error = None
-
- room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
-
- # Assert one user in room
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertEqual([self.other_user], users_in_room)
-
- # Enable require consent to send events
- self.event_creation_handler._block_events_without_consent_error = "Error"
-
- # Assert that the user is getting consent error
- self.helper.send(
- room_id, body="foo", tok=self.other_user_token, expect_code=403
- )
-
- # Test that the admin can still send shutdown
- url = "admin/shutdown_room/" + room_id
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Assert there is now no longer anyone in the room
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertEqual([], users_in_room)
-
- def test_shutdown_room_block_peek(self):
- """Test that a world_readable room can no longer be peeked into after
- it has been shut down.
- """
-
- self.event_creation_handler._block_events_without_consent_error = None
-
- room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
-
- # Enable world readable
- url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
- request, channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- json.dumps({"history_visibility": "world_readable"}),
- access_token=self.other_user_token,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Test that the admin can still send shutdown
- url = "admin/shutdown_room/" + room_id
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Assert we can no longer peek into the room
- self._assert_peek(room_id, expect_code=403)
-
- def _assert_peek(self, room_id, expect_code):
- """Assert that the admin user can (or cannot) peek into the room.
- """
-
- url = "rooms/%s/initialSync" % (room_id,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
- self.render(request)
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
-
- url = "events?timeout=0&room_id=" + room_id
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
- self.render(request)
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
-
-
-class PurgeRoomTestCase(unittest.HomeserverTestCase):
- """Test /purge_room admin API.
- """
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- def test_purge_room(self):
- room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- # All users have to have left the room.
- self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
-
- url = "/_synapse/admin/v1/purge_room"
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- {"room_id": room_id},
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Test that the following tables have been purged of all rows related to the room.
- for table in (
- "current_state_events",
- "event_backward_extremities",
- "event_forward_extremities",
- "event_json",
- "event_push_actions",
- "event_search",
- "events",
- "group_rooms",
- "public_room_list_stream",
- "receipts_graph",
- "receipts_linearized",
- "room_aliases",
- "room_depth",
- "room_memberships",
- "room_stats_state",
- "room_stats_current",
- "room_stats_historical",
- "room_stats_earliest_token",
- "rooms",
- "stream_ordering_to_exterm",
- "users_in_public_rooms",
- "users_who_share_private_rooms",
- "appservice_room_list",
- "e2e_room_keys",
- "event_push_summary",
- "pusher_throttle",
- "group_summary_rooms",
- "local_invites",
- "room_account_data",
- "room_tags",
- # "state_groups", # Current impl leaves orphaned state groups around.
- "state_groups_state",
- ):
- count = self.get_success(
- self.store.db.simple_select_one_onecol(
- table=table,
- keyvalues={"room_id": room_id},
- retcol="COUNT(*)",
- desc="test_purge_room",
- )
- )
-
- self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
-
-
-class RoomTestCase(unittest.HomeserverTestCase):
- """Test /room admin API.
- """
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- directory.register_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
- # Create user
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- def test_list_rooms(self):
- """Test that we can list rooms"""
- # Create 3 test rooms
- total_rooms = 3
- room_ids = []
- for x in range(total_rooms):
- room_id = self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok
- )
- room_ids.append(room_id)
-
- # Request the list of rooms
- url = "/_synapse/admin/v1/rooms"
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
-
- # Check request completed successfully
- self.assertEqual(200, int(channel.code), msg=channel.json_body)
-
- # Check that response json body contains a "rooms" key
- self.assertTrue(
- "rooms" in channel.json_body,
- msg="Response body does not " "contain a 'rooms' key",
- )
-
- # Check that 3 rooms were returned
- self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body)
-
- # Check their room_ids match
- returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]]
- self.assertEqual(room_ids, returned_room_ids)
-
- # Check that all fields are available
- for r in channel.json_body["rooms"]:
- self.assertIn("name", r)
- self.assertIn("canonical_alias", r)
- self.assertIn("joined_members", r)
- self.assertIn("joined_local_members", r)
- self.assertIn("version", r)
- self.assertIn("creator", r)
- self.assertIn("encryption", r)
- self.assertIn("federatable", r)
- self.assertIn("public", r)
- self.assertIn("join_rules", r)
- self.assertIn("guest_access", r)
- self.assertIn("history_visibility", r)
- self.assertIn("state_events", r)
-
- # Check that the correct number of total rooms was returned
- self.assertEqual(channel.json_body["total_rooms"], total_rooms)
-
- # Check that the offset is correct
- # Should be 0 as we aren't paginating
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that the prev_batch parameter is not present
- self.assertNotIn("prev_batch", channel.json_body)
-
- # We shouldn't receive a next token here as there's no further rooms to show
- self.assertNotIn("next_batch", channel.json_body)
-
- def test_list_rooms_pagination(self):
- """Test that we can get a full list of rooms through pagination"""
- # Create 5 test rooms
- total_rooms = 5
- room_ids = []
- for x in range(total_rooms):
- room_id = self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok
- )
- room_ids.append(room_id)
-
- # Set the name of the rooms so we get a consistent returned ordering
- for idx, room_id in enumerate(room_ids):
- self.helper.send_state(
- room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,
- )
-
- # Request the list of rooms
- returned_room_ids = []
- start = 0
- limit = 2
-
- run_count = 0
- should_repeat = True
- while should_repeat:
- run_count += 1
-
- url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % (
- start,
- limit,
- "name",
- )
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(
- 200, int(channel.result["code"]), msg=channel.result["body"]
- )
-
- self.assertTrue("rooms" in channel.json_body)
- for r in channel.json_body["rooms"]:
- returned_room_ids.append(r["room_id"])
-
- # Check that the correct number of total rooms was returned
- self.assertEqual(channel.json_body["total_rooms"], total_rooms)
-
- # Check that the offset is correct
- # We're only getting 2 rooms each page, so should be 2 * last run_count
- self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1))
-
- if run_count > 1:
- # Check the value of prev_batch is correct
- self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2))
-
- if "next_batch" not in channel.json_body:
- # We have reached the end of the list
- should_repeat = False
- else:
- # Make another query with an updated start value
- start = channel.json_body["next_batch"]
-
- # We should've queried the endpoint 3 times
- self.assertEqual(
- run_count,
- 3,
- msg="Should've queried 3 times for 5 rooms with limit 2 per query",
- )
-
- # Check that we received all of the room ids
- self.assertEqual(room_ids, returned_room_ids)
-
- url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- def test_correct_room_attributes(self):
- """Test the correct attributes for a room are returned"""
- # Create a test room
- room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- test_alias = "#test:test"
- test_room_name = "something"
-
- # Have another user join the room
- user_2 = self.register_user("user4", "pass")
- user_tok_2 = self.login("user4", "pass")
- self.helper.join(room_id, user_2, tok=user_tok_2)
-
- # Create a new alias to this room
- url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
- request, channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- {"room_id": room_id},
- access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Set this new alias as the canonical alias for this room
- self.helper.send_state(
- room_id,
- "m.room.aliases",
- {"aliases": [test_alias]},
- tok=self.admin_user_tok,
- state_key="test",
- )
- self.helper.send_state(
- room_id,
- "m.room.canonical_alias",
- {"alias": test_alias},
- tok=self.admin_user_tok,
- )
-
- # Set a name for the room
- self.helper.send_state(
- room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok,
- )
-
- # Request the list of rooms
- url = "/_synapse/admin/v1/rooms"
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Check that rooms were returned
- self.assertTrue("rooms" in channel.json_body)
- rooms = channel.json_body["rooms"]
-
- # Check that only one room was returned
- self.assertEqual(len(rooms), 1)
-
- # And that the value of the total_rooms key was correct
- self.assertEqual(channel.json_body["total_rooms"], 1)
-
- # Check that the offset is correct
- # We're not paginating, so should be 0
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that there is no `prev_batch`
- self.assertNotIn("prev_batch", channel.json_body)
-
- # Check that there is no `next_batch`
- self.assertNotIn("next_batch", channel.json_body)
-
- # Check that all provided attributes are set
- r = rooms[0]
- self.assertEqual(room_id, r["room_id"])
- self.assertEqual(test_room_name, r["name"])
- self.assertEqual(test_alias, r["canonical_alias"])
-
- def test_room_list_sort_order(self):
- """Test room list sort ordering. alphabetical name versus number of members,
- reversing the order, etc.
- """
-
- def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str):
- # Create a new alias to this room
- url = "/_matrix/client/r0/directory/room/%s" % (
- urllib.parse.quote(test_alias),
- )
- request, channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- {"room_id": room_id},
- access_token=admin_user_tok,
- )
- self.render(request)
- self.assertEqual(
- 200, int(channel.result["code"]), msg=channel.result["body"]
- )
-
- # Set this new alias as the canonical alias for this room
- self.helper.send_state(
- room_id,
- "m.room.aliases",
- {"aliases": [test_alias]},
- tok=admin_user_tok,
- state_key="test",
- )
- self.helper.send_state(
- room_id,
- "m.room.canonical_alias",
- {"alias": test_alias},
- tok=admin_user_tok,
- )
-
- def _order_test(
- order_type: str, expected_room_list: List[str], reverse: bool = False,
- ):
- """Request the list of rooms in a certain order. Assert that order is what
- we expect
-
- Args:
- order_type: The type of ordering to give the server
- expected_room_list: The list of room_ids in the order we expect to get
- back from the server
- """
- # Request the list of rooms in the given order
- url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)
- if reverse:
- url += "&dir=b"
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, channel.code, msg=channel.json_body)
-
- # Check that rooms were returned
- self.assertTrue("rooms" in channel.json_body)
- rooms = channel.json_body["rooms"]
-
- # Check for the correct total_rooms value
- self.assertEqual(channel.json_body["total_rooms"], 3)
-
- # Check that the offset is correct
- # We're not paginating, so should be 0
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that there is no `prev_batch`
- self.assertNotIn("prev_batch", channel.json_body)
-
- # Check that there is no `next_batch`
- self.assertNotIn("next_batch", channel.json_body)
-
- # Check that rooms were returned in alphabetical order
- returned_order = [r["room_id"] for r in rooms]
- self.assertListEqual(expected_room_list, returned_order) # order is checked
-
- # Create 3 test rooms
- room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C
- self.helper.send_state(
- room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok,
- )
- self.helper.send_state(
- room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok,
- )
- self.helper.send_state(
- room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok,
- )
-
- # Set room canonical room aliases
- _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok)
- _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok)
- _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok)
-
- # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3
- user_1 = self.register_user("bob1", "pass")
- user_1_tok = self.login("bob1", "pass")
- self.helper.join(room_id_2, user_1, tok=user_1_tok)
-
- user_2 = self.register_user("bob2", "pass")
- user_2_tok = self.login("bob2", "pass")
- self.helper.join(room_id_3, user_2, tok=user_2_tok)
-
- user_3 = self.register_user("bob3", "pass")
- user_3_tok = self.login("bob3", "pass")
- self.helper.join(room_id_3, user_3, tok=user_3_tok)
-
- # Test different sort orders, with forward and reverse directions
- _order_test("name", [room_id_1, room_id_2, room_id_3])
- _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True)
-
- _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3])
- _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True)
-
- _order_test("joined_members", [room_id_3, room_id_2, room_id_1])
- _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1])
- _order_test(
- "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True
- )
-
- _order_test("version", [room_id_1, room_id_2, room_id_3])
- _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("creator", [room_id_1, room_id_2, room_id_3])
- _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("encryption", [room_id_1, room_id_2, room_id_3])
- _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("federatable", [room_id_1, room_id_2, room_id_3])
- _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("public", [room_id_1, room_id_2, room_id_3])
- # Different sort order of SQlite and PostreSQL
- # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True)
-
- _order_test("join_rules", [room_id_1, room_id_2, room_id_3])
- _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("guest_access", [room_id_1, room_id_2, room_id_3])
- _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("history_visibility", [room_id_1, room_id_2, room_id_3])
- _order_test(
- "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True
- )
-
- _order_test("state_events", [room_id_3, room_id_2, room_id_1])
- _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- def test_search_term(self):
- """Test that searching for a room works correctly"""
- # Create two test rooms
- room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- room_name_1 = "something"
- room_name_2 = "else"
-
- # Set the name for each room
- self.helper.send_state(
- room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
- )
- self.helper.send_state(
- room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
- )
-
- def _search_test(
- expected_room_id: Optional[str],
- search_term: str,
- expected_http_code: int = 200,
- ):
- """Search for a room and check that the returned room's id is a match
-
- Args:
- expected_room_id: The room_id expected to be returned by the API. Set
- to None to expect zero results for the search
- search_term: The term to search for room names with
- expected_http_code: The expected http code for the request
- """
- url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
-
- if expected_http_code != 200:
- return
-
- # Check that rooms were returned
- self.assertTrue("rooms" in channel.json_body)
- rooms = channel.json_body["rooms"]
-
- # Check that the expected number of rooms were returned
- expected_room_count = 1 if expected_room_id else 0
- self.assertEqual(len(rooms), expected_room_count)
- self.assertEqual(channel.json_body["total_rooms"], expected_room_count)
-
- # Check that the offset is correct
- # We're not paginating, so should be 0
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that there is no `prev_batch`
- self.assertNotIn("prev_batch", channel.json_body)
-
- # Check that there is no `next_batch`
- self.assertNotIn("next_batch", channel.json_body)
-
- if expected_room_id:
- # Check that the first returned room id is correct
- r = rooms[0]
- self.assertEqual(expected_room_id, r["room_id"])
-
- # Perform search tests
- _search_test(room_id_1, "something")
- _search_test(room_id_1, "thing")
-
- _search_test(room_id_2, "else")
- _search_test(room_id_2, "se")
-
- _search_test(None, "foo")
- _search_test(None, "bar")
- _search_test(None, "", expected_http_code=400)
-
- def test_single_room(self):
- """Test that a single room can be requested correctly"""
- # Create two test rooms
- room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- room_name_1 = "something"
- room_name_2 = "else"
-
- # Set the name for each room
- self.helper.send_state(
- room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
- )
- self.helper.send_state(
- room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
- )
-
- url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, channel.code, msg=channel.json_body)
-
- self.assertIn("room_id", channel.json_body)
- self.assertIn("name", channel.json_body)
- self.assertIn("canonical_alias", channel.json_body)
- self.assertIn("joined_members", channel.json_body)
- self.assertIn("joined_local_members", channel.json_body)
- self.assertIn("version", channel.json_body)
- self.assertIn("creator", channel.json_body)
- self.assertIn("encryption", channel.json_body)
- self.assertIn("federatable", channel.json_body)
- self.assertIn("public", channel.json_body)
- self.assertIn("join_rules", channel.json_body)
- self.assertIn("guest_access", channel.json_body)
- self.assertIn("history_visibility", channel.json_body)
- self.assertIn("state_events", channel.json_body)
-
- self.assertEqual(room_id_1, channel.json_body["room_id"])
-
-
-class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
-
- servlets = [
- synapse.rest.admin.register_servlets,
- room.register_servlets,
- login.register_servlets,
- ]
-
- def prepare(self, reactor, clock, homeserver):
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- self.creator = self.register_user("creator", "test")
- self.creator_tok = self.login("creator", "test")
-
- self.second_user_id = self.register_user("second", "test")
- self.second_tok = self.login("second", "test")
-
- self.public_room_id = self.helper.create_room_as(
- self.creator, tok=self.creator_tok, is_public=True
- )
- self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id)
-
- def test_requester_is_no_admin(self):
- """
- If the user is not a server admin, an error 403 is returned.
- """
- body = json.dumps({"user_id": self.second_user_id})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.second_tok,
- )
- self.render(request)
-
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- def test_invalid_parameter(self):
- """
- If a parameter is missing, return an error
- """
- body = json.dumps({"unknown_parameter": "@unknown:test"})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
-
- def test_local_user_does_not_exist(self):
- """
- Tests that a lookup for a user that does not exist returns a 404
- """
- body = json.dumps({"user_id": "@unknown:test"})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- def test_remote_user(self):
- """
- Check that only local user can join rooms.
- """
- body = json.dumps({"user_id": "@not:exist.bla"})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(
- "This endpoint can only be used with local users",
- channel.json_body["error"],
- )
-
- def test_room_does_not_exist(self):
- """
- Check that unknown rooms/server return error 404.
- """
- body = json.dumps({"user_id": self.second_user_id})
- url = "/_synapse/admin/v1/join/!unknown:test"
-
- request, channel = self.make_request(
- "POST",
- url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("No known servers", channel.json_body["error"])
-
- def test_room_is_not_valid(self):
- """
- Check that invalid room names, return an error 400.
- """
- body = json.dumps({"user_id": self.second_user_id})
- url = "/_synapse/admin/v1/join/invalidroom"
-
- request, channel = self.make_request(
- "POST",
- url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(
- "invalidroom was not legal room ID or room alias",
- channel.json_body["error"],
- )
-
- def test_join_public_room(self):
- """
- Test joining a local user to a public room with "JoinRules.PUBLIC"
- """
- body = json.dumps({"user_id": self.second_user_id})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(self.public_room_id, channel.json_body["room_id"])
-
- # Validate if user is a member of the room
-
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
- )
- self.render(request)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
-
- def test_join_private_room_if_not_member(self):
- """
- Test joining a local user to a private room with "JoinRules.INVITE"
- when server admin is not member of this room.
- """
- private_room_id = self.helper.create_room_as(
- self.creator, tok=self.creator_tok, is_public=False
- )
- url = "/_synapse/admin/v1/join/{}".format(private_room_id)
- body = json.dumps({"user_id": self.second_user_id})
-
- request, channel = self.make_request(
- "POST",
- url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- def test_join_private_room_if_member(self):
- """
- Test joining a local user to a private room with "JoinRules.INVITE",
- when server admin is member of this room.
- """
- private_room_id = self.helper.create_room_as(
- self.creator, tok=self.creator_tok, is_public=False
- )
- self.helper.invite(
- room=private_room_id,
- src=self.creator,
- targ=self.admin_user,
- tok=self.creator_tok,
- )
- self.helper.join(
- room=private_room_id, user=self.admin_user, tok=self.admin_user_tok
- )
-
- # Validate if server admin is a member of the room
-
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
-
- # Join user to room.
-
- url = "/_synapse/admin/v1/join/{}".format(private_room_id)
- body = json.dumps({"user_id": self.second_user_id})
-
- request, channel = self.make_request(
- "POST",
- url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(private_room_id, channel.json_body["room_id"])
-
- # Validate if user is a member of the room
-
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
- )
- self.render(request)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
-
- def test_join_private_room_if_owner(self):
- """
- Test joining a local user to a private room with "JoinRules.INVITE",
- when server admin is owner of this room.
- """
- private_room_id = self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok, is_public=False
- )
- url = "/_synapse/admin/v1/join/{}".format(private_room_id)
- body = json.dumps({"user_id": self.second_user_id})
-
- request, channel = self.make_request(
- "POST",
- url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(private_room_id, channel.json_body["room_id"])
-
- # Validate if user is a member of the room
-
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
- )
- self.render(request)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import urllib.parse
+from typing import List, Optional
+
+from mock import Mock
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import directory, events, login, room
+
+from tests import unittest
+
+"""Tests admin REST events for /rooms paths."""
+
+
+class ShutdownRoomTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ events.register_servlets,
+ room.register_servlets,
+ room.register_deprecated_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.event_creation_handler = hs.get_event_creation_handler()
+ hs.config.user_consent_version = "1"
+
+ consent_uri_builder = Mock()
+ consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+ self.event_creation_handler._consent_uri_builder = consent_uri_builder
+
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_token = self.login("user", "pass")
+
+ # Mark the admin user as having consented
+ self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
+
+ def test_shutdown_room_consent(self):
+ """Test that we can shutdown rooms with local users who have not
+ yet accepted the privacy policy. This used to fail when we tried to
+ force part the user from the old room.
+ """
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
+
+ # Assert one user in room
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual([self.other_user], users_in_room)
+
+ # Enable require consent to send events
+ self.event_creation_handler._block_events_without_consent_error = "Error"
+
+ # Assert that the user is getting consent error
+ self.helper.send(
+ room_id, body="foo", tok=self.other_user_token, expect_code=403
+ )
+
+ # Test that the admin can still send shutdown
+ url = "admin/shutdown_room/" + room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Assert there is now no longer anyone in the room
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual([], users_in_room)
+
+ def test_shutdown_room_block_peek(self):
+ """Test that a world_readable room can no longer be peeked into after
+ it has been shut down.
+ """
+
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
+
+ # Enable world readable
+ url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ json.dumps({"history_visibility": "world_readable"}),
+ access_token=self.other_user_token,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Test that the admin can still send shutdown
+ url = "admin/shutdown_room/" + room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Assert we can no longer peek into the room
+ self._assert_peek(room_id, expect_code=403)
+
+ def _assert_peek(self, room_id, expect_code):
+ """Assert that the admin user can (or cannot) peek into the room.
+ """
+
+ url = "rooms/%s/initialSync" % (room_id,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ url = "events?timeout=0&room_id=" + room_id
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+
+class DeleteRoomTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ events.register_servlets,
+ room.register_servlets,
+ room.register_deprecated_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.event_creation_handler = hs.get_event_creation_handler()
+ hs.config.user_consent_version = "1"
+
+ consent_uri_builder = Mock()
+ consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+ self.event_creation_handler._consent_uri_builder = consent_uri_builder
+
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ # Mark the admin user as having consented
+ self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
+
+ self.room_id = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok
+ )
+ self.url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+
+ request, channel = self.make_request(
+ "POST", self.url, json.dumps({}), access_token=self.other_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_room_does_not_exist(self):
+ """
+ Check that unknown rooms/server return error 404.
+ """
+ url = "/_synapse/admin/v1/rooms/!unknown:test/delete"
+
+ request, channel = self.make_request(
+ "POST", url, json.dumps({}), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_room_is_not_valid(self):
+ """
+ Check that invalid room names, return an error 400.
+ """
+ url = "/_synapse/admin/v1/rooms/invalidroom/delete"
+
+ request, channel = self.make_request(
+ "POST", url, json.dumps({}), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "invalidroom is not a legal room ID", channel.json_body["error"],
+ )
+
+ def test_new_room_user_does_not_exist(self):
+ """
+ Tests that the user ID must be from local server but it does not have to exist.
+ """
+ body = json.dumps({"new_room_user_id": "@unknown:test"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertIn("new_room_id", channel.json_body)
+ self.assertIn("kicked_users", channel.json_body)
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ def test_new_room_user_is_not_local(self):
+ """
+ Check that only local users can create new room to move members.
+ """
+ body = json.dumps({"new_room_user_id": "@not:exist.bla"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "User must be our own: @not:exist.bla", channel.json_body["error"],
+ )
+
+ def test_block_is_not_bool(self):
+ """
+ If parameter `block` is not boolean, return an error
+ """
+ body = json.dumps({"block": "NotBool"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+ def test_purge_is_not_bool(self):
+ """
+ If parameter `purge` is not boolean, return an error
+ """
+ body = json.dumps({"purge": "NotBool"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+ def test_purge_room_and_block(self):
+ """Test to purge a room and block it.
+ Members will not be moved to a new room and will not receive a message.
+ """
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Test that room is not blocked
+ self._is_blocked(self.room_id, expect=False)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ body = json.dumps({"block": True, "purge": True})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url.encode("ascii"),
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(None, channel.json_body["new_room_id"])
+ self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ self._is_purged(self.room_id)
+ self._is_blocked(self.room_id, expect=True)
+ self._has_no_members(self.room_id)
+
+ def test_purge_room_and_not_block(self):
+ """Test to purge a room and do not block it.
+ Members will not be moved to a new room and will not receive a message.
+ """
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Test that room is not blocked
+ self._is_blocked(self.room_id, expect=False)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ body = json.dumps({"block": False, "purge": True})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url.encode("ascii"),
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(None, channel.json_body["new_room_id"])
+ self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ self._is_purged(self.room_id)
+ self._is_blocked(self.room_id, expect=False)
+ self._has_no_members(self.room_id)
+
+ def test_block_room_and_not_purge(self):
+ """Test to block a room without purging it.
+ Members will not be moved to a new room and will not receive a message.
+ The room will not be purged.
+ """
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Test that room is not blocked
+ self._is_blocked(self.room_id, expect=False)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ body = json.dumps({"block": False, "purge": False})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url.encode("ascii"),
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(None, channel.json_body["new_room_id"])
+ self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+ self._is_blocked(self.room_id, expect=False)
+ self._has_no_members(self.room_id)
+
+ def test_shutdown_room_consent(self):
+ """Test that we can shutdown rooms with local users who have not
+ yet accepted the privacy policy. This used to fail when we tried to
+ force part the user from the old room.
+ Members will be moved to a new room and will receive a message.
+ """
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ # Assert one user in room
+ users_in_room = self.get_success(self.store.get_users_in_room(self.room_id))
+ self.assertEqual([self.other_user], users_in_room)
+
+ # Enable require consent to send events
+ self.event_creation_handler._block_events_without_consent_error = "Error"
+
+ # Assert that the user is getting consent error
+ self.helper.send(
+ self.room_id, body="foo", tok=self.other_user_tok, expect_code=403
+ )
+
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ # Test that the admin can still send shutdown
+ url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+ self.assertIn("new_room_id", channel.json_body)
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ # Test that member has moved to new room
+ self._is_member(
+ room_id=channel.json_body["new_room_id"], user_id=self.other_user
+ )
+
+ self._is_purged(self.room_id)
+ self._has_no_members(self.room_id)
+
+ def test_shutdown_room_block_peek(self):
+ """Test that a world_readable room can no longer be peeked into after
+ it has been shut down.
+ Members will be moved to a new room and will receive a message.
+ """
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ # Enable world readable
+ url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,)
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ json.dumps({"history_visibility": "world_readable"}),
+ access_token=self.other_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ # Test that the admin can still send shutdown
+ url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+ self.assertIn("new_room_id", channel.json_body)
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ # Test that member has moved to new room
+ self._is_member(
+ room_id=channel.json_body["new_room_id"], user_id=self.other_user
+ )
+
+ self._is_purged(self.room_id)
+ self._has_no_members(self.room_id)
+
+ # Assert we can no longer peek into the room
+ self._assert_peek(self.room_id, expect_code=403)
+
+ def _is_blocked(self, room_id, expect=True):
+ """Assert that the room is blocked or not
+ """
+ d = self.store.is_room_blocked(room_id)
+ if expect:
+ self.assertTrue(self.get_success(d))
+ else:
+ self.assertIsNone(self.get_success(d))
+
+ def _has_no_members(self, room_id):
+ """Assert there is now no longer anyone in the room
+ """
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual([], users_in_room)
+
+ def _is_member(self, room_id, user_id):
+ """Test that user is member of the room
+ """
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertIn(user_id, users_in_room)
+
+ def _is_purged(self, room_id):
+ """Test that the following tables have been purged of all rows related to the room.
+ """
+ for table in (
+ "current_state_events",
+ "event_backward_extremities",
+ "event_forward_extremities",
+ "event_json",
+ "event_push_actions",
+ "event_search",
+ "events",
+ "group_rooms",
+ "public_room_list_stream",
+ "receipts_graph",
+ "receipts_linearized",
+ "room_aliases",
+ "room_depth",
+ "room_memberships",
+ "room_stats_state",
+ "room_stats_current",
+ "room_stats_historical",
+ "room_stats_earliest_token",
+ "rooms",
+ "stream_ordering_to_exterm",
+ "users_in_public_rooms",
+ "users_who_share_private_rooms",
+ "appservice_room_list",
+ "e2e_room_keys",
+ "event_push_summary",
+ "pusher_throttle",
+ "group_summary_rooms",
+ "local_invites",
+ "room_account_data",
+ "room_tags",
+ # "state_groups", # Current impl leaves orphaned state groups around.
+ "state_groups_state",
+ ):
+ count = self.get_success(
+ self.store.db_pool.simple_select_one_onecol(
+ table=table,
+ keyvalues={"room_id": room_id},
+ retcol="COUNT(*)",
+ desc="test_purge_room",
+ )
+ )
+
+ self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+
+ def _assert_peek(self, room_id, expect_code):
+ """Assert that the admin user can (or cannot) peek into the room.
+ """
+
+ url = "rooms/%s/initialSync" % (room_id,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ url = "events?timeout=0&room_id=" + room_id
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+
+class PurgeRoomTestCase(unittest.HomeserverTestCase):
+ """Test /purge_room admin API.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ def test_purge_room(self):
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # All users have to have left the room.
+ self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
+
+ url = "/_synapse/admin/v1/purge_room"
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Test that the following tables have been purged of all rows related to the room.
+ for table in (
+ "current_state_events",
+ "event_backward_extremities",
+ "event_forward_extremities",
+ "event_json",
+ "event_push_actions",
+ "event_search",
+ "events",
+ "group_rooms",
+ "public_room_list_stream",
+ "receipts_graph",
+ "receipts_linearized",
+ "room_aliases",
+ "room_depth",
+ "room_memberships",
+ "room_stats_state",
+ "room_stats_current",
+ "room_stats_historical",
+ "room_stats_earliest_token",
+ "rooms",
+ "stream_ordering_to_exterm",
+ "users_in_public_rooms",
+ "users_who_share_private_rooms",
+ "appservice_room_list",
+ "e2e_room_keys",
+ "event_push_summary",
+ "pusher_throttle",
+ "group_summary_rooms",
+ "room_account_data",
+ "room_tags",
+ # "state_groups", # Current impl leaves orphaned state groups around.
+ "state_groups_state",
+ ):
+ count = self.get_success(
+ self.store.db_pool.simple_select_one_onecol(
+ table=table,
+ keyvalues={"room_id": room_id},
+ retcol="COUNT(*)",
+ desc="test_purge_room",
+ )
+ )
+
+ self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+
+
+class RoomTestCase(unittest.HomeserverTestCase):
+ """Test /room admin API.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ directory.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ # Create user
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ def test_list_rooms(self):
+ """Test that we can list rooms"""
+ # Create 3 test rooms
+ total_rooms = 3
+ room_ids = []
+ for x in range(total_rooms):
+ room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+ room_ids.append(room_id)
+
+ # Request the list of rooms
+ url = "/_synapse/admin/v1/rooms"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ # Check request completed successfully
+ self.assertEqual(200, int(channel.code), msg=channel.json_body)
+
+ # Check that response json body contains a "rooms" key
+ self.assertTrue(
+ "rooms" in channel.json_body,
+ msg="Response body does not " "contain a 'rooms' key",
+ )
+
+ # Check that 3 rooms were returned
+ self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body)
+
+ # Check their room_ids match
+ returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]]
+ self.assertEqual(room_ids, returned_room_ids)
+
+ # Check that all fields are available
+ for r in channel.json_body["rooms"]:
+ self.assertIn("name", r)
+ self.assertIn("canonical_alias", r)
+ self.assertIn("joined_members", r)
+ self.assertIn("joined_local_members", r)
+ self.assertIn("version", r)
+ self.assertIn("creator", r)
+ self.assertIn("encryption", r)
+ self.assertIn("federatable", r)
+ self.assertIn("public", r)
+ self.assertIn("join_rules", r)
+ self.assertIn("guest_access", r)
+ self.assertIn("history_visibility", r)
+ self.assertIn("state_events", r)
+
+ # Check that the correct number of total rooms was returned
+ self.assertEqual(channel.json_body["total_rooms"], total_rooms)
+
+ # Check that the offset is correct
+ # Should be 0 as we aren't paginating
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that the prev_batch parameter is not present
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # We shouldn't receive a next token here as there's no further rooms to show
+ self.assertNotIn("next_batch", channel.json_body)
+
+ def test_list_rooms_pagination(self):
+ """Test that we can get a full list of rooms through pagination"""
+ # Create 5 test rooms
+ total_rooms = 5
+ room_ids = []
+ for x in range(total_rooms):
+ room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+ room_ids.append(room_id)
+
+ # Set the name of the rooms so we get a consistent returned ordering
+ for idx, room_id in enumerate(room_ids):
+ self.helper.send_state(
+ room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,
+ )
+
+ # Request the list of rooms
+ returned_room_ids = []
+ start = 0
+ limit = 2
+
+ run_count = 0
+ should_repeat = True
+ while should_repeat:
+ run_count += 1
+
+ url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % (
+ start,
+ limit,
+ "name",
+ )
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(
+ 200, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ self.assertTrue("rooms" in channel.json_body)
+ for r in channel.json_body["rooms"]:
+ returned_room_ids.append(r["room_id"])
+
+ # Check that the correct number of total rooms was returned
+ self.assertEqual(channel.json_body["total_rooms"], total_rooms)
+
+ # Check that the offset is correct
+ # We're only getting 2 rooms each page, so should be 2 * last run_count
+ self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1))
+
+ if run_count > 1:
+ # Check the value of prev_batch is correct
+ self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2))
+
+ if "next_batch" not in channel.json_body:
+ # We have reached the end of the list
+ should_repeat = False
+ else:
+ # Make another query with an updated start value
+ start = channel.json_body["next_batch"]
+
+ # We should've queried the endpoint 3 times
+ self.assertEqual(
+ run_count,
+ 3,
+ msg="Should've queried 3 times for 5 rooms with limit 2 per query",
+ )
+
+ # Check that we received all of the room ids
+ self.assertEqual(room_ids, returned_room_ids)
+
+ url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_correct_room_attributes(self):
+ """Test the correct attributes for a room are returned"""
+ # Create a test room
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ test_alias = "#test:test"
+ test_room_name = "something"
+
+ # Have another user join the room
+ user_2 = self.register_user("user4", "pass")
+ user_tok_2 = self.login("user4", "pass")
+ self.helper.join(room_id, user_2, tok=user_tok_2)
+
+ # Create a new alias to this room
+ url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Set this new alias as the canonical alias for this room
+ self.helper.send_state(
+ room_id,
+ "m.room.aliases",
+ {"aliases": [test_alias]},
+ tok=self.admin_user_tok,
+ state_key="test",
+ )
+ self.helper.send_state(
+ room_id,
+ "m.room.canonical_alias",
+ {"alias": test_alias},
+ tok=self.admin_user_tok,
+ )
+
+ # Set a name for the room
+ self.helper.send_state(
+ room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok,
+ )
+
+ # Request the list of rooms
+ url = "/_synapse/admin/v1/rooms"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check that only one room was returned
+ self.assertEqual(len(rooms), 1)
+
+ # And that the value of the total_rooms key was correct
+ self.assertEqual(channel.json_body["total_rooms"], 1)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ # Check that all provided attributes are set
+ r = rooms[0]
+ self.assertEqual(room_id, r["room_id"])
+ self.assertEqual(test_room_name, r["name"])
+ self.assertEqual(test_alias, r["canonical_alias"])
+
+ def test_room_list_sort_order(self):
+ """Test room list sort ordering. alphabetical name versus number of members,
+ reversing the order, etc.
+ """
+
+ def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str):
+ # Create a new alias to this room
+ url = "/_matrix/client/r0/directory/room/%s" % (
+ urllib.parse.quote(test_alias),
+ )
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(
+ 200, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ # Set this new alias as the canonical alias for this room
+ self.helper.send_state(
+ room_id,
+ "m.room.aliases",
+ {"aliases": [test_alias]},
+ tok=admin_user_tok,
+ state_key="test",
+ )
+ self.helper.send_state(
+ room_id,
+ "m.room.canonical_alias",
+ {"alias": test_alias},
+ tok=admin_user_tok,
+ )
+
+ def _order_test(
+ order_type: str, expected_room_list: List[str], reverse: bool = False,
+ ):
+ """Request the list of rooms in a certain order. Assert that order is what
+ we expect
+
+ Args:
+ order_type: The type of ordering to give the server
+ expected_room_list: The list of room_ids in the order we expect to get
+ back from the server
+ """
+ # Request the list of rooms in the given order
+ url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)
+ if reverse:
+ url += "&dir=b"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check for the correct total_rooms value
+ self.assertEqual(channel.json_body["total_rooms"], 3)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ # Check that rooms were returned in alphabetical order
+ returned_order = [r["room_id"] for r in rooms]
+ self.assertListEqual(expected_room_list, returned_order) # order is checked
+
+ # Create 3 test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok,
+ )
+
+ # Set room canonical room aliases
+ _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok)
+ _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok)
+ _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok)
+
+ # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3
+ user_1 = self.register_user("bob1", "pass")
+ user_1_tok = self.login("bob1", "pass")
+ self.helper.join(room_id_2, user_1, tok=user_1_tok)
+
+ user_2 = self.register_user("bob2", "pass")
+ user_2_tok = self.login("bob2", "pass")
+ self.helper.join(room_id_3, user_2, tok=user_2_tok)
+
+ user_3 = self.register_user("bob3", "pass")
+ user_3_tok = self.login("bob3", "pass")
+ self.helper.join(room_id_3, user_3, tok=user_3_tok)
+
+ # Test different sort orders, with forward and reverse directions
+ _order_test("name", [room_id_1, room_id_2, room_id_3])
+ _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+ _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3])
+ _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+ _order_test("joined_members", [room_id_3, room_id_2, room_id_1])
+ _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1])
+ _order_test(
+ "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True
+ )
+
+ _order_test("version", [room_id_1, room_id_2, room_id_3])
+ _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("creator", [room_id_1, room_id_2, room_id_3])
+ _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("encryption", [room_id_1, room_id_2, room_id_3])
+ _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("federatable", [room_id_1, room_id_2, room_id_3])
+ _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("public", [room_id_1, room_id_2, room_id_3])
+ # Different sort order of SQlite and PostreSQL
+ # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+ _order_test("join_rules", [room_id_1, room_id_2, room_id_3])
+ _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("guest_access", [room_id_1, room_id_2, room_id_3])
+ _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("history_visibility", [room_id_1, room_id_2, room_id_3])
+ _order_test(
+ "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True
+ )
+
+ _order_test("state_events", [room_id_3, room_id_2, room_id_1])
+ _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ def test_search_term(self):
+ """Test that searching for a room works correctly"""
+ # Create two test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ room_name_1 = "something"
+ room_name_2 = "else"
+
+ # Set the name for each room
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+ )
+
+ def _search_test(
+ expected_room_id: Optional[str],
+ search_term: str,
+ expected_http_code: int = 200,
+ ):
+ """Search for a room and check that the returned room's id is a match
+
+ Args:
+ expected_room_id: The room_id expected to be returned by the API. Set
+ to None to expect zero results for the search
+ search_term: The term to search for room names with
+ expected_http_code: The expected http code for the request
+ """
+ url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
+
+ if expected_http_code != 200:
+ return
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check that the expected number of rooms were returned
+ expected_room_count = 1 if expected_room_id else 0
+ self.assertEqual(len(rooms), expected_room_count)
+ self.assertEqual(channel.json_body["total_rooms"], expected_room_count)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ if expected_room_id:
+ # Check that the first returned room id is correct
+ r = rooms[0]
+ self.assertEqual(expected_room_id, r["room_id"])
+
+ # Perform search tests
+ _search_test(room_id_1, "something")
+ _search_test(room_id_1, "thing")
+
+ _search_test(room_id_2, "else")
+ _search_test(room_id_2, "se")
+
+ _search_test(None, "foo")
+ _search_test(None, "bar")
+ _search_test(None, "", expected_http_code=400)
+
+ def test_single_room(self):
+ """Test that a single room can be requested correctly"""
+ # Create two test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ room_name_1 = "something"
+ room_name_2 = "else"
+
+ # Set the name for each room
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+ )
+
+ url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ self.assertIn("room_id", channel.json_body)
+ self.assertIn("name", channel.json_body)
+ self.assertIn("canonical_alias", channel.json_body)
+ self.assertIn("joined_members", channel.json_body)
+ self.assertIn("joined_local_members", channel.json_body)
+ self.assertIn("version", channel.json_body)
+ self.assertIn("creator", channel.json_body)
+ self.assertIn("encryption", channel.json_body)
+ self.assertIn("federatable", channel.json_body)
+ self.assertIn("public", channel.json_body)
+ self.assertIn("join_rules", channel.json_body)
+ self.assertIn("guest_access", channel.json_body)
+ self.assertIn("history_visibility", channel.json_body)
+ self.assertIn("state_events", channel.json_body)
+
+ self.assertEqual(room_id_1, channel.json_body["room_id"])
+
+ def test_room_members(self):
+ """Test that room members can be requested correctly"""
+ # Create two test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # Have another user join the room
+ user_1 = self.register_user("foo", "pass")
+ user_tok_1 = self.login("foo", "pass")
+ self.helper.join(room_id_1, user_1, tok=user_tok_1)
+
+ # Have another user join the room
+ user_2 = self.register_user("bar", "pass")
+ user_tok_2 = self.login("bar", "pass")
+ self.helper.join(room_id_1, user_2, tok=user_tok_2)
+ self.helper.join(room_id_2, user_2, tok=user_tok_2)
+
+ # Have another user join the room
+ user_3 = self.register_user("foobar", "pass")
+ user_tok_3 = self.login("foobar", "pass")
+ self.helper.join(room_id_2, user_3, tok=user_tok_3)
+
+ url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ self.assertCountEqual(
+ ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"]
+ )
+ self.assertEqual(channel.json_body["total"], 3)
+
+ url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ self.assertCountEqual(
+ ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"]
+ )
+ self.assertEqual(channel.json_body["total"], 3)
+
+
+class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.creator = self.register_user("creator", "test")
+ self.creator_tok = self.login("creator", "test")
+
+ self.second_user_id = self.register_user("second", "test")
+ self.second_tok = self.login("second", "test")
+
+ self.public_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=True
+ )
+ self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id)
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.second_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_invalid_parameter(self):
+ """
+ If a parameter is missing, return an error
+ """
+ body = json.dumps({"unknown_parameter": "@unknown:test"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+ def test_local_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ body = json.dumps({"user_id": "@unknown:test"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_remote_user(self):
+ """
+ Check that only local user can join rooms.
+ """
+ body = json.dumps({"user_id": "@not:exist.bla"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "This endpoint can only be used with local users",
+ channel.json_body["error"],
+ )
+
+ def test_room_does_not_exist(self):
+ """
+ Check that unknown rooms/server return error 404.
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+ url = "/_synapse/admin/v1/join/!unknown:test"
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("No known servers", channel.json_body["error"])
+
+ def test_room_is_not_valid(self):
+ """
+ Check that invalid room names, return an error 400.
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+ url = "/_synapse/admin/v1/join/invalidroom"
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "invalidroom was not legal room ID or room alias",
+ channel.json_body["error"],
+ )
+
+ def test_join_public_room(self):
+ """
+ Test joining a local user to a public room with "JoinRules.PUBLIC"
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(self.public_room_id, channel.json_body["room_id"])
+
+ # Validate if user is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
+
+ def test_join_private_room_if_not_member(self):
+ """
+ Test joining a local user to a private room with "JoinRules.INVITE"
+ when server admin is not member of this room.
+ """
+ private_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=False
+ )
+ url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_join_private_room_if_member(self):
+ """
+ Test joining a local user to a private room with "JoinRules.INVITE",
+ when server admin is member of this room.
+ """
+ private_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=False
+ )
+ self.helper.invite(
+ room=private_room_id,
+ src=self.creator,
+ targ=self.admin_user,
+ tok=self.creator_tok,
+ )
+ self.helper.join(
+ room=private_room_id, user=self.admin_user, tok=self.admin_user_tok
+ )
+
+ # Validate if server admin is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+
+ # Join user to room.
+
+ url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["room_id"])
+
+ # Validate if user is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+
+ def test_join_private_room_if_owner(self):
+ """
+ Test joining a local user to a private room with "JoinRules.INVITE",
+ when server admin is owner of this room.
+ """
+ private_room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok, is_public=False
+ )
+ url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["room_id"])
+
+ # Validate if user is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index cca5f548e6..17d0aae2e9 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -20,6 +20,8 @@ import urllib.parse
from mock import Mock
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.api.constants import UserTypes
from synapse.api.errors import HttpResponseException, ResourceLimitError
@@ -335,7 +337,9 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
store = self.hs.get_datastore()
# Set monthly active users to the limit
- store.get_monthly_active_count = Mock(return_value=self.hs.config.max_mau_value)
+ store.get_monthly_active_count = Mock(
+ return_value=defer.succeed(self.hs.config.max_mau_value)
+ )
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
self.get_failure(
@@ -588,7 +592,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
- return_value=self.hs.config.max_mau_value
+ return_value=defer.succeed(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@@ -628,7 +632,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
- return_value=self.hs.config.max_mau_value
+ return_value=defer.succeed(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@@ -857,6 +861,53 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
+ def test_reactivate_user(self):
+ """
+ Test reactivating another user.
+ """
+
+ # Deactivate the user.
+ request, channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=json.dumps({"deactivated": True}).encode(encoding="utf_8"),
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Attempt to reactivate the user (without a password).
+ request, channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=json.dumps({"deactivated": False}).encode(encoding="utf_8"),
+ )
+ self.render(request)
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Reactivate the user.
+ request, channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=json.dumps({"deactivated": False, "password": "foo"}).encode(
+ encoding="utf_8"
+ ),
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_other_user, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(False, channel.json_body["deactivated"])
+
def test_set_user_as_admin(self):
"""
Test setting the admin flag on a user.
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 95475bb651..d4e7fa1293 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -45,50 +45,63 @@ class RetentionTestCase(unittest.HomeserverTestCase):
}
self.hs = self.setup_test_homeserver(config=config)
+
return self.hs
def prepare(self, reactor, clock, homeserver):
self.user_id = self.register_user("user", "password")
self.token = self.login("user", "password")
- def test_retention_state_event(self):
- """Tests that the server configuration can limit the values a user can set to the
- room's retention policy.
+ self.store = self.hs.get_datastore()
+ self.serializer = self.hs.get_event_client_serializer()
+ self.clock = self.hs.get_clock()
+
+ def test_retention_event_purged_with_state_event(self):
+ """Tests that expired events are correctly purged when the room's retention policy
+ is defined by a state event.
"""
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ # Set the room's retention period to 2 days.
+ lifetime = one_day_ms * 2
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": one_day_ms * 4},
+ body={"max_lifetime": lifetime},
tok=self.token,
- expect_code=400,
)
+ self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+
+ def test_retention_event_purged_with_state_event_outside_allowed(self):
+ """Tests that the server configuration can override the policy for a room when
+ running the purge jobs.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ # Set a max_lifetime higher than the maximum allowed value.
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": one_hour_ms},
+ body={"max_lifetime": one_day_ms * 4},
tok=self.token,
- expect_code=400,
)
- def test_retention_event_purged_with_state_event(self):
- """Tests that expired events are correctly purged when the room's retention policy
- is defined by a state event.
- """
- room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ # Check that the event is purged after waiting for the maximum allowed duration
+ # instead of the one specified in the room's policy.
+ self._test_retention_event_purged(room_id, one_day_ms * 1.5)
- # Set the room's retention period to 2 days.
- lifetime = one_day_ms * 2
+ # Set a max_lifetime lower than the minimum allowed value.
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": lifetime},
+ body={"max_lifetime": one_hour_ms},
tok=self.token,
)
- self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+ # Check that the event is purged after waiting for the minimum allowed duration
+ # instead of the one specified in the room's policy.
+ self._test_retention_event_purged(room_id, one_day_ms * 0.5)
def test_retention_event_purged_without_state_event(self):
"""Tests that expired events are correctly purged when the room's retention policy
@@ -126,7 +139,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
events.append(self.get_success(store.get_event(valid_event_id)))
- # Advance the time by anothe 2 days. After this, the first event should be
+ # Advance the time by another 2 days. After this, the first event should be
# outdated but not the second one.
self.reactor.advance(one_day_ms * 2 / 1000)
@@ -140,11 +153,33 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# That event should be the second, not outdated event.
self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
- def _test_retention_event_purged(self, room_id, increment):
+ def _test_retention_event_purged(self, room_id: str, increment: float):
+ """Run the following test scenario to test the message retention policy support:
+
+ 1. Send event 1
+ 2. Increment time by `increment`
+ 3. Send event 2
+ 4. Increment time by `increment`
+ 5. Check that event 1 has been purged
+ 6. Check that event 2 has not been purged
+ 7. Check that state events that were sent before event 1 aren't purged.
+ The main reason for sending a second event is because currently Synapse won't
+ purge the latest message in a room because it would otherwise result in a lack of
+ forward extremities for this room. It's also a good thing to ensure the purge jobs
+ aren't too greedy and purge messages they shouldn't.
+
+ Args:
+ room_id: The ID of the room to test retention in.
+ increment: The number of milliseconds to advance the clock each time. Must be
+ defined so that events in the room aren't purged if they are `increment`
+ old but are purged if they are `increment * 2` old.
+ """
# Get the create event to, later, check that we can still access it.
message_handler = self.hs.get_message_handler()
create_event = self.get_success(
- message_handler.get_room_data(self.user_id, room_id, EventTypes.Create)
+ message_handler.get_room_data(
+ self.user_id, room_id, EventTypes.Create, state_key="", is_guest=False
+ )
)
# Send a first event to the room. This is the event we'll want to be purged at the
@@ -154,7 +189,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
expired_event_id = resp.get("event_id")
# Check that we can retrieve the event.
- expired_event = self.get_event(room_id, expired_event_id)
+ expired_event = self.get_event(expired_event_id)
self.assertEqual(
expired_event.get("content", {}).get("body"), "1", expired_event
)
@@ -172,26 +207,31 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# one should still be kept.
self.reactor.advance(increment / 1000)
- # Check that the event has been purged from the database.
- self.get_event(room_id, expired_event_id, expected_code=404)
+ # Check that the first event has been purged from the database, i.e. that we
+ # can't retrieve it anymore, because it has expired.
+ self.get_event(expired_event_id, expect_none=True)
- # Check that the event that hasn't been purged can still be retrieved.
- valid_event = self.get_event(room_id, valid_event_id)
+ # Check that the event that hasn't expired can still be retrieved.
+ valid_event = self.get_event(valid_event_id)
self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event)
# Check that we can still access state events that were sent before the event that
# has been purged.
self.get_event(room_id, create_event.event_id)
- def get_event(self, room_id, event_id, expected_code=200):
- url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+ def get_event(self, event_id, expect_none=False):
+ event = self.get_success(self.store.get_event(event_id, allow_none=True))
- request, channel = self.make_request("GET", url, access_token=self.token)
- self.render(request)
+ if expect_none:
+ self.assertIsNone(event)
+ return {}
- self.assertEqual(channel.code, expected_code, channel.result)
+ self.assertIsNotNone(event)
- return channel.json_body
+ time_now = self.clock.time_msec()
+ serialized = self.get_success(self.serializer.serialize_event(event, time_now))
+
+ return serialized
class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index c6d8f24fe9..3ddcca288b 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -62,8 +62,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -76,14 +75,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -111,8 +109,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -132,7 +129,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -160,8 +156,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -174,14 +169,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -396,7 +390,7 @@ class CASTestCase(unittest.HomeserverTestCase):
</cas:serviceResponse>
"""
% cas_user_id
- )
+ ).encode("utf-8")
mocked_http_client = Mock(spec=["get_raw"])
mocked_http_client.get_raw.side_effect = get_raw
@@ -512,19 +506,22 @@ class JWTTestCase(unittest.HomeserverTestCase):
]
jwt_secret = "secret"
+ jwt_algorithm = "HS256"
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
self.hs.config.jwt_enabled = True
self.hs.config.jwt_secret = self.jwt_secret
- self.hs.config.jwt_algorithm = "HS256"
+ self.hs.config.jwt_algorithm = self.jwt_algorithm
return self.hs
def jwt_encode(self, token, secret=jwt_secret):
- return jwt.encode(token, secret, "HS256").decode("ascii")
+ return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii")
def jwt_login(self, *args):
- params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
+ params = json.dumps(
+ {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
+ )
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
return channel
@@ -542,35 +539,126 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_invalid_signature(self):
channel = self.jwt_login({"sub": "frog"}, "notsecret")
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
- self.assertEqual(channel.json_body["error"], "Invalid JWT")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"],
+ "JWT validation failed: Signature verification failed",
+ )
def test_login_jwt_expired(self):
channel = self.jwt_login({"sub": "frog", "exp": 864000})
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
- self.assertEqual(channel.json_body["error"], "JWT expired")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"], "JWT validation failed: Signature has expired"
+ )
def test_login_jwt_not_before(self):
now = int(time.time())
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
- self.assertEqual(channel.json_body["error"], "Invalid JWT")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"],
+ "JWT validation failed: The token is not yet valid (nbf)",
+ )
def test_login_no_sub(self):
channel = self.jwt_login({"username": "root"})
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
+ @override_config(
+ {
+ "jwt_config": {
+ "jwt_enabled": True,
+ "secret": jwt_secret,
+ "algorithm": jwt_algorithm,
+ "issuer": "test-issuer",
+ }
+ }
+ )
+ def test_login_iss(self):
+ """Test validating the issuer claim."""
+ # A valid issuer.
+ channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ # An invalid issuer.
+ channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"], "JWT validation failed: Invalid issuer"
+ )
+
+ # Not providing an issuer.
+ channel = self.jwt_login({"sub": "kermit"})
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"],
+ 'JWT validation failed: Token is missing the "iss" claim',
+ )
+
+ def test_login_iss_no_config(self):
+ """Test providing an issuer claim without requiring it in the configuration."""
+ channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ @override_config(
+ {
+ "jwt_config": {
+ "jwt_enabled": True,
+ "secret": jwt_secret,
+ "algorithm": jwt_algorithm,
+ "audiences": ["test-audience"],
+ }
+ }
+ )
+ def test_login_aud(self):
+ """Test validating the audience claim."""
+ # A valid audience.
+ channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ # An invalid audience.
+ channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"], "JWT validation failed: Invalid audience"
+ )
+
+ # Not providing an audience.
+ channel = self.jwt_login({"sub": "kermit"})
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"],
+ 'JWT validation failed: Token is missing the "aud" claim',
+ )
+
+ def test_login_aud_no_config(self):
+ """Test providing an audience without requiring it in the configuration."""
+ channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"], "JWT validation failed: Invalid audience"
+ )
+
def test_login_no_token(self):
- params = json.dumps({"type": "m.login.jwt"})
+ params = json.dumps({"type": "org.matrix.login.jwt"})
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
@@ -638,7 +726,9 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
return jwt.encode(token, secret, "RS256").decode("ascii")
def jwt_login(self, *args):
- params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
+ params = json.dumps(
+ {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
+ )
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
return channel
@@ -650,6 +740,9 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
def test_login_jwt_invalid_signature(self):
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
- self.assertEqual(channel.json_body["error"], "Invalid JWT")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"],
+ "JWT validation failed: Signature verification failed",
+ )
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 0fdff79aa7..3c66255dac 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -60,7 +60,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def test_put_presence_disabled(self):
"""
- PUT to the status endpoint with use_presence disbled will NOT call
+ PUT to the status endpoint with use_presence disabled will NOT call
set_state on the presence handler.
"""
self.hs.config.use_presence = False
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 8df58b4a63..ace0a3c08d 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -70,8 +70,8 @@ class MockHandlerProfileTestCase(unittest.TestCase):
profile_handler=self.mock_handler,
)
- def _get_user_by_req(request=None, allow_guest=False):
- return defer.succeed(synapse.types.create_requester(myid))
+ async def _get_user_by_req(request=None, allow_guest=False):
+ return synapse.types.create_requester(myid)
hs.get_auth().get_user_by_req = _get_user_by_req
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 4886bbb401..c6c6edeac2 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -19,18 +19,16 @@
"""Tests REST events for /rooms paths."""
import json
+from urllib import parse as urlparse
-from mock import Mock
-from six.moves.urllib import parse as urlparse
-
-from twisted.internet import defer
+from mock import Mock, patch
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
from synapse.rest.client.v1 import directory, login, profile, room
-from synapse.rest.client.v2_alpha import account
-from synapse.types import JsonDict, RoomAlias
+from synapse.rest.client.v2_alpha import account, room_upgrade_rest_servlet
+from synapse.types import JsonDict, RoomAlias, UserID
from synapse.util.stringutils import random_string
from tests import unittest
@@ -51,8 +49,8 @@ class RoomBase(unittest.HomeserverTestCase):
self.hs.get_federation_handler = Mock(return_value=Mock())
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
+ async def _insert_client_ip(*args, **kwargs):
+ return None
self.hs.get_datastore().insert_client_ip = _insert_client_ip
@@ -677,6 +675,91 @@ class RoomMemberStateTestCase(RoomBase):
self.assertEquals(json.loads(content), channel.json_body)
+class RoomJoinRatelimitTestCase(RoomBase):
+ user_id = "@sid1:red"
+
+ servlets = [
+ profile.register_servlets,
+ room.register_servlets,
+ ]
+
+ @unittest.override_config(
+ {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
+ )
+ def test_join_local_ratelimit(self):
+ """Tests that local joins are actually rate-limited."""
+ for i in range(5):
+ self.helper.create_room_as(self.user_id)
+
+ self.helper.create_room_as(self.user_id, expect_code=429)
+
+ @unittest.override_config(
+ {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
+ )
+ def test_join_local_ratelimit_profile_change(self):
+ """Tests that sending a profile update into all of the user's joined rooms isn't
+ rate-limited by the rate-limiter on joins."""
+
+ # Create and join more rooms than the rate-limiting config allows in a second.
+ room_ids = [
+ self.helper.create_room_as(self.user_id),
+ self.helper.create_room_as(self.user_id),
+ self.helper.create_room_as(self.user_id),
+ ]
+ self.reactor.advance(1)
+ room_ids = room_ids + [
+ self.helper.create_room_as(self.user_id),
+ self.helper.create_room_as(self.user_id),
+ self.helper.create_room_as(self.user_id),
+ ]
+
+ # Create a profile for the user, since it hasn't been done on registration.
+ store = self.hs.get_datastore()
+ store.create_profile(UserID.from_string(self.user_id).localpart)
+
+ # Update the display name for the user.
+ path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
+ request, channel = self.make_request("PUT", path, {"displayname": "John Doe"})
+ self.render(request)
+ self.assertEquals(channel.code, 200, channel.json_body)
+
+ # Check that all the rooms have been sent a profile update into.
+ for room_id in room_ids:
+ path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (
+ room_id,
+ self.user_id,
+ )
+
+ request, channel = self.make_request("GET", path)
+ self.render(request)
+ self.assertEquals(channel.code, 200)
+
+ self.assertIn("displayname", channel.json_body)
+ self.assertEquals(channel.json_body["displayname"], "John Doe")
+
+ @unittest.override_config(
+ {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
+ )
+ def test_join_local_ratelimit_idempotent(self):
+ """Tests that the room join endpoints remain idempotent despite rate-limiting
+ on room joins."""
+ room_id = self.helper.create_room_as(self.user_id)
+
+ # Let's test both paths to be sure.
+ paths_to_test = [
+ "/_matrix/client/r0/rooms/%s/join",
+ "/_matrix/client/r0/join/%s",
+ ]
+
+ for path in paths_to_test:
+ # Make sure we send more requests than the rate-limiting config would allow
+ # if all of these requests ended up joining the user to a room.
+ for i in range(6):
+ request, channel = self.make_request("POST", path % room_id, {})
+ self.render(request)
+ self.assertEquals(channel.code, 200)
+
+
class RoomMessagesTestCase(RoomBase):
""" Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
@@ -1976,3 +2059,158 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
"""An alias which does not point to the room raises a SynapseError."""
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
+
+
+# To avoid the tests timing out don't add a delay to "annoy the requester".
+@patch("random.randint", new=lambda a, b: 0)
+class ShadowBannedTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ directory.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ room_upgrade_rest_servlet.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.banned_user_id = self.register_user("banned", "test")
+ self.banned_access_token = self.login("banned", "test")
+
+ self.store = self.hs.get_datastore()
+
+ self.get_success(
+ self.store.db_pool.simple_update(
+ table="users",
+ keyvalues={"name": self.banned_user_id},
+ updatevalues={"shadow_banned": True},
+ desc="shadow_ban",
+ )
+ )
+
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+ def test_invite(self):
+ """Invites from shadow-banned users don't actually get sent."""
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # Inviting the user completes successfully.
+ self.helper.invite(
+ room=room_id,
+ src=self.banned_user_id,
+ tok=self.banned_access_token,
+ targ=self.other_user_id,
+ )
+
+ # But the user wasn't actually invited.
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(self.other_user_id)
+ )
+ self.assertEqual(invited_rooms, [])
+
+ def test_invite_3pid(self):
+ """Ensure that a 3PID invite does not attempt to contact the identity server."""
+ identity_handler = self.hs.get_handlers().identity_handler
+ identity_handler.lookup_3pid = Mock(
+ side_effect=AssertionError("This should not get called")
+ )
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # Inviting the user completes successfully.
+ request, channel = self.make_request(
+ "POST",
+ "/rooms/%s/invite" % (room_id,),
+ {"id_server": "test", "medium": "email", "address": "test@test.test"},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ # This should have raised an error earlier, but double check this wasn't called.
+ identity_handler.lookup_3pid.assert_not_called()
+
+ def test_create_room(self):
+ """Invitations during a room creation should be discarded, but the room still gets created."""
+ # The room creation is successful.
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/createRoom",
+ {"visibility": "public", "invite": [self.other_user_id]},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ room_id = channel.json_body["room_id"]
+
+ # But the user wasn't actually invited.
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(self.other_user_id)
+ )
+ self.assertEqual(invited_rooms, [])
+
+ # Since a real room was created, the other user should be able to join it.
+ self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
+
+ # Both users should be in the room.
+ users = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])
+
+ def test_message(self):
+ """Messages from shadow-banned users don't actually get sent."""
+
+ room_id = self.helper.create_room_as(
+ self.other_user_id, tok=self.other_access_token
+ )
+
+ # The user should be in the room.
+ self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token)
+
+ # Sending a message should complete successfully.
+ result = self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "with right label"},
+ tok=self.banned_access_token,
+ )
+ self.assertIn("event_id", result)
+ event_id = result["event_id"]
+
+ latest_events = self.get_success(
+ self.store.get_latest_event_ids_in_room(room_id)
+ )
+ self.assertNotIn(event_id, latest_events)
+
+ def test_upgrade(self):
+ """A room upgrade should fail, but look like it succeeded."""
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
+ {"new_version": "6"},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ # A new room_id should be returned.
+ self.assertIn("replacement_room", channel.json_body)
+
+ new_room_id = channel.json_body["replacement_room"]
+
+ # It doesn't really matter what API we use here, we just want to assert
+ # that the room doesn't exist.
+ summary = self.get_success(self.store.get_room_summary(new_room_id))
+ # The summary should be empty since the room doesn't exist.
+ self.assertEqual(summary, {})
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 18260bb90e..94d2bf2eb1 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -46,7 +46,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_handlers().federation_handler = Mock()
- def get_user_by_access_token(token=None, allow_guest=False):
+ async def get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
@@ -55,8 +55,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_auth().get_user_by_access_token = get_user_by_access_token
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
+ async def _insert_client_ip(*args, **kwargs):
+ return None
hs.get_datastore().insert_client_ip = _insert_client_ip
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 22d734e763..e66c9a4c4c 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -39,7 +39,9 @@ class RestHelper(object):
resource = attr.ib()
auth_user_id = attr.ib()
- def create_room_as(self, room_creator=None, is_public=True, tok=None):
+ def create_room_as(
+ self, room_creator=None, is_public=True, tok=None, expect_code=200,
+ ):
temp_id = self.auth_user_id
self.auth_user_id = room_creator
path = "/_matrix/client/r0/createRoom"
@@ -54,9 +56,11 @@ class RestHelper(object):
)
render(request, self.resource, self.hs.get_reactor())
- assert channel.result["code"] == b"200", channel.result
+ assert channel.result["code"] == b"%d" % expect_code, channel.result
self.auth_user_id = temp_id
- return channel.json_body["room_id"]
+
+ if expect_code == 200:
+ return channel.json_body["room_id"]
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
self.change_membership(
@@ -88,7 +92,28 @@ class RestHelper(object):
expect_code=expect_code,
)
- def change_membership(self, room, src, targ, membership, tok=None, expect_code=200):
+ def change_membership(
+ self,
+ room: str,
+ src: str,
+ targ: str,
+ membership: str,
+ extra_data: dict = {},
+ tok: Optional[str] = None,
+ expect_code: int = 200,
+ ) -> None:
+ """
+ Send a membership state event into a room.
+
+ Args:
+ room: The ID of the room to send to
+ src: The mxid of the event sender
+ targ: The mxid of the event's target. The state key
+ membership: The type of membership event
+ extra_data: Extra information to include in the content of the event
+ tok: The user access token to use
+ expect_code: The expected HTTP response code
+ """
temp_id = self.auth_user_id
self.auth_user_id = src
@@ -97,6 +122,7 @@ class RestHelper(object):
path = path + "?access_token=%s" % tok
data = {"membership": membership}
+ data.update(extra_data)
request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8")
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 3ab611f618..152a5182fa 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -108,6 +108,46 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Assert we can't log in with the old password
self.attempt_wrong_password_login("kermit", old_password)
+ def test_basic_password_reset_canonicalise_email(self):
+ """Test basic password reset flow
+ Request password reset with different spelling
+ """
+ old_password = "monkey"
+ new_password = "kangeroo"
+
+ user_id = self.register_user("kermit", old_password)
+ self.login("kermit", old_password)
+
+ email_profile = "test@example.com"
+ email_passwort_reset = "TEST@EXAMPLE.COM"
+
+ # Add a threepid
+ self.get_success(
+ self.store.user_add_threepid(
+ user_id=user_id,
+ medium="email",
+ address=email_profile,
+ validated_at=0,
+ added_at=0,
+ )
+ )
+
+ client_secret = "foobar"
+ session_id = self._request_token(email_passwort_reset, client_secret)
+
+ self.assertEquals(len(self.email_attempts), 1)
+ link = self._get_link_from_email()
+
+ self._validate_token(link)
+
+ self._reset_password(new_password, session_id, client_secret)
+
+ # Assert we can log in with the new password
+ self.login("kermit", new_password)
+
+ # Assert we can't log in with the old password
+ self.attempt_wrong_password_login("kermit", old_password)
+
def test_cant_reset_password_without_clicking_link(self):
"""Test that we do actually need to click the link in the email
"""
@@ -386,44 +426,67 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.email = "test@example.com"
self.url_3pid = b"account/3pid"
- def test_add_email(self):
- """Test adding an email to profile
- """
- client_secret = "foobar"
- session_id = self._request_token(self.email, client_secret)
+ def test_add_valid_email(self):
+ self.get_success(self._add_email(self.email, self.email))
- self.assertEquals(len(self.email_attempts), 1)
- link = self._get_link_from_email()
+ def test_add_valid_email_second_time(self):
+ self.get_success(self._add_email(self.email, self.email))
+ self.get_success(
+ self._request_token_invalid_email(
+ self.email,
+ expected_errcode=Codes.THREEPID_IN_USE,
+ expected_error="Email is already in use",
+ )
+ )
- self._validate_token(link)
+ def test_add_valid_email_second_time_canonicalise(self):
+ self.get_success(self._add_email(self.email, self.email))
+ self.get_success(
+ self._request_token_invalid_email(
+ "TEST@EXAMPLE.COM",
+ expected_errcode=Codes.THREEPID_IN_USE,
+ expected_error="Email is already in use",
+ )
+ )
- request, channel = self.make_request(
- "POST",
- b"/_matrix/client/unstable/account/3pid/add",
- {
- "client_secret": client_secret,
- "sid": session_id,
- "auth": {
- "type": "m.login.password",
- "user": self.user_id,
- "password": "test",
- },
- },
- access_token=self.user_id_tok,
+ def test_add_email_no_at(self):
+ self.get_success(
+ self._request_token_invalid_email(
+ "address-without-at.bar",
+ expected_errcode=Codes.UNKNOWN,
+ expected_error="Unable to parse email address",
+ )
)
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ def test_add_email_two_at(self):
+ self.get_success(
+ self._request_token_invalid_email(
+ "foo@foo@test.bar",
+ expected_errcode=Codes.UNKNOWN,
+ expected_error="Unable to parse email address",
+ )
+ )
- # Get user
- request, channel = self.make_request(
- "GET", self.url_3pid, access_token=self.user_id_tok,
+ def test_add_email_bad_format(self):
+ self.get_success(
+ self._request_token_invalid_email(
+ "user@bad.example.net@good.example.com",
+ expected_errcode=Codes.UNKNOWN,
+ expected_error="Unable to parse email address",
+ )
)
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
+ def test_add_email_domain_to_lower(self):
+ self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar"))
+
+ def test_add_email_domain_with_umlaut(self):
+ self.get_success(self._add_email("foo@Öumlaut.com", "foo@öumlaut.com"))
+
+ def test_add_email_address_casefold(self):
+ self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com"))
+
+ def test_address_trim(self):
+ self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))
def test_add_email_if_disabled(self):
"""Test adding email to profile when doing so is disallowed
@@ -616,6 +679,19 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
return channel.json_body["sid"]
+ def _request_token_invalid_email(
+ self, email, expected_errcode, expected_error, client_secret="foobar",
+ ):
+ request, channel = self.make_request(
+ "POST",
+ b"account/3pid/email/requestToken",
+ {"client_secret": client_secret, "email": email, "send_attempt": 1},
+ )
+ self.render(request)
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(expected_errcode, channel.json_body["errcode"])
+ self.assertEqual(expected_error, channel.json_body["error"])
+
def _validate_token(self, link):
# Remove the host
path = link.replace("https://example.com", "")
@@ -643,3 +719,42 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
assert match, "Could not find link in email"
return match.group(0)
+
+ def _add_email(self, request_email, expected_email):
+ """Test adding an email to profile
+ """
+ client_secret = "foobar"
+ session_id = self._request_token(request_email, client_secret)
+
+ self.assertEquals(len(self.email_attempts), 1)
+ link = self._get_link_from_email()
+
+ self._validate_token(link)
+
+ request, channel = self.make_request(
+ "POST",
+ b"/_matrix/client/unstable/account/3pid/add",
+ {
+ "client_secret": client_secret,
+ "sid": session_id,
+ "auth": {
+ "type": "m.login.password",
+ "user": self.user_id,
+ "password": "test",
+ },
+ },
+ access_token=self.user_id_tok,
+ )
+
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_3pid, access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"])
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 7e3dc22f64..0f33c7806d 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -116,8 +116,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
+ @override_config({"enable_registration": False})
def test_POST_disabled_registration(self):
- self.hs.config.enable_registration = False
request_data = json.dumps({"username": "kermit", "password": "monkey"})
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
@@ -160,7 +160,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
@@ -186,7 +186,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index c7e5859970..99c9f4e928 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -15,8 +15,7 @@
import itertools
import json
-
-import six
+import urllib
from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
@@ -100,7 +99,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(400, channel.code, channel.json_body)
def test_basic_paginate_relations(self):
- """Tests that calling pagination API corectly the latest relations.
+ """Tests that calling pagination API correctly the latest relations.
"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
self.assertEquals(200, channel.code, channel.json_body)
@@ -134,7 +133,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
# Make sure next_batch has something in it that looks like it could be a
# valid token.
self.assertIsInstance(
- channel.json_body.get("next_batch"), six.string_types, channel.json_body
+ channel.json_body.get("next_batch"), str, channel.json_body
)
def test_repeated_paginate_relations(self):
@@ -278,7 +277,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
prev_token = None
found_event_ids = []
- encoded_key = six.moves.urllib.parse.quote_plus("👍".encode("utf-8"))
+ encoded_key = urllib.parse.quote_plus("👍".encode("utf-8"))
for _ in range(20):
from_token = ""
if prev_token:
@@ -670,7 +669,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
query = ""
if key:
- query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8"))
+ query = "?key=" + urllib.parse.quote_plus(key.encode("utf-8"))
original_id = parent_id if parent_id else self.parent_id
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 99eb477149..6850c666be 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -53,7 +53,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
Tell the mock http client to expect an outgoing GET request for the given key
"""
- def get_json(destination, path, ignore_backoff=False, **kwargs):
+ async def get_json(destination, path, ignore_backoff=False, **kwargs):
self.assertTrue(ignore_backoff)
self.assertEqual(destination, server_name)
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
@@ -177,7 +177,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
# wire up outbound POST /key/v2/query requests from hs2 so that they
# will be forwarded to hs1
- def post_json(destination, path, data):
+ async def post_json(destination, path, data):
self.assertEqual(destination, self.hs.hostname)
self.assertEqual(
path, "/_matrix/key/v2/query",
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 1ca648ef2b..f4f3e56777 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -12,22 +12,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-
import os
import shutil
import tempfile
from binascii import unhexlify
from io import BytesIO
from typing import Optional
+from urllib import parse
from mock import Mock
-from six.moves.urllib import parse
import attr
-import PIL.Image as Image
from parameterized import parameterized_class
+from PIL import Image as Image
+from twisted.internet import defer
from twisted.internet.defer import Deferred
from synapse.logging.context import make_deferred_yieldable
@@ -79,7 +78,9 @@ class MediaStorageTests(unittest.HomeserverTestCase):
# This uses a real blocking threadpool so we have to wait for it to be
# actually done :/
- x = self.media_storage.ensure_media_is_in_local_cache(file_info)
+ x = defer.ensureDeferred(
+ self.media_storage.ensure_media_is_in_local_cache(file_info)
+ )
# Hotloop until the threadpool does its job...
self.wait_on_thread(x)
@@ -232,7 +233,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual(len(self.fetches), 1)
self.assertEqual(self.fetches[0][1], "example.com")
self.assertEqual(
- self.fetches[0][2], "/_matrix/media/v1/download/" + self.media_id
+ self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id
)
self.assertEqual(self.fetches[0][3], {"allow_remote": "false"})
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 2826211f32..74765a582b 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -12,8 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import json
import os
+import re
+
+from mock import patch
import attr
@@ -131,7 +134,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.reactor.nameResolver = Resolver()
def test_cache_returns_correct_type(self):
- self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
request, channel = self.make_request(
"GET", "url_preview?url=http://matrix.org", shorthand=False
@@ -187,7 +190,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
def test_non_ascii_preview_httpequiv(self):
- self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
b"<html><head>"
@@ -221,7 +224,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
def test_non_ascii_preview_content_type(self):
- self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
b"<html><head>"
@@ -254,7 +257,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
def test_overlong_title(self):
- self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
b"<html><head>"
@@ -292,7 +295,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
IP addresses can be previewed directly.
"""
- self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
request, channel = self.make_request(
"GET", "url_preview?url=http://example.com", shorthand=False
@@ -439,7 +442,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# Hardcode the URL resolving to the IP we want.
self.lookups["example.com"] = [
(IPv4Address, "1.1.1.2"),
- (IPv4Address, "8.8.8.8"),
+ (IPv4Address, "10.1.2.3"),
]
request, channel = self.make_request(
@@ -518,7 +521,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
Accept-Language header is sent to the remote server
"""
- self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
# Build and make a request to the server
request, channel = self.make_request(
@@ -562,3 +565,126 @@ class URLPreviewTests(unittest.HomeserverTestCase):
),
server.data,
)
+
+ def test_oembed_photo(self):
+ """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
+ # Route the HTTP version to an HTTP endpoint so that the tests work.
+ with patch.dict(
+ "synapse.rest.media.v1.preview_url_resource._oembed_patterns",
+ {
+ re.compile(
+ r"http://twitter\.com/.+/status/.+"
+ ): "http://publish.twitter.com/oembed",
+ },
+ clear=True,
+ ):
+
+ self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+ self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+ result = {
+ "version": "1.0",
+ "type": "photo",
+ "url": "http://cdn.twitter.com/matrixdotorg",
+ }
+ oembed_content = json.dumps(result).encode("utf-8")
+
+ end_content = (
+ b"<html><head>"
+ b"<title>Some Title</title>"
+ b'<meta property="og:description" content="hi" />'
+ b"</head></html>"
+ )
+
+ request, channel = self.make_request(
+ "GET",
+ "url_preview?url=http://twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: application/json; charset="utf8"\r\n\r\n'
+ )
+ % (len(oembed_content),)
+ + oembed_content
+ )
+
+ self.pump()
+
+ client = self.reactor.tcpClients[1][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body, {"og:title": "Some Title", "og:description": "hi"}
+ )
+
+ def test_oembed_rich(self):
+ """Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
+ # Route the HTTP version to an HTTP endpoint so that the tests work.
+ with patch.dict(
+ "synapse.rest.media.v1.preview_url_resource._oembed_patterns",
+ {
+ re.compile(
+ r"http://twitter\.com/.+/status/.+"
+ ): "http://publish.twitter.com/oembed",
+ },
+ clear=True,
+ ):
+
+ self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+ result = {
+ "version": "1.0",
+ "type": "rich",
+ "html": "<div>Content Preview</div>",
+ }
+ end_content = json.dumps(result).encode("utf-8")
+
+ request, channel = self.make_request(
+ "GET",
+ "url_preview?url=http://twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: application/json; charset="utf8"\r\n\r\n'
+ )
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body,
+ {"og:title": None, "og:description": "Content Preview"},
+ )
diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
new file mode 100644
index 0000000000..2d021f6565
--- /dev/null
+++ b/tests/rest/test_health.py
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.rest.health import HealthResource
+
+from tests import unittest
+
+
+class HealthCheckTests(unittest.HomeserverTestCase):
+ def setUp(self):
+ super().setUp()
+
+ # replace the JsonResource with a HealthResource.
+ self.resource = HealthResource()
+
+ def test_health(self):
+ request, channel = self.make_request("GET", "/health", shorthand=False)
+ self.render(request)
+
+ self.assertEqual(request.code, 200)
+ self.assertEqual(channel.result["body"], b"OK")
diff --git a/tests/server.py b/tests/server.py
index 1644710aa0..b6e0b14e78 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -2,8 +2,6 @@ import json
import logging
from io import BytesIO
-from six import text_type
-
import attr
from zope.interface import implementer
@@ -174,7 +172,7 @@ def make_request(
if not path.startswith(b"/"):
path = b"/" + path
- if isinstance(content, text_type):
+ if isinstance(content, str):
content = content.encode("utf8")
site = FakeSite()
@@ -239,6 +237,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def __init__(self):
self.threadpool = ThreadPool(self)
+ self._tcp_callbacks = {}
self._udp = []
lookups = self.lookups = {}
@@ -270,6 +269,29 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def getThreadPool(self):
return self.threadpool
+ def add_tcp_client_callback(self, host, port, callback):
+ """Add a callback that will be invoked when we receive a connection
+ attempt to the given IP/port using `connectTCP`.
+
+ Note that the callback gets run before we return the connection to the
+ client, which means callbacks cannot block while waiting for writes.
+ """
+ self._tcp_callbacks[(host, port)] = callback
+
+ def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
+ """Fake L{IReactorTCP.connectTCP}.
+ """
+
+ conn = super().connectTCP(
+ host, port, factory, timeout=timeout, bindAddress=None
+ )
+
+ callback = self._tcp_callbacks.get((host, port))
+ if callback:
+ callback()
+
+ return conn
+
class ThreadPool:
"""
@@ -488,7 +510,7 @@ class FakeTransport(object):
try:
self.other.dataReceived(to_write)
except Exception as e:
- logger.warning("Exception writing to protocol: %s", e)
+ logger.exception("Exception writing to protocol: %s", e)
return
self.buffer = self.buffer[len(to_write) :]
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 99908edba3..23db821fb7 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -27,6 +27,7 @@ from synapse.server_notices.resource_limits_server_notices import (
)
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.unittest import override_config
from tests.utils import default_config
@@ -79,7 +80,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
return_value=defer.succeed("!something:localhost")
)
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
- self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({}))
+ self._rlsn._store.get_tags_for_room = Mock(
+ side_effect=lambda user_id, room_id: make_awaitable({})
+ )
@override_config({"hs_disabled": True})
def test_maybe_send_server_notice_disabled_hs(self):
@@ -101,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
- return_value=defer.succeed({"123": mock_event})
+ return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event
@@ -119,7 +122,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
- return_value=defer.succeed({"123": mock_event})
+ return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -214,7 +217,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
- return_value=defer.succeed({"123": mock_event})
+ return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -258,7 +261,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.user_id = "@user_id:test"
def test_server_notice_only_sent_once(self):
- self.store.get_monthly_active_count = Mock(return_value=1000)
+ self.store.get_monthly_active_count = Mock(return_value=defer.succeed(1000))
self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(1000)
@@ -275,7 +278,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.server_notices_manager.get_or_create_notice_room_for_user(self.user_id)
)
- token = self.get_success(self.event_source.get_current_token())
+ token = self.event_source.get_current_token()
events, _ = self.get_success(
self.store.get_recent_events_for_room(
room_id, limit=100, end_token=token.room_key
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index a44960203e..f2955a9c69 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -14,11 +14,12 @@
# limitations under the License.
import itertools
-
-from six.moves import zip
+from typing import List
import attr
+from twisted.internet import defer
+
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions
from synapse.event_auth import auth_types_for_event
@@ -43,6 +44,11 @@ MEMBERSHIP_CONTENT_BAN = {"membership": Membership.BAN}
ORIGIN_SERVER_TS = 0
+class FakeClock:
+ def sleep(self, msec):
+ return defer.succeed(None)
+
+
class FakeEvent(object):
"""A fake event we use as a convenience.
@@ -419,6 +425,7 @@ class StateTestCase(unittest.TestCase):
state_before = dict(state_at_event[prev_events[0]])
else:
state_d = resolve_events_with_store(
+ FakeClock(),
ROOM_ID,
RoomVersions.V2.identifier,
[state_at_event[n] for n in prev_events],
@@ -426,7 +433,7 @@ class StateTestCase(unittest.TestCase):
state_res_store=TestStateResolutionStore(event_map),
)
- state_before = self.successResultOf(state_d)
+ state_before = self.successResultOf(defer.ensureDeferred(state_d))
state_after = dict(state_before)
if fake_event.state_key is not None:
@@ -567,6 +574,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
# Test that we correctly handle passing `None` as the event_map
state_d = resolve_events_with_store(
+ FakeClock(),
ROOM_ID,
RoomVersions.V2.identifier,
[self.state_at_bob, self.state_at_charlie],
@@ -574,7 +582,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
state_res_store=TestStateResolutionStore(self.event_map),
)
- state = self.successResultOf(state_d)
+ state = self.successResultOf(defer.ensureDeferred(state_d))
self.assert_dict(self.expected_combined_state, state)
@@ -601,9 +609,11 @@ class TestStateResolutionStore(object):
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
"""
- return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
+ return defer.succeed(
+ {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
+ )
- def _get_auth_chain(self, event_ids):
+ def _get_auth_chain(self, event_ids: List[str]) -> List[str]:
"""Gets the full auth chain for a set of events (including rejected
events).
@@ -615,10 +625,10 @@ class TestStateResolutionStore(object):
presence of rejected events
Args:
- event_ids (list): The event IDs of the events to fetch the auth
+ event_ids: The event IDs of the events to fetch the auth
chain for. Must be state events.
Returns:
- Deferred[list[str]]: List of event IDs of the auth chain.
+ List of event IDs of the auth chain.
"""
# Simple DFS for auth chain
@@ -641,4 +651,4 @@ class TestStateResolutionStore(object):
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
common = set(chains[0]).intersection(*chains[1:])
- return set(chains[0]).union(*chains[1:]) - common
+ return defer.succeed(set(chains[0]).union(*chains[1:]) - common)
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 5a50e4fdd4..319e2c2325 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -323,7 +323,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
self.table_name = "table_" + hs.get_secrets().token_hex(6)
self.get_success(
- self.storage.db.runInteraction(
+ self.storage.db_pool.runInteraction(
"create",
lambda x, *a: x.execute(*a),
"CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)"
@@ -331,7 +331,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.storage.db.runInteraction(
+ self.storage.db_pool.runInteraction(
"index",
lambda x, *a: x.execute(*a),
"CREATE UNIQUE INDEX %sindex ON %s(id, username)"
@@ -354,9 +354,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
value_values = [["hello"], ["there"]]
self.get_success(
- self.storage.db.runInteraction(
+ self.storage.db_pool.runInteraction(
"test",
- self.storage.db.simple_upsert_many_txn,
+ self.storage.db_pool.simple_upsert_many_txn,
self.table_name,
key_names,
key_values,
@@ -367,7 +367,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
# Check results are what we expect
res = self.get_success(
- self.storage.db.simple_select_list(
+ self.storage.db_pool.simple_select_list(
self.table_name, None, ["id, username, value"]
)
)
@@ -381,9 +381,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
value_values = [["bleb"]]
self.get_success(
- self.storage.db.runInteraction(
+ self.storage.db_pool.runInteraction(
"test",
- self.storage.db.simple_upsert_many_txn,
+ self.storage.db_pool.simple_upsert_many_txn,
self.table_name,
key_names,
key_values,
@@ -394,7 +394,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
# Check results are what we expect
res = self.get_success(
- self.storage.db.simple_select_list(
+ self.storage.db_pool.simple_select_list(
self.table_name, None, ["id, username, value"]
)
)
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index ef296e7dab..17fbde284a 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -24,13 +24,14 @@ from twisted.internet import defer
from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.config._base import ConfigError
-from synapse.storage.data_stores.main.appservice import (
+from synapse.storage.database import DatabasePool, make_conn
+from synapse.storage.databases.main.appservice import (
ApplicationServiceStore,
ApplicationServiceTransactionStore,
)
-from synapse.storage.database import Database, make_conn
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
@@ -178,14 +179,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_appservice_state_none(self):
service = Mock(id="999")
- state = yield self.store.get_appservice_state(service)
+ state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
self.assertEquals(None, state)
@defer.inlineCallbacks
def test_get_appservice_state_up(self):
yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
service = Mock(id=self.as_list[0]["id"])
- state = yield self.store.get_appservice_state(service)
+ state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
self.assertEquals(ApplicationServiceState.UP, state)
@defer.inlineCallbacks
@@ -194,20 +195,22 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN)
yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
service = Mock(id=self.as_list[1]["id"])
- state = yield self.store.get_appservice_state(service)
+ state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
self.assertEquals(ApplicationServiceState.DOWN, state)
@defer.inlineCallbacks
def test_get_appservices_by_state_none(self):
- services = yield self.store.get_appservices_by_state(
- ApplicationServiceState.DOWN
+ services = yield defer.ensureDeferred(
+ self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
)
self.assertEquals(0, len(services))
@defer.inlineCallbacks
def test_set_appservices_state_down(self):
service = Mock(id=self.as_list[1]["id"])
- yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ )
rows = yield self.db_pool.runQuery(
self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
@@ -219,9 +222,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_appservices_state_multiple_up(self):
service = Mock(id=self.as_list[1]["id"])
- yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
- yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
- yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.UP)
+ )
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ )
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.UP)
+ )
rows = yield self.db_pool.runQuery(
self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
@@ -339,7 +348,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
def test_get_oldest_unsent_txn_none(self):
service = Mock(id=self.as_list[0]["id"])
- txn = yield self.store.get_oldest_unsent_txn(service)
+ txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service))
self.assertEquals(None, txn)
@defer.inlineCallbacks
@@ -349,14 +358,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
# we aren't testing store._base stuff here, so mock this out
- self.store.get_events_as_list = Mock(return_value=events)
+ self.store.get_events_as_list = Mock(return_value=make_awaitable(events))
yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
yield self._insert_txn(service.id, 10, events)
yield self._insert_txn(service.id, 11, other_events)
yield self._insert_txn(service.id, 12, other_events)
- txn = yield self.store.get_oldest_unsent_txn(service)
+ txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service))
self.assertEquals(service, txn.service)
self.assertEquals(10, txn.id)
self.assertEquals(events, txn.events)
@@ -366,8 +375,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
- services = yield self.store.get_appservices_by_state(
- ApplicationServiceState.DOWN
+ services = yield defer.ensureDeferred(
+ self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
)
self.assertEquals(1, len(services))
self.assertEquals(self.as_list[0]["id"], services[0].id)
@@ -379,8 +388,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP)
- services = yield self.store.get_appservices_by_state(
- ApplicationServiceState.DOWN
+ services = yield defer.ensureDeferred(
+ self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
)
self.assertEquals(2, len(services))
self.assertEquals(
@@ -391,7 +400,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
# required for ApplicationServiceTransactionStoreTestCase tests
class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(TestTransactionStore, self).__init__(database, db_conn, hs)
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 940b166129..2efbc97c2e 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -9,7 +9,9 @@ from tests import unittest
class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
- self.updates = self.hs.get_datastore().db.updates # type: BackgroundUpdater
+ self.updates = (
+ self.hs.get_datastore().db_pool.updates
+ ) # type: BackgroundUpdater
# the base test class should have run the real bg updates for us
self.assertTrue(
self.get_success(self.updates.has_completed_background_updates())
@@ -29,7 +31,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
store = self.hs.get_datastore()
self.get_success(
- store.db.simple_insert(
+ store.db_pool.simple_insert(
"background_updates",
values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
)
@@ -40,7 +42,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
def update(progress, count):
yield self.clock.sleep((count * duration_ms) / 1000)
progress = {"my_key": progress["my_key"] + 1}
- yield store.db.runInteraction(
+ yield store.db_pool.runInteraction(
"update_progress",
self.updates._background_update_progress_txn,
"test_update",
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 278961c331..13bcac743a 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -21,11 +21,11 @@ from mock import Mock
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
from synapse.storage.engines import create_engine
from tests import unittest
-from tests.utils import TestHomeServer
+from tests.utils import TestHomeServer, default_config
class SQLBaseStoreTestCase(unittest.TestCase):
@@ -49,10 +49,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.db_pool.runWithConnection = runWithConnection
- config = Mock()
- config._disable_native_upserts = True
- config.caches = Mock()
- config.caches.event_cache_size = 1
+ config = default_config(name="test", parse=True)
hs = TestHomeServer("test", config=config)
sqlite_config = {"name": "sqlite3"}
@@ -60,7 +57,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
fake_engine = Mock(wraps=engine)
fake_engine.can_native_upsert = False
- db = Database(Mock(), Mock(config=sqlite_config), fake_engine)
+ db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine)
db._db_pool = self.db_pool
self.datastore = SQLBaseStore(db, None, hs)
@@ -69,8 +66,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_1col(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db.simple_insert(
- table="tablename", values={"columname": "Value"}
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_insert(
+ table="tablename", values={"columname": "Value"}
+ )
)
self.mock_txn.execute.assert_called_with(
@@ -81,10 +80,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_3cols(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db.simple_insert(
- table="tablename",
- # Use OrderedDict() so we can assert on the SQL generated
- values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_insert(
+ table="tablename",
+ # Use OrderedDict() so we can assert on the SQL generated
+ values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
+ )
)
self.mock_txn.execute.assert_called_with(
@@ -96,7 +97,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
- value = yield self.datastore.db.simple_select_one_onecol(
+ value = yield self.datastore.db_pool.simple_select_one_onecol(
table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
)
@@ -110,7 +111,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = (1, 2, 3)
- ret = yield self.datastore.db.simple_select_one(
+ ret = yield self.datastore.db_pool.simple_select_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
retcols=["colA", "colB", "colC"],
@@ -126,7 +127,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 0
self.mock_txn.fetchone.return_value = None
- ret = yield self.datastore.db.simple_select_one(
+ ret = yield self.datastore.db_pool.simple_select_one(
table="tablename",
keyvalues={"keycol": "Not here"},
retcols=["colA"],
@@ -141,7 +142,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
- ret = yield self.datastore.db.simple_select_list(
+ ret = yield self.datastore.db_pool.simple_select_list(
table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
)
@@ -154,7 +155,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_update_one_1col(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db.simple_update_one(
+ yield self.datastore.db_pool.simple_update_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
updatevalues={"columnname": "New Value"},
@@ -169,7 +170,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_update_one_4cols(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db.simple_update_one(
+ yield self.datastore.db_pool.simple_update_one(
table="tablename",
keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
@@ -184,7 +185,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_delete_one(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db.simple_delete_one(
+ yield self.datastore.db_pool.simple_delete_one(
table="tablename", keyvalues={"keycol": "Go away"}
)
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 43425c969a..43639ca286 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID("alice", "test")
- self.requester = Requester(self.user, None, False, None, None)
+ self.requester = Requester(self.user, None, False, False, None, None)
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
@@ -47,12 +47,12 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
"""
# Make sure we don't clash with in progress updates.
self.assertTrue(
- self.store.db.updates._all_done, "Background updates are still ongoing"
+ self.store.db_pool.updates._all_done, "Background updates are still ongoing"
)
schema_path = os.path.join(
prepare_database.dir_path,
- "data_stores",
+ "databases",
"main",
"schema",
"delta",
@@ -64,19 +64,19 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
prepare_database.executescript(txn, schema_path)
self.get_success(
- self.store.db.runInteraction(
+ self.store.db_pool.runInteraction(
"test_delete_forward_extremities", run_delta_file
)
)
# Ugh, have to reset this flag
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
def test_soft_failed_extremities_handled_correctly(self):
@@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID.from_string(self.register_user("user1", "password"))
self.token1 = self.login("user1", "password")
- self.requester = Requester(self.user, None, False, None, None)
+ self.requester = Requester(self.user, None, False, False, None, None)
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
self.event_creator = homeserver.get_event_creation_handler()
@@ -353,6 +353,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
"3"
] = 300000
+
self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
# All entries within time frame
self.assertEqual(
@@ -362,7 +363,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
3,
)
# Oldest room to expire
- self.pump(1)
+ self.pump(1.01)
self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
self.assertEqual(
len(
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 3b483bc7f0..224ea6fd79 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -86,7 +86,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.pump(0)
result = self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -117,7 +117,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.pump(0)
result = self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -204,10 +204,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
def test_devices_last_seen_bg_update(self):
# First make sure we have completed all updates.
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
user_id = "@user:id"
@@ -225,7 +225,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# But clear the associated entry in devices table
self.get_success(
- self.store.db.simple_update(
+ self.store.db_pool.simple_update(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
updatevalues={"last_seen": None, "ip": None, "user_agent": None},
@@ -252,7 +252,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# Register the background update to run again.
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
table="background_updates",
values={
"update_name": "devices_last_seen",
@@ -263,14 +263,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
# ... and tell the DataStore that it hasn't finished all updates yet
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
# Now let's actually drive the updates to completion
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
# We should now get the correct result again
@@ -293,10 +293,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
def test_old_user_ips_pruned(self):
# First make sure we have completed all updates.
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
user_id = "@user:id"
@@ -315,7 +315,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# We should see that in the DB
result = self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -341,7 +341,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# We should get no results.
result = self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index c2539b353a..87ed8f8cd1 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -34,7 +34,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_store_new_device(self):
- yield self.store.store_device("user_id", "device_id", "display_name")
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id", "device_id", "display_name")
+ )
res = yield self.store.get_device("user_id", "device_id")
self.assertDictContainsSubset(
@@ -48,11 +50,17 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_get_devices_by_user(self):
- yield self.store.store_device("user_id", "device1", "display_name 1")
- yield self.store.store_device("user_id", "device2", "display_name 2")
- yield self.store.store_device("user_id2", "device3", "display_name 3")
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id", "device1", "display_name 1")
+ )
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id", "device2", "display_name 2")
+ )
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id2", "device3", "display_name 3")
+ )
- res = yield self.store.get_devices_by_user("user_id")
+ res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id"))
self.assertEqual(2, len(res.keys()))
self.assertDictContainsSubset(
{
@@ -76,13 +84,13 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
device_ids = ["device_id1", "device_id2"]
# Add two device updates with a single stream_id
- yield self.store.add_device_change_to_streams(
- "user_id", device_ids, ["somehost"]
+ yield defer.ensureDeferred(
+ self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
)
# Get all device updates ever meant for this remote
- now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
- "somehost", -1, limit=100
+ now_stream_id, device_updates = yield defer.ensureDeferred(
+ self.store.get_device_updates_by_remote("somehost", -1, limit=100)
)
# Check original device_ids are contained within these updates
@@ -99,19 +107,23 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_update_device(self):
- yield self.store.store_device("user_id", "device_id", "display_name 1")
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id", "device_id", "display_name 1")
+ )
res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
- yield self.store.update_device("user_id", "device_id")
+ yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 1", res["display_name"])
# do the update
- yield self.store.update_device(
- "user_id", "device_id", new_display_name="display_name 2"
+ yield defer.ensureDeferred(
+ self.store.update_device(
+ "user_id", "device_id", new_display_name="display_name 2"
+ )
)
# check it worked
@@ -121,7 +133,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_update_unknown_device(self):
with self.assertRaises(synapse.api.errors.StoreError) as cm:
- yield self.store.update_device(
- "user_id", "unknown_device_id", new_display_name="display_name 2"
+ yield defer.ensureDeferred(
+ self.store.update_device(
+ "user_id", "unknown_device_id", new_display_name="display_name 2"
+ )
)
self.assertEqual(404, cm.exception.code)
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index 4e128e1047..daac947cb2 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -34,8 +34,10 @@ class DirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_room_to_alias(self):
- yield self.store.create_room_alias_association(
- room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ yield defer.ensureDeferred(
+ self.store.create_room_alias_association(
+ room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ )
)
self.assertEquals(
@@ -45,24 +47,36 @@ class DirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_alias_to_room(self):
- yield self.store.create_room_alias_association(
- room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ yield defer.ensureDeferred(
+ self.store.create_room_alias_association(
+ room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ )
)
self.assertObjectHasAttributes(
{"room_id": self.room.to_string(), "servers": ["test"]},
- (yield self.store.get_association_from_room_alias(self.alias)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_association_from_room_alias(self.alias)
+ )
+ ),
)
@defer.inlineCallbacks
def test_delete_alias(self):
- yield self.store.create_room_alias_association(
- room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ yield defer.ensureDeferred(
+ self.store.create_room_alias_association(
+ room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ )
)
- room_id = yield self.store.delete_room_alias(self.alias)
+ room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias))
self.assertEqual(self.room.to_string(), room_id)
self.assertIsNone(
- (yield self.store.get_association_from_room_alias(self.alias))
+ (
+ yield defer.ensureDeferred(
+ self.store.get_association_from_room_alias(self.alias)
+ )
+ )
)
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 398d546280..d57cdffd8b 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -30,11 +30,13 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
now = 1470174257070
json = {"key": "value"}
- yield self.store.store_device("user", "device", None)
+ yield defer.ensureDeferred(self.store.store_device("user", "device", None))
yield self.store.set_e2e_device_keys("user", "device", now, json)
- res = yield self.store.get_e2e_device_keys((("user", "device"),))
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_device_keys((("user", "device"),))
+ )
self.assertIn("user", res)
self.assertIn("device", res["user"])
dev = res["user"]["device"]
@@ -45,7 +47,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
now = 1470174257070
json = {"key": "value"}
- yield self.store.store_device("user", "device", None)
+ yield defer.ensureDeferred(self.store.store_device("user", "device", None))
changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
self.assertTrue(changed)
@@ -61,9 +63,13 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
json = {"key": "value"}
yield self.store.set_e2e_device_keys("user", "device", now, json)
- yield self.store.store_device("user", "device", "display_name")
+ yield defer.ensureDeferred(
+ self.store.store_device("user", "device", "display_name")
+ )
- res = yield self.store.get_e2e_device_keys((("user", "device"),))
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_device_keys((("user", "device"),))
+ )
self.assertIn("user", res)
self.assertIn("device", res["user"])
dev = res["user"]["device"]
@@ -75,18 +81,18 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
def test_multiple_devices(self):
now = 1470174257070
- yield self.store.store_device("user1", "device1", None)
- yield self.store.store_device("user1", "device2", None)
- yield self.store.store_device("user2", "device1", None)
- yield self.store.store_device("user2", "device2", None)
+ yield defer.ensureDeferred(self.store.store_device("user1", "device1", None))
+ yield defer.ensureDeferred(self.store.store_device("user1", "device2", None))
+ yield defer.ensureDeferred(self.store.store_device("user2", "device1", None))
+ yield defer.ensureDeferred(self.store.store_device("user2", "device2", None))
yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
- res = yield self.store.get_e2e_device_keys(
- (("user1", "device1"), ("user2", "device2"))
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2")))
)
self.assertIn("user1", res)
self.assertIn("device1", res["user1"])
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 3aeec0dc0f..d4c3b867e3 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -56,7 +56,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
for i in range(0, 20):
- self.get_success(self.store.db.runInteraction("insert", insert_event, i))
+ self.get_success(
+ self.store.db_pool.runInteraction("insert", insert_event, i)
+ )
# this should get the last ten
r = self.get_success(self.store.get_prev_events_for_room(room_id))
@@ -81,13 +83,13 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
for i in range(0, 20):
self.get_success(
- self.store.db.runInteraction("insert", insert_event, i, room1)
+ self.store.db_pool.runInteraction("insert", insert_event, i, room1)
)
self.get_success(
- self.store.db.runInteraction("insert", insert_event, i, room2)
+ self.store.db_pool.runInteraction("insert", insert_event, i, room2)
)
self.get_success(
- self.store.db.runInteraction("insert", insert_event, i, room3)
+ self.store.db_pool.runInteraction("insert", insert_event, i, room3)
)
# Test simple case
@@ -164,7 +166,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
depth = depth_map[event_id]
- self.store.db.simple_insert_txn(
+ self.store.db_pool.simple_insert_txn(
txn,
table="events",
values={
@@ -179,7 +181,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
},
)
- self.store.db.simple_insert_many_txn(
+ self.store.db_pool.simple_insert_many_txn(
txn,
table="event_auth",
values=[
@@ -192,7 +194,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
for event_id in auth_graph:
next_stream_ordering += 1
self.get_success(
- self.store.db.runInteraction(
+ self.store.db_pool.runInteraction(
"insert", insert_event, event_id, next_stream_ordering
)
)
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index a7b85004e5..949846fe33 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
room_creator = self.hs.get_room_creation_handler()
user = UserID("alice", "test")
- requester = Requester(user, None, False, None, None)
+ requester = Requester(user, None, False, False, None, None)
# Real events, forward extremities
events = [(3, 2), (6, 2), (4, 6)]
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index b45bc9c115..238bad5b45 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -39,14 +39,18 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_http(self):
- yield self.store.get_unread_push_actions_for_user_in_range_for_http(
- USER_ID, 0, 1000, 20
+ yield defer.ensureDeferred(
+ self.store.get_unread_push_actions_for_user_in_range_for_http(
+ USER_ID, 0, 1000, 20
+ )
)
@defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_email(self):
- yield self.store.get_unread_push_actions_for_user_in_range_for_email(
- USER_ID, 0, 1000, 20
+ yield defer.ensureDeferred(
+ self.store.get_unread_push_actions_for_user_in_range_for_email(
+ USER_ID, 0, 1000, 20
+ )
)
@defer.inlineCallbacks
@@ -56,7 +60,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def _assert_counts(noitf_count, highlight_count):
- counts = yield self.store.db.runInteraction(
+ counts = yield self.store.db_pool.runInteraction(
"", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
)
self.assertEquals(
@@ -72,10 +76,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
event.internal_metadata.stream_ordering = stream
event.depth = stream
- yield self.store.add_push_actions_to_staging(
- event.event_id, {user_id: action}
+ yield defer.ensureDeferred(
+ self.store.add_push_actions_to_staging(
+ event.event_id, {user_id: action}
+ )
)
- yield self.store.db.runInteraction(
+ yield self.store.db_pool.runInteraction(
"",
self.persist_events_store._set_push_actions_for_event_and_users_txn,
[(event, None)],
@@ -83,12 +89,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
def _rotate(stream):
- return self.store.db.runInteraction(
+ return self.store.db_pool.runInteraction(
"", self.store._rotate_notifs_before_txn, stream
)
def _mark_read(stream, depth):
- return self.store.db.runInteraction(
+ return self.store.db_pool.runInteraction(
"",
self.store._remove_old_push_actions_before_txn,
room_id,
@@ -117,7 +123,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield _inject_actions(6, PlAIN_NOTIF)
yield _rotate(7)
- yield self.store.db.simple_delete(
+ yield self.store.db_pool.simple_delete(
table="event_push_actions", keyvalues={"1": 1}, desc=""
)
@@ -136,20 +142,22 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts):
- return self.store.db.simple_insert(
- "events",
- {
- "stream_ordering": so,
- "received_ts": ts,
- "event_id": "event%i" % so,
- "type": "",
- "room_id": "",
- "content": "",
- "processed": True,
- "outlier": False,
- "topological_ordering": 0,
- "depth": 0,
- },
+ return defer.ensureDeferred(
+ self.store.db_pool.simple_insert(
+ "events",
+ {
+ "stream_ordering": so,
+ "received_ts": ts,
+ "event_id": "event%i" % so,
+ "type": "",
+ "room_id": "",
+ "content": "",
+ "processed": True,
+ "outlier": False,
+ "topological_ordering": 0,
+ "depth": 0,
+ },
+ )
)
# start with the base case where there are no events in the table
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 55e9ecf264..9b9a183e7f 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -14,7 +14,7 @@
# limitations under the License.
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from tests.unittest import HomeserverTestCase
@@ -27,9 +27,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- self.db = self.store.db # type: Database
+ self.db_pool = self.store.db_pool # type: DatabasePool
- self.get_success(self.db.runInteraction("_setup_db", self._setup_db))
+ self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
def _setup_db(self, txn):
txn.execute("CREATE SEQUENCE foobar_seq")
@@ -47,7 +47,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
def _create(conn):
return MultiWriterIdGenerator(
conn,
- self.db,
+ self.db_pool,
instance_name=instance_name,
table="foobar",
instance_column="instance_name",
@@ -55,7 +55,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
sequence_name="foobar_seq",
)
- return self.get_success(self.db.runWithConnection(_create))
+ return self.get_success(self.db_pool.runWithConnection(_create))
def _insert_rows(self, instance_name: str, number: int):
def _insert(txn):
@@ -65,7 +65,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
(instance_name,),
)
- self.get_success(self.db.runInteraction("test_single_instance", _insert))
+ self.get_success(self.db_pool.runInteraction("test_single_instance", _insert))
def test_empty(self):
"""Test an ID generator against an empty database gives sensible
@@ -88,7 +88,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -98,12 +98,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
self.get_success(_get_next_async())
self.assertEqual(id_gen.get_positions(), {"master": 8})
- self.assertEqual(id_gen.get_current_token("master"), 8)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
def test_multi_instance(self):
"""Test that reads and writes from multiple processes are handled
@@ -116,8 +116,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen = self._create_id_generator("second")
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
- self.assertEqual(first_id_gen.get_current_token("first"), 3)
- self.assertEqual(first_id_gen.get_current_token("second"), 7)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -166,7 +166,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -176,9 +176,45 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
- self.get_success(self.db.runInteraction("test", _get_next_txn))
+ self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
self.assertEqual(id_gen.get_positions(), {"master": 8})
- self.assertEqual(id_gen.get_current_token("master"), 8)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
+
+ def test_get_persisted_upto_position(self):
+ """Test that `get_persisted_upto_position` correctly tracks updates to
+ positions.
+ """
+
+ self._insert_rows("first", 3)
+ self._insert_rows("second", 5)
+
+ id_gen = self._create_id_generator("first")
+
+ # Min is 3 and there is a gap between 5, so we expect it to be 3.
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+ # We advance "first" straight to 6. Min is now 5 but there is no gap so
+ # we expect it to be 6
+ id_gen.advance("first", 6)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 6)
+
+ # No gap, so we expect 7.
+ id_gen.advance("second", 7)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 7)
+
+ # We haven't seen 8 yet, so we expect 7 still.
+ id_gen.advance("second", 9)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 7)
+
+ # Now that we've seen 7, 8 and 9 we can got straight to 9.
+ id_gen.advance("first", 8)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 9)
+
+ # Jump forward with gaps. The minimum is 11, even though we haven't seen
+ # 10 we know that everything before 11 must be persisted.
+ id_gen.advance("first", 11)
+ id_gen.advance("second", 15)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 11)
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index ab0df5ea93..fbf8af940a 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -35,7 +35,7 @@ class DataStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_users_paginate(self):
yield self.store.register_user(self.user.to_string(), "pass")
- yield self.store.create_profile(self.user.localpart)
+ yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
users, total = yield self.store.get_users_paginate(
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 9c04e92577..9870c74883 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.api.constants import UserTypes
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.unittest import default_config, override_config
FORTY_DAYS = 40 * 24 * 60 * 60
@@ -78,7 +79,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
# XXX why are we doing this here? this function is only run at startup
# so it is odd to re-run it here.
self.get_success(
- self.store.db.runInteraction(
+ self.store.db_pool.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
)
@@ -204,7 +205,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.user_add_threepid(user, "email", email, now, now)
)
- d = self.store.db.runInteraction(
+ d = self.store.db_pool.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
self.get_success(d)
@@ -230,7 +231,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
)
self.get_success(d)
- self.store.upsert_monthly_active_user = Mock()
+ self.store.upsert_monthly_active_user = Mock(
+ side_effect=lambda user_id: make_awaitable(None)
+ )
d = self.store.populate_monthly_active_users(user_id)
self.get_success(d)
@@ -238,7 +241,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_not_called()
def test_populate_monthly_users_should_update(self):
- self.store.upsert_monthly_active_user = Mock()
+ self.store.upsert_monthly_active_user = Mock(
+ side_effect=lambda user_id: make_awaitable(None)
+ )
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
@@ -251,7 +256,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_called_once()
def test_populate_monthly_users_should_not_update(self):
- self.store.upsert_monthly_active_user = Mock()
+ self.store.upsert_monthly_active_user = Mock(
+ side_effect=lambda user_id: make_awaitable(None)
+ )
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
self.store.user_last_seen_monthly_active = Mock(
@@ -280,7 +287,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
]
self.hs.config.mau_limits_reserved_threepids = threepids
- d = self.store.db.runInteraction(
+ d = self.store.db_pool.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
self.get_success(d)
@@ -293,8 +300,12 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.register_user(user_id=user2, password_hash=None))
now = int(self.hs.get_clock().time_msec())
- self.store.user_add_threepid(user1, "email", user1_email, now, now)
- self.store.user_add_threepid(user2, "email", user2_email, now, now)
+ self.get_success(
+ self.store.user_add_threepid(user1, "email", user1_email, now, now)
+ )
+ self.get_success(
+ self.store.user_add_threepid(user2, "email", user2_email, now, now)
+ )
users = self.get_success(self.store.get_registered_reserved_users())
self.assertEqual(len(users), len(threepids))
@@ -333,7 +344,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
def test_no_users_when_not_tracking(self):
- self.store.upsert_monthly_active_user = Mock()
+ self.store.upsert_monthly_active_user = Mock(
+ side_effect=lambda user_id: make_awaitable(None)
+ )
self.get_success(self.store.populate_monthly_active_users("@user:sever"))
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 9b6f7211ae..9d5b8aa47d 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -33,7 +33,7 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_displayname(self):
- yield self.store.create_profile(self.u_frank.localpart)
+ yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
@@ -43,7 +43,7 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_avatar_url(self):
- yield self.store.create_profile(self.u_frank.localpart)
+ yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
yield self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here"
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index b9fafaa1a6..918387733b 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
+
+from synapse.api.errors import NotFoundError
from synapse.rest.client.v1 import room
from tests.unittest import HomeserverTestCase
@@ -44,28 +47,19 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_storage()
# Get the topological token
- event = store.get_topological_token_for_event(last["event_id"])
- self.pump()
- event = self.successResultOf(event)
+ event = self.get_success(
+ store.get_topological_token_for_event(last["event_id"])
+ )
# Purge everything before this topological token
- purge = storage.purge_events.purge_history(self.room_id, event, True)
- self.pump()
- self.assertEqual(self.successResultOf(purge), None)
-
- # Try and get the events
- get_first = store.get_event(first["event_id"])
- get_second = store.get_event(second["event_id"])
- get_third = store.get_event(third["event_id"])
- get_last = store.get_event(last["event_id"])
- self.pump()
+ self.get_success(storage.purge_events.purge_history(self.room_id, event, True))
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
# and last is not.
- self.failureResultOf(get_first)
- self.failureResultOf(get_second)
- self.failureResultOf(get_third)
- self.successResultOf(get_last)
+ self.get_failure(store.get_event(first["event_id"]), NotFoundError)
+ self.get_failure(store.get_event(second["event_id"]), NotFoundError)
+ self.get_failure(store.get_event(third["event_id"]), NotFoundError)
+ self.get_success(store.get_event(last["event_id"]))
def test_purge_wont_delete_extrems(self):
"""
@@ -80,28 +74,21 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_datastore()
# Set the topological token higher than it should be
- event = storage.get_topological_token_for_event(last["event_id"])
- self.pump()
- event = self.successResultOf(event)
+ event = self.get_success(
+ storage.get_topological_token_for_event(last["event_id"])
+ )
event = "t{}-{}".format(
*list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
)
# Purge everything before this topological token
- purge = storage.purge_history(self.room_id, event, True)
+ purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
self.pump()
f = self.failureResultOf(purge)
self.assertIn("greater than forward", f.value.args[0])
# Try and get the events
- get_first = storage.get_event(first["event_id"])
- get_second = storage.get_event(second["event_id"])
- get_third = storage.get_event(third["event_id"])
- get_last = storage.get_event(last["event_id"])
- self.pump()
-
- # Nothing is deleted.
- self.successResultOf(get_first)
- self.successResultOf(get_second)
- self.successResultOf(get_third)
- self.successResultOf(get_last)
+ self.get_success(storage.get_event(first["event_id"]))
+ self.get_success(storage.get_event(second["event_id"]))
+ self.get_success(storage.get_event(third["event_id"]))
+ self.get_success(storage.get_event(last["event_id"]))
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index db3667dc43..1ea35d60c1 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -237,7 +237,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def build(self, prev_event_ids):
- built_event = yield self._base_builder.build(prev_event_ids)
+ built_event = yield defer.ensureDeferred(
+ self._base_builder.build(prev_event_ids)
+ )
built_event._event_id = self._event_id
built_event._dict["event_id"] = self._event_id
@@ -249,6 +251,10 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def room_id(self):
return self._base_builder.room_id
+ @property
+ def type(self):
+ return self._base_builder.type
+
event_1, context_1 = self.get_success(
self.event_creation_handler.create_new_client_event(
EventIdManglingBuilder(
@@ -341,7 +347,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
event_json = self.get_success(
- self.store.db.simple_select_one_onecol(
+ self.store.db_pool.simple_select_one_onecol(
table="event_json",
keyvalues={"event_id": msg_event.event_id},
retcol="json",
@@ -359,7 +365,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.reactor.advance(60 * 60 * 2)
event_json = self.get_success(
- self.store.db.simple_select_one_onecol(
+ self.store.db_pool.simple_select_one_onecol(
table="event_json",
keyvalues={"event_id": msg_event.event_id},
retcol="json",
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 71a40a0a49..58f827d8d3 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -17,6 +17,7 @@
from twisted.internet import defer
from synapse.api.constants import UserTypes
+from synapse.api.errors import ThreepidValidationError
from tests import unittest
from tests.utils import setup_test_homeserver
@@ -58,8 +59,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_add_tokens(self):
yield self.store.register_user(self.user_id, self.pwhash)
- yield self.store.add_access_token_to_user(
- self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.store.add_access_token_to_user(
+ self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+ )
)
result = yield self.store.get_user_by_access_token(self.tokens[1])
@@ -74,11 +77,15 @@ class RegistrationStoreTestCase(unittest.TestCase):
def test_user_delete_access_tokens(self):
# add some tokens
yield self.store.register_user(self.user_id, self.pwhash)
- yield self.store.add_access_token_to_user(
- self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.store.add_access_token_to_user(
+ self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
+ )
)
- yield self.store.add_access_token_to_user(
- self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.store.add_access_token_to_user(
+ self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+ )
)
# now delete some
@@ -116,3 +123,33 @@ class RegistrationStoreTestCase(unittest.TestCase):
)
res = yield self.store.is_support_user(SUPPORT_USER)
self.assertTrue(res)
+
+ @defer.inlineCallbacks
+ def test_3pid_inhibit_invalid_validation_session_error(self):
+ """Tests that enabling the configuration option to inhibit 3PID errors on
+ /requestToken also inhibits validation errors caused by an unknown session ID.
+ """
+
+ # Check that, with the config setting set to false (the default value), a
+ # validation error is caused by the unknown session ID.
+ try:
+ yield defer.ensureDeferred(
+ self.store.validate_threepid_session(
+ "fake_sid", "fake_client_secret", "fake_token", 0,
+ )
+ )
+ except ThreepidValidationError as e:
+ self.assertEquals(e.msg, "Unknown session_id", e)
+
+ # Set the config setting to true.
+ self.store._ignore_unknown_session_error = True
+
+ # Check that now the validation error is caused by the token not matching.
+ try:
+ yield defer.ensureDeferred(
+ self.store.validate_threepid_session(
+ "fake_sid", "fake_client_secret", "fake_token", 0,
+ )
+ )
+ except ThreepidValidationError as e:
+ self.assertEquals(e.msg, "Validation token not found or has expired", e)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 3b78d48896..d07b985a8e 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -37,11 +37,13 @@ class RoomStoreTestCase(unittest.TestCase):
self.alias = RoomAlias.from_string("#a-room-name:test")
self.u_creator = UserID.from_string("@creator:test")
- yield self.store.store_room(
- self.room.to_string(),
- room_creator_user_id=self.u_creator.to_string(),
- is_public=True,
- room_version=RoomVersions.V1,
+ yield defer.ensureDeferred(
+ self.store.store_room(
+ self.room.to_string(),
+ room_creator_user_id=self.u_creator.to_string(),
+ is_public=True,
+ room_version=RoomVersions.V1,
+ )
)
@defer.inlineCallbacks
@@ -56,6 +58,10 @@ class RoomStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
+ def test_get_room_unknown_room(self):
+ self.assertIsNone((yield self.store.get_room("!uknown:test")),)
+
+ @defer.inlineCallbacks
def test_get_room_with_stats(self):
self.assertDictContainsSubset(
{
@@ -66,6 +72,10 @@ class RoomStoreTestCase(unittest.TestCase):
(yield self.store.get_room_with_stats(self.room.to_string())),
)
+ @defer.inlineCallbacks
+ def test_get_room_with_stats_unknown_room(self):
+ self.assertIsNone((yield self.store.get_room_with_stats("!uknown:test")),)
+
class RoomEventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
@@ -80,17 +90,21 @@ class RoomEventsStoreTestCase(unittest.TestCase):
self.room = RoomID.from_string("!abcde:test")
- yield self.store.store_room(
- self.room.to_string(),
- room_creator_user_id="@creator:text",
- is_public=True,
- room_version=RoomVersions.V1,
+ yield defer.ensureDeferred(
+ self.store.store_room(
+ self.room.to_string(),
+ room_creator_user_id="@creator:text",
+ is_public=True,
+ room_version=RoomVersions.V1,
+ )
)
@defer.inlineCallbacks
def inject_room_event(self, **kwargs):
- yield self.storage.persistence.persist_event(
- self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
+ yield defer.ensureDeferred(
+ self.storage.persistence.persist_event(
+ self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
+ )
)
@defer.inlineCallbacks
@@ -101,7 +115,9 @@ class RoomEventsStoreTestCase(unittest.TestCase):
etype=EventTypes.Name, name=name, content={"name": name}, depth=1
)
- state = yield self.store.get_current_state(room_id=self.room.to_string())
+ state = yield defer.ensureDeferred(
+ self.store.get_current_state(room_id=self.room.to_string())
+ )
self.assertEquals(1, len(state))
self.assertObjectHasAttributes(
@@ -117,7 +133,9 @@ class RoomEventsStoreTestCase(unittest.TestCase):
etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
)
- state = yield self.store.get_current_state(room_id=self.room.to_string())
+ state = yield defer.ensureDeferred(
+ self.store.get_current_state(room_id=self.room.to_string())
+ )
self.assertEquals(1, len(state))
self.assertObjectHasAttributes(
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 5dd46005e6..d98fe8754d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -118,18 +118,22 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
def test_get_joined_users_from_context(self):
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
- bob_event = event_injection.inject_member_event(
- self.hs, room, self.u_bob, Membership.JOIN
+ bob_event = self.get_success(
+ event_injection.inject_member_event(
+ self.hs, room, self.u_bob, Membership.JOIN
+ )
)
# first, create a regular event
- event, context = event_injection.create_event(
- self.hs,
- room_id=room,
- sender=self.u_alice,
- prev_event_ids=[bob_event.event_id],
- type="m.test.1",
- content={},
+ event, context = self.get_success(
+ event_injection.create_event(
+ self.hs,
+ room_id=room,
+ sender=self.u_alice,
+ prev_event_ids=[bob_event.event_id],
+ type="m.test.1",
+ content={},
+ )
)
users = self.get_success(
@@ -140,22 +144,26 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# Regression test for #7376: create a state event whose key matches bob's
# user_id, but which is *not* a membership event, and persist that; then check
# that `get_joined_users_from_context` returns the correct users for the next event.
- non_member_event = event_injection.inject_event(
- self.hs,
- room_id=room,
- sender=self.u_bob,
- prev_event_ids=[bob_event.event_id],
- type="m.test.2",
- state_key=self.u_bob,
- content={},
+ non_member_event = self.get_success(
+ event_injection.inject_event(
+ self.hs,
+ room_id=room,
+ sender=self.u_bob,
+ prev_event_ids=[bob_event.event_id],
+ type="m.test.2",
+ state_key=self.u_bob,
+ content={},
+ )
)
- event, context = event_injection.create_event(
- self.hs,
- room_id=room,
- sender=self.u_alice,
- prev_event_ids=[non_member_event.event_id],
- type="m.test.3",
- content={},
+ event, context = self.get_success(
+ event_injection.create_event(
+ self.hs,
+ room_id=room,
+ sender=self.u_alice,
+ prev_event_ids=[non_member_event.event_id],
+ type="m.test.3",
+ content={},
+ )
)
users = self.get_success(
self.store.get_joined_users_from_context(event, context)
@@ -171,20 +179,20 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def test_can_rerun_update(self):
# First make sure we have completed all updates.
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
# Now let's create a room, which will insert a membership
user = UserID("alice", "test")
- requester = Requester(user, None, False, None, None)
+ requester = Requester(user, None, False, False, None, None)
self.get_success(self.room_creator.create_room(requester, {}))
# Register the background update to run again.
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
table="background_updates",
values={
"update_name": "current_state_events_membership",
@@ -195,12 +203,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
)
# ... and tell the DataStore that it hasn't finished all updates yet
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
# Now let's actually drive the updates to completion
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 0b88308ff4..8bd12fa847 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -44,11 +44,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room = RoomID.from_string("!abc123:test")
- yield self.store.store_room(
- self.room.to_string(),
- room_creator_user_id="@creator:text",
- is_public=True,
- room_version=RoomVersions.V1,
+ yield defer.ensureDeferred(
+ self.store.store_room(
+ self.room.to_string(),
+ room_creator_user_id="@creator:text",
+ is_public=True,
+ room_version=RoomVersions.V1,
+ )
)
@defer.inlineCallbacks
@@ -64,11 +66,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
},
)
- event, context = yield self.event_creation_handler.create_new_client_event(
- builder
+ event, context = yield defer.ensureDeferred(
+ self.event_creation_handler.create_new_client_event(builder)
)
- yield self.storage.persistence.persist_event(event, context)
+ yield defer.ensureDeferred(
+ self.storage.persistence.persist_event(event, context)
+ )
return event
@@ -87,8 +91,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield self.storage.state.get_state_groups_ids(
- self.room, [e2.event_id]
+ state_group_map = yield defer.ensureDeferred(
+ self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
state_map = list(state_group_map.values())[0]
@@ -106,8 +110,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield self.storage.state.get_state_groups(
- self.room, [e2.event_id]
+ state_group_map = yield defer.ensureDeferred(
+ self.storage.state.get_state_groups(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0]
@@ -148,7 +152,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we get the full state as of the final event
- state = yield self.storage.state.get_state_for_event(e5.event_id)
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(e5.event_id)
+ )
self.assertIsNotNone(e4)
@@ -164,22 +170,28 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we can filter to the m.room.name event (with a '' state key)
- state = yield self.storage.state.get_state_for_event(
- e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(
+ e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
+ )
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key)
- state = yield self.storage.state.get_state_for_event(
- e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(
+ e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
+ )
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key)
- state = yield self.storage.state.get_state_for_event(
- e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(
+ e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
+ )
)
self.assertStateMapEqual(
@@ -188,12 +200,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can grab a specific room member without filtering out the
# other event types
- state = yield self.storage.state.get_state_for_event(
- e5.event_id,
- state_filter=StateFilter(
- types={EventTypes.Member: {self.u_alice.to_string()}},
- include_others=True,
- ),
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(
+ e5.event_id,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {self.u_alice.to_string()}},
+ include_others=True,
+ ),
+ )
)
self.assertStateMapEqual(
@@ -206,11 +220,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check that we can grab everything except members
- state = yield self.storage.state.get_state_for_event(
- e5.event_id,
- state_filter=StateFilter(
- types={EventTypes.Member: set()}, include_others=True
- ),
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(
+ e5.event_id,
+ state_filter=StateFilter(
+ types={EventTypes.Member: set()}, include_others=True
+ ),
+ )
)
self.assertStateMapEqual(
@@ -222,8 +238,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
room_id = self.room.to_string()
- group_ids = yield self.storage.state.get_state_groups_ids(
- room_id, [e5.event_id]
+ group_ids = yield defer.ensureDeferred(
+ self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
)
group = list(group_ids.keys())[0]
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 6a545d2eb0..ecfafe68a9 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -40,7 +40,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
def test_search_user_dir(self):
# normally when alice searches the directory she should just find
# bob because bobby doesn't share a room with her.
- r = yield self.store.search_user_dir(ALICE, "bob", 10)
+ r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"])
self.assertEqual(1, len(r["results"]))
self.assertDictEqual(
@@ -51,7 +51,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
def test_search_user_dir_all_users(self):
self.hs.config.user_directory_search_all_users = True
try:
- r = yield self.store.search_user_dir(ALICE, "bob", 10)
+ r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"])
self.assertEqual(2, len(r["results"]))
self.assertDictEqual(
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 89dcc58b99..4a4548433f 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -1,3 +1,18 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from mock import Mock
from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
@@ -10,6 +25,7 @@ from synapse.util.retryutils import NotRetryingDestination
from tests import unittest
from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
+from tests.test_utils import make_awaitable
class MessageAcceptTests(unittest.HomeserverTestCase):
@@ -26,7 +42,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
user_id = UserID("us", "test")
- our_user = Requester(user_id, None, False, None, None)
+ our_user = Requester(user_id, None, False, False, None, None)
room_creator = self.homeserver.get_room_creation_handler()
room_deferred = ensureDeferred(
room_creator.create_room(
@@ -95,7 +111,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
prev_events that said event references.
"""
- def post_json(destination, path, data, headers=None, timeout=0):
+ async def post_json(destination, path, data, headers=None, timeout=0):
# If it asks us for new missing events, give them NOTHING
if path.startswith("/_matrix/federation/v1/get_missing_events/"):
return {"events": []}
@@ -173,7 +189,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register a mock on the store so that the incoming update doesn't fail because
# we don't share a room with the user.
store = self.homeserver.get_datastore()
- store.get_rooms_for_user = Mock(return_value=["!someroom:test"])
+ store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"]))
# Manually inject a fake device list update. We need this update to include at
# least one prev_id so that the user's device list will need to be retried.
@@ -218,23 +234,26 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register mock device list retrieval on the federation client.
federation_client = self.homeserver.get_federation_client()
federation_client.query_user_devices = Mock(
- return_value={
- "user_id": remote_user_id,
- "stream_id": 1,
- "devices": [],
- "master_key": {
- "user_id": remote_user_id,
- "usage": ["master"],
- "keys": {"ed25519:" + remote_master_key: remote_master_key},
- },
- "self_signing_key": {
+ return_value=succeed(
+ {
"user_id": remote_user_id,
- "usage": ["self_signing"],
- "keys": {
- "ed25519:" + remote_self_signing_key: remote_self_signing_key
+ "stream_id": 1,
+ "devices": [],
+ "master_key": {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
},
- },
- }
+ "self_signing_key": {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:"
+ + remote_self_signing_key: remote_self_signing_key
+ },
+ },
+ }
+ )
)
# Resync the device list.
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 49667ed7f4..654a6fa42d 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -166,7 +166,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.do_sync_for_user(token5)
self.do_sync_for_user(token6)
- # But old user cant
+ # But old user can't
with self.assertRaises(SynapseError) as cm:
self.do_sync_for_user(token1)
diff --git a/tests/test_server.py b/tests/test_server.py
index e9a43b1e45..655c918a15 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -12,31 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
import re
-from six import StringIO
-
from twisted.internet.defer import Deferred
-from twisted.python.failure import Failure
-from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, RedirectException, SynapseError
-from synapse.http.server import (
- DirectServeResource,
- JsonResource,
- OptionsResource,
- wrap_html_request_handler,
-)
-from synapse.http.site import SynapseSite, logger
+from synapse.config.server import parse_listener_def
+from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource
+from synapse.http.site import SynapseSite
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock
from tests import unittest
from tests.server import (
- FakeTransport,
ThreadedMemoryReactorClock,
make_request,
render,
@@ -168,6 +157,28 @@ class JsonResourceTests(unittest.TestCase):
self.assertEqual(channel.json_body["error"], "Unrecognized request")
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+ def test_head_request(self):
+ """
+ JsonResource.handler_for_request gives correctly decoded URL args to
+ the callback, while Twisted will give the raw bytes of URL query
+ arguments.
+ """
+
+ def _callback(request, **kwargs):
+ return 200, {"result": True}
+
+ res = JsonResource(self.homeserver)
+ res.register_paths(
+ "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet",
+ )
+
+ # The path was registered as GET, but this is a HEAD request.
+ request, channel = make_request(self.reactor, b"HEAD", b"/_matrix/foo")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertNotIn("body", channel.result)
+
class OptionsResourceTests(unittest.TestCase):
def setUp(self):
@@ -189,7 +200,13 @@ class OptionsResourceTests(unittest.TestCase):
request.prepath = [] # This doesn't get set properly by make_request.
# Create a site and query for the resource.
- site = SynapseSite("test", "site_tag", {}, self.resource, "1.0")
+ site = SynapseSite(
+ "test",
+ "site_tag",
+ parse_listener_def({"type": "http", "port": 0}),
+ self.resource,
+ "1.0",
+ )
request.site = site
resource = site.getResourceFor(request)
@@ -198,10 +215,10 @@ class OptionsResourceTests(unittest.TestCase):
return channel
def test_unknown_options_request(self):
- """An OPTIONS requests to an unknown URL still returns 200 OK."""
+ """An OPTIONS requests to an unknown URL still returns 204 No Content."""
channel = self._make_request(b"OPTIONS", b"/foo/")
- self.assertEqual(channel.result["code"], b"200")
- self.assertEqual(channel.result["body"], b"{}")
+ self.assertEqual(channel.result["code"], b"204")
+ self.assertNotIn("body", channel.result)
# Ensure the correct CORS headers have been added
self.assertTrue(
@@ -218,10 +235,10 @@ class OptionsResourceTests(unittest.TestCase):
)
def test_known_options_request(self):
- """An OPTIONS requests to an known URL still returns 200 OK."""
+ """An OPTIONS requests to an known URL still returns 204 No Content."""
channel = self._make_request(b"OPTIONS", b"/res/")
- self.assertEqual(channel.result["code"], b"200")
- self.assertEqual(channel.result["body"], b"{}")
+ self.assertEqual(channel.result["code"], b"204")
+ self.assertNotIn("body", channel.result)
# Ensure the correct CORS headers have been added
self.assertTrue(
@@ -250,18 +267,17 @@ class OptionsResourceTests(unittest.TestCase):
class WrapHtmlRequestHandlerTests(unittest.TestCase):
- class TestResource(DirectServeResource):
+ class TestResource(DirectServeHtmlResource):
callback = None
- @wrap_html_request_handler
async def _async_render_GET(self, request):
- return await self.callback(request)
+ await self.callback(request)
def setUp(self):
self.reactor = ThreadedMemoryReactorClock()
def test_good_response(self):
- def callback(request):
+ async def callback(request):
request.write(b"response")
request.finish()
@@ -281,7 +297,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
with the right location.
"""
- def callback(request, **kwargs):
+ async def callback(request, **kwargs):
raise RedirectException(b"/look/an/eagle", 301)
res = WrapHtmlRequestHandlerTests.TestResource()
@@ -301,7 +317,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
returned too
"""
- def callback(request, **kwargs):
+ async def callback(request, **kwargs):
e = RedirectException(b"/no/over/there", 304)
e.cookies.append(b"session=yespls")
raise e
@@ -319,51 +335,18 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
self.assertEqual(cookies_headers, [b"session=yespls"])
+ def test_head_request(self):
+ """A head request should work by being turned into a GET request."""
-class SiteTestCase(unittest.HomeserverTestCase):
- def test_lose_connection(self):
- """
- We log the URI correctly redacted when we lose the connection.
- """
+ async def callback(request):
+ request.write(b"response")
+ request.finish()
- class HangingResource(Resource):
- """
- A Resource that strategically hangs, as if it were processing an
- answer.
- """
+ res = WrapHtmlRequestHandlerTests.TestResource()
+ res.callback = callback
- def render(self, request):
- return NOT_DONE_YET
-
- # Set up a logging handler that we can inspect afterwards
- output = StringIO()
- handler = logging.StreamHandler(output)
- logger.addHandler(handler)
- old_level = logger.level
- logger.setLevel(10)
- self.addCleanup(logger.setLevel, old_level)
- self.addCleanup(logger.removeHandler, handler)
-
- # Make a resource and a Site, the resource will hang and allow us to
- # time out the request while it's 'processing'
- base_resource = Resource()
- base_resource.putChild(b"", HangingResource())
- site = SynapseSite("test", "site_tag", {}, base_resource, "1.0")
-
- server = site.buildProtocol(None)
- client = AccumulatingProtocol()
- client.makeConnection(FakeTransport(server, self.reactor))
- server.makeConnection(FakeTransport(client, self.reactor))
-
- # Send a request with an access token that will get redacted
- server.dataReceived(b"GET /?access_token=bar HTTP/1.0\r\n\r\n")
- self.pump()
-
- # Lose the connection
- e = Failure(Exception("Failed123"))
- server.connectionLost(e)
- handler.flush()
-
- # Our access token is redacted and the failure reason is logged.
- self.assertIn("/?access_token=<redacted>", output.getvalue())
- self.assertIn("Failed123", output.getvalue())
+ request, channel = make_request(self.reactor, b"HEAD", b"/path")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertNotIn("body", channel.result)
diff --git a/tests/test_state.py b/tests/test_state.py
index 66f22f6813..b5c3667d2a 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -97,17 +97,19 @@ class StateGroupStore(object):
self._group_to_state[state_group] = dict(current_state_ids)
- return state_group
+ return defer.succeed(state_group)
def get_events(self, event_ids, **kwargs):
- return {
- e_id: self._event_id_to_event[e_id]
- for e_id in event_ids
- if e_id in self._event_id_to_event
- }
+ return defer.succeed(
+ {
+ e_id: self._event_id_to_event[e_id]
+ for e_id in event_ids
+ if e_id in self._event_id_to_event
+ }
+ )
def get_state_group_delta(self, name):
- return None, None
+ return defer.succeed((None, None))
def register_events(self, events):
for e in events:
@@ -120,7 +122,7 @@ class StateGroupStore(object):
self._event_to_state_group[event_id] = state_group
def get_room_version_id(self, room_id):
- return RoomVersions.V1.identifier
+ return defer.succeed(RoomVersions.V1.identifier)
class DictObj(dict):
@@ -202,14 +204,16 @@ class StateTestCase(unittest.TestCase):
context_store = {} # type: dict[str, EventContext]
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
ctx_c = context_store["C"]
ctx_d = context_store["D"]
- prev_state_ids = yield ctx_d.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertEqual(2, len(prev_state_ids))
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@@ -244,7 +248,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@@ -253,7 +259,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_d = context_store["D"]
- prev_state_ids = yield ctx_d.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@@ -300,7 +306,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@@ -310,7 +318,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_e = context_store["E"]
- prev_state_ids = yield ctx_e.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids())
self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
@@ -373,7 +381,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@@ -383,7 +393,7 @@ class StateTestCase(unittest.TestCase):
ctx_b = context_store["B"]
ctx_d = context_store["D"]
- prev_state_ids = yield ctx_d.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
@@ -411,12 +421,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- context = yield self.state.compute_event_context(event, old_state=old_state)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event, old_state=old_state)
+ )
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertCountEqual(
(e.event_id for e in old_state), current_state_ids.values()
)
@@ -434,12 +446,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- context = yield self.state.compute_event_context(event, old_state=old_state)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event, old_state=old_state)
+ )
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertCountEqual(
(e.event_id for e in old_state + [event]), current_state_ids.values()
)
@@ -462,7 +476,7 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = self.store.store_state_group(
+ group_name = yield self.store.store_state_group(
prev_event_id,
event.room_id,
None,
@@ -471,9 +485,9 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id, group_name)
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(self.state.compute_event_context(event))
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(
{e.event_id for e in old_state}, set(current_state_ids.values())
@@ -494,7 +508,7 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = self.store.store_state_group(
+ group_name = yield self.store.store_state_group(
prev_event_id,
event.room_id,
None,
@@ -503,9 +517,9 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id, group_name)
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(self.state.compute_event_context(event))
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))
@@ -544,7 +558,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(len(current_state_ids), 6)
@@ -586,7 +600,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(len(current_state_ids), 6)
@@ -641,7 +655,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
@@ -669,14 +683,15 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
+ @defer.inlineCallbacks
def _get_context(
self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
):
- sg1 = self.store.store_state_group(
+ sg1 = yield self.store.store_state_group(
prev_event_id_1,
event.room_id,
None,
@@ -685,7 +700,7 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id_1, sg1)
- sg2 = self.store.store_state_group(
+ sg2 = yield self.store.store_state_group(
prev_event_id_2,
event.room_id,
None,
@@ -694,4 +709,5 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id_2, sg2)
- return self.state.compute_event_context(event)
+ result = yield defer.ensureDeferred(self.state.compute_event_context(event))
+ return result
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index 5c2817cf28..b89798336c 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -14,7 +14,6 @@
import json
-import six
from mock import Mock
from twisted.test.proto_helpers import MemoryReactorClock
@@ -60,7 +59,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"401", channel.result)
self.assertTrue(channel.json_body is not None)
- self.assertIsInstance(channel.json_body["session"], six.text_type)
+ self.assertIsInstance(channel.json_body["session"], str)
self.assertIsInstance(channel.json_body["flows"], list)
for flow in channel.json_body["flows"]:
@@ -125,6 +124,6 @@ class TermsTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertTrue(channel.json_body is not None)
- self.assertIsInstance(channel.json_body["user_id"], six.text_type)
- self.assertIsInstance(channel.json_body["access_token"], six.text_type)
- self.assertIsInstance(channel.json_body["device_id"], six.text_type)
+ self.assertIsInstance(channel.json_body["user_id"], str)
+ self.assertIsInstance(channel.json_body["access_token"], str)
+ self.assertIsInstance(channel.json_body["device_id"], str)
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 7b345b03bb..508aeba078 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,7 +17,7 @@
"""
Utilities for running the unit tests
"""
-from typing import Awaitable, TypeVar
+from typing import Any, Awaitable, TypeVar
TV = TypeVar("TV")
@@ -36,3 +36,8 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
# if next didn't raise, the awaitable hasn't completed.
raise Exception("awaitable has not yet completed")
+
+
+async def make_awaitable(result: Any):
+ """Create an awaitable that just returns a result."""
+ return result
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 431e9f8e5e..8522c6fc09 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from typing import Optional, Tuple
import synapse.server
@@ -23,15 +22,12 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import Collection
-from tests.test_utils import get_awaitable_result
-
-
"""
Utility functions for poking events into the storage of the server under test.
"""
-def inject_member_event(
+async def inject_member_event(
hs: synapse.server.HomeServer,
room_id: str,
sender: str,
@@ -48,7 +44,7 @@ def inject_member_event(
if extra_content:
content.update(extra_content)
- return inject_event(
+ return await inject_event(
hs,
room_id=room_id,
type=EventTypes.Member,
@@ -59,7 +55,7 @@ def inject_member_event(
)
-def inject_event(
+async def inject_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None,
@@ -74,37 +70,27 @@ def inject_event(
prev_event_ids: prev_events for the event. If not specified, will be looked up
kwargs: fields for the event to be created
"""
- test_reactor = hs.get_reactor()
+ event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
- event, context = create_event(hs, room_version, prev_event_ids, **kwargs)
-
- d = hs.get_storage().persistence.persist_event(event, context)
- test_reactor.advance(0)
- get_awaitable_result(d)
+ await hs.get_storage().persistence.persist_event(event, context)
return event
-def create_event(
+async def create_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None,
**kwargs
) -> Tuple[EventBase, EventContext]:
- test_reactor = hs.get_reactor()
-
if room_version is None:
- d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
- test_reactor.advance(0)
- room_version = get_awaitable_result(d)
+ room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"])
builder = hs.get_event_builder_factory().for_room_version(
KNOWN_ROOM_VERSIONS[room_version], kwargs
)
- d = hs.get_event_creation_handler().create_new_client_event(
+ event, context = await hs.get_event_creation_handler().create_new_client_event(
builder, prev_event_ids=prev_event_ids
)
- test_reactor.advance(0)
- event, context = get_awaitable_result(d)
return event, context
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index f7381b2885..531a9b9118 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -40,7 +40,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.store = self.hs.get_datastore()
self.storage = self.hs.get_storage()
- yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
+ yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
@defer.inlineCallbacks
def test_filtering(self):
@@ -53,7 +53,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
#
# before we do that, we persist some other events to act as state.
- self.inject_visibility("@admin:hs", "joined")
+ yield self.inject_visibility("@admin:hs", "joined")
for i in range(0, 10):
yield self.inject_room_member("@resident%i:hs" % i)
@@ -64,8 +64,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
evt = yield self.inject_room_member(user, extra_content={"a": "b"})
events_to_filter.append(evt)
- filtered = yield filter_events_for_server(
- self.storage, "test_server", events_to_filter
+ filtered = yield defer.ensureDeferred(
+ filter_events_for_server(self.storage, "test_server", events_to_filter)
)
# the result should be 5 redacted events, and 5 unredacted events.
@@ -102,8 +102,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
yield self.hs.get_datastore().mark_user_erased("@erased:local_hs")
# ... and the filtering happens.
- filtered = yield filter_events_for_server(
- self.storage, "test_server", events_to_filter
+ filtered = yield defer.ensureDeferred(
+ filter_events_for_server(self.storage, "test_server", events_to_filter)
)
for i in range(0, len(events_to_filter)):
@@ -137,10 +137,12 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
},
)
- event, context = yield self.event_creation_handler.create_new_client_event(
- builder
+ event, context = yield defer.ensureDeferred(
+ self.event_creation_handler.create_new_client_event(builder)
+ )
+ yield defer.ensureDeferred(
+ self.storage.persistence.persist_event(event, context)
)
- yield self.storage.persistence.persist_event(event, context)
return event
@defer.inlineCallbacks
@@ -158,11 +160,13 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
},
)
- event, context = yield self.event_creation_handler.create_new_client_event(
- builder
+ event, context = yield defer.ensureDeferred(
+ self.event_creation_handler.create_new_client_event(builder)
)
- yield self.storage.persistence.persist_event(event, context)
+ yield defer.ensureDeferred(
+ self.storage.persistence.persist_event(event, context)
+ )
return event
@defer.inlineCallbacks
@@ -179,11 +183,13 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
},
)
- event, context = yield self.event_creation_handler.create_new_client_event(
- builder
+ event, context = yield defer.ensureDeferred(
+ self.event_creation_handler.create_new_client_event(builder)
)
- yield self.storage.persistence.persist_event(event, context)
+ yield defer.ensureDeferred(
+ self.storage.persistence.persist_event(event, context)
+ )
return event
@defer.inlineCallbacks
@@ -265,8 +271,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
storage.main = test_store
storage.state = test_store
- filtered = yield filter_events_for_server(
- test_store, "test_server", events_to_filter
+ filtered = yield defer.ensureDeferred(
+ filter_events_for_server(test_store, "test_server", events_to_filter)
)
logger.info("Filtering took %f seconds", time.time() - start)
diff --git a/tests/unittest.py b/tests/unittest.py
index 6b6f224e9c..7b80999a74 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -229,7 +229,7 @@ class HomeserverTestCase(TestCase):
self.site = SynapseSite(
logger_name="synapse.access.http.fake",
site_tag="test",
- config={},
+ config=self.hs.config.server.listeners[0],
resource=self.resource,
server_version_string="1",
)
@@ -241,20 +241,20 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "user_id"):
if self.hijack_auth:
- def get_user_by_access_token(token=None, allow_guest=False):
- return succeed(
- {
- "user": UserID.from_string(self.helper.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
- )
-
- def get_user_by_req(request, allow_guest=False, rights="access"):
- return succeed(
- create_requester(
- UserID.from_string(self.helper.auth_user_id), 1, False, None
- )
+ async def get_user_by_access_token(token=None, allow_guest=False):
+ return {
+ "user": UserID.from_string(self.helper.auth_user_id),
+ "token_id": 1,
+ "is_guest": False,
+ }
+
+ async def get_user_by_req(request, allow_guest=False, rights="access"):
+ return create_requester(
+ UserID.from_string(self.helper.auth_user_id),
+ 1,
+ False,
+ False,
+ None,
)
self.hs.get_auth().get_user_by_req = get_user_by_req
@@ -422,8 +422,8 @@ class HomeserverTestCase(TestCase):
async def run_bg_updates():
with LoggingContext("run_bg_updates", request="run_bg_updates-1"):
- while not await stor.db.updates.has_completed_background_updates():
- await stor.db.updates.do_next_background_update(1)
+ while not await stor.db_pool.updates.has_completed_background_updates():
+ await stor.db_pool.updates.do_next_background_update(1)
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
stor = hs.get_datastore()
@@ -544,7 +544,7 @@ class HomeserverTestCase(TestCase):
"""
event_creator = self.hs.get_event_creation_handler()
secrets = self.hs.get_secrets()
- requester = Requester(user, None, False, None, None)
+ requester = Requester(user, None, False, False, None, None)
event, context = self.get_success(
event_creator.create_event(
@@ -571,7 +571,7 @@ class HomeserverTestCase(TestCase):
Add the given event as an extremity to the room.
"""
self.get_success(
- self.hs.get_datastore().db.simple_insert(
+ self.hs.get_datastore().db_pool.simple_insert(
table="event_forward_extremities",
values={"room_id": room_id, "event_id": event_id},
desc="test_add_extremity",
@@ -603,7 +603,9 @@ class HomeserverTestCase(TestCase):
user: MXID of the user to inject the membership for.
membership: The membership type.
"""
- event_injection.inject_member_event(self.hs, room, user, membership)
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room, user, membership)
+ )
class FederatingHomeserverTestCase(HomeserverTestCase):
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 4d2b9e0d64..0363735d4f 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -366,11 +366,11 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def fn(self, arg1, arg2):
pass
- @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
- def list_fn(self, args1, arg2):
+ @descriptors.cachedList("fn", "args1")
+ async def list_fn(self, args1, arg2):
assert current_context().request == "c1"
# we want this to behave like an asynchronous function
- yield run_on_reactor()
+ await run_on_reactor()
assert current_context().request == "c1"
return self.mock(args1, arg2)
@@ -416,10 +416,10 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def fn(self, arg1, arg2):
pass
- @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
- def list_fn(self, args1, arg2):
+ @descriptors.cachedList("fn", "args1")
+ async def list_fn(self, args1, arg2):
# we want this to behave like an asynchronous function
- yield run_on_reactor()
+ await run_on_reactor()
return self.mock(args1, arg2)
obj = Cls()
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
index e90e08d1c0..8d6627ec33 100644
--- a/tests/util/test_file_consumer.py
+++ b/tests/util/test_file_consumer.py
@@ -15,9 +15,9 @@
import threading
+from io import StringIO
from mock import NonCallableMock
-from six import StringIO
from twisted.internet import defer, reactor
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index ca3858b184..0e52811948 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -14,8 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from six.moves import range
-
from twisted.internet import defer, reactor
from twisted.internet.defer import CancelledError
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 95301c013c..58ee918f65 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -124,7 +124,7 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_make_deferred_yieldable(self):
- # a function which retuns an incomplete deferred, but doesn't follow
+ # a function which returns an incomplete deferred, but doesn't follow
# the synapse rules.
def blocking_function():
d = defer.Deferred()
@@ -183,7 +183,7 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_make_deferred_yieldable_with_await(self):
- # an async function which retuns an incomplete coroutine, but doesn't
+ # an async function which returns an incomplete coroutine, but doesn't
# follow the synapse rules.
async def blocking_function():
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 9e348694ad..bc42ffce88 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -26,9 +26,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
def test_new_destination(self):
"""A happy-path case with a new destination and a successful operation"""
store = self.hs.get_datastore()
- d = get_retry_limiter("test_dest", self.clock, store)
- self.pump()
- limiter = self.successResultOf(d)
+ limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
# advance the clock a bit before making the request
self.pump(1)
@@ -36,18 +34,14 @@ class RetryLimiterTestCase(HomeserverTestCase):
with limiter:
pass
- d = store.get_destination_retry_timings("test_dest")
- self.pump()
- new_timings = self.successResultOf(d)
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertIsNone(new_timings)
def test_limiter(self):
"""General test case which walks through the process of a failing request"""
store = self.hs.get_datastore()
- d = get_retry_limiter("test_dest", self.clock, store)
- self.pump()
- limiter = self.successResultOf(d)
+ limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
self.pump(1)
try:
@@ -58,29 +52,22 @@ class RetryLimiterTestCase(HomeserverTestCase):
except AssertionError:
pass
- # wait for the update to land
- self.pump()
-
- d = store.get_destination_retry_timings("test_dest")
- self.pump()
- new_timings = self.successResultOf(d)
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertEqual(new_timings["failure_ts"], failure_ts)
self.assertEqual(new_timings["retry_last_ts"], failure_ts)
self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL)
# now if we try again we should get a failure
- d = get_retry_limiter("test_dest", self.clock, store)
- self.pump()
- self.failureResultOf(d, NotRetryingDestination)
+ self.get_failure(
+ get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination
+ )
#
# advance the clock and try again
#
self.pump(MIN_RETRY_INTERVAL)
- d = get_retry_limiter("test_dest", self.clock, store)
- self.pump()
- limiter = self.successResultOf(d)
+ limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
self.pump(1)
try:
@@ -91,12 +78,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
except AssertionError:
pass
- # wait for the update to land
- self.pump()
-
- d = store.get_destination_retry_timings("test_dest")
- self.pump()
- new_timings = self.successResultOf(d)
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertEqual(new_timings["failure_ts"], failure_ts)
self.assertEqual(new_timings["retry_last_ts"], retry_ts)
self.assertGreaterEqual(
@@ -110,9 +92,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
# one more go, with success
#
self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
- d = get_retry_limiter("test_dest", self.clock, store)
- self.pump()
- limiter = self.successResultOf(d)
+ limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
self.pump(1)
with limiter:
@@ -121,7 +101,5 @@ class RetryLimiterTestCase(HomeserverTestCase):
# wait for the update to land
self.pump()
- d = store.get_destination_retry_timings("test_dest")
- self.pump()
- new_timings = self.successResultOf(d)
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertIsNone(new_timings)
diff --git a/tests/util/test_stringutils.py b/tests/util/test_stringutils.py
index 4f4da29a98..8491f7cc83 100644
--- a/tests/util/test_stringutils.py
+++ b/tests/util/test_stringutils.py
@@ -28,9 +28,6 @@ class StringUtilsTestCase(unittest.TestCase):
"_--something==_",
"...--==-18913",
"8Dj2odd-e9asd.cd==_--ddas-secret-",
- # We temporarily allow : characters: https://github.com/matrix-org/synapse/issues/6766
- # To be removed in a future release
- "SECRET:1234567890",
]
bad = [
diff --git a/tests/util/test_threepids.py b/tests/util/test_threepids.py
new file mode 100644
index 0000000000..5513724d87
--- /dev/null
+++ b/tests/util/test_threepids.py
@@ -0,0 +1,49 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.threepids import canonicalise_email
+
+from tests.unittest import HomeserverTestCase
+
+
+class CanonicaliseEmailTests(HomeserverTestCase):
+ def test_no_at(self):
+ with self.assertRaises(ValueError):
+ canonicalise_email("address-without-at.bar")
+
+ def test_two_at(self):
+ with self.assertRaises(ValueError):
+ canonicalise_email("foo@foo@test.bar")
+
+ def test_bad_format(self):
+ with self.assertRaises(ValueError):
+ canonicalise_email("user@bad.example.net@good.example.com")
+
+ def test_valid_format(self):
+ self.assertEqual(canonicalise_email("foo@test.bar"), "foo@test.bar")
+
+ def test_domain_to_lower(self):
+ self.assertEqual(canonicalise_email("foo@TEST.BAR"), "foo@test.bar")
+
+ def test_domain_with_umlaut(self):
+ self.assertEqual(canonicalise_email("foo@Öumlaut.com"), "foo@öumlaut.com")
+
+ def test_address_casefold(self):
+ self.assertEqual(
+ canonicalise_email("Strauß@Example.com"), "strauss@example.com"
+ )
+
+ def test_address_trim(self):
+ self.assertEqual(canonicalise_email(" foo@test.bar "), "foo@test.bar")
diff --git a/tests/utils.py b/tests/utils.py
index 59c020a051..a61cbdef44 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -21,9 +21,9 @@ import time
import uuid
import warnings
from inspect import getcallargs
+from urllib import parse as urlparse
from mock import Mock, patch
-from six.moves.urllib import parse as urlparse
from twisted.internet import defer, reactor
@@ -154,6 +154,10 @@ def default_config(name, parse=False):
"account": {"per_second": 10000, "burst_count": 10000},
"failed_attempts": {"per_second": 10000, "burst_count": 10000},
},
+ "rc_joins": {
+ "local": {"per_second": 10000, "burst_count": 10000},
+ "remote": {"per_second": 10000, "burst_count": 10000},
+ },
"saml2_enabled": False,
"public_baseurl": None,
"default_identity_server": None,
@@ -168,6 +172,7 @@ def default_config(name, parse=False):
# background, which upsets the test runner.
"update_user_directory": False,
"caches": {"global_factor": 1},
+ "listeners": [{"port": 0, "type": "http"}],
}
if parse:
@@ -637,14 +642,8 @@ class DeferredMockCallable(object):
)
-@defer.inlineCallbacks
-def create_room(hs, room_id, creator_id):
+async def create_room(hs, room_id: str, creator_id: str):
"""Creates and persist a creation event for the given room
-
- Args:
- hs
- room_id (str)
- creator_id (str)
"""
persistence_store = hs.get_storage().persistence
@@ -652,7 +651,7 @@ def create_room(hs, room_id, creator_id):
event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler()
- yield store.store_room(
+ await store.store_room(
room_id=room_id,
room_creator_user_id=creator_id,
is_public=False,
@@ -670,6 +669,6 @@ def create_room(hs, room_id, creator_id):
},
)
- event, context = yield event_creation_handler.create_new_client_event(builder)
+ event, context = await event_creation_handler.create_new_client_event(builder)
- yield persistence_store.persist_event(event, context)
+ await persistence_store.persist_event(event, context)
|