diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 0bfb86bf1f..8ab56ec94c 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -36,7 +36,7 @@ from tests import unittest
from tests.utils import mock_getRawHeaders, setup_test_homeserver
-class TestHandlers(object):
+class TestHandlers:
def __init__(self, hs):
self.auth_handler = synapse.handlers.auth.AuthHandler(hs)
@@ -62,12 +62,15 @@ class AuthTestCase(unittest.TestCase):
# this is overridden for the appservice tests
self.store.get_app_service_by_token = Mock(return_value=None)
+ self.store.insert_client_ip = Mock(return_value=defer.succeed(None))
self.store.is_support_user = Mock(return_value=defer.succeed(False))
@defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self):
user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
- self.store.get_user_by_access_token = Mock(return_value=user_info)
+ self.store.get_user_by_access_token = Mock(
+ return_value=defer.succeed(user_info)
+ )
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
@@ -76,23 +79,25 @@ class AuthTestCase(unittest.TestCase):
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self):
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self):
user_info = {"name": self.test_user, "token_id": "ditto"}
- self.store.get_user_by_access_token = Mock(return_value=user_info)
+ self.store.get_user_by_access_token = Mock(
+ return_value=defer.succeed(user_info)
+ )
request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, MissingClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@@ -103,7 +108,7 @@ class AuthTestCase(unittest.TestCase):
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
@@ -123,7 +128,7 @@ class AuthTestCase(unittest.TestCase):
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "192.168.10.10"
@@ -142,25 +147,25 @@ class AuthTestCase(unittest.TestCase):
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "131.111.8.42"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
@@ -168,11 +173,11 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_appservice_missing_token(self):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, MissingClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@@ -185,7 +190,11 @@ class AuthTestCase(unittest.TestCase):
)
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ # This just needs to return a truth-y value.
+ self.store.get_user_by_id = Mock(
+ return_value=defer.succeed({"is_guest": False})
+ )
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
@@ -204,20 +213,22 @@ class AuthTestCase(unittest.TestCase):
)
app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
self.failureResultOf(d, AuthError)
@defer.inlineCallbacks
def test_get_user_from_macaroon(self):
self.store.get_user_by_access_token = Mock(
- return_value={"name": "@baldrick:matrix.org", "device_id": "device"}
+ return_value=defer.succeed(
+ {"name": "@baldrick:matrix.org", "device_id": "device"}
+ )
)
user_id = "@baldrick:matrix.org"
@@ -241,8 +252,8 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self):
- self.store.get_user_by_id = Mock(return_value={"is_guest": True})
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True}))
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
@@ -282,16 +293,20 @@ class AuthTestCase(unittest.TestCase):
def get_user(tok):
if token != tok:
- return None
- return {
- "name": USER_ID,
- "is_guest": False,
- "token_id": 1234,
- "device_id": "DEVICE",
- }
+ return defer.succeed(None)
+ return defer.succeed(
+ {
+ "name": USER_ID,
+ "is_guest": False,
+ "token_id": 1234,
+ "device_id": "DEVICE",
+ }
+ )
self.store.get_user_by_access_token = get_user
- self.store.get_user_by_id = Mock(return_value={"is_guest": False})
+ self.store.get_user_by_id = Mock(
+ return_value=defer.succeed({"is_guest": False})
+ )
# check the token works
request = Mock(args={})
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 4e67503cf0..d2d535d23c 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -369,14 +369,18 @@ class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_filter_presence_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}}
- filter_id = yield self.datastore.add_user_filter(
- user_localpart=user_localpart, user_filter=user_filter_json
+ filter_id = yield defer.ensureDeferred(
+ self.datastore.add_user_filter(
+ user_localpart=user_localpart, user_filter=user_filter_json
+ )
)
event = MockEvent(sender="@foo:bar", type="m.profile")
events = [event]
- user_filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
+ user_filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart, filter_id=filter_id
+ )
)
results = user_filter.filter_presence(events=events)
@@ -386,8 +390,10 @@ class FilteringTestCase(unittest.TestCase):
def test_filter_presence_no_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}}
- filter_id = yield self.datastore.add_user_filter(
- user_localpart=user_localpart + "2", user_filter=user_filter_json
+ filter_id = yield defer.ensureDeferred(
+ self.datastore.add_user_filter(
+ user_localpart=user_localpart + "2", user_filter=user_filter_json
+ )
)
event = MockEvent(
event_id="$asdasd:localhost",
@@ -396,8 +402,10 @@ class FilteringTestCase(unittest.TestCase):
)
events = [event]
- user_filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart + "2", filter_id=filter_id
+ user_filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart + "2", filter_id=filter_id
+ )
)
results = user_filter.filter_presence(events=events)
@@ -406,14 +414,18 @@ class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_filter_room_state_match(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
- filter_id = yield self.datastore.add_user_filter(
- user_localpart=user_localpart, user_filter=user_filter_json
+ filter_id = yield defer.ensureDeferred(
+ self.datastore.add_user_filter(
+ user_localpart=user_localpart, user_filter=user_filter_json
+ )
)
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
events = [event]
- user_filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
+ user_filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart, filter_id=filter_id
+ )
)
results = user_filter.filter_room_state(events=events)
@@ -422,16 +434,20 @@ class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_filter_room_state_no_match(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
- filter_id = yield self.datastore.add_user_filter(
- user_localpart=user_localpart, user_filter=user_filter_json
+ filter_id = yield defer.ensureDeferred(
+ self.datastore.add_user_filter(
+ user_localpart=user_localpart, user_filter=user_filter_json
+ )
)
event = MockEvent(
sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
)
events = [event]
- user_filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
+ user_filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart, filter_id=filter_id
+ )
)
results = user_filter.filter_room_state(events)
@@ -457,16 +473,20 @@ class FilteringTestCase(unittest.TestCase):
def test_add_filter(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
- filter_id = yield self.filtering.add_user_filter(
- user_localpart=user_localpart, user_filter=user_filter_json
+ filter_id = yield defer.ensureDeferred(
+ self.filtering.add_user_filter(
+ user_localpart=user_localpart, user_filter=user_filter_json
+ )
)
self.assertEquals(filter_id, 0)
self.assertEquals(
user_filter_json,
(
- yield self.datastore.get_user_filter(
- user_localpart=user_localpart, filter_id=0
+ yield defer.ensureDeferred(
+ self.datastore.get_user_filter(
+ user_localpart=user_localpart, filter_id=0
+ )
)
),
)
@@ -475,12 +495,16 @@ class FilteringTestCase(unittest.TestCase):
def test_get_filter(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
- filter_id = yield self.datastore.add_user_filter(
- user_localpart=user_localpart, user_filter=user_filter_json
+ filter_id = yield defer.ensureDeferred(
+ self.datastore.add_user_filter(
+ user_localpart=user_localpart, user_filter=user_filter_json
+ )
)
- filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
+ filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart, filter_id=filter_id
+ )
)
self.assertEquals(filter.get_filter_json(), user_filter_json)
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index d580e729c5..1e1f30d790 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -1,4 +1,6 @@
from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
+from synapse.appservice import ApplicationService
+from synapse.types import create_requester
from tests import unittest
@@ -18,6 +20,77 @@ class TestRatelimiter(unittest.TestCase):
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
+ def test_allowed_user_via_can_requester_do_action(self):
+ user_requester = create_requester("@user:example.com")
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
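+        # rate_hz=0.1 with burst_count=1 allows roughly one action per 10 seconds,
+        # so the attempt at t=5 is rejected and the window reopens at t=10.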
+ allowed, time_allowed = limiter.can_requester_do_action(
+ user_requester, _time_now_s=0
+ )
+ self.assertTrue(allowed)
+ self.assertEquals(10.0, time_allowed)
+
+ allowed, time_allowed = limiter.can_requester_do_action(
+ user_requester, _time_now_s=5
+ )
+ self.assertFalse(allowed)
+ self.assertEquals(10.0, time_allowed)
+
+ allowed, time_allowed = limiter.can_requester_do_action(
+ user_requester, _time_now_s=10
+ )
+ self.assertTrue(allowed)
+ self.assertEquals(20.0, time_allowed)
+
+ def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
+ appservice = ApplicationService(
+ None, "example.com", id="foo", rate_limited=True,
+ )
+ as_requester = create_requester("@user:example.com", app_service=appservice)
+
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ allowed, time_allowed = limiter.can_requester_do_action(
+ as_requester, _time_now_s=0
+ )
+ self.assertTrue(allowed)
+ self.assertEquals(10.0, time_allowed)
+
+ allowed, time_allowed = limiter.can_requester_do_action(
+ as_requester, _time_now_s=5
+ )
+ self.assertFalse(allowed)
+ self.assertEquals(10.0, time_allowed)
+
+ allowed, time_allowed = limiter.can_requester_do_action(
+ as_requester, _time_now_s=10
+ )
+ self.assertTrue(allowed)
+ self.assertEquals(20.0, time_allowed)
+
+ def test_allowed_appservice_via_can_requester_do_action(self):
+ appservice = ApplicationService(
+ None, "example.com", id="foo", rate_limited=False,
+ )
+ as_requester = create_requester("@user:example.com", app_service=appservice)
+
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ allowed, time_allowed = limiter.can_requester_do_action(
+ as_requester, _time_now_s=0
+ )
+ self.assertTrue(allowed)
+ self.assertEquals(-1, time_allowed)
+
+ allowed, time_allowed = limiter.can_requester_do_action(
+ as_requester, _time_now_s=5
+ )
+ self.assertTrue(allowed)
+ self.assertEquals(-1, time_allowed)
+
+ allowed, time_allowed = limiter.can_requester_do_action(
+ as_requester, _time_now_s=10
+ )
+ self.assertTrue(allowed)
+ self.assertEquals(-1, time_allowed)
+
def test_allowed_via_ratelimit(self):
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 4003869ed6..236b608d58 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -50,13 +50,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_regex_user_id_prefix_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@irc_foobar:matrix.org"
- self.assertTrue((yield self.service.is_interested(self.event)))
+ self.assertTrue(
+ (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ )
@defer.inlineCallbacks
def test_regex_user_id_prefix_no_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@someone_else:matrix.org"
- self.assertFalse((yield self.service.is_interested(self.event)))
+ self.assertFalse(
+ (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ )
@defer.inlineCallbacks
def test_regex_room_member_is_checked(self):
@@ -64,7 +68,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.event.sender = "@someone_else:matrix.org"
self.event.type = "m.room.member"
self.event.state_key = "@irc_foobar:matrix.org"
- self.assertTrue((yield self.service.is_interested(self.event)))
+ self.assertTrue(
+ (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ )
@defer.inlineCallbacks
def test_regex_room_id_match(self):
@@ -72,7 +78,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("!some_prefix.*some_suffix:matrix.org")
)
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
- self.assertTrue((yield self.service.is_interested(self.event)))
+ self.assertTrue(
+ (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ )
@defer.inlineCallbacks
def test_regex_room_id_no_match(self):
@@ -80,19 +88,26 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("!some_prefix.*some_suffix:matrix.org")
)
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
- self.assertFalse((yield self.service.is_interested(self.event)))
+ self.assertFalse(
+ (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ )
@defer.inlineCallbacks
def test_regex_alias_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
- self.store.get_aliases_for_room.return_value = [
- "#irc_foobar:matrix.org",
- "#athing:matrix.org",
- ]
- self.store.get_users_in_room.return_value = []
- self.assertTrue((yield self.service.is_interested(self.event, self.store)))
+ self.store.get_aliases_for_room.return_value = defer.succeed(
+ ["#irc_foobar:matrix.org", "#athing:matrix.org"]
+ )
+ self.store.get_users_in_room.return_value = defer.succeed([])
+ self.assertTrue(
+ (
+ yield defer.ensureDeferred(
+ self.service.is_interested(self.event, self.store)
+ )
+ )
+ )
def test_non_exclusive_alias(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
@@ -135,12 +150,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
- self.store.get_aliases_for_room.return_value = [
- "#xmpp_foobar:matrix.org",
- "#athing:matrix.org",
- ]
- self.store.get_users_in_room.return_value = []
- self.assertFalse((yield self.service.is_interested(self.event, self.store)))
+ self.store.get_aliases_for_room.return_value = defer.succeed(
+ ["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
+ )
+ self.store.get_users_in_room.return_value = defer.succeed([])
+ self.assertFalse(
+ (
+ yield defer.ensureDeferred(
+ self.service.is_interested(self.event, self.store)
+ )
+ )
+ )
@defer.inlineCallbacks
def test_regex_multiple_matches(self):
@@ -149,9 +169,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
)
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@irc_foobar:matrix.org"
- self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"]
- self.store.get_users_in_room.return_value = []
- self.assertTrue((yield self.service.is_interested(self.event, self.store)))
+ self.store.get_aliases_for_room.return_value = defer.succeed(
+ ["#irc_barfoo:matrix.org"]
+ )
+ self.store.get_users_in_room.return_value = defer.succeed([])
+ self.assertTrue(
+ (
+ yield defer.ensureDeferred(
+ self.service.is_interested(self.event, self.store)
+ )
+ )
+ )
@defer.inlineCallbacks
def test_interested_in_self(self):
@@ -161,19 +189,24 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.event.type = "m.room.member"
self.event.content = {"membership": "invite"}
self.event.state_key = self.service.sender
- self.assertTrue((yield self.service.is_interested(self.event)))
+ self.assertTrue(
+ (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ )
@defer.inlineCallbacks
def test_member_list_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
- self.store.get_users_in_room.return_value = [
- "@alice:here",
- "@irc_fo:here", # AS user
- "@bob:here",
- ]
- self.store.get_aliases_for_room.return_value = []
+ # Note that @irc_fo:here is the AS user.
+ self.store.get_users_in_room.return_value = defer.succeed(
+ ["@alice:here", "@irc_fo:here", "@bob:here"]
+ )
+ self.store.get_aliases_for_room.return_value = defer.succeed([])
self.event.sender = "@xmpp_foobar:matrix.org"
self.assertTrue(
- (yield self.service.is_interested(event=self.event, store=self.store))
+ (
+ yield defer.ensureDeferred(
+ self.service.is_interested(event=self.event, store=self.store)
+ )
+ )
)
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 52f89d3f83..68a4caabbf 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -25,6 +25,7 @@ from synapse.appservice.scheduler import (
from synapse.logging.context import make_deferred_yieldable
from tests import unittest
+from tests.test_utils import make_awaitable
from ..utils import MockClock
@@ -52,11 +53,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.store.get_appservice_state = Mock(
return_value=defer.succeed(ApplicationServiceState.UP)
)
- txn.send = Mock(return_value=defer.succeed(True))
+ txn.send = Mock(return_value=make_awaitable(True))
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
# actual call
- self.txnctrl.send(service, events)
+ self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events # txn made and saved
@@ -77,7 +78,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
# actual call
- self.txnctrl.send(service, events)
+ self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events # txn made and saved
@@ -98,11 +99,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
return_value=defer.succeed(ApplicationServiceState.UP)
)
self.store.set_appservice_state = Mock(return_value=defer.succeed(True))
- txn.send = Mock(return_value=defer.succeed(False)) # fails to send
+ txn.send = Mock(return_value=make_awaitable(False)) # fails to send
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
# actual call
- self.txnctrl.send(service, events)
+ self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events
@@ -144,7 +145,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.recoverer.recover()
# shouldn't have called anything prior to waiting for exp backoff
self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
- txn.send = Mock(return_value=True)
+ txn.send = Mock(return_value=make_awaitable(True))
+ txn.complete.return_value = make_awaitable(None)
# wait for exp backoff
self.clock.advance_time(2)
self.assertEquals(1, txn.send.call_count)
@@ -169,7 +171,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.recoverer.recover()
self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
- txn.send = Mock(return_value=False)
+ txn.send = Mock(return_value=make_awaitable(False))
+ txn.complete.return_value = make_awaitable(None)
self.clock.advance_time(2)
self.assertEquals(1, txn.send.call_count)
self.assertEquals(0, txn.complete.call_count)
@@ -182,7 +185,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.assertEquals(3, txn.send.call_count)
self.assertEquals(0, txn.complete.call_count)
self.assertEquals(0, self.callback.call_count)
- txn.send = Mock(return_value=True) # successfully send the txn
+ txn.send = Mock(return_value=make_awaitable(True)) # successfully send the txn
pop_txn = True # returns the txn the first time, then no more.
self.clock.advance_time(16)
self.assertEquals(1, txn.send.call_count) # new mock reset call count
diff --git a/tests/config/test_base.py b/tests/config/test_base.py
new file mode 100644
index 0000000000..42ee5f56d9
--- /dev/null
+++ b/tests/config/test_base.py
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path
+import tempfile
+
+from synapse.config import ConfigError
+from synapse.util.stringutils import random_string
+
+from tests import unittest
+
+
+class BaseConfigTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.hs = hs
+
+ def test_loading_missing_templates(self):
+ # Use a temporary directory that exists on the system, but that isn't likely to
+ # contain template files
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ # Attempt to load an HTML template from our custom template directory
+ template = self.hs.config.read_templates(["sso_error.html"], tmp_dir)[0]
+
+ # If no errors, we should've gotten the default template instead
+
+ # Render the template
+ a_random_string = random_string(5)
+ html_content = template.render({"error_description": a_random_string})
+
+ # Check that our string exists in the template
+ self.assertIn(
+ a_random_string,
+ html_content,
+ "Template file did not contain our test string",
+ )
+
+ def test_loading_custom_templates(self):
+ # Use a temporary directory that exists on the system
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ # Create a temporary bogus template file
+ with tempfile.NamedTemporaryFile(dir=tmp_dir) as tmp_template:
+ # Get temporary file's filename
+ template_filename = os.path.basename(tmp_template.name)
+
+ # Write a custom HTML template
+ contents = b"{{ test_variable }}"
+ tmp_template.write(contents)
+ tmp_template.flush()
+
+ # Attempt to load the template from our custom template directory
+ template = (
+ self.hs.config.read_templates([template_filename], tmp_dir)
+ )[0]
+
+ # Render the template
+ a_random_string = random_string(5)
+ html_content = template.render({"test_variable": a_random_string})
+
+ # Check that our string exists in the template
+ self.assertIn(
+ a_random_string,
+ html_content,
+ "Template file did not contain our test string",
+ )
+
+ def test_loading_template_from_nonexistent_custom_directory(self):
+ with self.assertRaises(ConfigError):
+ self.hs.config.read_templates(
+ ["some_filename.html"], "a_nonexistent_directory"
+ )
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index f9ce609923..8ff1460c0d 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -23,6 +23,7 @@ from nacl.signing import SigningKey
from signedjson.key import encode_verify_key_base64, get_verify_key
from twisted.internet import defer
+from twisted.internet.defer import Deferred, ensureDeferred
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
@@ -33,16 +34,17 @@ from synapse.crypto.keyring import (
)
from synapse.logging.context import (
LoggingContext,
- PreserveLoggingContext,
current_context,
make_deferred_yieldable,
)
from synapse.storage.keys import FetchKeyResult
from tests import unittest
+from tests.test_utils import make_awaitable
+from tests.unittest import logcontext_clean
-class MockPerspectiveServer(object):
+class MockPerspectiveServer:
def __init__(self):
self.server_name = "mock_server"
self.key = signedjson.key.generate_signing_key(0)
@@ -66,56 +68,42 @@ class MockPerspectiveServer(object):
signedjson.sign.sign_json(res, self.server_name, self.key)
+@logcontext_clean
class KeyringTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
- self.mock_perspective_server = MockPerspectiveServer()
- self.http_client = Mock()
-
- config = self.default_config()
- config["trusted_key_servers"] = [
- {
- "server_name": self.mock_perspective_server.server_name,
- "verify_keys": self.mock_perspective_server.get_verify_keys(),
- }
- ]
-
- return self.setup_test_homeserver(
- handlers=None, http_client=self.http_client, config=config
- )
-
- def check_context(self, _, expected):
+ def check_context(self, val, expected):
self.assertEquals(getattr(current_context(), "request", None), expected)
+ return val
def test_verify_json_objects_for_server_awaits_previous_requests(self):
- key1 = signedjson.key.generate_signing_key(1)
+ mock_fetcher = keyring.KeyFetcher()
+ mock_fetcher.get_keys = Mock()
+ kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
- kr = keyring.Keyring(self.hs)
+ # a signed object that we are going to try to validate
+ key1 = signedjson.key.generate_signing_key(1)
json1 = {}
signedjson.sign.sign_json(json1, "server10", key1)
- persp_resp = {
- "server_keys": [
- self.mock_perspective_server.get_signed_key(
- "server10", signedjson.key.get_verify_key(key1)
- )
- ]
- }
- persp_deferred = defer.Deferred()
+ # start off a first set of lookups. We make the mock fetcher block until this
+ # deferred completes.
+ first_lookup_deferred = Deferred()
- @defer.inlineCallbacks
- def get_perspectives(**kwargs):
- self.assertEquals(current_context().request, "11")
- with PreserveLoggingContext():
- yield persp_deferred
- return persp_resp
+ async def first_lookup_fetch(keys_to_fetch):
+ self.assertEquals(current_context().request, "context_11")
+ self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})
- self.http_client.post_json.side_effect = get_perspectives
+ await make_deferred_yieldable(first_lookup_deferred)
+ return {
+ "server10": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
+ }
+ }
+
+ mock_fetcher.get_keys.side_effect = first_lookup_fetch
- # start off a first set of lookups
- @defer.inlineCallbacks
- def first_lookup():
- with LoggingContext("11") as context_11:
- context_11.request = "11"
+ async def first_lookup():
+ with LoggingContext("context_11") as context_11:
+ context_11.request = "context_11"
res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
@@ -124,7 +112,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# the unsigned json should be rejected pretty quickly
self.assertTrue(res_deferreds[1].called)
try:
- yield res_deferreds[1]
+ await res_deferreds[1]
self.assertFalse("unsigned json didn't cause a failure")
except SynapseError:
pass
@@ -132,45 +120,51 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.assertFalse(res_deferreds[0].called)
res_deferreds[0].addBoth(self.check_context, None)
- yield make_deferred_yieldable(res_deferreds[0])
+ await make_deferred_yieldable(res_deferreds[0])
- # let verify_json_objects_for_server finish its work before we kill the
- # logcontext
- yield self.clock.sleep(0)
+ d0 = ensureDeferred(first_lookup())
- d0 = first_lookup()
-
- # wait a tick for it to send the request to the perspectives server
- # (it first tries the datastore)
- self.pump()
- self.http_client.post_json.assert_called_once()
+ mock_fetcher.get_keys.assert_called_once()
# a second request for a server with outstanding requests
# should block rather than start a second call
- @defer.inlineCallbacks
- def second_lookup():
- with LoggingContext("12") as context_12:
- context_12.request = "12"
- self.http_client.post_json.reset_mock()
- self.http_client.post_json.return_value = defer.Deferred()
+
+ async def second_lookup_fetch(keys_to_fetch):
+ self.assertEquals(current_context().request, "context_12")
+ return {
+ "server10": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
+ }
+ }
+
+ mock_fetcher.get_keys.reset_mock()
+ mock_fetcher.get_keys.side_effect = second_lookup_fetch
+ second_lookup_state = [0]
+
+ async def second_lookup():
+ with LoggingContext("context_12") as context_12:
+ context_12.request = "context_12"
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test")]
)
res_deferreds_2[0].addBoth(self.check_context, None)
- yield make_deferred_yieldable(res_deferreds_2[0])
+ second_lookup_state[0] = 1
+ await make_deferred_yieldable(res_deferreds_2[0])
+ second_lookup_state[0] = 2
- # let verify_json_objects_for_server finish its work before we kill the
- # logcontext
- yield self.clock.sleep(0)
-
- d2 = second_lookup()
+ d2 = ensureDeferred(second_lookup())
self.pump()
- self.http_client.post_json.assert_not_called()
+ # the second request should be pending, but the fetcher should not yet have been
+ # called
+ self.assertEqual(second_lookup_state[0], 1)
+ mock_fetcher.get_keys.assert_not_called()
# complete the first request
- persp_deferred.callback(persp_resp)
+ first_lookup_deferred.callback(None)
+
+ # and now both verifications should succeed.
self.get_success(d0)
self.get_success(d2)
@@ -190,7 +184,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# should fail immediately on an unsigned object
d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
- self.failureResultOf(d, SynapseError)
+ self.get_failure(d, SynapseError)
# should succeed on a signed object
d = _verify_json_for_server(kr, "server9", json1, 500, "test signed")
@@ -202,7 +196,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
with a null `ts_valid_until_ms`
"""
mock_fetcher = keyring.KeyFetcher()
- mock_fetcher.get_keys = Mock(return_value=defer.succeed({}))
+ mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
kr = keyring.Keyring(
self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
@@ -221,7 +215,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# should fail immediately on an unsigned object
d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
- self.failureResultOf(d, SynapseError)
+ self.get_failure(d, SynapseError)
# should fail on a signed object with a non-zero minimum_valid_until_ms,
# as it tries to refetch the keys and fails.
@@ -245,17 +239,15 @@ class KeyringTestCase(unittest.HomeserverTestCase):
"""Two requests for the same key should be deduped."""
key1 = signedjson.key.generate_signing_key(1)
- def get_keys(keys_to_fetch):
+ async def get_keys(keys_to_fetch):
# there should only be one request object (with the max validity)
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
- return defer.succeed(
- {
- "server1": {
- get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
- }
+ return {
+ "server1": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
}
- )
+ }
mock_fetcher = keyring.KeyFetcher()
mock_fetcher.get_keys = Mock(side_effect=get_keys)
@@ -282,25 +274,19 @@ class KeyringTestCase(unittest.HomeserverTestCase):
"""If the first fetcher cannot provide a recent enough key, we fall back"""
key1 = signedjson.key.generate_signing_key(1)
- def get_keys1(keys_to_fetch):
+ async def get_keys1(keys_to_fetch):
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
- return defer.succeed(
- {
- "server1": {
- get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)
- }
- }
- )
+ return {
+ "server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
+ }
- def get_keys2(keys_to_fetch):
+ async def get_keys2(keys_to_fetch):
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
- return defer.succeed(
- {
- "server1": {
- get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
- }
+ return {
+ "server1": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
}
- )
+ }
mock_fetcher1 = keyring.KeyFetcher()
mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
@@ -325,6 +311,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
mock_fetcher2.get_keys.assert_called_once()
+@logcontext_clean
class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.http_client = Mock()
@@ -355,7 +342,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
}
signedjson.sign.sign_json(response, SERVER_NAME, testkey)
- def get_json(destination, path, **kwargs):
+ async def get_json(destination, path, **kwargs):
self.assertEqual(destination, SERVER_NAME)
self.assertEqual(path, "/_matrix/key/v2/server/key1")
return response
@@ -444,7 +431,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
Tell the mock http client to expect a perspectives-server key query
"""
- def post_json(destination, path, data, **kwargs):
+ async def post_json(destination, path, data, **kwargs):
self.assertEqual(destination, self.mock_perspective_server.server_name)
self.assertEqual(path, "/_matrix/key/v2/query")
@@ -580,14 +567,12 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
# remove the perspectives server's signature
response = build_response()
del response["signatures"][self.mock_perspective_server.server_name]
- self.http_client.post_json.return_value = {"server_keys": [response]}
keys = get_key_from_perspectives(response)
self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig")
# remove the origin server's signature
response = build_response()
del response["signatures"][SERVER_NAME]
- self.http_client.post_json.return_value = {"server_keys": [response]}
keys = get_key_from_perspectives(response)
self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 0c9987be54..1471cc1a28 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -15,14 +15,13 @@
from mock import Mock
-from twisted.internet import defer
-
from synapse.api.errors import Codes, SynapseError
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.types import UserID
from tests import unittest
+from tests.test_utils import make_awaitable
class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
@@ -59,7 +58,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Artificially raise the complexity
store = self.hs.get_datastore()
- store.get_current_state_event_counts = lambda x: defer.succeed(500 * 1.23)
+ store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
# Get the room complexity again -- make sure it's our artificial value
request, channel = self.make_request(
@@ -78,9 +77,40 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
+ fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
+ handler.federation_handler.do_invite_join = Mock(
+ return_value=make_awaitable(("", 1))
+ )
+
+ d = handler._remote_join(
+ None,
+ ["other.example.com"],
+ "roomid",
+ UserID.from_string(u1),
+ {"membership": "join"},
+ )
+
+ self.pump()
+
+ # The request failed with a SynapseError saying the resource limit was
+ # exceeded.
+ f = self.get_failure(d, SynapseError)
+ self.assertEqual(f.value.code, 400, f.value)
+ self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+ def test_join_too_large_admin(self):
+        # Check whether an admin can join if the option "admins_can_join" is
+        # undefined; it defaults to false, so the join should fail.
+
+ u1 = self.register_user("u1", "pass", admin=True)
+
+ handler = self.hs.get_room_member_handler()
+ fed_transport = self.hs.get_federation_transport_client()
+
+ # Mock out some things, because we don't want to test the whole join
+ fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
- return_value=defer.succeed(("", 1))
+ return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
@@ -116,13 +146,13 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(return_value=defer.succeed(None))
+ fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
handler.federation_handler.do_invite_join = Mock(
- return_value=defer.succeed(("", 1))
+ return_value=make_awaitable(("", 1))
)
# Artificially raise the complexity
- self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed(
+ self.hs.get_datastore().get_current_state_event_counts = lambda x: make_awaitable(
600
)
@@ -141,3 +171,81 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
f = self.get_failure(d, SynapseError)
self.assertEqual(f.value.code, 400)
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+
+class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
+    # Test the behaviour of joining rooms which exceed the complexity limit when
+    # the option limit_remote_rooms.admins_can_join is True.
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def default_config(self):
+ config = super().default_config()
+ config["limit_remote_rooms"] = {
+ "enabled": True,
+ "complexity": 0.05,
+ "admins_can_join": True,
+ }
+ return config
+
+ def test_join_too_large_no_admin(self):
+        # A user who is not an admin should not be able to join a remote room
+ # which is too complex.
+
+ u1 = self.register_user("u1", "pass")
+
+ handler = self.hs.get_room_member_handler()
+ fed_transport = self.hs.get_federation_transport_client()
+
+ # Mock out some things, because we don't want to test the whole join
+ fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
+ handler.federation_handler.do_invite_join = Mock(
+ return_value=make_awaitable(("", 1))
+ )
+
+ d = handler._remote_join(
+ None,
+ ["other.example.com"],
+ "roomid",
+ UserID.from_string(u1),
+ {"membership": "join"},
+ )
+
+ self.pump()
+
+ # The request failed with a SynapseError saying the resource limit was
+ # exceeded.
+ f = self.get_failure(d, SynapseError)
+ self.assertEqual(f.value.code, 400, f.value)
+ self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+ def test_join_too_large_admin(self):
+ # An admin should be able to join rooms where a complexity check fails.
+
+ u1 = self.register_user("u1", "pass", admin=True)
+
+ handler = self.hs.get_room_member_handler()
+ fed_transport = self.hs.get_federation_transport_client()
+
+ # Mock out some things, because we don't want to test the whole join
+ fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
+ handler.federation_handler.do_invite_join = Mock(
+ return_value=make_awaitable(("", 1))
+ )
+
+ d = handler._remote_join(
+ None,
+ ["other.example.com"],
+ "roomid",
+ UserID.from_string(u1),
+ {"membership": "join"},
+ )
+
+ self.pump()
+
+        # The request succeeds since the user is an admin
+ self.get_success(d)
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
new file mode 100644
index 0000000000..1a3ccb263d
--- /dev/null
+++ b/tests/federation/test_federation_catch_up.py
@@ -0,0 +1,422 @@
+from typing import List, Tuple
+
+from mock import Mock
+
+from synapse.events import EventBase
+from synapse.federation.sender import PerDestinationQueue, TransactionManager
+from synapse.federation.units import Edu
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.test_utils import event_injection, make_awaitable
+from tests.unittest import FederatingHomeserverTestCase, override_config
+
+
+class FederationCatchUpTestCases(FederatingHomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(
+ federation_transport_client=Mock(spec=["send_transaction"]),
+ )
+
+ def prepare(self, reactor, clock, hs):
+ # stub out get_current_hosts_in_room
+ state_handler = hs.get_state_handler()
+
+ # This mock is crucial for destination_rooms to be populated.
+ state_handler.get_current_hosts_in_room = Mock(
+ return_value=make_awaitable(["test", "host2"])
+ )
+
+ # whenever send_transaction is called, record the pdu data
+ self.pdus = []
+ self.failed_pdus = []
+ self.is_online = True
+ self.hs.get_federation_transport_client().send_transaction.side_effect = (
+ self.record_transaction
+ )
+
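+    # Fake federation transport: while the test marks the destination as online,
+    # record the PDUs and return an empty (successful) response; while offline,
+    # record the would-be PDUs and raise to simulate a connection failure.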
+ async def record_transaction(self, txn, json_cb):
+ if self.is_online:
+ data = json_cb()
+ self.pdus.extend(data["pdus"])
+ return {}
+ else:
+ data = json_cb()
+ self.failed_pdus.extend(data["pdus"])
+ raise IOError("Failed to connect because this is a test!")
+
+ def get_destination_room(self, room: str, destination: str = "host2") -> dict:
+ """
+ Gets the destination_rooms entry for a (destination, room_id) pair.
+
+ Args:
+ room: room ID
+            destination: which destination to look up; defaults to "host2"
+
+ Returns:
+ Dictionary of { event_id: str, stream_ordering: int }
+ """
+ event_id, stream_ordering = self.get_success(
+ self.hs.get_datastore().db_pool.execute(
+ "test:get_destination_rooms",
+ None,
+ """
+ SELECT event_id, stream_ordering
+ FROM destination_rooms dr
+ JOIN events USING (stream_ordering)
+ WHERE dr.destination = ? AND dr.room_id = ?
+ """,
+ destination,
+ room,
+ )
+ )[0]
+ return {"event_id": event_id, "stream_ordering": stream_ordering}
+
+ @override_config({"send_federation": True})
+ def test_catch_up_destination_rooms_tracking(self):
+ """
+ Tests that we populate the `destination_rooms` table as needed.
+ """
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room = self.helper.create_room_as("u1", tok=u1_token)
+
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room, "@user:host2", "join")
+ )
+
+ event_id_1 = self.helper.send(room, "wombats!", tok=u1_token)["event_id"]
+
+ row_1 = self.get_destination_room(room)
+
+ event_id_2 = self.helper.send(room, "rabbits!", tok=u1_token)["event_id"]
+
+ row_2 = self.get_destination_room(room)
+
+ # check: events correctly registered in order
+ self.assertEqual(row_1["event_id"], event_id_1)
+ self.assertEqual(row_2["event_id"], event_id_2)
+ self.assertEqual(row_1["stream_ordering"], row_2["stream_ordering"] - 1)
+
+ @override_config({"send_federation": True})
+ def test_catch_up_last_successful_stream_ordering_tracking(self):
+ """
+        Tests that we track `last_successful_stream_ordering` for a destination:
+        it stays unset while the destination is unreachable and is updated once
+        a send succeeds.
+ """
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room = self.helper.create_room_as("u1", tok=u1_token)
+
+ # take the remote offline
+ self.is_online = False
+
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room, "@user:host2", "join")
+ )
+
+ self.helper.send(room, "wombats!", tok=u1_token)
+ self.pump()
+
+ lsso_1 = self.get_success(
+ self.hs.get_datastore().get_destination_last_successful_stream_ordering(
+ "host2"
+ )
+ )
+
+ self.assertIsNone(
+ lsso_1,
+ "There should be no last successful stream ordering for an always-offline destination",
+ )
+
+ # bring the remote online
+ self.is_online = True
+
+ event_id_2 = self.helper.send(room, "rabbits!", tok=u1_token)["event_id"]
+
+ lsso_2 = self.get_success(
+ self.hs.get_datastore().get_destination_last_successful_stream_ordering(
+ "host2"
+ )
+ )
+ row_2 = self.get_destination_room(room)
+
+ self.assertEqual(
+ self.pdus[0]["content"]["body"],
+ "rabbits!",
+ "Test fault: didn't receive the right PDU",
+ )
+ self.assertEqual(
+ row_2["event_id"],
+ event_id_2,
+ "Test fault: destination_rooms not updated correctly",
+ )
+ self.assertEqual(
+ lsso_2,
+ row_2["stream_ordering"],
+ "Send succeeded but not marked as last_successful_stream_ordering",
+ )
+
+ @override_config({"send_federation": True}) # critical to federate
+ def test_catch_up_from_blank_state(self):
+ """
+ Runs an overall test of federation catch-up from scratch.
+        Further tests will focus on narrower aspects and edge cases, but this
+        test aims to provide an overall view.
+ """
+ # bring the other server online
+ self.is_online = True
+
+ # let's make some events for the other server to receive
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room_1 = self.helper.create_room_as("u1", tok=u1_token)
+ room_2 = self.helper.create_room_as("u1", tok=u1_token)
+
+ # also critical to federate
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join")
+ )
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room_2, "@user:host2", "join")
+ )
+
+ self.helper.send_state(
+ room_1, event_type="m.room.topic", body={"topic": "wombat"}, tok=u1_token
+ )
+
+ # check: PDU received for topic event
+ self.assertEqual(len(self.pdus), 1)
+ self.assertEqual(self.pdus[0]["type"], "m.room.topic")
+
+ # take the remote offline
+ self.is_online = False
+
+ # send another event
+ self.helper.send(room_1, "hi user!", tok=u1_token)
+
+ # check: things didn't go well since the remote is down
+ self.assertEqual(len(self.failed_pdus), 1)
+ self.assertEqual(self.failed_pdus[0]["content"]["body"], "hi user!")
+
+ # let's delete the federation transmission queue
+ # (this pretends we are starting up fresh.)
+ self.assertFalse(
+ self.hs.get_federation_sender()
+ ._per_destination_queues["host2"]
+ .transmission_loop_running
+ )
+ del self.hs.get_federation_sender()._per_destination_queues["host2"]
+
+ # let's also clear any backoffs
+ self.get_success(
+ self.hs.get_datastore().set_destination_retry_timings("host2", None, 0, 0)
+ )
+
+ # bring the remote online and clear the received pdu list
+ self.is_online = True
+ self.pdus = []
+
+ # now we need to initiate a federation transaction somehow…
+ # to do that, let's send another event (because it's simple to do)
+ # (do it to another room otherwise the catch-up logic decides it doesn't
+ # need to catch up room_1 — something I overlooked when first writing
+ # this test)
+ self.helper.send(room_2, "wombats!", tok=u1_token)
+
+ # we should now have received both PDUs
+ self.assertEqual(len(self.pdus), 2)
+ self.assertEqual(self.pdus[0]["content"]["body"], "hi user!")
+ self.assertEqual(self.pdus[1]["content"]["body"], "wombats!")
+
+ def make_fake_destination_queue(
+ self, destination: str = "host2"
+ ) -> Tuple[PerDestinationQueue, List[EventBase]]:
+ """
+ Makes a fake per-destination queue.
+ """
+ transaction_manager = TransactionManager(self.hs)
+ per_dest_queue = PerDestinationQueue(self.hs, transaction_manager, destination)
+ results_list = []
+
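+        # Stand-in for TransactionManager.send_new_transaction: collect the PDUs
+        # that would have been sent and always report success.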
+ async def fake_send(
+ destination_tm: str,
+ pending_pdus: List[EventBase],
+ _pending_edus: List[Edu],
+ ) -> bool:
+ assert destination == destination_tm
+ results_list.extend(pending_pdus)
+ return True # success!
+
+ transaction_manager.send_new_transaction = fake_send
+
+ return per_dest_queue, results_list
+
+ @override_config({"send_federation": True})
+ def test_catch_up_loop(self):
+ """
+ Tests the behaviour of _catch_up_transmission_loop.
+ """
+
+ # ARRANGE:
+ # - a local user (u1)
+ # - 3 rooms which u1 is joined to (and remote user @user:host2 is
+ # joined to)
+ # - some events (1 to 5) in those rooms
+ # we have 'already sent' events 1 and 2 to host2
+ per_dest_queue, sent_pdus = self.make_fake_destination_queue()
+
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room_1 = self.helper.create_room_as("u1", tok=u1_token)
+ room_2 = self.helper.create_room_as("u1", tok=u1_token)
+ room_3 = self.helper.create_room_as("u1", tok=u1_token)
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join")
+ )
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room_2, "@user:host2", "join")
+ )
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room_3, "@user:host2", "join")
+ )
+
+ # create some events
+ self.helper.send(room_1, "you hear me!!", tok=u1_token)
+ event_id_2 = self.helper.send(room_2, "wombats!", tok=u1_token)["event_id"]
+ self.helper.send(room_3, "Matrix!", tok=u1_token)
+ event_id_4 = self.helper.send(room_2, "rabbits!", tok=u1_token)["event_id"]
+ event_id_5 = self.helper.send(room_3, "Synapse!", tok=u1_token)["event_id"]
+
+ # destination_rooms should already be populated, but let us pretend that we already
+ # sent (successfully) up to and including event id 2
+ event_2 = self.get_success(self.hs.get_datastore().get_event(event_id_2))
+
+ # also fetch event 5 so we know its last_successful_stream_ordering later
+ event_5 = self.get_success(self.hs.get_datastore().get_event(event_id_5))
+
+ self.get_success(
+ self.hs.get_datastore().set_destination_last_successful_stream_ordering(
+ "host2", event_2.internal_metadata.stream_ordering
+ )
+ )
+
+ # ACT
+ self.get_success(per_dest_queue._catch_up_transmission_loop())
+
+ # ASSERT, noticing in particular:
+ # - event 3 not sent out, because event 5 replaces it
+ # - order is least recent first, so event 5 comes after event 4
+ # - catch-up is completed
+ self.assertEqual(len(sent_pdus), 2)
+ self.assertEqual(sent_pdus[0].event_id, event_id_4)
+ self.assertEqual(sent_pdus[1].event_id, event_id_5)
+ self.assertFalse(per_dest_queue._catching_up)
+ self.assertEqual(
+ per_dest_queue._last_successful_stream_ordering,
+ event_5.internal_metadata.stream_ordering,
+ )
+
+ @override_config({"send_federation": True})
+ def test_catch_up_on_synapse_startup(self):
+ """
+ Tests the behaviour of get_catch_up_outstanding_destinations and
+ _wake_destinations_needing_catchup.
+ """
+
+ # list of sorted server names (note that there are more servers than the batch
+ # size used in get_catch_up_outstanding_destinations).
+ server_names = ["server%02d" % number for number in range(42)] + ["zzzerver"]
+
+ # ARRANGE:
+ # - a local user (u1)
+ # - a room which u1 is joined to (and remote users @user:serverXX are
+ # joined to)
+
+ # mark the remotes as online
+ self.is_online = True
+
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room_id = self.helper.create_room_as("u1", tok=u1_token)
+
+ for server_name in server_names:
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, room_id, "@user:%s" % server_name, "join"
+ )
+ )
+
+ # create an event
+ self.helper.send(room_id, "deary me!", tok=u1_token)
+
+ # ASSERT:
+ # - All servers are up to date so none should have outstanding catch-up
+ outstanding_when_successful = self.get_success(
+ self.hs.get_datastore().get_catch_up_outstanding_destinations(None)
+ )
+ self.assertEqual(outstanding_when_successful, [])
+
+ # ACT:
+ # - Make the remote servers unreachable
+ self.is_online = False
+
+ # - Mark zzzerver as being backed-off from
+ now = self.clock.time_msec()
+ self.get_success(
+ self.hs.get_datastore().set_destination_retry_timings(
+ "zzzerver", now, now, 24 * 60 * 60 * 1000 # retry in 1 day
+ )
+ )
+
+ # - Send an event
+ self.helper.send(room_id, "can anyone hear me?", tok=u1_token)
+
+ # ASSERT (get_catch_up_outstanding_destinations):
+ # - all remotes are outstanding
+ # - they are returned in batches of 25, in order
+ outstanding_1 = self.get_success(
+ self.hs.get_datastore().get_catch_up_outstanding_destinations(None)
+ )
+
+ self.assertEqual(len(outstanding_1), 25)
+ self.assertEqual(outstanding_1, server_names[0:25])
+
+ outstanding_2 = self.get_success(
+ self.hs.get_datastore().get_catch_up_outstanding_destinations(
+ outstanding_1[-1]
+ )
+ )
+ self.assertNotIn("zzzerver", outstanding_2)
+ self.assertEqual(len(outstanding_2), 17)
+ self.assertEqual(outstanding_2, server_names[25:-1])
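+ # (server_names[25:-1] is server25..server41; "zzzerver" is excluded because it is still backed off)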
+
+ # ACT: call _wake_destinations_needing_catchup
+
+ # patch wake_destination to just count the destinations instead
+ woken = []
+
+ def wake_destination_track(destination):
+ woken.append(destination)
+
+ self.hs.get_federation_sender().wake_destination = wake_destination_track
+
+ # cancel the pre-existing timer for _wake_destinations_needing_catchup
+ # this is because we are calling it manually rather than waiting for it
+ # to be called automatically
+ self.hs.get_federation_sender()._catchup_after_startup_timer.cancel()
+
+ self.get_success(
+ self.hs.get_federation_sender()._wake_destinations_needing_catchup(), by=5.0
+ )
+
+ # ASSERT (_wake_destinations_needing_catchup):
+ # - all remotes are woken up, save for zzzerver
+ self.assertNotIn("zzzerver", woken)
+ # - all destinations are woken exactly once; they appear once in woken.
+ self.assertCountEqual(woken, server_names[:-1])
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index d1bd18da39..917762e6b6 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -34,7 +34,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
# Ensure a new Awaitable is created for each call.
- mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_Id: make_awaitable(
+ mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable(
["test", "host2"]
)
return self.setup_test_homeserver(
@@ -47,13 +47,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
)
- mock_send_transaction.return_value = defer.succeed({})
+ mock_send_transaction.return_value = make_awaitable({})
sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
)
- self.successResultOf(sender.send_read_receipt(receipt))
+ self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
self.pump()
@@ -87,13 +87,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
)
- mock_send_transaction.return_value = defer.succeed({})
+ mock_send_transaction.return_value = make_awaitable({})
sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
)
- self.successResultOf(sender.send_read_receipt(receipt))
+ self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
self.pump()
@@ -125,7 +125,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["other_id"], {"ts": 1234}
)
- self.successResultOf(sender.send_read_receipt(receipt))
+ self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
self.pump()
mock_send_transaction.assert_not_called()
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 296dc887be..da933ecd75 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -15,6 +15,8 @@
# limitations under the License.
import logging
+from parameterized import parameterized
+
from synapse.events import make_event_from_dict
from synapse.federation.federation_server import server_matches_acl_event
from synapse.rest import admin
@@ -23,6 +25,37 @@ from synapse.rest.client.v1 import login, room
from tests import unittest
+class FederationServerTests(unittest.FederatingHomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ @parameterized.expand([(b"",), (b"foo",), (b'{"limit": Infinity}',)])
+ def test_bad_request(self, query_content):
+ """
+ Querying with bad data returns a reasonable error code.
+ """
+ u1 = self.register_user("u1", "pass")
+ u1_token = self.login("u1", "pass")
+
+ room_1 = self.helper.create_room_as(u1, tok=u1_token)
+ self.inject_room_member(room_1, "@user:other.example.com", "join")
+
+ "/get_missing_events/(?P<room_id>[^/]*)/?"
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
+ query_content,
+ )
+ self.render(request)
+ self.assertEquals(400, channel.code, channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON")
+
+
class ServerACLsTestCase(unittest.TestCase):
def test_blacklisted_server(self):
e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index 27d83bb7d9..72e22d655f 100644
--- a/tests/federation/transport/test_server.py
+++ b/tests/federation/transport/test_server.py
@@ -26,7 +26,7 @@ from tests.unittest import override_config
class RoomDirectoryFederationTests(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
- class Authenticator(object):
+ class Authenticator:
def authenticate_request(self, request, content):
return defer.succeed("otherserver.nottld")
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index ebabe9a7d6..2a0b7c1b56 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.handlers.appservice import ApplicationServicesHandler
+from tests.test_utils import make_awaitable
from tests.utils import MockClock
from .. import unittest
@@ -117,9 +118,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self._mkservice_alias(is_interested_in_alias=False),
]
- self.mock_as_api.query_alias.return_value = defer.succeed(True)
+ self.mock_as_api.query_alias.return_value = make_awaitable(True)
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_association_from_room_alias.return_value = defer.succeed(
+ self.mock_store.get_association_from_room_alias.return_value = make_awaitable(
Mock(room_id=room_id, servers=servers)
)
@@ -135,7 +136,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def _mkservice(self, is_interested):
service = Mock()
- service.is_interested.return_value = defer.succeed(is_interested)
+ service.is_interested.return_value = make_awaitable(is_interested)
service.token = "mock_service_token"
service.url = "mock_service_url"
return service
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index c01b04e1dc..97877c2e42 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -24,10 +24,11 @@ from synapse.api.errors import ResourceLimitError
from synapse.handlers.auth import AuthHandler
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
-class AuthHandlers(object):
+class AuthHandlers:
def __init__(self, hs):
self.auth_handler = AuthHandler(hs)
@@ -142,7 +143,7 @@ class AuthTestCase(unittest.TestCase):
def test_mau_limits_exceeded_large(self):
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.large_number_of_users)
+ return_value=make_awaitable(self.large_number_of_users)
)
with self.assertRaises(ResourceLimitError):
@@ -153,7 +154,7 @@ class AuthTestCase(unittest.TestCase):
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.large_number_of_users)
+ return_value=make_awaitable(self.large_number_of_users)
)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
@@ -168,7 +169,7 @@ class AuthTestCase(unittest.TestCase):
# If not in monthly active cohort
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.auth_blocking._max_mau_value)
+ return_value=make_awaitable(self.auth_blocking._max_mau_value)
)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
@@ -178,7 +179,7 @@ class AuthTestCase(unittest.TestCase):
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.auth_blocking._max_mau_value)
+ return_value=make_awaitable(self.auth_blocking._max_mau_value)
)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
@@ -188,10 +189,10 @@ class AuthTestCase(unittest.TestCase):
)
# If in monthly active cohort
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
- return_value=defer.succeed(self.hs.get_clock().time_msec())
+ return_value=make_awaitable(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.auth_blocking._max_mau_value)
+ return_value=make_awaitable(self.auth_blocking._max_mau_value)
)
yield defer.ensureDeferred(
self.auth_handler.get_access_token_for_user_id(
@@ -199,10 +200,10 @@ class AuthTestCase(unittest.TestCase):
)
)
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
- return_value=defer.succeed(self.hs.get_clock().time_msec())
+ return_value=make_awaitable(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.auth_blocking._max_mau_value)
+ return_value=make_awaitable(self.auth_blocking._max_mau_value)
)
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
@@ -215,7 +216,7 @@ class AuthTestCase(unittest.TestCase):
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.small_number_of_users)
+ return_value=make_awaitable(self.small_number_of_users)
)
# Ensure does not raise exception
yield defer.ensureDeferred(
@@ -225,7 +226,7 @@ class AuthTestCase(unittest.TestCase):
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.small_number_of_users)
+ return_value=make_awaitable(self.small_number_of_users)
)
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 6aa322bf3a..969d44c787 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -35,6 +35,17 @@ class DeviceTestCase(unittest.HomeserverTestCase):
# These tests assume that it starts 1000 seconds in.
self.reactor.advance(1000)
+ def test_device_is_created_with_invalid_name(self):
+ self.get_failure(
+ self.handler.check_device_registered(
+ user_id="@boris:foo",
+ device_id="foo",
+ initial_device_display_name="a"
+ * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1),
+ ),
+ synapse.api.errors.SynapseError,
+ )
+
def test_device_is_created_if_doesnt_exist(self):
res = self.get_success(
self.handler.check_device_registered(
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 00bb776271..bc0c5aefdc 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -16,8 +16,6 @@
from mock import Mock
-from twisted.internet import defer
-
import synapse
import synapse.api.errors
from synapse.api.constants import EventTypes
@@ -26,6 +24,7 @@ from synapse.rest.client.v1 import directory, login, room
from synapse.types import RoomAlias, create_requester
from tests import unittest
+from tests.test_utils import make_awaitable
class DirectoryTestCase(unittest.HomeserverTestCase):
@@ -71,7 +70,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
def test_get_remote_association(self):
- self.mock_federation.make_query.return_value = defer.succeed(
+ self.mock_federation.make_query.return_value = make_awaitable(
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
)
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 210ddcbb88..366dcfb670 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -30,7 +30,7 @@ from tests import unittest, utils
class E2eKeysHandlerTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
- super(E2eKeysHandlerTestCase, self).__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
self.hs = None # type: synapse.server.HomeServer
self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 3362050ce0..7adde9b9de 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -47,7 +47,7 @@ room_keys = {
class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
- super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
self.hs = None # type: synapse.server.HomeServer
self.handler = None # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 1bb25ab684..d5087e58be 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -21,7 +21,6 @@ from mock import Mock, patch
import attr
import pymacaroons
-from twisted.internet import defer
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
@@ -75,7 +74,24 @@ COMMON_CONFIG = {
COOKIE_NAME = b"oidc_session"
COOKIE_PATH = "/_synapse/oidc"
-MockedMappingProvider = Mock(OidcMappingProvider)
+
+class TestMappingProvider(OidcMappingProvider):
+ @staticmethod
+ def parse_config(config):
+ return
+
+ def get_remote_user_id(self, userinfo):
+ return userinfo["sub"]
+
+ async def map_user_attributes(self, userinfo, token):
+ return {"localpart": userinfo["username"], "display_name": None}
+
+ # Do not include get_extra_attributes to test backwards compatibility paths.
+
+
+class TestMappingProviderExtra(TestMappingProvider):
+ async def get_extra_attributes(self, userinfo, token):
+ return {"phone": userinfo["phone"]}
def simple_async_mock(return_value=None, raises=None):
@@ -116,15 +132,19 @@ class OidcHandlerTestCase(HomeserverTestCase):
config = self.default_config()
config["public_baseurl"] = BASE_URL
- oidc_config = config.get("oidc_config", {})
+ oidc_config = {}
oidc_config["enabled"] = True
oidc_config["client_id"] = CLIENT_ID
oidc_config["client_secret"] = CLIENT_SECRET
oidc_config["issuer"] = ISSUER
oidc_config["scopes"] = SCOPES
oidc_config["user_mapping_provider"] = {
- "module": __name__ + ".MockedMappingProvider"
+ "module": __name__ + ".TestMappingProvider",
}
+
+ # Update this config with what's in the default config so that
+ # override_config works as expected.
+ oidc_config.update(config.get("oidc_config", {}))
config["oidc_config"] = oidc_config
hs = self.setup_test_homeserver(
@@ -155,11 +175,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
@override_config({"oidc_config": {"discover": True}})
- @defer.inlineCallbacks
def test_discovery(self):
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
- metadata = yield defer.ensureDeferred(self.handler.load_metadata())
+ metadata = self.get_success(self.handler.load_metadata())
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.assertEqual(metadata.issuer, ISSUER)
@@ -171,43 +190,40 @@ class OidcHandlerTestCase(HomeserverTestCase):
# subsequent calls should be cached
self.http_client.reset_mock()
- yield defer.ensureDeferred(self.handler.load_metadata())
+ self.get_success(self.handler.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG})
- @defer.inlineCallbacks
def test_no_discovery(self):
"""When discovery is disabled, it should not try to load from discovery document."""
- yield defer.ensureDeferred(self.handler.load_metadata())
+ self.get_success(self.handler.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG})
- @defer.inlineCallbacks
def test_load_jwks(self):
"""JWKS loading is done once (then cached) if used."""
- jwks = yield defer.ensureDeferred(self.handler.load_jwks())
+ jwks = self.get_success(self.handler.load_jwks())
self.http_client.get_json.assert_called_once_with(JWKS_URI)
self.assertEqual(jwks, {"keys": []})
# subsequent calls should be cached…
self.http_client.reset_mock()
- yield defer.ensureDeferred(self.handler.load_jwks())
+ self.get_success(self.handler.load_jwks())
self.http_client.get_json.assert_not_called()
# …unless forced
self.http_client.reset_mock()
- yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ self.get_success(self.handler.load_jwks(force=True))
self.http_client.get_json.assert_called_once_with(JWKS_URI)
# Throw if the JWKS uri is missing
with self.metadata_edit({"jwks_uri": None}):
- with self.assertRaises(RuntimeError):
- yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
# Return empty key set if JWKS are not used
self.handler._scopes = [] # not asking the openid scope
self.http_client.get_json.reset_mock()
- jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ jwks = self.get_success(self.handler.load_jwks(force=True))
self.http_client.get_json.assert_not_called()
self.assertEqual(jwks, {"keys": []})
@@ -289,11 +305,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
# This should not throw
self.handler._validate_metadata()
- @defer.inlineCallbacks
def test_redirect_request(self):
"""The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["addCookie"])
- url = yield defer.ensureDeferred(
+ url = self.get_success(
self.handler.handle_redirect_request(req, b"http://client/redirect")
)
url = urlparse(url)
@@ -333,20 +348,18 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(params["nonce"], [nonce])
self.assertEqual(redirect, "http://client/redirect")
- @defer.inlineCallbacks
def test_callback_error(self):
"""Errors from the provider returned in the callback are displayed."""
self.handler._render_error = Mock()
request = Mock(args={})
request.args[b"error"] = [b"invalid_client"]
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_client", "")
request.args[b"error_description"] = [b"some description"]
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_client", "some description")
- @defer.inlineCallbacks
def test_callback(self):
"""Code callback works and display errors if something went wrong.
@@ -367,39 +380,48 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "foo",
"preferred_username": "bar",
}
- user_id = UserID("foo", "domain.org")
+ user_id = "@foo:domain.org"
self.handler._render_error = Mock(return_value=None)
self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
self.handler._auth_handler.complete_sso_login = simple_async_mock()
- request = Mock(spec=["args", "getCookie", "addCookie"])
+ request = Mock(
+ spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+ )
code = "code"
state = "state"
nonce = "nonce"
client_redirect_url = "http://client/redirect"
- session = self.handler._generate_oidc_session_token(
+ user_agent = "Browser"
+ ip_address = "10.0.0.1"
+ request.getCookie.return_value = self.handler._generate_oidc_session_token(
state=state,
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
- request.getCookie.return_value = session
request.args = {}
request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")]
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ request.requestHeaders = Mock(spec=["getRawHeaders"])
+ request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
+ request.getClientIP.return_value = ip_address
+
+ self.get_success(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url,
+ user_id, request, client_redirect_url, {},
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._map_userinfo_to_user.assert_called_once_with(
+ userinfo, token, user_agent, ip_address
+ )
self.handler._fetch_userinfo.assert_not_called()
self.handler._render_error.assert_not_called()
@@ -407,13 +429,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.handler._map_userinfo_to_user = simple_async_mock(
raises=MappingException()
)
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mapping_error")
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
# Handle ID token errors
self.handler._parse_id_token = simple_async_mock(raises=Exception())
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
self.handler._auth_handler.complete_sso_login.reset_mock()
@@ -424,30 +446,31 @@ class OidcHandlerTestCase(HomeserverTestCase):
# With userinfo fetching
self.handler._scopes = [] # do not ask the "openid" scope
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url,
+ user_id, request, client_redirect_url, {},
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called()
- self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._map_userinfo_to_user.assert_called_once_with(
+ userinfo, token, user_agent, ip_address
+ )
self.handler._fetch_userinfo.assert_called_once_with(token)
self.handler._render_error.assert_not_called()
# Handle userinfo fetching error
self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error")
# Handle code exchange failure
self.handler._exchange_code = simple_async_mock(
raises=OidcError("invalid_request")
)
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
- @defer.inlineCallbacks
def test_callback_session(self):
"""The callback verifies the session presence and validity"""
self.handler._render_error = Mock(return_value=None)
@@ -456,20 +479,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
# Missing cookie
request.args = {}
request.getCookie.return_value = None
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("missing_session", "No session cookie found")
# Missing session parameter
request.args = {}
request.getCookie.return_value = "session"
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request", "State parameter is missing")
# Invalid cookie
request.args = {}
request.args[b"state"] = [b"state"]
request.getCookie.return_value = "session"
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_session")
# Mismatching session
@@ -482,18 +505,17 @@ class OidcHandlerTestCase(HomeserverTestCase):
request.args = {}
request.args[b"state"] = [b"mismatching state"]
request.getCookie.return_value = session
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mismatching_session")
# Valid session
request.args = {}
request.args[b"state"] = [b"state"]
request.getCookie.return_value = session
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
@override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
- @defer.inlineCallbacks
def test_exchange_code(self):
"""Code exchange behaves correctly and handles various error scenarios."""
token = {"type": "bearer"}
@@ -502,7 +524,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
)
code = "code"
- ret = yield defer.ensureDeferred(self.handler._exchange_code(code))
+ ret = self.get_success(self.handler._exchange_code(code))
kwargs = self.http_client.request.call_args[1]
self.assertEqual(ret, token)
@@ -524,10 +546,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
body=b'{"error": "foo", "error_description": "bar"}',
)
)
- with self.assertRaises(OidcError) as exc:
- yield defer.ensureDeferred(self.handler._exchange_code(code))
- self.assertEqual(exc.exception.error, "foo")
- self.assertEqual(exc.exception.error_description, "bar")
+ exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ self.assertEqual(exc.value.error, "foo")
+ self.assertEqual(exc.value.error_description, "bar")
# Internal server error with no JSON body
self.http_client.request = simple_async_mock(
@@ -535,9 +556,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
code=500, phrase=b"Internal Server Error", body=b"Not JSON",
)
)
- with self.assertRaises(OidcError) as exc:
- yield defer.ensureDeferred(self.handler._exchange_code(code))
- self.assertEqual(exc.exception.error, "server_error")
+ exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ self.assertEqual(exc.value.error, "server_error")
# Internal server error with JSON body
self.http_client.request = simple_async_mock(
@@ -547,17 +567,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
body=b'{"error": "internal_server_error"}',
)
)
- with self.assertRaises(OidcError) as exc:
- yield defer.ensureDeferred(self.handler._exchange_code(code))
- self.assertEqual(exc.exception.error, "internal_server_error")
+
+ exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ self.assertEqual(exc.value.error, "internal_server_error")
# 4xx error without "error" field
self.http_client.request = simple_async_mock(
return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
)
- with self.assertRaises(OidcError) as exc:
- yield defer.ensureDeferred(self.handler._exchange_code(code))
- self.assertEqual(exc.exception.error, "server_error")
+ exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ self.assertEqual(exc.value.error, "server_error")
# 2xx error with "error" field
self.http_client.request = simple_async_mock(
@@ -565,6 +584,121 @@ class OidcHandlerTestCase(HomeserverTestCase):
code=200, phrase=b"OK", body=b'{"error": "some_error"}',
)
)
- with self.assertRaises(OidcError) as exc:
- yield defer.ensureDeferred(self.handler._exchange_code(code))
- self.assertEqual(exc.exception.error, "some_error")
+ exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ self.assertEqual(exc.value.error, "some_error")
+
+ @override_config(
+ {
+ "oidc_config": {
+ "user_mapping_provider": {
+ "module": __name__ + ".TestMappingProviderExtra"
+ }
+ }
+ }
+ )
+ def test_extra_attributes(self):
+ """
+ Login while using a mapping provider that implements get_extra_attributes.
+ """
+ token = {
+ "type": "bearer",
+ "id_token": "id_token",
+ "access_token": "access_token",
+ }
+ userinfo = {
+ "sub": "foo",
+ "phone": "1234567",
+ }
+ user_id = "@foo:domain.org"
+ self.handler._exchange_code = simple_async_mock(return_value=token)
+ self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
+ self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+ self.handler._auth_handler.complete_sso_login = simple_async_mock()
+ request = Mock(
+ spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+ )
+
+ state = "state"
+ client_redirect_url = "http://client/redirect"
+ request.getCookie.return_value = self.handler._generate_oidc_session_token(
+ state=state,
+ nonce="nonce",
+ client_redirect_url=client_redirect_url,
+ ui_auth_session_id=None,
+ )
+
+ request.args = {}
+ request.args[b"code"] = [b"code"]
+ request.args[b"state"] = [state.encode("utf-8")]
+
+ request.requestHeaders = Mock(spec=["getRawHeaders"])
+ request.requestHeaders.getRawHeaders.return_value = [b"Browser"]
+ request.getClientIP.return_value = "10.0.0.1"
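+ # the callback handler reads the User-Agent header and client IP from the request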
+
+ self.get_success(self.handler.handle_oidc_callback(request))
+
+ self.handler._auth_handler.complete_sso_login.assert_called_once_with(
+ user_id, request, client_redirect_url, {"phone": "1234567"},
+ )
+
+ def test_map_userinfo_to_user(self):
+ """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
+ userinfo = {
+ "sub": "test_user",
+ "username": "test_user",
+ }
+ # The token doesn't matter with the default user mapping provider.
+ token = {}
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user:test")
+
+ # Some providers return an integer ID.
+ userinfo = {
+ "sub": 1234,
+ "username": "test_user_2",
+ }
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user_2:test")
+
+ # Test if the mxid is already taken
+ store = self.hs.get_datastore()
+ user3 = UserID.from_string("@test_user_3:test")
+ self.get_success(
+ store.register_user(user_id=user3.to_string(), password_hash=None)
+ )
+ userinfo = {"sub": "test3", "username": "test_user_3"}
+ e = self.get_failure(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ ),
+ MappingException,
+ )
+ self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken")
+
+ @override_config({"oidc_config": {"allow_existing_users": True}})
+ def test_map_userinfo_to_existing_user(self):
+ """Existing users can log in with OpenID Connect when allow_existing_users is True."""
+ store = self.hs.get_datastore()
+ user4 = UserID.from_string("@test_user_4:test")
+ self.get_success(
+ store.register_user(user_id=user4.to_string(), password_hash=None)
+ )
+ userinfo = {
+ "sub": "test4",
+ "username": "test_user_4",
+ }
+ token = {}
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user_4:test")
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 05ea40a7de..306dcfe944 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -19,6 +19,7 @@ from mock import Mock, call
from signedjson.key import generate_signing_key
from synapse.api.constants import EventTypes, Membership, PresenceState
+from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events.builder import EventBuilder
from synapse.handlers.presence import (
@@ -32,7 +33,6 @@ from synapse.handlers.presence import (
handle_update,
)
from synapse.rest.client.v1 import room
-from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, get_domain_from_id
from tests import unittest
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 42a236aa58..1cef10feff 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -24,10 +24,11 @@ from synapse.handlers.profile import MasterProfileHandler
from synapse.types import UserID
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
-class ProfileHandlers(object):
+class ProfileHandlers:
def __init__(self, hs):
self.profile_handler = MasterProfileHandler(hs)
@@ -63,12 +64,16 @@ class ProfileTestCase(unittest.TestCase):
self.bob = UserID.from_string("@4567:test")
self.alice = UserID.from_string("@alice:remote")
+ yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
+
self.handler = hs.get_profile_handler()
self.hs = hs
@defer.inlineCallbacks
def test_get_my_name(self):
- yield self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
+ )
displayname = yield defer.ensureDeferred(
self.handler.get_displayname(self.frank)
@@ -101,7 +106,12 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.frank.localpart)
+ )
+ ),
+ "Frank",
)
@defer.inlineCallbacks
@@ -109,10 +119,17 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_displayname = False
# Setting displayname for the first time is allowed
- yield self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
+ )
self.assertEquals(
- (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.frank.localpart)
+ )
+ ),
+ "Frank",
)
# Setting displayname a second time is forbidden
@@ -136,7 +153,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_other_name(self):
- self.mock_federation.make_query.return_value = defer.succeed(
+ self.mock_federation.make_query.return_value = make_awaitable(
{"displayname": "Alice"}
)
@@ -154,7 +171,10 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_incoming_fed_query(self):
- yield self.store.set_profile_displayname("caroline", "Caroline", 1)
+ yield defer.ensureDeferred(self.store.create_profile("caroline"))
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname("caroline", "Caroline", 1)
+ )
response = yield defer.ensureDeferred(
self.query_handlers["profile"](
@@ -166,8 +186,10 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_avatar(self):
- yield self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png", 1
+ yield defer.ensureDeferred(
+ self.store.set_profile_avatar_url(
+ self.frank.localpart, "http://my.server/me.png", 1
+ )
)
avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
@@ -184,7 +206,11 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.frank.localpart)
+ )
+ ),
"http://my.server/pic.gif",
)
@@ -198,7 +224,11 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.frank.localpart)
+ )
+ ),
"http://my.server/me.png",
)
@@ -207,12 +237,18 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_avatar_url = False
# Setting displayname for the first time is allowed
- yield self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png", 1
+ yield defer.ensureDeferred(
+ self.store.set_profile_avatar_url(
+ self.frank.localpart, "http://my.server/me.png", 1
+ )
)
self.assertEquals(
- (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.frank.localpart)
+ )
+ ),
"http://my.server/me.png",
)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index e951a62a6d..312c03c83d 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -15,8 +15,7 @@
from mock import Mock
-from twisted.internet import defer
-
+from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler
@@ -25,15 +24,18 @@ from synapse.rest.client.v2_alpha.register import (
_map_email_to_displayname,
register_servlets,
)
+from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, UserID, create_requester
from tests.server import FakeChannel
+from tests.test_utils import make_awaitable
from tests.unittest import override_config
+from tests.utils import mock_getRawHeaders
from .. import unittest
-class RegistrationHandlers(object):
+class RegistrationHandlers:
def __init__(self, hs):
self.registration_handler = RegistrationHandler(hs)
@@ -108,7 +110,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_get_or_create_user_mau_not_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.count_monthly_users = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value - 1)
+ return_value=make_awaitable(self.hs.config.max_mau_value - 1)
)
# Ensure does not throw exception
self.get_success(self.get_or_create_user(self.requester, "c", "User"))
@@ -116,7 +118,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_get_or_create_user_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.lots_of_users)
+ return_value=make_awaitable(self.lots_of_users)
)
self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"),
@@ -124,7 +126,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.max_mau_value)
)
self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"),
@@ -134,14 +136,14 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_register_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.lots_of_users)
+ return_value=make_awaitable(self.lots_of_users)
)
self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError
)
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.max_mau_value)
)
self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError
@@ -197,7 +199,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
- self.store.is_real_user = Mock(return_value=defer.succeed(False))
+ self.store.is_real_user = Mock(return_value=make_awaitable(False))
user_id = self.get_success(self.handler.register_user(localpart="support"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
@@ -209,8 +211,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self):
room_alias_str = "#room:test"
- self.store.count_real_users = Mock(return_value=defer.succeed(1))
- self.store.is_real_user = Mock(return_value=defer.succeed(True))
+ self.store.count_real_users = Mock(return_value=make_awaitable(1))
+ self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
directory_handler = self.hs.get_handlers().directory_handler
@@ -224,8 +226,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
- self.store.count_real_users = Mock(return_value=defer.succeed(2))
- self.store.is_real_user = Mock(return_value=defer.succeed(True))
+ self.store.count_real_users = Mock(return_value=make_awaitable(2))
+ self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
@@ -484,6 +486,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
+ def test_spam_checker_deny(self):
+ """A spam checker can deny registration, which results in an error."""
+
+ class DenyAll:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info
+ ):
+ return RegistrationBehaviour.DENY
+
+ # Configure a spam checker that denies all users.
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [DenyAll()]
+
+ self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
+
+ def test_spam_checker_shadow_ban(self):
+ """A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
+
+ class BanAll:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info
+ ):
+ return RegistrationBehaviour.SHADOW_BAN
+
+ # Configure a spam checker that shadow-bans all users.
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [BanAll()]
+
+ user_id = self.get_success(self.handler.register_user(localpart="user"))
+
+ # Get an access token.
+ token = self.macaroon_generator.generate_access_token(user_id)
+ self.get_success(
+ self.store.add_access_token_to_user(
+ user_id=user_id, token=token, device_id=None, valid_until_ms=None
+ )
+ )
+
+ # Ensure the user was marked as shadow-banned.
+ request = Mock(args={})
+ request.args[b"access_token"] = [token.encode("ascii")]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ auth = Auth(self.hs)
+ requester = self.get_success(auth.get_user_by_req(request))
+
+ self.assertTrue(requester.shadow_banned)
+
def test_email_to_displayname_mapping(self):
"""Test that custom emails are mapped to new user displaynames correctly"""
self._check_mapping(
@@ -527,16 +576,16 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Mock Synapse's threepid validator
get_threepid_validation_session = Mock(
- return_value=defer.succeed(
+ return_value=make_awaitable(
{"medium": "email", "address": email, "validated_at": 0}
)
)
self.store.get_threepid_validation_session = get_threepid_validation_session
- delete_threepid_session = Mock(return_value=defer.succeed(None))
+ delete_threepid_session = Mock(return_value=make_awaitable(None))
self.store.delete_threepid_session = delete_threepid_session
# Mock Synapse's http json post method to check for the internal bind call
- post_json_get_json = Mock(return_value=defer.succeed(None))
+ post_json_get_json = Mock(return_value=make_awaitable(None))
self.hs.get_simple_http_client().post_json_get_json = post_json_get_json
# Retrieve a UIA session ID
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 07092f026a..0229f58315 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -15,7 +15,7 @@
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
-from synapse.storage.data_stores.main import stats
+from synapse.storage.databases.main import stats
from tests import unittest
@@ -48,16 +48,16 @@ class StatsRoomTests(unittest.HomeserverTestCase):
Add the background updates we need to run.
"""
# Ugh, have to reset this flag
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{"update_name": "populate_stats_prepare", "progress_json": "{}"},
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_rooms",
@@ -67,7 +67,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_users",
@@ -77,7 +77,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
@@ -87,8 +87,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
- def get_all_room_state(self):
- return self.store.db.simple_select_list(
+ async def get_all_room_state(self):
+ return await self.store.db_pool.simple_select_list(
"room_stats_state", None, retcols=("name", "topic", "canonical_alias")
)
@@ -102,7 +102,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000)
return self.get_success(
- self.store.db.simple_select_one(
+ self.store.db_pool.simple_select_one(
table + "_historical",
{id_col: stat_id, end_ts: end_ts},
cols,
@@ -115,10 +115,10 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self._add_background_updates()
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
def test_initial_room(self):
@@ -152,10 +152,10 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self._add_background_updates()
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
r = self.get_success(self.get_all_room_state())
@@ -192,9 +192,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# the position that the deltas should begin at, once they take over.
self.hs.config.stats_enabled = True
self.handler.stats_enabled = True
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
self.get_success(
- self.store.db.simple_update_one(
+ self.store.db_pool.simple_update_one(
table="stats_incremental_position",
keyvalues={},
updatevalues={"stream_id": 0},
@@ -202,17 +202,17 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{"update_name": "populate_stats_prepare", "progress_json": "{}"},
)
)
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
# Now, before the table is actually ingested, add some more events.
@@ -223,13 +223,13 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Now do the initial ingestion.
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
@@ -239,12 +239,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
self.reactor.advance(86401)
@@ -259,7 +259,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# self.handler.notify_new_event()
# We need to let the delta processor advance…
- self.pump(10 * 60)
+ self.reactor.advance(10 * 60)
# Get the slices! There should be two -- day 1, and day 2.
r = self.get_success(self.store.get_statistics_for_subject("room", room_1, 0))
@@ -352,6 +352,37 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+ def test_updating_profile_information_does_not_increase_joined_members_count(self):
+ """
+ Check that the joined_members count does not increase when a user changes their
+ profile information (which is done by sending another join membership event into
+ the room.
+ """
+ self._perform_background_initial_update()
+
+ # Create a user and room
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+
+ # Get the current room stats
+ r1stats_ante = self._get_current_stats("room", r1)
+
+ # Send a profile update into the room
+ new_profile = {"displayname": "bob"}
+ self.helper.change_membership(
+ r1, u1, u1, "join", extra_data=new_profile, tok=u1token
+ )
+
+ # Get the new room stats
+ r1stats_post = self._get_current_stats("room", r1)
+
+ # Ensure that the user count did not change
+ self.assertEqual(r1stats_post["joined_members"], r1stats_ante["joined_members"])
+ self.assertEqual(
+ r1stats_post["local_users_in_room"], r1stats_ante["local_users_in_room"]
+ )
+
def test_send_state_event_nonoverwriting(self):
"""
When we send a non-overwriting state event, it increments total_events AND current_state_events
@@ -675,15 +706,15 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# preparation stage of the initial background update
# Ugh, have to reset this flag
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
self.get_success(
- self.store.db.simple_delete(
+ self.store.db_pool.simple_delete(
"room_stats_current", {"1": 1}, "test_delete_stats"
)
)
self.get_success(
- self.store.db.simple_delete(
+ self.store.db_pool.simple_delete(
"user_stats_current", {"1": 1}, "test_delete_stats"
)
)
@@ -695,9 +726,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# now do the background updates
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_rooms",
@@ -707,7 +738,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_users",
@@ -717,7 +748,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
@@ -728,10 +759,10 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
r1stats_complete = self._get_current_stats("room", r1)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 5878f74175..3fec09ea8a 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -21,9 +21,10 @@ from mock import ANY, Mock, call
from twisted.internet import defer
from synapse.api.errors import AuthError
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.unittest import override_config
from tests.utils import register_federation_servlets
@@ -72,6 +73,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"delivered_txn",
"get_received_txn_response",
"set_received_txn_response",
+ "get_destination_last_successful_stream_ordering",
"get_destination_retry_timings",
"get_devices_by_remote",
"maybe_store_room_on_invite",
@@ -79,6 +81,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"get_user_directory_stream_pos",
"get_current_state_deltas",
"get_device_updates_by_remote",
+ "get_room_max_stream_ordering",
]
)
@@ -115,10 +118,14 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
retry_timings_res
)
- self.datastore.get_device_updates_by_remote.return_value = defer.succeed(
+ self.datastore.get_device_updates_by_remote.return_value = make_awaitable(
(0, [])
)
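+ # pretend that no last-successful stream ordering has been recorded for any destination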
+ self.datastore.get_destination_last_successful_stream_ordering.return_value = make_awaitable(
+ None
+ )
+
def get_received_txn_response(*args):
return defer.succeed(None)
@@ -126,10 +133,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.room_members = []
- def check_user_in_room(room_id, user_id):
+ async def check_user_in_room(room_id, user_id):
if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room")
- return defer.succeed(None)
+ return None
hs.get_auth().check_user_in_room = check_user_in_room
@@ -143,19 +150,21 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_users_in_room = get_users_in_room
- self.datastore.get_user_directory_stream_pos.return_value = (
+ self.datastore.get_user_directory_stream_pos.side_effect = (
# we deliberately return a non-None stream pos to avoid doing an initial_spam
- defer.succeed(1)
+ lambda: make_awaitable(1)
)
self.datastore.get_current_state_deltas.return_value = (0, None)
self.datastore.get_to_device_stream_token = lambda: 0
- self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed(
+ self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
([], 0)
)
- self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
- self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
+ self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
+ None
+ )
+ self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
None
)
@@ -166,7 +175,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.handler.started_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
+ timeout=20000,
)
)
@@ -193,7 +205,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.handler.started_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
+ timeout=20000,
)
)
@@ -268,7 +283,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.handler.stopped_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
)
)
@@ -308,7 +325,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.handler.started_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
+ timeout=10000,
)
)
@@ -347,7 +367,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.handler.started_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
+ timeout=10000,
)
)
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index ddee8d9e3a..48f750d357 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -238,7 +238,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def test_spam_checker(self):
"""
- A user which fails to the spam checks will not appear in search results.
+ A user which fails the spam checks will not appear in search results.
"""
u1 = self.register_user("user1", "pass")
u1_token = self.login(u1, "pass")
@@ -269,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Configure a spam checker that does not filter any users.
spam_checker = self.hs.get_spam_checker()
- class AllowAll(object):
+ class AllowAll:
def check_username_for_spam(self, user_profile):
# Allow all users.
return False
@@ -282,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users.
- class BlockAll(object):
+ class BlockAll:
def check_username_for_spam(self, user_profile):
# All users are spammy.
return True
@@ -339,7 +339,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def get_users_in_public_rooms(self):
r = self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
"users_in_public_rooms", None, ("user_id", "room_id")
)
)
@@ -350,7 +350,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def get_users_who_share_private_rooms(self):
return self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
"users_who_share_private_rooms",
None,
["user_id", "other_user_id", "room_id"],
@@ -362,10 +362,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
Add the background updates we need to run.
"""
# Ugh, have to reset this flag
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_createtables",
@@ -374,7 +374,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_process_rooms",
@@ -384,7 +384,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_process_users",
@@ -394,7 +394,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_cleanup",
@@ -437,10 +437,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self._add_background_updates()
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
shares_private = self.get_users_who_share_private_rooms()
@@ -476,10 +476,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self._add_background_updates()
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
shares_private = self.get_users_who_share_private_rooms()
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index 2096ba3c91..3e5a856584 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -133,7 +133,7 @@ def create_test_cert_file(sanlist):
@implementer(IOpenSSLServerConnectionCreator)
-class TestServerTLSConnectionFactory(object):
+class TestServerTLSConnectionFactory:
"""An SSL connection creator which returns connections which present a certificate
signed by our test CA."""
@@ -145,7 +145,7 @@ class TestServerTLSConnectionFactory(object):
self._cert_file = create_test_cert_file(sanlist)
def serverConnectionForTLS(self, tlsProtocol):
- ctx = SSL.Context(SSL.TLSv1_METHOD)
+ ctx = SSL.Context(SSL.SSLv23_METHOD)
ctx.use_certificate_file(self._cert_file)
ctx.use_privatekey_file(get_test_key_file())
return Connection(ctx, None)
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index db260d599e..c3f7a28dcc 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -972,7 +972,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_well_known_cache(self):
self.reactor.lookups["testserv"] = "1.2.3.4"
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
@@ -995,7 +997,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
well_known_server.loseConnection()
# repeat the request: it should hit the cache
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, b"target-server")
@@ -1003,7 +1007,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((1000.0,))
# now it should connect again
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
@@ -1026,7 +1032,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.lookups["testserv"] = "1.2.3.4"
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
@@ -1052,7 +1060,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
# another lookup.
self.reactor.pump((900.0,))
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
# The resolver may retry a few times, so fonx all requests that come along
attempts = 0
@@ -1082,7 +1092,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((10000.0,))
# Repeat the request; this time it should fail if the lookup fails.
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
clients = self.reactor.tcpClients
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
@@ -1252,7 +1264,7 @@ def _log_request(request):
@implementer(IPolicyForHTTPS)
-class TrustingTLSPolicyForHTTPS(object):
+class TrustingTLSPolicyForHTTPS:
"""An IPolicyForHTTPS which checks that the certificate belongs to the
right server, but doesn't check the certificate chain."""
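The recurring change in this file, and in the federation client tests that follow, is that client methods such as `get_well_known` and `get_json` are now coroutine functions rather than Deferred-returning ones, so each call is wrapped in `defer.ensureDeferred` before being handed to `successResultOf`/`failureResultOf`, which only understand Deferreds. A small self-contained illustration of why the wrapping is needed (not part of the diff, written against stock Twisted):

from twisted.internet import defer
from twisted.trial import unittest


class EnsureDeferredExample(unittest.TestCase):
    def test_wrap_coroutine(self):
        async def get_value() -> int:
            return 42

        # Calling an async function returns a coroutine object, which the
        # trial assertion helpers cannot inspect directly; ensureDeferred
        # converts it into a Deferred that fires with the coroutine's result.
        d = defer.ensureDeferred(get_value())
        self.assertEqual(self.successResultOf(d), 42)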
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index fff4f0cbf4..212484a7fe 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -16,6 +16,7 @@
from mock import Mock
from netaddr import IPSet
+from parameterized import parameterized
from twisted.internet import defer
from twisted.internet.defer import TimeoutError
@@ -58,7 +59,9 @@ class FederationClientTests(HomeserverTestCase):
@defer.inlineCallbacks
def do_request():
with LoggingContext("one") as context:
- fetch_d = self.cl.get_json("testserv:8008", "foo/bar")
+ fetch_d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar")
+ )
# Nothing happened yet
self.assertNoResult(fetch_d)
@@ -120,7 +123,9 @@ class FederationClientTests(HomeserverTestCase):
"""
If the DNS lookup returns an error, it will bubble up.
"""
- d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
+ )
self.pump()
f = self.failureResultOf(d)
@@ -128,7 +133,9 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
def test_client_connection_refused(self):
- d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ )
self.pump()
@@ -154,7 +161,9 @@ class FederationClientTests(HomeserverTestCase):
If the HTTP request is not connected and is timed out, it'll give a
ConnectingCancelledError or TimeoutError.
"""
- d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ )
self.pump()
@@ -184,7 +193,9 @@ class FederationClientTests(HomeserverTestCase):
If the HTTP request is connected, but gets no response before being
timed out, it'll give a ResponseNeverReceived.
"""
- d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ )
self.pump()
@@ -226,7 +237,7 @@ class FederationClientTests(HomeserverTestCase):
# Try making a GET request to a blacklisted IPv4 address
# ------------------------------------------------------
# Make the request
- d = cl.get_json("internal:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(cl.get_json("internal:8008", "foo/bar", timeout=10000))
# Nothing happened yet
self.assertNoResult(d)
@@ -244,7 +255,9 @@ class FederationClientTests(HomeserverTestCase):
# Try making a POST request to a blacklisted IPv6 address
# -------------------------------------------------------
# Make the request
- d = cl.post_json("internalv6:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ cl.post_json("internalv6:8008", "foo/bar", timeout=10000)
+ )
# Nothing has happened yet
self.assertNoResult(d)
@@ -263,7 +276,7 @@ class FederationClientTests(HomeserverTestCase):
# Try making a GET request to a non-blacklisted IPv4 address
# ----------------------------------------------------------
# Make the request
- d = cl.post_json("fine:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(cl.post_json("fine:8008", "foo/bar", timeout=10000))
# Nothing has happened yet
self.assertNoResult(d)
@@ -286,7 +299,7 @@ class FederationClientTests(HomeserverTestCase):
request = MatrixFederationRequest(
method="GET", destination="testserv:8008", path="foo/bar"
)
- d = self.cl._send_request(request, timeout=10000)
+ d = defer.ensureDeferred(self.cl._send_request(request, timeout=10000))
self.pump()
@@ -305,12 +318,14 @@ class FederationClientTests(HomeserverTestCase):
r = self.successResultOf(d)
self.assertEqual(r.code, 200)
- def test_client_headers_no_body(self):
+ @parameterized.expand(["get_json", "post_json", "delete_json", "put_json"])
+ def test_timeout_reading_body(self, method_name: str):
"""
If the HTTP request is connected, but gets no response before being
- timed out, it'll give a ResponseNeverReceived.
+ timed out, it'll give a RequestSendFailed with can_retry.
"""
- d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
+ method = getattr(self.cl, method_name)
+ d = defer.ensureDeferred(method("testserv:8008", "foo/bar", timeout=10000))
self.pump()
@@ -334,7 +349,9 @@ class FederationClientTests(HomeserverTestCase):
self.reactor.advance(10.5)
f = self.failureResultOf(d)
- self.assertIsInstance(f.value, TimeoutError)
+ self.assertIsInstance(f.value, RequestSendFailed)
+ self.assertTrue(f.value.can_retry)
+ self.assertIsInstance(f.value.inner_exception, defer.TimeoutError)
def test_client_requires_trailing_slashes(self):
"""
@@ -342,7 +359,9 @@ class FederationClientTests(HomeserverTestCase):
requiring a trailing slash. We need to retry the request with a
trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622.
"""
- d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+ )
# Send the request
self.pump()
@@ -395,7 +414,9 @@ class FederationClientTests(HomeserverTestCase):
See test_client_requires_trailing_slashes() for context.
"""
- d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+ )
# Send the request
self.pump()
@@ -432,7 +453,11 @@ class FederationClientTests(HomeserverTestCase):
self.failureResultOf(d)
def test_client_sends_body(self):
- self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"})
+ defer.ensureDeferred(
+ self.cl.post_json(
+ "testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}
+ )
+ )
self.pump()
@@ -453,7 +478,7 @@ class FederationClientTests(HomeserverTestCase):
def test_closes_connection(self):
"""Check that the client closes unused HTTP connections"""
- d = self.cl.get_json("testserv:8008", "foo/bar")
+ d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
self.pump()
@@ -486,6 +511,53 @@ class FederationClientTests(HomeserverTestCase):
self.assertFalse(conn.disconnecting)
# wait for a while
- self.pump(120)
+ self.reactor.advance(120)
self.assertTrue(conn.disconnecting)
+
+ @parameterized.expand([(b"",), (b"foo",), (b'{"a": Infinity}',)])
+ def test_json_error(self, return_value):
+ """
+ Test what happens if invalid JSON is returned from the remote endpoint.
+ """
+
+ test_d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8008)
+
+ # complete the connection and wire it up to a fake transport
+ protocol = factory.buildProtocol(None)
+ transport = StringTransport()
+ protocol.makeConnection(transport)
+
+ # that should have made it send the request to the transport
+ self.assertRegex(transport.value(), b"^GET /foo/bar")
+ self.assertRegex(transport.value(), b"Host: testserv:8008")
+
+ # Deferred is still without a result
+ self.assertNoResult(test_d)
+
+ # Send it the HTTP response
+ protocol.dataReceived(
+ b"HTTP/1.1 200 OK\r\n"
+ b"Server: Fake\r\n"
+ b"Content-Type: application/json\r\n"
+ b"Content-Length: %i\r\n"
+ b"\r\n"
+ b"%s" % (len(return_value), return_value)
+ )
+
+ self.pump()
+
+ f = self.failureResultOf(test_d)
+ self.assertIsInstance(f.value, ValueError)
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
new file mode 100644
index 0000000000..45089158ce
--- /dev/null
+++ b/tests/http/test_servlet.py
@@ -0,0 +1,80 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+from io import BytesIO
+
+from mock import Mock
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import (
+ parse_json_object_from_request,
+ parse_json_value_from_request,
+)
+
+from tests import unittest
+
+
+def make_request(content):
+ """Make an object that acts enough like a request."""
+ request = Mock(spec=["content"])
+
+ if isinstance(content, dict):
+ content = json.dumps(content).encode("utf8")
+
+ request.content = BytesIO(content)
+ return request
+
+
+class TestServletUtils(unittest.TestCase):
+ def test_parse_json_value(self):
+ """Basic tests for parse_json_value_from_request."""
+ # Test round-tripping.
+ obj = {"foo": 1}
+ result = parse_json_value_from_request(make_request(obj))
+ self.assertEqual(result, obj)
+
+ # Results don't have to be objects.
+ result = parse_json_value_from_request(make_request(b'["foo"]'))
+ self.assertEqual(result, ["foo"])
+
+ # Test empty.
+ with self.assertRaises(SynapseError):
+ parse_json_value_from_request(make_request(b""))
+
+ result = parse_json_value_from_request(make_request(b""), allow_empty_body=True)
+ self.assertIsNone(result)
+
+ # Invalid UTF-8.
+ with self.assertRaises(SynapseError):
+ parse_json_value_from_request(make_request(b"\xFF\x00"))
+
+ # Invalid JSON.
+ with self.assertRaises(SynapseError):
+ parse_json_value_from_request(make_request(b"foo"))
+
+ with self.assertRaises(SynapseError):
+ parse_json_value_from_request(make_request(b'{"foo": Infinity}'))
+
+ def test_parse_json_object(self):
+ """Basic tests for parse_json_object_from_request."""
+ # Test empty.
+ result = parse_json_object_from_request(
+ make_request(b""), allow_empty_body=True
+ )
+ self.assertEqual(result, {})
+
+ # Test not an object
+ with self.assertRaises(SynapseError):
+ parse_json_object_from_request(make_request(b'["foo"]'))
diff --git a/tests/http/test_simple_client.py b/tests/http/test_simple_client.py
new file mode 100644
index 0000000000..a1cf0862d4
--- /dev/null
+++ b/tests/http/test_simple_client.py
@@ -0,0 +1,180 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from mock import Mock
+
+from netaddr import IPSet
+
+from twisted.internet import defer
+from twisted.internet.error import DNSLookupError
+
+from synapse.http import RequestTimedOutError
+from synapse.http.client import SimpleHttpClient
+from synapse.server import HomeServer
+
+from tests.unittest import HomeserverTestCase
+
+
+class SimpleHttpClientTests(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs: "HomeServer"):
+ # Add a DNS entry for a test server
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ self.cl = hs.get_simple_http_client()
+
+ def test_dns_error(self):
+ """
+ If the DNS lookup returns an error, it will bubble up.
+ """
+ d = defer.ensureDeferred(self.cl.get_json("http://testserv2:8008/foo/bar"))
+ self.pump()
+
+ f = self.failureResultOf(d)
+ self.assertIsInstance(f.value, DNSLookupError)
+
+ def test_client_connection_refused(self):
+ d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar"))
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertNoResult(d)
+
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8008)
+ e = Exception("go away")
+ factory.clientConnectionFailed(None, e)
+ self.pump(0.5)
+
+ f = self.failureResultOf(d)
+
+ self.assertIs(f.value, e)
+
+ def test_client_never_connect(self):
+ """
+ If the HTTP request is not connected and is timed out, it'll give a
+ RequestTimedOutError.
+ """
+ d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar"))
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertNoResult(d)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ self.assertEqual(clients[0][0], "1.2.3.4")
+ self.assertEqual(clients[0][1], 8008)
+
+ # Deferred is still without a result
+ self.assertNoResult(d)
+
+ # Push by enough to time it out
+ self.reactor.advance(120)
+ f = self.failureResultOf(d)
+
+ self.assertIsInstance(f.value, RequestTimedOutError)
+
+ def test_client_connect_no_response(self):
+ """
+ If the HTTP request is connected, but gets no response before being
+ timed out, it'll give a RequestTimedOutError.
+ """
+ d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar"))
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertNoResult(d)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ self.assertEqual(clients[0][0], "1.2.3.4")
+ self.assertEqual(clients[0][1], 8008)
+
+ conn = Mock()
+ client = clients[0][2].buildProtocol(None)
+ client.makeConnection(conn)
+
+ # Deferred is still without a result
+ self.assertNoResult(d)
+
+ # Push by enough to time it out
+ self.reactor.advance(120)
+ f = self.failureResultOf(d)
+
+ self.assertIsInstance(f.value, RequestTimedOutError)
+
+ def test_client_ip_range_blacklist(self):
+ """Ensure that Synapse does not try to connect to blacklisted IPs"""
+
+ # Add some DNS entries we'll blacklist
+ self.reactor.lookups["internal"] = "127.0.0.1"
+ self.reactor.lookups["internalv6"] = "fe80:0:0:0:0:8a2e:370:7337"
+ ip_blacklist = IPSet(["127.0.0.0/8", "fe80::/64"])
+
+ cl = SimpleHttpClient(self.hs, ip_blacklist=ip_blacklist)
+
+ # Try making a GET request to a blacklisted IPv4 address
+ # ------------------------------------------------------
+ # Make the request
+ d = defer.ensureDeferred(cl.get_json("http://internal:8008/foo/bar"))
+ self.pump(1)
+
+ # Check that it was unable to resolve the address
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 0)
+
+ self.failureResultOf(d, DNSLookupError)
+
+ # Try making a POST request to a blacklisted IPv6 address
+ # -------------------------------------------------------
+ # Make the request
+ d = defer.ensureDeferred(
+ cl.post_json_get_json("http://internalv6:8008/foo/bar", {})
+ )
+
+ # Move the reactor forwards
+ self.pump(1)
+
+ # Check that it was unable to resolve the address
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 0)
+
+ # Check that it was due to a blacklisted DNS lookup
+ self.failureResultOf(d, DNSLookupError)
+
+ # Try making a GET request to a non-blacklisted IPv4 address
+ # ----------------------------------------------------------
+ # Make the request
+ d = defer.ensureDeferred(cl.get_json("http://testserv:8008/foo/bar"))
+
+ # Nothing has happened yet
+ self.assertNoResult(d)
+
+ # Move the reactor forwards
+ self.pump(1)
+
+ # Check that it was able to resolve the address
+ clients = self.reactor.tcpClients
+ self.assertNotEqual(len(clients), 0)
+
+ # The connection will still fail, since nothing is listening at this address
+ self.failureResultOf(d, RequestTimedOutError)
diff --git a/tests/logging/test_structured.py b/tests/logging/test_structured.py
index 451d05c0f0..d36f5f426c 100644
--- a/tests/logging/test_structured.py
+++ b/tests/logging/test_structured.py
@@ -29,12 +29,12 @@ from synapse.logging.context import LoggingContext
from tests.unittest import DEBUG, HomeserverTestCase
-class FakeBeginner(object):
+class FakeBeginner:
def beginLoggingTo(self, observers, **kwargs):
self.observers = observers
-class StructuredLoggingTestBase(object):
+class StructuredLoggingTestBase:
"""
Test base that registers a cleanup handler to reset the stdlib log handler
to 'unset'.
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 9c778a0e45..ccbb82f6a3 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -47,7 +47,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that the new user exists with all provided attributes
self.assertEqual(user_id, "@bob:test")
self.assertTrue(access_token)
- self.assertTrue(self.store.get_user_by_id(user_id))
+ self.assertTrue(self.get_success(self.store.get_user_by_id(user_id)))
# Check that the email was assigned
emails = self.get_success(self.store.user_get_threepids(user_id))
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 83032cc9ea..3224568640 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -27,7 +27,7 @@ from tests.unittest import HomeserverTestCase
@attr.s
-class _User(object):
+class _User:
"Helper wrapper for user ID and access token"
id = attr.ib()
token = attr.ib()
@@ -170,7 +170,7 @@ class EmailPusherTests(HomeserverTestCase):
last_stream_ordering = pushers[0]["last_stream_ordering"]
# Advance time a bit, so the pusher will register something has happened
- self.pump(100)
+ self.pump(10)
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
pushers = self.get_success(
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 06575ba0a6..ae60874ec3 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -65,7 +65,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Since we use sqlite in memory databases we need to make sure the
# databases objects are the same.
- self.worker_hs.get_datastore().db = hs.get_datastore().db
+ self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool
self.test_handler = self._build_replication_data_handler()
self.worker_hs.replication_data_handler = self.test_handler
@@ -198,7 +198,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.streamer = self.hs.get_replication_streamer()
store = self.hs.get_datastore()
- self.database = store.db
+ self.database_pool = store.db_pool
self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -254,7 +254,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
)
store = worker_hs.get_datastore()
- store.db._db_pool = self.database._db_pool
+ store.db_pool._db_pool = self.database_pool._db_pool
repl_handler = ReplicationCommandHandler(worker_hs)
client = ClientReplicationStreamProtocol(
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 1a88c7fb80..c0ee1cfbd6 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -20,6 +20,7 @@ from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_
from synapse.handlers.room import RoomEventSource
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.storage.roommember import RoomsForUser
+from synapse.types import PersistedEventPosition
from tests.server import FakeTransport
@@ -58,7 +59,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
# Patch up the equality operator for events so that we can check
# whether lists of events match using assertEquals
self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)]
- return super(SlavedEventStoreTestCase, self).setUp()
+ return super().setUp()
def prepare(self, *args, **kwargs):
super().prepare(*args, **kwargs)
@@ -160,7 +161,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 0, "notify_count": 0},
+ {"highlight_count": 0, "unread_count": 0, "notify_count": 0},
)
self.persist(
@@ -173,7 +174,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 0, "notify_count": 1},
+ {"highlight_count": 0, "unread_count": 0, "notify_count": 1},
)
self.persist(
@@ -188,7 +189,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 1, "notify_count": 2},
+ {"highlight_count": 1, "unread_count": 0, "notify_count": 2},
)
def test_get_rooms_for_user_with_stream_ordering(self):
@@ -204,10 +205,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
)
self.replicate()
+
+ expected_pos = PersistedEventPosition(
+ "master", j2.internal_metadata.stream_ordering
+ )
self.check(
"get_rooms_for_user_with_stream_ordering",
(USER_ID_2,),
- {(ROOM_ID, j2.internal_metadata.stream_ordering)},
+ {(ROOM_ID, expected_pos)},
)
def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
@@ -293,9 +298,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
# the membership change is only any use to us if the room is in the
# joined_rooms list.
if membership_changes:
- self.assertEqual(
- joined_rooms, {(ROOM_ID, j2.internal_metadata.stream_ordering)}
+ expected_pos = PersistedEventPosition(
+ "master", j2.internal_metadata.stream_ordering
)
+ self.assertEqual(joined_rooms, {(ROOM_ID, expected_pos)})
event_id = 0
@@ -366,7 +372,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
state_handler = self.hs.get_state_handler()
context = self.get_success(state_handler.compute_event_context(event))
- self.master_store.add_push_actions_to_staging(
- event.event_id, {user_id: actions for user_id, actions in push_actions}
+ self.get_success(
+ self.master_store.add_push_actions_to_staging(
+ event.event_id,
+ {user_id: actions for user_id, actions in push_actions},
+ False,
+ )
)
return event, context
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 8d4dbf232e..1d7edee5ba 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -16,15 +16,14 @@ import logging
from mock import Mock
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.events.builder import EventBuilderFactory
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client.v1 import login, room
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.test_utils import make_awaitable
logger = logging.getLogger(__name__)
@@ -46,7 +45,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new event.
"""
mock_client = Mock(spec=["put_json"])
- mock_client.put_json.side_effect = lambda *_, **__: defer.succeed({})
+ mock_client.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
@@ -74,7 +73,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new events.
"""
mock_client1 = Mock(spec=["put_json"])
- mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
+ mock_client1.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
@@ -86,7 +85,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
)
mock_client2 = Mock(spec=["put_json"])
- mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
+ mock_client2.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
@@ -137,7 +136,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new typing EDUs.
"""
mock_client1 = Mock(spec=["put_json"])
- mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
+ mock_client1.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
@@ -149,7 +148,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
)
mock_client2 = Mock(spec=["put_json"])
- mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
+ mock_client2.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
@@ -176,7 +175,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.get_success(
typing_handler.started_typing(
target_user=UserID.from_string(user),
- auth_user=UserID.from_string(user),
+ requester=create_requester(user),
room_id=room,
timeout=20000,
)
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index b1a4decced..0f1144fe1e 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -178,7 +178,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self.fetches = []
- def get_file(destination, path, output_stream, args=None, max_size=None):
+ async def get_file(destination, path, output_stream, args=None, max_size=None):
"""
Returns tuple[int,dict,str,int] of file length, response headers,
absolute URI, and response code.
@@ -192,7 +192,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
d = Deferred()
d.addCallback(write_to)
self.fetches.append((d, destination, path, args))
- return make_deferred_yieldable(d)
+ return await make_deferred_yieldable(d)
client = Mock()
client.get_file = get_file
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index faa7f381a9..92c9058887 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -221,7 +221,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
# Ensure the display name was not updated.
request, channel = self.make_request(
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
new file mode 100644
index 0000000000..bf79086f78
--- /dev/null
+++ b/tests/rest/admin/test_event_reports.py
@@ -0,0 +1,382 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import report_event
+
+from tests import unittest
+
+
+class EventReportsTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ report_event.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ self.room_id1 = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok, is_public=True
+ )
+ self.helper.join(self.room_id1, user=self.admin_user, tok=self.admin_user_tok)
+
+ self.room_id2 = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok, is_public=True
+ )
+ self.helper.join(self.room_id2, user=self.admin_user, tok=self.admin_user_tok)
+
+ # Two rooms and two users. Each user creates and reports five events in each room
+ for i in range(5):
+ self._create_event_and_report(
+ room_id=self.room_id1, user_tok=self.other_user_tok,
+ )
+ for i in range(5):
+ self._create_event_and_report(
+ room_id=self.room_id2, user_tok=self.other_user_tok,
+ )
+ for i in range(5):
+ self._create_event_and_report(
+ room_id=self.room_id1, user_tok=self.admin_user_tok,
+ )
+ for i in range(5):
+ self._create_event_and_report(
+ room_id=self.room_id2, user_tok=self.admin_user_tok,
+ )
+
+ self.url = "/_synapse/admin/v1/event_reports"
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, a 403 error is returned.
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.other_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_default_success(self):
+ """
+ Testing list of reported events
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ def test_limit(self):
+ """
+ Testing list of reported events with limit
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 5)
+ self.assertEqual(channel.json_body["next_token"], 5)
+ self._check_fields(channel.json_body["event_reports"])
+
+ def test_from(self):
+ """
+ Testing list of reported events with a defined starting point (from)
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 15)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ def test_limit_and_from(self):
+ """
+ Testing list of reported events with a defined starting point and limit
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(channel.json_body["next_token"], 15)
+ self.assertEqual(len(channel.json_body["event_reports"]), 10)
+ self._check_fields(channel.json_body["event_reports"])
+
+ def test_filter_room(self):
+ """
+ Testing list of reported events filtered by room
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ self.url + "?room_id=%s" % self.room_id1,
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 10)
+ self.assertEqual(len(channel.json_body["event_reports"]), 10)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ for report in channel.json_body["event_reports"]:
+ self.assertEqual(report["room_id"], self.room_id1)
+
+ def test_filter_user(self):
+ """
+ Testing list of reported events filtered by user
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ self.url + "?user_id=%s" % self.other_user,
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 10)
+ self.assertEqual(len(channel.json_body["event_reports"]), 10)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ for report in channel.json_body["event_reports"]:
+ self.assertEqual(report["user_id"], self.other_user)
+
+ def test_filter_user_and_room(self):
+ """
+ Testing list of reported events filtered by user and room
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ self.url + "?user_id=%s&room_id=%s" % (self.other_user, self.room_id1),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 5)
+ self.assertEqual(len(channel.json_body["event_reports"]), 5)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ for report in channel.json_body["event_reports"]:
+ self.assertEqual(report["user_id"], self.other_user)
+ self.assertEqual(report["room_id"], self.room_id1)
+
+ def test_valid_search_order(self):
+ """
+ Testing search order. Order by timestamps.
+ """
+
+ # fetch the most recent first, largest timestamp
+ request, channel = self.make_request(
+ "GET", self.url + "?dir=b", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ report = 1
+ while report < len(channel.json_body["event_reports"]):
+ self.assertGreaterEqual(
+ channel.json_body["event_reports"][report - 1]["received_ts"],
+ channel.json_body["event_reports"][report]["received_ts"],
+ )
+ report += 1
+
+ # fetch the oldest first, smallest timestamp
+ request, channel = self.make_request(
+ "GET", self.url + "?dir=f", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ report = 1
+ while report < len(channel.json_body["event_reports"]):
+ self.assertLessEqual(
+ channel.json_body["event_reports"][report - 1]["received_ts"],
+ channel.json_body["event_reports"][report]["received_ts"],
+ )
+ report += 1
+
+ def test_invalid_search_order(self):
+ """
+ Testing that an invalid search order returns a 400
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual("Unknown direction: bar", channel.json_body["error"])
+
+ def test_limit_is_negative(self):
+ """
+ Testing that a negative limit parameter returns a 400
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ def test_from_is_negative(self):
+ """
+ Testing that a negative from parameter returns a 400
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ def test_next_token(self):
+ """
+ Testing that `next_token` appears at the right place
+ """
+
+ # `next_token` does not appear
+ # Number of results is the number of entries
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does not appear
+ # Number of max results is larger than the number of entries
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does appear
+ # Number of max results is smaller than the number of entries
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 19)
+ self.assertEqual(channel.json_body["next_token"], 19)
+
+ # Check that setting `from` to the value of `next_token` returns the
+ # remaining entries and that `next_token` no longer appears
+ request, channel = self.make_request(
+ "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 1)
+ self.assertNotIn("next_token", channel.json_body)
+
+ def _create_event_and_report(self, room_id, user_tok):
+ """Create and report events
+ """
+ resp = self.helper.send(room_id, tok=user_tok)
+ event_id = resp["event_id"]
+
+ request, channel = self.make_request(
+ "POST",
+ "rooms/%s/report/%s" % (room_id, event_id),
+ json.dumps({"score": -100, "reason": "this makes me sad"}),
+ access_token=user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ def _check_fields(self, content):
+ """Checks that all attributes are present in a event report
+ """
+ for c in content:
+ self.assertIn("id", c)
+ self.assertIn("received_ts", c)
+ self.assertIn("room_id", c)
+ self.assertIn("event_id", c)
+ self.assertIn("user_id", c)
+ self.assertIn("reason", c)
+ self.assertIn("content", c)
+ self.assertIn("sender", c)
+ self.assertIn("room_alias", c)
+ self.assertIn("event_json", c)
+ self.assertIn("score", c["content"])
+ self.assertIn("reason", c["content"])
+ self.assertIn("auth_events", c["event_json"])
+ self.assertIn("type", c["event_json"])
+ self.assertIn("room_id", c["event_json"])
+ self.assertIn("sender", c["event_json"])
+ self.assertIn("content", c["event_json"])
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index ba8552c29f..6dfc709dc5 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -283,6 +283,23 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+ def test_purge_is_not_bool(self):
+ """
+ If the parameter `purge` is not a boolean, an error is returned
+ """
+ body = json.dumps({"purge": "NotBool"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
def test_purge_room_and_block(self):
"""Test to purge a room and block it.
Members will not be moved to a new room and will not receive a message.
@@ -297,7 +314,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
# Assert one user in room
self._is_member(room_id=self.room_id, user_id=self.other_user)
- body = json.dumps({"block": True})
+ body = json.dumps({"block": True, "purge": True})
request, channel = self.make_request(
"POST",
@@ -331,7 +348,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
# Assert one user in room
self._is_member(room_id=self.room_id, user_id=self.other_user)
- body = json.dumps({"block": False})
+ body = json.dumps({"block": False, "purge": True})
request, channel = self.make_request(
"POST",
@@ -351,6 +368,42 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self._is_blocked(self.room_id, expect=False)
self._has_no_members(self.room_id)
+ def test_block_room_and_not_purge(self):
+ """Test to block a room without purging it.
+ Members will not be moved to a new room and will not receive a message.
+ The room will not be purged.
+ """
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Test that room is not blocked
+ self._is_blocked(self.room_id, expect=False)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ body = json.dumps({"block": False, "purge": False})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url.encode("ascii"),
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(None, channel.json_body["new_room_id"])
+ self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+ self._is_blocked(self.room_id, expect=False)
+ self._has_no_members(self.room_id)
+
def test_shutdown_room_consent(self):
"""Test that we can shutdown rooms with local users who have not
yet accepted the privacy policy. This used to fail when we tried to
@@ -513,7 +566,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"state_groups_state",
):
count = self.get_success(
- self.store.db.simple_select_one_onecol(
+ self.store.db_pool.simple_select_one_onecol(
table=table,
keyvalues={"room_id": room_id},
retcol="COUNT(*)",
@@ -614,7 +667,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
"state_groups_state",
):
count = self.get_success(
- self.store.db.simple_select_one_onecol(
+ self.store.db_pool.simple_select_one_onecol(
table=table,
keyvalues={"room_id": room_id},
retcol="COUNT(*)",
@@ -1121,6 +1174,8 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertIn("room_id", channel.json_body)
self.assertIn("name", channel.json_body)
+ self.assertIn("topic", channel.json_body)
+ self.assertIn("avatar", channel.json_body)
self.assertIn("canonical_alias", channel.json_body)
self.assertIn("joined_members", channel.json_body)
self.assertIn("joined_local_members", channel.json_body)
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index f16eef15f7..98d0623734 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -22,11 +22,12 @@ from mock import Mock
import synapse.rest.admin
from synapse.api.constants import UserTypes
-from synapse.api.errors import HttpResponseException, ResourceLimitError
-from synapse.rest.client.v1 import login
+from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
+from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import sync
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.unittest import override_config
@@ -335,7 +336,9 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
store = self.hs.get_datastore()
# Set monthly active users to the limit
- store.get_monthly_active_count = Mock(return_value=self.hs.config.max_mau_value)
+ store.get_monthly_active_count = Mock(
+ return_value=make_awaitable(self.hs.config.max_mau_value)
+ )
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
self.get_failure(
@@ -588,7 +591,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
- return_value=self.hs.config.max_mau_value
+ return_value=make_awaitable(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@@ -628,7 +631,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
- return_value=self.hs.config.max_mau_value
+ return_value=make_awaitable(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@@ -871,6 +874,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self._is_erased("@user:test", False)
+ d = self.store.mark_user_erased("@user:test")
+ self.assertIsNone(self.get_success(d))
+ self._is_erased("@user:test", True)
# Attempt to reactivate the user (without a password).
request, channel = self.make_request(
@@ -903,6 +910,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
+ self._is_erased("@user:test", False)
def test_set_user_as_admin(self):
"""
@@ -992,3 +1000,104 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Ensure they're still alive
self.assertEqual(0, channel.json_body["deactivated"])
+
+ def _is_erased(self, user_id, expect):
+ """Assert that the user is erased or not
+ """
+ d = self.store.is_user_erased(user_id)
+ if expect:
+ self.assertTrue(self.get_success(d))
+ else:
+ self.assertFalse(self.get_success(d))
+
+
+class UserMembershipRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.url = "/_synapse/admin/v1/users/%s/joined_rooms" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_no_auth(self):
+ """
+ Try to list rooms of a user without authentication.
+ """
+ request, channel = self.make_request("GET", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ other_user_token = self.login("user", "pass")
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms"
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not local returns a 400
+ """
+ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms"
+
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ def test_get_rooms(self):
+ """
+ Tests that a normal lookup for rooms is successful
+ """
+ # Create rooms and join
+ other_user_tok = self.login("user", "pass")
+ number_rooms = 5
+ for n in range(number_rooms):
+ self.helper.create_room_as(self.other_user, tok=other_user_tok)
+
+ # Get rooms
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(number_rooms, channel.json_body["total"])
+ self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index cc264cf0b5..47c0d5634c 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -46,50 +46,63 @@ class RetentionTestCase(unittest.HomeserverTestCase):
}
self.hs = self.setup_test_homeserver(config=config)
+
return self.hs
def prepare(self, reactor, clock, homeserver):
self.user_id = self.register_user("user", "password")
self.token = self.login("user", "password")
- def test_retention_state_event(self):
- """Tests that the server configuration can limit the values a user can set to the
- room's retention policy.
+ self.store = self.hs.get_datastore()
+ self.serializer = self.hs.get_event_client_serializer()
+ self.clock = self.hs.get_clock()
+
+ def test_retention_event_purged_with_state_event(self):
+ """Tests that expired events are correctly purged when the room's retention policy
+ is defined by a state event.
"""
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ # Set the room's retention period to 2 days.
+ lifetime = one_day_ms * 2
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": one_day_ms * 4},
+ body={"max_lifetime": lifetime},
tok=self.token,
- expect_code=400,
)
+ self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+
+ def test_retention_event_purged_with_state_event_outside_allowed(self):
+ """Tests that the server configuration can override the policy for a room when
+ running the purge jobs.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ # Set a max_lifetime higher than the maximum allowed value.
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": one_hour_ms},
+ body={"max_lifetime": one_day_ms * 4},
tok=self.token,
- expect_code=400,
)
- def test_retention_event_purged_with_state_event(self):
- """Tests that expired events are correctly purged when the room's retention policy
- is defined by a state event.
- """
- room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ # Check that the event is purged after waiting for the maximum allowed duration
+ # instead of the one specified in the room's policy.
+ self._test_retention_event_purged(room_id, one_day_ms * 1.5)
- # Set the room's retention period to 2 days.
- lifetime = one_day_ms * 2
+ # Set a max_lifetime lower than the minimum allowed value.
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": lifetime},
+ body={"max_lifetime": one_hour_ms},
tok=self.token,
)
- self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+ # Check that the event is purged after waiting for the minimum allowed duration
+ # instead of the one specified in the room's policy.
+ self._test_retention_event_purged(room_id, one_day_ms * 0.5)
def test_retention_event_purged_without_state_event(self):
"""Tests that expired events are correctly purged when the room's retention policy
@@ -141,11 +154,33 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# That event should be the second, not outdated event.
self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
- def _test_retention_event_purged(self, room_id, increment):
+ def _test_retention_event_purged(self, room_id: str, increment: float):
+ """Run the following test scenario to test the message retention policy support:
+
+ 1. Send event 1
+ 2. Increment time by `increment`
+ 3. Send event 2
+ 4. Increment time by `increment`
+ 5. Check that event 1 has been purged
+ 6. Check that event 2 has not been purged
+ 7. Check that state events that were sent before event 1 aren't purged.
+ The main reason for sending a second event is that Synapse currently won't purge
+ the latest message in a room, because doing so would leave the room without any
+ forward extremities. It also ensures the purge jobs aren't too greedy and don't
+ purge messages they shouldn't.
+
+ Args:
+ room_id: The ID of the room to test retention in.
+ increment: The number of milliseconds to advance the clock each time. Must be
+ defined so that events in the room aren't purged if they are `increment`
+ old but are purged if they are `increment * 2` old.
+ """
# Get the create event to, later, check that we can still access it.
message_handler = self.hs.get_message_handler()
create_event = self.get_success(
- message_handler.get_room_data(self.user_id, room_id, EventTypes.Create)
+ message_handler.get_room_data(
+ self.user_id, room_id, EventTypes.Create, state_key=""
+ )
)
# Send a first event to the room. This is the event we'll want to be purged at the
@@ -155,7 +190,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
expired_event_id = resp.get("event_id")
# Check that we can retrieve the event.
- expired_event = self.get_event(room_id, expired_event_id)
+ expired_event = self.get_event(expired_event_id)
self.assertEqual(
expired_event.get("content", {}).get("body"), "1", expired_event
)
@@ -173,26 +208,31 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# one should still be kept.
self.reactor.advance(increment / 1000)
- # Check that the event has been purged from the database.
- self.get_event(room_id, expired_event_id, expected_code=404)
+ # Check that the first event has been purged from the database, i.e. that we
+ # can't retrieve it anymore, because it has expired.
+ self.get_event(expired_event_id, expect_none=True)
- # Check that the event that hasn't been purged can still be retrieved.
- valid_event = self.get_event(room_id, valid_event_id)
+ # Check that the event that hasn't expired can still be retrieved.
+ valid_event = self.get_event(valid_event_id)
self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event)
# Check that we can still access state events that were sent before the event that
# has been purged.
self.get_event(room_id, create_event.event_id)
- def get_event(self, room_id, event_id, expected_code=200):
- url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+ def get_event(self, event_id, expect_none=False):
+ event = self.get_success(self.store.get_event(event_id, allow_none=True))
- request, channel = self.make_request("GET", url, access_token=self.token)
- self.render(request)
+ if expect_none:
+ self.assertIsNone(event)
+ return {}
- self.assertEqual(channel.code, expected_code, channel.result)
+ self.assertIsNotNone(event)
- return channel.json_body
+ time_now = self.clock.time_msec()
+ serialized = self.get_success(self.serializer.serialize_event(event, time_now))
+
+ return serialized
class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
new file mode 100644
index 0000000000..dfe4bf7762
--- /dev/null
+++ b/tests/rest/client/test_shadow_banned.py
@@ -0,0 +1,312 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock, patch
+
+import synapse.rest.admin
+from synapse.api.constants import EventTypes
+from synapse.rest.client.v1 import directory, login, profile, room
+from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
+
+from tests import unittest
+
+
+class _ShadowBannedBase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, homeserver):
+ # Create two users, one of which is shadow-banned.
+ self.banned_user_id = self.register_user("banned", "test")
+ self.banned_access_token = self.login("banned", "test")
+
+ self.store = self.hs.get_datastore()
+
+ self.get_success(
+ self.store.db_pool.simple_update(
+ table="users",
+ keyvalues={"name": self.banned_user_id},
+ updatevalues={"shadow_banned": True},
+ desc="shadow_ban",
+ )
+ )
+
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+
+# To avoid the tests timing out, don't add a delay to "annoy the requester".
+@patch("random.randint", new=lambda a, b: 0)
+class RoomTestCase(_ShadowBannedBase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ directory.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ room_upgrade_rest_servlet.register_servlets,
+ ]
+
+ def test_invite(self):
+ """Invites from shadow-banned users don't actually get sent."""
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # Inviting the user completes successfully.
+ self.helper.invite(
+ room=room_id,
+ src=self.banned_user_id,
+ tok=self.banned_access_token,
+ targ=self.other_user_id,
+ )
+
+ # But the user wasn't actually invited.
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(self.other_user_id)
+ )
+ self.assertEqual(invited_rooms, [])
+
+ def test_invite_3pid(self):
+ """Ensure that a 3PID invite does not attempt to contact the identity server."""
+ identity_handler = self.hs.get_handlers().identity_handler
+ identity_handler.lookup_3pid = Mock(
+ side_effect=AssertionError("This should not get called")
+ )
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # Inviting the user completes successfully.
+ request, channel = self.make_request(
+ "POST",
+ "/rooms/%s/invite" % (room_id,),
+ {"id_server": "test", "medium": "email", "address": "test@test.test"},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ # This should have raised an error earlier, but double check this wasn't called.
+ identity_handler.lookup_3pid.assert_not_called()
+
+ def test_create_room(self):
+ """Invitations during a room creation should be discarded, but the room still gets created."""
+ # The room creation is successful.
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/createRoom",
+ {"visibility": "public", "invite": [self.other_user_id]},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ room_id = channel.json_body["room_id"]
+
+ # But the user wasn't actually invited.
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(self.other_user_id)
+ )
+ self.assertEqual(invited_rooms, [])
+
+ # Since a real room was created, the other user should be able to join it.
+ self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
+
+ # Both users should be in the room.
+ users = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])
+
+ def test_message(self):
+ """Messages from shadow-banned users don't actually get sent."""
+
+ room_id = self.helper.create_room_as(
+ self.other_user_id, tok=self.other_access_token
+ )
+
+ # The user should be in the room.
+ self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token)
+
+ # Sending a message should complete successfully.
+ result = self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "with right label"},
+ tok=self.banned_access_token,
+ )
+ self.assertIn("event_id", result)
+ event_id = result["event_id"]
+
+ latest_events = self.get_success(
+ self.store.get_latest_event_ids_in_room(room_id)
+ )
+ self.assertNotIn(event_id, latest_events)
+
+ def test_upgrade(self):
+ """A room upgrade should fail, but look like it succeeded."""
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
+ {"new_version": "6"},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ # A new room_id should be returned.
+ self.assertIn("replacement_room", channel.json_body)
+
+ new_room_id = channel.json_body["replacement_room"]
+
+ # It doesn't really matter what API we use here, we just want to assert
+ # that the room doesn't exist.
+ summary = self.get_success(self.store.get_room_summary(new_room_id))
+ # The summary should be empty since the room doesn't exist.
+ self.assertEqual(summary, {})
+
+ def test_typing(self):
+ """Typing notifications should not be propagated into the room."""
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ request, channel = self.make_request(
+ "PUT",
+ "/rooms/%s/typing/%s" % (room_id, self.banned_user_id),
+ {"typing": True, "timeout": 30000},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ # There should be no typing events.
+ event_source = self.hs.get_event_sources().sources["typing"]
+ self.assertEquals(event_source.get_current_key(), 0)
+
+ # The other user can join and send typing events.
+ self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
+
+ request, channel = self.make_request(
+ "PUT",
+ "/rooms/%s/typing/%s" % (room_id, self.other_user_id),
+ {"typing": True, "timeout": 30000},
+ access_token=self.other_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ # These appear in the room.
+ self.assertEquals(event_source.get_current_key(), 1)
+ events = self.get_success(
+ event_source.get_new_events(from_key=0, room_ids=[room_id])
+ )
+ self.assertEquals(
+ events[0],
+ [
+ {
+ "type": "m.typing",
+ "room_id": room_id,
+ "content": {"user_ids": [self.other_user_id]},
+ }
+ ],
+ )
+
+
+# To avoid the tests timing out, don't add a delay to "annoy the requester".
+@patch("random.randint", new=lambda a, b: 0)
+class ProfileTestCase(_ShadowBannedBase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ profile.register_servlets,
+ room.register_servlets,
+ ]
+
+ def test_displayname(self):
+ """Profile changes should succeed, but don't end up in a room."""
+ original_display_name = "banned"
+ new_display_name = "new name"
+
+ # Join a room.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # The update should succeed.
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/profile/%s/displayname" % (self.banned_user_id,),
+ {"displayname": new_display_name},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEqual(channel.json_body, {})
+
+ # The user's display name should be updated.
+ request, channel = self.make_request(
+ "GET", "/profile/%s/displayname" % (self.banned_user_id,)
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.json_body["displayname"], new_display_name)
+
+ # But the display name in the room should not be.
+ message_handler = self.hs.get_message_handler()
+ event = self.get_success(
+ message_handler.get_room_data(
+ self.banned_user_id, room_id, "m.room.member", self.banned_user_id,
+ )
+ )
+ self.assertEqual(
+ event.content, {"membership": "join", "displayname": original_display_name}
+ )
+
+ def test_room_displayname(self):
+ """Changes to state events for a room should be processed, but not end up in the room."""
+ original_display_name = "banned"
+ new_display_name = "new name"
+
+ # Join a room.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # The update should succeed.
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
+ % (room_id, self.banned_user_id),
+ {"membership": "join", "displayname": new_display_name},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertIn("event_id", channel.json_body)
+
+ # The display name in the room should not be changed.
+ message_handler = self.hs.get_message_handler()
+ event = self.get_success(
+ message_handler.get_room_data(
+ self.banned_user_id, room_id, "m.room.member", self.banned_user_id,
+ )
+ )
+ self.assertEqual(
+ event.content, {"membership": "join", "displayname": original_display_name}
+ )
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
deleted file mode 100644
index d03e121664..0000000000
--- a/tests/rest/client/test_third_party_rules.py
+++ /dev/null
@@ -1,170 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License');
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import threading
-from typing import Dict
-
-from mock import Mock
-
-from synapse.events import EventBase
-from synapse.module_api import ModuleApi
-from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
-from synapse.types import Requester, StateMap
-
-from tests import unittest
-
-thread_local = threading.local()
-
-
-class ThirdPartyRulesTestModule:
- def __init__(self, config: Dict, module_api: ModuleApi):
- # keep a record of the "current" rules module, so that the test can patch
- # it if desired.
- thread_local.rules_module = self
- self.module_api = module_api
-
- async def on_create_room(
- self, requester: Requester, config: dict, is_requester_admin: bool
- ):
- return True
-
- async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
- return True
-
- @staticmethod
- def parse_config(config):
- return config
-
-
-def current_rules_module() -> ThirdPartyRulesTestModule:
- return thread_local.rules_module
-
-
-class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
- servlets = [
- admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def default_config(self):
- config = super().default_config()
- config["third_party_event_rules"] = {
- "module": __name__ + ".ThirdPartyRulesTestModule",
- "config": {},
- }
- return config
-
- def prepare(self, reactor, clock, homeserver):
- # Create a user and room to play with during the tests
- self.user_id = self.register_user("kermit", "monkey")
- self.tok = self.login("kermit", "monkey")
-
- self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
-
- def test_third_party_rules(self):
- """Tests that a forbidden event is forbidden from being sent, but an allowed one
- can be sent.
- """
- # patch the rules module with a Mock which will return False for some event
- # types
- async def check(ev, state):
- return ev.type != "foo.bar.forbidden"
-
- callback = Mock(spec=[], side_effect=check)
- current_rules_module().check_event_allowed = callback
-
- request, channel = self.make_request(
- "PUT",
- "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id,
- {},
- access_token=self.tok,
- )
- self.render(request)
- self.assertEquals(channel.result["code"], b"200", channel.result)
-
- callback.assert_called_once()
-
- # there should be various state events in the state arg: do some basic checks
- state_arg = callback.call_args[0][1]
- for k in (("m.room.create", ""), ("m.room.member", self.user_id)):
- self.assertIn(k, state_arg)
- ev = state_arg[k]
- self.assertEqual(ev.type, k[0])
- self.assertEqual(ev.state_key, k[1])
-
- request, channel = self.make_request(
- "PUT",
- "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % self.room_id,
- {},
- access_token=self.tok,
- )
- self.render(request)
- self.assertEquals(channel.result["code"], b"403", channel.result)
-
- def test_modify_event(self):
- """Tests that the module can successfully tweak an event before it is persisted.
- """
- # first patch the event checker so that it will modify the event
- async def check(ev: EventBase, state):
- ev.content = {"x": "y"}
- return True
-
- current_rules_module().check_event_allowed = check
-
- # now send the event
- request, channel = self.make_request(
- "PUT",
- "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
- {"x": "x"},
- access_token=self.tok,
- )
- self.render(request)
- self.assertEqual(channel.result["code"], b"200", channel.result)
- event_id = channel.json_body["event_id"]
-
- # ... and check that it got modified
- request, channel = self.make_request(
- "GET",
- "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
- access_token=self.tok,
- )
- self.render(request)
- self.assertEqual(channel.result["code"], b"200", channel.result)
- ev = channel.json_body
- self.assertEqual(ev["content"]["x"], "y")
-
- def test_send_event(self):
- """Tests that the module can send an event into a room via the module api"""
- content = {
- "msgtype": "m.text",
- "body": "Hello!",
- }
- event_dict = {
- "room_id": self.room_id,
- "type": "m.room.message",
- "content": content,
- "sender": self.user_id,
- }
- event = self.get_success(
- current_rules_module().module_api.create_and_send_event_into_room(
- event_dict
- )
- ) # type: EventBase
-
- self.assertEquals(event.sender, self.user_id)
- self.assertEquals(event.room_id, self.room_id)
- self.assertEquals(event.type, "m.room.message")
- self.assertEquals(event.content, content)
diff --git a/tests/rest/client/third_party_rules.py b/tests/rest/client/third_party_rules.py
new file mode 100644
index 0000000000..715e87de08
--- /dev/null
+++ b/tests/rest/client/third_party_rules.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.types import Requester
+
+from tests import unittest
+
+
+class ThirdPartyRulesTestModule:
+ def __init__(self, config, *args, **kwargs):
+ pass
+
+ async def on_create_room(
+ self, requester: Requester, config: dict, is_requester_admin: bool
+ ):
+ return True
+
+ async def check_event_allowed(self, event, context):
+ if event.type == "foo.bar.forbidden":
+ return False
+ else:
+ return True
+
+ @staticmethod
+ def parse_config(config):
+ return config
+
+
+class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["third_party_event_rules"] = {
+ "module": "tests.rest.client.third_party_rules.ThirdPartyRulesTestModule",
+ "config": {},
+ }
+
+ self.hs = self.setup_test_homeserver(config=config)
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ # Create a user and room to play with during the tests
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ def test_third_party_rules(self):
+ """Tests that a forbidden event is forbidden from being sent, but an allowed one
+ can be sent.
+ """
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id,
+ {},
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % self.room_id,
+ {},
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"403", channel.result)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index db52725cfe..5d987a30c7 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -7,8 +7,9 @@ from mock import Mock
import jwt
import synapse.rest.admin
+from synapse.appservice import ApplicationService
from synapse.rest.client.v1 import login, logout
-from synapse.rest.client.v2_alpha import devices
+from synapse.rest.client.v2_alpha import devices, register
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
from tests import unittest
@@ -62,8 +63,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -76,14 +76,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -111,8 +110,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -132,7 +130,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -160,8 +157,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -174,14 +170,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -754,3 +749,134 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
channel.json_body["error"],
"JWT validation failed: Signature verification failed",
)
+
+
+AS_USER = "as_user_alice"
+
+
+class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ register.register_servlets,
+ ]
+
+ def register_as_user(self, username):
+ request, channel = self.make_request(
+ b"POST",
+ "/_matrix/client/r0/register?access_token=%s" % (self.service.token,),
+ {"username": username},
+ )
+ self.render(request)
+
+ def make_homeserver(self, reactor, clock):
+ self.hs = self.setup_test_homeserver()
+
+ self.service = ApplicationService(
+ id="unique_identifier",
+ token="some_token",
+ hostname="example.com",
+ sender="@asbot:example.com",
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {"regex": r"@as_user.*", "exclusive": False}
+ ],
+ ApplicationService.NS_ROOMS: [],
+ ApplicationService.NS_ALIASES: [],
+ },
+ )
+ self.another_service = ApplicationService(
+ id="another__identifier",
+ token="another_token",
+ hostname="example.com",
+ sender="@as2bot:example.com",
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {"regex": r"@as2_user.*", "exclusive": False}
+ ],
+ ApplicationService.NS_ROOMS: [],
+ ApplicationService.NS_ALIASES: [],
+ },
+ )
+
+ self.hs.get_datastore().services_cache.append(self.service)
+ self.hs.get_datastore().services_cache.append(self.another_service)
+ return self.hs
+
+ def test_login_appservice_user(self):
+ """Test that an appservice user can use /login
+ """
+ self.register_as_user(AS_USER)
+
+ params = {
+ "type": login.LoginRestServlet.APPSERVICE_TYPE,
+ "identifier": {"type": "m.id.user", "user": AS_USER},
+ }
+ request, channel = self.make_request(
+ b"POST", LOGIN_URL, params, access_token=self.service.token
+ )
+
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ def test_login_appservice_user_bot(self):
+ """Test that the appservice bot can use /login
+ """
+ self.register_as_user(AS_USER)
+
+ params = {
+ "type": login.LoginRestServlet.APPSERVICE_TYPE,
+ "identifier": {"type": "m.id.user", "user": self.service.sender},
+ }
+ request, channel = self.make_request(
+ b"POST", LOGIN_URL, params, access_token=self.service.token
+ )
+
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ def test_login_appservice_wrong_user(self):
+ """Test that non-as users cannot login with the as token
+ """
+ self.register_as_user(AS_USER)
+
+ params = {
+ "type": login.LoginRestServlet.APPSERVICE_TYPE,
+ "identifier": {"type": "m.id.user", "user": "fibble_wibble"},
+ }
+ request, channel = self.make_request(
+ b"POST", LOGIN_URL, params, access_token=self.service.token
+ )
+
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+
+ def test_login_appservice_wrong_as(self):
+ """Test that as users cannot login with wrong as token
+ """
+ self.register_as_user(AS_USER)
+
+ params = {
+ "type": login.LoginRestServlet.APPSERVICE_TYPE,
+ "identifier": {"type": "m.id.user", "user": AS_USER},
+ }
+ request, channel = self.make_request(
+ b"POST", LOGIN_URL, params, access_token=self.another_service.token
+ )
+
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+
+ def test_login_appservice_no_token(self):
+ """Test that users must provide a token when using the appservice
+ login method
+ """
+ self.register_as_user(AS_USER)
+
+ params = {
+ "type": login.LoginRestServlet.APPSERVICE_TYPE,
+ "identifier": {"type": "m.id.user", "user": AS_USER},
+ }
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
+
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 8df58b4a63..ace0a3c08d 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -70,8 +70,8 @@ class MockHandlerProfileTestCase(unittest.TestCase):
profile_handler=self.mock_handler,
)
- def _get_user_by_req(request=None, allow_guest=False):
- return defer.succeed(synapse.types.create_requester(myid))
+ async def _get_user_by_req(request=None, allow_guest=False):
+ return synapse.types.create_requester(myid)
hs.get_auth().get_user_by_req = _get_user_by_req
diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py
new file mode 100644
index 0000000000..081052f6a6
--- /dev/null
+++ b/tests/rest/client/v1/test_push_rule_attrs.py
@@ -0,0 +1,448 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import synapse
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import login, push_rule, room
+
+from tests.unittest import HomeserverTestCase
+
+
+class PushRuleAttributesTestCase(HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ push_rule.register_servlets,
+ ]
+ hijack_auth = False
+
+ def test_enabled_on_creation(self):
+ """
+ Tests the GET and PUT of push rules' `enabled` endpoints.
+ Tests that a rule is enabled upon creation (the recreation case, where a rule
+ with the same ruleId previously existed and was disabled, is tested separately).
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # GET enabled for that new rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["enabled"], True)
+
+ def test_enabled_on_recreation(self):
+ """
+ Tests the GET and PUT of push rules' `enabled` endpoints.
+ Tests that a rule is enabled upon creation, even if a rule with that
+ ruleId existed previously and was disabled.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # disable the rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/enabled",
+ {"enabled": False},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # check rule disabled
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["enabled"], False)
+
+ # DELETE the rule
+ request, channel = self.make_request(
+ "DELETE", "/pushrules/global/override/best.friend", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # GET enabled for that new rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["enabled"], True)
+
+ def test_enabled_disable(self):
+ """
+ Tests the GET and PUT of push rules' `enabled` endpoints.
+ Tests that a rule is disabled and enabled when we ask for it.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # disable the rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/enabled",
+ {"enabled": False},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # check rule disabled
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["enabled"], False)
+
+ # re-enable the rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/enabled",
+ {"enabled": True},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # check rule enabled
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["enabled"], True)
+
+ def test_enabled_404_when_get_non_existent(self):
+ """
+ Tests that `enabled` gives 404 when the rule doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # GET enabled for that new rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # DELETE the rule
+ request, channel = self.make_request(
+ "DELETE", "/pushrules/global/override/best.friend", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # check 404 for deleted rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_enabled_404_when_get_non_existent_server_rule(self):
+ """
+ Tests that `enabled` gives 404 when the server-default rule doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_enabled_404_when_put_non_existent_rule(self):
+ """
+ Tests that `enabled` gives 404 when we put to a rule that doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # enable & check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/enabled",
+ {"enabled": True},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_enabled_404_when_put_non_existent_server_rule(self):
+ """
+ Tests that `enabled` gives 404 when we put to a server-default rule that doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # enable & check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/.m.muahahah/enabled",
+ {"enabled": True},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_actions_get(self):
+ """
+ Tests that `actions` gives you what you expect on a fresh rule.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # GET actions for that new rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/actions", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}]
+ )
+
+ def test_actions_put(self):
+ """
+ Tests that PUT on actions updates the value you'd get from GET.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # change the rule actions
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/actions",
+ {"actions": ["dont_notify"]},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # GET actions for that new rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/actions", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["actions"], ["dont_notify"])
+
+ def test_actions_404_when_get_non_existent(self):
+ """
+ Tests that `actions` gives 404 when the rule doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # DELETE the rule
+ request, channel = self.make_request(
+ "DELETE", "/pushrules/global/override/best.friend", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # check 404 for deleted rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_actions_404_when_get_non_existent_server_rule(self):
+ """
+ Tests that `actions` gives 404 when the server-default rule doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_actions_404_when_put_non_existent_rule(self):
+ """
+ Tests that `actions` gives 404 when putting to a rule that doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # enable & check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/actions",
+ {"actions": ["dont_notify"]},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_actions_404_when_put_non_existent_server_rule(self):
+ """
+ Tests that `actions` gives 404 when putting to a server-default rule that doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # enable & check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/.m.muahahah/actions",
+ {"actions": ["dont_notify"]},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 5ccda8b2bd..0d809d25d5 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -23,14 +23,12 @@ from urllib import parse as urlparse
from mock import Mock
-from twisted.internet import defer
-
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
from synapse.rest.client.v1 import directory, login, profile, room
from synapse.rest.client.v2_alpha import account
-from synapse.types import JsonDict, RoomAlias
+from synapse.types import JsonDict, RoomAlias, UserID
from synapse.util.stringutils import random_string
from tests import unittest
@@ -51,8 +49,8 @@ class RoomBase(unittest.HomeserverTestCase):
self.hs.get_federation_handler = Mock(return_value=Mock())
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
+ async def _insert_client_ip(*args, **kwargs):
+ return None
self.hs.get_datastore().insert_client_ip = _insert_client_ip
@@ -677,6 +675,92 @@ class RoomMemberStateTestCase(RoomBase):
self.assertEquals(json.loads(content), channel.json_body)
+class RoomJoinRatelimitTestCase(RoomBase):
+ user_id = "@sid1:red"
+
+ servlets = [
+ profile.register_servlets,
+ room.register_servlets,
+ ]
+
+ @unittest.override_config(
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_join_local_ratelimit(self):
+ """Tests that local joins are actually rate-limited."""
+ for i in range(3):
+ self.helper.create_room_as(self.user_id)
+
+ self.helper.create_room_as(self.user_id, expect_code=429)
+
+ @unittest.override_config(
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_join_local_ratelimit_profile_change(self):
+ """Tests that sending a profile update into all of the user's joined rooms isn't
+ rate-limited by the rate-limiter on joins."""
+
+ # Create and join as many rooms as the rate-limiting config allows in a second.
+ room_ids = [
+ self.helper.create_room_as(self.user_id),
+ self.helper.create_room_as(self.user_id),
+ self.helper.create_room_as(self.user_id),
+ ]
+ # Give the rate-limiter some time to forget about our multi-join.
+ self.reactor.advance(2)
+ # Add one to make sure we're joined to more rooms than the config allows us to
+ # join in a second.
+ room_ids.append(self.helper.create_room_as(self.user_id))
+
+ # Create a profile for the user, since it hasn't been done on registration.
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.create_profile(UserID.from_string(self.user_id).localpart)
+ )
+
+ # Update the display name for the user.
+ path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
+ request, channel = self.make_request("PUT", path, {"displayname": "John Doe"})
+ self.render(request)
+ self.assertEquals(channel.code, 200, channel.json_body)
+
+ # Check that a profile update has been sent into all of the rooms.
+ for room_id in room_ids:
+ path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (
+ room_id,
+ self.user_id,
+ )
+
+ request, channel = self.make_request("GET", path)
+ self.render(request)
+ self.assertEquals(channel.code, 200)
+
+ self.assertIn("displayname", channel.json_body)
+ self.assertEquals(channel.json_body["displayname"], "John Doe")
+
+ @unittest.override_config(
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_join_local_ratelimit_idempotent(self):
+ """Tests that the room join endpoints remain idempotent despite rate-limiting
+ on room joins."""
+ room_id = self.helper.create_room_as(self.user_id)
+
+ # Let's test both paths to be sure.
+ paths_to_test = [
+ "/_matrix/client/r0/rooms/%s/join",
+ "/_matrix/client/r0/join/%s",
+ ]
+
+ for path in paths_to_test:
+ # Make sure we send more requests than the rate-limiting config would allow
+ # if all of these requests ended up joining the user to a room.
+ for i in range(4):
+ request, channel = self.make_request("POST", path % room_id, {})
+ self.render(request)
+ self.assertEquals(channel.code, 200)
+
+
class RoomMessagesTestCase(RoomBase):
""" Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
@@ -821,6 +905,7 @@ class RoomMessageListTestCase(RoomBase):
first_token = self.get_success(
store.get_topological_token_for_event(first_event_id)
)
+ first_token_str = self.get_success(first_token.to_string(store))
# Send a second message in the room, which won't be removed, and which we'll
# use as the marker to purge events before.
@@ -828,6 +913,7 @@ class RoomMessageListTestCase(RoomBase):
second_token = self.get_success(
store.get_topological_token_for_event(second_event_id)
)
+ second_token_str = self.get_success(second_token.to_string(store))
# Send a third event in the room to ensure we don't fall under any edge case
# due to our marker being the latest forward extremity in the room.
@@ -837,7 +923,11 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
- % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
+ % (
+ self.room_id,
+ second_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -852,7 +942,7 @@ class RoomMessageListTestCase(RoomBase):
pagination_handler._purge_history(
purge_id=purge_id,
room_id=self.room_id,
- token=second_token,
+ token=second_token_str,
delete_local_events=True,
)
)
@@ -862,7 +952,11 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
- % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
+ % (
+ self.room_id,
+ second_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -876,7 +970,11 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
- % (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})),
+ % (
+ self.room_id,
+ first_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 18260bb90e..94d2bf2eb1 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -46,7 +46,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_handlers().federation_handler = Mock()
- def get_user_by_access_token(token=None, allow_guest=False):
+ async def get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
@@ -55,8 +55,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_auth().get_user_by_access_token = get_user_by_access_token
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
+ async def _insert_client_ip(*args, **kwargs):
+ return None
hs.get_datastore().insert_client_ip = _insert_client_ip
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 22d734e763..afaf9f7b85 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -30,7 +30,7 @@ from tests.server import make_request, render
@attr.s
-class RestHelper(object):
+class RestHelper:
"""Contains extra helper functions to quickly and clearly perform a given
REST action, which isn't the focus of the test.
"""
@@ -39,7 +39,9 @@ class RestHelper(object):
resource = attr.ib()
auth_user_id = attr.ib()
- def create_room_as(self, room_creator=None, is_public=True, tok=None):
+ def create_room_as(
+ self, room_creator=None, is_public=True, tok=None, expect_code=200,
+ ):
temp_id = self.auth_user_id
self.auth_user_id = room_creator
path = "/_matrix/client/r0/createRoom"
@@ -54,9 +56,11 @@ class RestHelper(object):
)
render(request, self.resource, self.hs.get_reactor())
- assert channel.result["code"] == b"200", channel.result
+ assert channel.result["code"] == b"%d" % expect_code, channel.result
self.auth_user_id = temp_id
- return channel.json_body["room_id"]
+
+ if expect_code == 200:
+ return channel.json_body["room_id"]
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
self.change_membership(
@@ -88,7 +92,28 @@ class RestHelper(object):
expect_code=expect_code,
)
- def change_membership(self, room, src, targ, membership, tok=None, expect_code=200):
+ def change_membership(
+ self,
+ room: str,
+ src: str,
+ targ: str,
+ membership: str,
+ extra_data: dict = {},
+ tok: Optional[str] = None,
+ expect_code: int = 200,
+ ) -> None:
+ """
+ Send a membership state event into a room.
+
+ Args:
+ room: The ID of the room to send to
+ src: The mxid of the event sender
+ targ: The mxid of the event's target, used as the event's state key
+ membership: The type of membership event
+ extra_data: Extra information to include in the content of the event
+ tok: The user access token to use
+ expect_code: The expected HTTP response code
+ """
temp_id = self.auth_user_id
self.auth_user_id = src
@@ -97,6 +122,7 @@ class RestHelper(object):
path = path + "?access_token=%s" % tok
data = {"membership": membership}
+ data.update(extra_data)
request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8")
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 0a51aeff92..ae2cd67f35 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -19,6 +19,7 @@ import os
import re
from email.parser import Parser
from typing import Optional
+from urllib.parse import urlencode
import pkg_resources
@@ -27,6 +28,7 @@ from synapse.api.constants import LoginType, Membership
from synapse.api.errors import Codes
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, register
+from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from tests import unittest
from tests.unittest import override_config
@@ -70,6 +72,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
+ self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
def test_basic_password_reset(self):
"""Test basic password reset flow
@@ -251,8 +254,32 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Remove the host
path = link.replace("https://example.com", "")
+ # Load the password reset confirmation page
request, channel = self.make_request("GET", path, shorthand=False)
- self.render(request)
+ request.render(self.submit_token_resource)
+ self.pump()
+ self.assertEquals(200, channel.code, channel.result)
+
+ # Now POST to the same endpoint, mimicking the same behaviour as clicking the
+ # password reset confirm button
+
+ # Send arguments as url-encoded form data, matching the template's behaviour
+ form_args = []
+ for key, value_list in request.args.items():
+ for value in value_list:
+ arg = (key, value)
+ form_args.append(arg)
+
+ # Confirm the password reset
+ request, channel = self.make_request(
+ "POST",
+ path,
+ content=urlencode(form_args).encode("utf8"),
+ shorthand=False,
+ content_is_form=True,
+ )
+ request.render(self.submit_token_resource)
+ self.pump()
self.assertEquals(200, channel.code, channel.result)
def _get_link_from_email(self):
@@ -705,6 +732,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
@override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
def test_next_link_domain_whitelist(self):
"""Tests next_link parameters must fit the whitelist if provided"""
+
+ # Ensure not providing a next_link parameter still works
+ self._request_token(
+ "something@example.com", "some_secret", next_link=None, expect_code=200,
+ )
+
self._request_token(
"something@example.com",
"some_secret",
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index e0e9e94fbf..de00350580 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
+
from synapse.api.errors import Codes
from synapse.rest.client.v2_alpha import filter
@@ -73,8 +75,10 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
def test_get_filter(self):
- filter_id = self.filtering.add_user_filter(
- user_localpart="apple", user_filter=self.EXAMPLE_FILTER
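+ # add_user_filter is now a coroutine, so wrap it in a Deferred the test can poll.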
+ filter_id = defer.ensureDeferred(
+ self.filtering.add_user_filter(
+ user_localpart="apple", user_filter=self.EXAMPLE_FILTER
+ )
)
self.reactor.advance(1)
filter_id = filter_id.result
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index ceca4041e1..ecf697e5e0 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -112,8 +112,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
+ @override_config({"enable_registration": False})
def test_POST_disabled_registration(self):
- self.hs.config.enable_registration = False
request_data = json.dumps({"username": "kermit", "password": "monkey"})
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
@@ -156,7 +156,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
@@ -182,7 +182,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py
new file mode 100644
index 0000000000..5ae72fd008
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_shared_rooms.py
@@ -0,0 +1,138 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Half-Shot
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import synapse.rest.admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import shared_rooms
+
+from tests import unittest
+
+
+class UserSharedRoomsTest(unittest.HomeserverTestCase):
+ """
+ Tests the UserSharedRoomsServlet.
+ """
+
+ servlets = [
+ login.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ shared_rooms.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["update_user_directory"] = True
+ return self.setup_test_homeserver(config=config)
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.handler = hs.get_user_directory_handler()
+
+ def _get_shared_rooms(self, token, other_user):
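+ # Query the unstable MSC2666 shared-rooms endpoint as the user owning the given token.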
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s"
+ % other_user,
+ access_token=token,
+ )
+ self.render(request)
+ return request, channel
+
+ def test_shared_room_list_public(self):
+ """
+ A room should show up in the shared list of rooms between two users
+ if it is public.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+ self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
+
+ request, channel = self._get_shared_rooms(u1_token, u2)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 1)
+ self.assertEquals(channel.json_body["joined"][0], room)
+
+ def test_shared_room_list_private(self):
+ """
+ A room should show up in the shared list of rooms between two users
+ if it is private.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
+ self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
+
+ request, channel = self._get_shared_rooms(u1_token, u2)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 1)
+ self.assertEquals(channel.json_body["joined"][0], room)
+
+ def test_shared_room_list_mixed(self):
+ """
+ The shared room list between two users should contain both public and private
+ rooms.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ room_public = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+ room_private = self.helper.create_room_as(u2, is_public=False, tok=u2_token)
+ self.helper.invite(room_public, src=u1, targ=u2, tok=u1_token)
+ self.helper.invite(room_private, src=u2, targ=u1, tok=u2_token)
+ self.helper.join(room_public, user=u2, tok=u2_token)
+ self.helper.join(room_private, user=u1, tok=u1_token)
+
+ request, channel = self._get_shared_rooms(u1_token, u2)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 2)
+ self.assertTrue(room_public in channel.json_body["joined"])
+ self.assertTrue(room_private in channel.json_body["joined"])
+
+ def test_shared_room_list_after_leave(self):
+ """
+ A room should no longer be considered shared if the other
+ user has left it.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+ self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
+
+ # Assert the room shows up in the shared rooms list before either user leaves
+ request, channel = self._get_shared_rooms(u1_token, u2)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 1)
+ self.assertEquals(channel.json_body["joined"][0], room)
+
+ self.helper.leave(room, user=u1, tok=u1_token)
+
+ request, channel = self._get_shared_rooms(u2_token, u1)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 0)
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index fa3a3ec1bd..a31e44c97e 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -16,9 +16,9 @@
import json
import synapse.rest.admin
-from synapse.api.constants import EventContentFields, EventTypes
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import sync
+from synapse.rest.client.v2_alpha import read_marker, sync
from tests import unittest
from tests.server import TimedOutException
@@ -324,3 +324,156 @@ class SyncTypingTests(unittest.HomeserverTestCase):
"GET", sync_url % (access_token, next_batch)
)
self.assertRaises(TimedOutException, self.render, request)
+
+
+class UnreadMessagesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ read_marker.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.url = "/sync?since=%s"
+ self.next_batch = "s0"
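+ # Remember the sync token so each check below can do an incremental sync.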
+
+ # Register the first user (used to check the unread counts).
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ # Create the room we'll check unread counts for.
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ # Register the second user (used to send events to the room).
+ self.user2 = self.register_user("kermit2", "monkey")
+ self.tok2 = self.login("kermit2", "monkey")
+
+ # Change the power levels of the room so that the second user can send state
+ # events.
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.PowerLevels,
+ {
+ "users": {self.user_id: 100, self.user2: 100},
+ "users_default": 0,
+ "events": {
+ "m.room.name": 50,
+ "m.room.power_levels": 100,
+ "m.room.history_visibility": 100,
+ "m.room.canonical_alias": 50,
+ "m.room.avatar": 50,
+ "m.room.tombstone": 100,
+ "m.room.server_acl": 100,
+ "m.room.encryption": 100,
+ },
+ "events_default": 0,
+ "state_default": 50,
+ "ban": 50,
+ "kick": 50,
+ "redact": 50,
+ "invite": 0,
+ },
+ tok=self.tok,
+ )
+
+ def test_unread_counts(self):
+ """Tests that /sync returns the right value for the unread count (MSC2654)."""
+
+ # Check that our own messages don't increase the unread count.
+ self.helper.send(self.room_id, "hello", tok=self.tok)
+ self._check_unread_count(0)
+
+ # Join the new user and check that this doesn't increase the unread count.
+ self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
+ self._check_unread_count(0)
+
+ # Check that the new user sending a message increases our unread count.
+ res = self.helper.send(self.room_id, "hello", tok=self.tok2)
+ self._check_unread_count(1)
+
+ # Send a read receipt to tell the server we've read the latest event.
+ body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
+ request, channel = self.make_request(
+ "POST",
+ "/rooms/%s/read_markers" % self.room_id,
+ body,
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check that the unread counter is back to 0.
+ self._check_unread_count(0)
+
+ # Check that room name changes increase the unread counter.
+ self.helper.send_state(
+ self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
+ )
+ self._check_unread_count(1)
+
+ # Check that room topic changes increase the unread counter.
+ self.helper.send_state(
+ self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
+ )
+ self._check_unread_count(2)
+
+ # Check that encrypted messages increase the unread counter.
+ self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2)
+ self._check_unread_count(3)
+
+ # Check that custom events with a body increase the unread counter.
+ self.helper.send_event(
+ self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
+ )
+ self._check_unread_count(4)
+
+ # Check that edits don't increase the unread counter.
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "body": "hello",
+ "msgtype": "m.text",
+ "m.relates_to": {"rel_type": RelationTypes.REPLACE},
+ },
+ tok=self.tok2,
+ )
+ self._check_unread_count(4)
+
+ # Check that notices don't increase the unread counter.
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"body": "hello", "msgtype": "m.notice"},
+ tok=self.tok2,
+ )
+ self._check_unread_count(4)
+
+ # Check that tombstone events increase the unread counter.
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.Tombstone,
+ {"replacement_room": "!someroom:test"},
+ tok=self.tok2,
+ )
+ self._check_unread_count(5)
+
+ def _check_unread_count(self, expected_count: int):
+ """Syncs and compares the unread count with the expected value."""
+
+ request, channel = self.make_request(
+ "GET", self.url % self.next_batch, access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ room_entry = channel.json_body["rooms"]["join"][self.room_id]
+ self.assertEqual(
+ room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
+ )
+
+ # Store the next batch for the next request.
+ self.next_batch = channel.json_body["next_batch"]
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 99eb477149..6850c666be 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -53,7 +53,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
Tell the mock http client to expect an outgoing GET request for the given key
"""
- def get_json(destination, path, ignore_backoff=False, **kwargs):
+ async def get_json(destination, path, ignore_backoff=False, **kwargs):
self.assertTrue(ignore_backoff)
self.assertEqual(destination, server_name)
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
@@ -177,7 +177,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
# wire up outbound POST /key/v2/query requests from hs2 so that they
# will be forwarded to hs1
- def post_json(destination, path, data):
+ async def post_json(destination, path, data):
self.assertEqual(destination, self.hs.hostname)
self.assertEqual(
path, "/_matrix/key/v2/query",
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index f4f3e56777..5f897d49cf 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -120,12 +120,13 @@ class _TestImage:
extension = attr.ib(type=bytes)
expected_cropped = attr.ib(type=Optional[bytes])
expected_scaled = attr.ib(type=Optional[bytes])
+ expected_found = attr.ib(default=True, type=bool)
@parameterized_class(
("test_image",),
[
- # smol png
+ # smoll png
(
_TestImage(
unhexlify(
@@ -161,6 +162,8 @@ class _TestImage:
None,
),
),
+ # an empty file
+ (_TestImage(b"", b"image/gif", b".gif", None, None, False,),),
],
)
class MediaRepoTests(unittest.HomeserverTestCase):
@@ -303,12 +306,16 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
def test_thumbnail_crop(self):
- self._test_thumbnail("crop", self.test_image.expected_cropped)
+ self._test_thumbnail(
+ "crop", self.test_image.expected_cropped, self.test_image.expected_found
+ )
def test_thumbnail_scale(self):
- self._test_thumbnail("scale", self.test_image.expected_scaled)
+ self._test_thumbnail(
+ "scale", self.test_image.expected_scaled, self.test_image.expected_found
+ )
- def _test_thumbnail(self, method, expected_body):
+ def _test_thumbnail(self, method, expected_body, expected_found):
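+ # expected_found is False for inputs (such as the empty file) that cannot be thumbnailed.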
params = "?width=32&height=32&method=" + method
request, channel = self.make_request(
"GET", self.media_id + params, shorthand=False
@@ -325,11 +332,23 @@ class MediaRepoTests(unittest.HomeserverTestCase):
)
self.pump()
- self.assertEqual(channel.code, 200)
- if expected_body is not None:
+ if expected_found:
+ self.assertEqual(channel.code, 200)
+ if expected_body is not None:
+ self.assertEqual(
+ channel.result["body"], expected_body, channel.result["body"]
+ )
+ else:
+ # ensure that the result is at least some valid image
+ Image.open(BytesIO(channel.result["body"]))
+ else:
+ # A 404 with a JSON body.
+ self.assertEqual(channel.code, 404)
self.assertEqual(
- channel.result["body"], expected_body, channel.result["body"]
+ channel.json_body,
+ {
+ "errcode": "M_NOT_FOUND",
+ "error": "Not found [b'example.com', b'12345?width=32&height=32&method=%s']"
+ % method,
+ },
)
- else:
- # ensure that the result is at least some valid image
- Image.open(BytesIO(channel.result["body"]))
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 74765a582b..c00a7b9114 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -32,7 +32,7 @@ from tests.server import FakeTransport
@attr.s
-class FakeResponse(object):
+class FakeResponse:
version = attr.ib()
code = attr.ib()
phrase = attr.ib()
@@ -43,7 +43,7 @@ class FakeResponse(object):
@property
def request(self):
@attr.s
- class FakeTransport(object):
+ class FakeTransport:
absoluteURI = self.absoluteURI
return FakeTransport()
@@ -111,7 +111,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.lookups = {}
- class Resolver(object):
+ class Resolver:
def resolveHostName(
_self,
resolutionReceiver,
diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
new file mode 100644
index 0000000000..2d021f6565
--- /dev/null
+++ b/tests/rest/test_health.py
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.rest.health import HealthResource
+
+from tests import unittest
+
+
+class HealthCheckTests(unittest.HomeserverTestCase):
+ def setUp(self):
+ super().setUp()
+
+ # replace the JsonResource with a HealthResource.
+ self.resource = HealthResource()
+
+ def test_health(self):
+ request, channel = self.make_request("GET", "/health", shorthand=False)
+ self.render(request)
+
+ self.assertEqual(request.code, 200)
+ self.assertEqual(channel.result["body"], b"OK")
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index b090bb974c..dcd65c2a50 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -21,7 +21,7 @@ from tests import unittest
class WellKnownTests(unittest.HomeserverTestCase):
def setUp(self):
- super(WellKnownTests, self).setUp()
+ super().setUp()
# replace the JsonResource with a WellKnownResource
self.resource = WellKnownResource(self.hs)
diff --git a/tests/server.py b/tests/server.py
index b6e0b14e78..b404ad4e2a 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -1,6 +1,6 @@
import json
import logging
-from io import BytesIO
+from io import SEEK_END, BytesIO
import attr
from zope.interface import implementer
@@ -35,7 +35,7 @@ class TimedOutException(Exception):
@attr.s
-class FakeChannel(object):
+class FakeChannel:
"""
A fake Twisted Web Channel (the part that interfaces with the
wire).
@@ -135,6 +135,7 @@ def make_request(
request=SynapseRequest,
shorthand=True,
federation_auth_origin=None,
+ content_is_form=False,
):
"""
Make a web request using the given method and path, feed it the
@@ -150,6 +151,8 @@ def make_request(
with the usual REST API path, if it doesn't contain it.
federation_auth_origin (bytes|None): if set to not-None, we will add a fake
Authorization header pretending to be the given server name.
+ content_is_form: Whether the content is URL encoded form data. Adds the
+ 'Content-Type': 'application/x-www-form-urlencoded' header.
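+ For example (illustrative), a test can pass content=urlencode(args).encode("utf8")
+ together with content_is_form=True to simulate an HTML form submission.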
Returns:
Tuple[synapse.http.site.SynapseRequest, channel]
@@ -181,6 +184,8 @@ def make_request(
req = request(channel)
req.process = lambda: b""
req.content = BytesIO(content)
+ # Twisted expects to be at the end of the content when parsing the request.
+ req.content.seek(SEEK_END)
req.postpath = list(map(unquote, path[1:].split(b"/")))
if access_token:
@@ -195,7 +200,13 @@ def make_request(
)
if content:
- req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
+ if content_is_form:
+ req.requestHeaders.addRawHeader(
+ b"Content-Type", b"application/x-www-form-urlencoded"
+ )
+ else:
+ # Assume the body is JSON
+ req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
req.requestReceived(method, path, b"1.1")
@@ -242,14 +253,14 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
lookups = self.lookups = {}
@implementer(IResolverSimple)
- class FakeResolver(object):
+ class FakeResolver:
def getHostByName(self, name, timeout=None):
if name not in lookups:
return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
return succeed(lookups[name])
self.nameResolver = SimpleResolverComplexifier(FakeResolver())
- super(ThreadedMemoryReactorClock, self).__init__()
+ super().__init__()
def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
p = udp.Port(port, protocol, interface, maxPacketSize, self)
@@ -371,7 +382,7 @@ def get_clock():
@attr.s(cmp=False)
-class FakeTransport(object):
+class FakeTransport:
"""
A twisted.internet.interfaces.ITransport implementation which sends all its data
straight into an IProtocol object: it exists to connect two IProtocols together.
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 99908edba3..6382b19dc3 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -27,6 +27,7 @@ from synapse.server_notices.resource_limits_server_notices import (
)
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.unittest import override_config
from tests.utils import default_config
@@ -66,7 +67,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
raise Exception("Failed to find reference to ResourceLimitsServerNotices")
self._rlsn._store.user_last_seen_monthly_active = Mock(
- return_value=defer.succeed(1000)
+ return_value=make_awaitable(1000)
)
self._rlsn._server_notices_manager.send_notice = Mock(
return_value=defer.succeed(Mock())
@@ -79,7 +80,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
return_value=defer.succeed("!something:localhost")
)
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
- self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({}))
+ self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({}))
@override_config({"hs_disabled": True})
def test_maybe_send_server_notice_disabled_hs(self):
@@ -101,7 +102,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
- return_value=defer.succeed({"123": mock_event})
+ return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event
@@ -119,7 +120,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
- return_value=defer.succeed({"123": mock_event})
+ return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -155,7 +156,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self._rlsn._store.user_last_seen_monthly_active = Mock(
- return_value=defer.succeed(None)
+ return_value=make_awaitable(None)
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -214,7 +215,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
- return_value=defer.succeed({"123": mock_event})
+ return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -258,10 +259,10 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.user_id = "@user_id:test"
def test_server_notice_only_sent_once(self):
- self.store.get_monthly_active_count = Mock(return_value=1000)
+ self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000))
self.store.user_last_seen_monthly_active = Mock(
- return_value=defer.succeed(1000)
+ return_value=make_awaitable(1000)
)
# Call the function multiple times to ensure we only send the notice once
@@ -275,7 +276,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.server_notices_manager.get_or_create_notice_room_for_user(self.user_id)
)
- token = self.get_success(self.event_source.get_current_token())
+ token = self.event_source.get_current_token()
events, _ = self.get_success(
self.store.get_recent_events_for_room(
room_id, limit=100, end_token=token.room_key
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index f2955a9c69..ad9bbef9d2 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -49,7 +49,7 @@ class FakeClock:
return defer.succeed(None)
-class FakeEvent(object):
+class FakeEvent:
"""A fake event we use as a convenience.
NOTE: Again as a convenience we use "node_ids" rather than event_ids to
@@ -595,7 +595,7 @@ def pairwise(iterable):
@attr.s
-class TestStateResolutionStore(object):
+class TestStateResolutionStore:
event_map = attr.ib()
def get_events(self, event_ids, allow_rejected=False):
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 5a50e4fdd4..f5afed017c 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -99,7 +99,7 @@ class CacheTestCase(unittest.HomeserverTestCase):
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def test_passthrough(self):
- class A(object):
+ class A:
@cached()
def func(self, key):
return key
@@ -113,7 +113,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
def test_hit(self):
callcount = [0]
- class A(object):
+ class A:
@cached()
def func(self, key):
callcount[0] += 1
@@ -131,7 +131,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
def test_invalidate(self):
callcount = [0]
- class A(object):
+ class A:
@cached()
def func(self, key):
callcount[0] += 1
@@ -149,7 +149,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
self.assertEquals(callcount[0], 2)
def test_invalidate_missing(self):
- class A(object):
+ class A:
@cached()
def func(self, key):
return key
@@ -160,7 +160,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
def test_max_entries(self):
callcount = [0]
- class A(object):
+ class A:
@cached(max_entries=10)
def func(self, key):
callcount[0] += 1
@@ -187,7 +187,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
d = defer.succeed(123)
- class A(object):
+ class A:
@cached()
def func(self, key):
callcount[0] += 1
@@ -205,7 +205,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
callcount = [0]
callcount2 = [0]
- class A(object):
+ class A:
@cached()
def func(self, key):
callcount[0] += 1
@@ -238,7 +238,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
callcount = [0]
callcount2 = [0]
- class A(object):
+ class A:
@cached(max_entries=2)
def func(self, key):
callcount[0] += 1
@@ -275,7 +275,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
callcount = [0]
callcount2 = [0]
- class A(object):
+ class A:
@cached()
def func(self, key):
callcount[0] += 1
@@ -323,7 +323,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
self.table_name = "table_" + hs.get_secrets().token_hex(6)
self.get_success(
- self.storage.db.runInteraction(
+ self.storage.db_pool.runInteraction(
"create",
lambda x, *a: x.execute(*a),
"CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)"
@@ -331,7 +331,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.storage.db.runInteraction(
+ self.storage.db_pool.runInteraction(
"index",
lambda x, *a: x.execute(*a),
"CREATE UNIQUE INDEX %sindex ON %s(id, username)"
@@ -354,9 +354,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
value_values = [["hello"], ["there"]]
self.get_success(
- self.storage.db.runInteraction(
+ self.storage.db_pool.runInteraction(
"test",
- self.storage.db.simple_upsert_many_txn,
+ self.storage.db_pool.simple_upsert_many_txn,
self.table_name,
key_names,
key_values,
@@ -367,7 +367,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
# Check results are what we expect
res = self.get_success(
- self.storage.db.simple_select_list(
+ self.storage.db_pool.simple_select_list(
self.table_name, None, ["id, username, value"]
)
)
@@ -381,9 +381,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
value_values = [["bleb"]]
self.get_success(
- self.storage.db.runInteraction(
+ self.storage.db_pool.runInteraction(
"test",
- self.storage.db.simple_upsert_many_txn,
+ self.storage.db_pool.simple_upsert_many_txn,
self.table_name,
key_names,
key_values,
@@ -394,7 +394,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
# Check results are what we expect
res = self.get_success(
- self.storage.db.simple_select_list(
+ self.storage.db_pool.simple_select_list(
self.table_name, None, ["id, username, value"]
)
)
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index ef296e7dab..46f94914ff 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -24,13 +24,14 @@ from twisted.internet import defer
from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.config._base import ConfigError
-from synapse.storage.data_stores.main.appservice import (
+from synapse.storage.database import DatabasePool, make_conn
+from synapse.storage.databases.main.appservice import (
ApplicationServiceStore,
ApplicationServiceTransactionStore,
)
-from synapse.storage.database import Database, make_conn
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
@@ -178,14 +179,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_appservice_state_none(self):
service = Mock(id="999")
- state = yield self.store.get_appservice_state(service)
+ state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
self.assertEquals(None, state)
@defer.inlineCallbacks
def test_get_appservice_state_up(self):
yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
service = Mock(id=self.as_list[0]["id"])
- state = yield self.store.get_appservice_state(service)
+ state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
self.assertEquals(ApplicationServiceState.UP, state)
@defer.inlineCallbacks
@@ -194,20 +195,22 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN)
yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
service = Mock(id=self.as_list[1]["id"])
- state = yield self.store.get_appservice_state(service)
+ state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
self.assertEquals(ApplicationServiceState.DOWN, state)
@defer.inlineCallbacks
def test_get_appservices_by_state_none(self):
- services = yield self.store.get_appservices_by_state(
- ApplicationServiceState.DOWN
+ services = yield defer.ensureDeferred(
+ self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
)
self.assertEquals(0, len(services))
@defer.inlineCallbacks
def test_set_appservices_state_down(self):
service = Mock(id=self.as_list[1]["id"])
- yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ )
rows = yield self.db_pool.runQuery(
self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
@@ -219,9 +222,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_appservices_state_multiple_up(self):
service = Mock(id=self.as_list[1]["id"])
- yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
- yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
- yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.UP)
+ )
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ )
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.UP)
+ )
rows = yield self.db_pool.runQuery(
self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
@@ -234,7 +243,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
def test_create_appservice_txn_first(self):
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
- txn = yield self.store.create_appservice_txn(service, events)
+ txn = yield defer.ensureDeferred(
+ self.store.create_appservice_txn(service, events)
+ )
self.assertEquals(txn.id, 1)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@@ -246,7 +257,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._set_last_txn(service.id, 9643) # AS is falling behind
yield self._insert_txn(service.id, 9644, events)
yield self._insert_txn(service.id, 9645, events)
- txn = yield self.store.create_appservice_txn(service, events)
+ txn = yield defer.ensureDeferred(
+ self.store.create_appservice_txn(service, events)
+ )
self.assertEquals(txn.id, 9646)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@@ -256,7 +269,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
yield self._set_last_txn(service.id, 9643)
- txn = yield self.store.create_appservice_txn(service, events)
+ txn = yield defer.ensureDeferred(
+ self.store.create_appservice_txn(service, events)
+ )
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@@ -277,7 +292,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._insert_txn(self.as_list[2]["id"], 10, events)
yield self._insert_txn(self.as_list[3]["id"], 9643, events)
- txn = yield self.store.create_appservice_txn(service, events)
+ txn = yield defer.ensureDeferred(
+ self.store.create_appservice_txn(service, events)
+ )
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@@ -289,7 +306,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
txn_id = 1
yield self._insert_txn(service.id, txn_id, events)
- yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+ yield defer.ensureDeferred(
+ self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+ )
res = yield self.db_pool.runQuery(
self.engine.convert_param_style(
@@ -315,7 +334,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
txn_id = 5
yield self._set_last_txn(service.id, 4)
yield self._insert_txn(service.id, txn_id, events)
- yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+ yield defer.ensureDeferred(
+ self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+ )
res = yield self.db_pool.runQuery(
self.engine.convert_param_style(
@@ -339,7 +360,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
def test_get_oldest_unsent_txn_none(self):
service = Mock(id=self.as_list[0]["id"])
- txn = yield self.store.get_oldest_unsent_txn(service)
+ txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service))
self.assertEquals(None, txn)
@defer.inlineCallbacks
@@ -349,14 +370,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
# we aren't testing store._base stuff here, so mock this out
- self.store.get_events_as_list = Mock(return_value=events)
+ self.store.get_events_as_list = Mock(return_value=make_awaitable(events))
yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
yield self._insert_txn(service.id, 10, events)
yield self._insert_txn(service.id, 11, other_events)
yield self._insert_txn(service.id, 12, other_events)
- txn = yield self.store.get_oldest_unsent_txn(service)
+ txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service))
self.assertEquals(service, txn.service)
self.assertEquals(10, txn.id)
self.assertEquals(events, txn.events)
@@ -366,8 +387,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
- services = yield self.store.get_appservices_by_state(
- ApplicationServiceState.DOWN
+ services = yield defer.ensureDeferred(
+ self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
)
self.assertEquals(1, len(services))
self.assertEquals(self.as_list[0]["id"], services[0].id)
@@ -379,8 +400,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP)
- services = yield self.store.get_appservices_by_state(
- ApplicationServiceState.DOWN
+ services = yield defer.ensureDeferred(
+ self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
)
self.assertEquals(2, len(services))
self.assertEquals(
@@ -391,8 +412,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
# required for ApplicationServiceTransactionStoreTestCase tests
class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
- def __init__(self, database: Database, db_conn, hs):
- super(TestTransactionStore, self).__init__(database, db_conn, hs)
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 940b166129..02aae1c13d 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -1,7 +1,5 @@
from mock import Mock
-from twisted.internet import defer
-
from synapse.storage.background_updates import BackgroundUpdater
from tests import unittest
@@ -9,7 +7,9 @@ from tests import unittest
class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
- self.updates = self.hs.get_datastore().db.updates # type: BackgroundUpdater
+ self.updates = (
+ self.hs.get_datastore().db_pool.updates
+ ) # type: BackgroundUpdater
# the base test class should have run the real bg updates for us
self.assertTrue(
self.get_success(self.updates.has_completed_background_updates())
@@ -29,18 +29,17 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
store = self.hs.get_datastore()
self.get_success(
- store.db.simple_insert(
+ store.db_pool.simple_insert(
"background_updates",
values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
)
)
# first step: make a bit of progress
- @defer.inlineCallbacks
- def update(progress, count):
- yield self.clock.sleep((count * duration_ms) / 1000)
+ async def update(progress, count):
+ await self.clock.sleep((count * duration_ms) / 1000)
progress = {"my_key": progress["my_key"] + 1}
- yield store.db.runInteraction(
+ await store.db_pool.runInteraction(
"update_progress",
self.updates._background_update_progress_txn,
"test_update",
@@ -65,13 +64,12 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# second step: complete the update
# we should now get run with a much bigger number of items to update
- @defer.inlineCallbacks
- def update(progress, count):
+ async def update(progress, count):
self.assertEqual(progress, {"my_key": 2})
self.assertAlmostEqual(
count, target_background_update_duration_ms / duration_ms, places=0,
)
- yield self.updates._end_background_update("test_update")
+ await self.updates._end_background_update("test_update")
return count
self.update_handler.side_effect = update
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index b589506c60..eac7e4dcd2 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -21,7 +21,7 @@ from mock import Mock
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
from synapse.storage.engines import create_engine
from tests import unittest
@@ -56,8 +56,9 @@ class SQLBaseStoreTestCase(unittest.TestCase):
engine = create_engine(sqlite_config)
fake_engine = Mock(wraps=engine)
fake_engine.can_native_upsert = False
+ fake_engine.in_transaction.return_value = False
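+ # Newer pool code consults engine.in_transaction(); make the mock report an idle connection.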
- db = Database(Mock(), Mock(config=sqlite_config), fake_engine)
+ db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine)
db._db_pool = self.db_pool
self.datastore = SQLBaseStore(db, None, hs)
@@ -66,8 +67,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_1col(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db.simple_insert(
- table="tablename", values={"columname": "Value"}
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_insert(
+ table="tablename", values={"columname": "Value"}
+ )
)
self.mock_txn.execute.assert_called_with(
@@ -78,10 +81,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_3cols(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db.simple_insert(
- table="tablename",
- # Use OrderedDict() so we can assert on the SQL generated
- values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_insert(
+ table="tablename",
+ # Use OrderedDict() so we can assert on the SQL generated
+ values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
+ )
)
self.mock_txn.execute.assert_called_with(
@@ -93,8 +98,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
- value = yield self.datastore.db.simple_select_one_onecol(
- table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
+ value = yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_select_one_onecol(
+ table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
+ )
)
self.assertEquals("Value", value)
@@ -107,10 +114,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = (1, 2, 3)
- ret = yield self.datastore.db.simple_select_one(
- table="tablename",
- keyvalues={"keycol": "TheKey"},
- retcols=["colA", "colB", "colC"],
+ ret = yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_select_one(
+ table="tablename",
+ keyvalues={"keycol": "TheKey"},
+ retcols=["colA", "colB", "colC"],
+ )
)
self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
@@ -123,11 +132,13 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 0
self.mock_txn.fetchone.return_value = None
- ret = yield self.datastore.db.simple_select_one(
- table="tablename",
- keyvalues={"keycol": "Not here"},
- retcols=["colA"],
- allow_none=True,
+ ret = yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_select_one(
+ table="tablename",
+ keyvalues={"keycol": "Not here"},
+ retcols=["colA"],
+ allow_none=True,
+ )
)
self.assertFalse(ret)
@@ -138,8 +149,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
- ret = yield self.datastore.db.simple_select_list(
- table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
+ ret = yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_select_list(
+ table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
+ )
)
self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
@@ -151,10 +164,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_update_one_1col(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db.simple_update_one(
- table="tablename",
- keyvalues={"keycol": "TheKey"},
- updatevalues={"columnname": "New Value"},
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_update_one(
+ table="tablename",
+ keyvalues={"keycol": "TheKey"},
+ updatevalues={"columnname": "New Value"},
+ )
)
self.mock_txn.execute.assert_called_with(
@@ -166,10 +181,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_update_one_4cols(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db.simple_update_one(
- table="tablename",
- keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
- updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_update_one(
+ table="tablename",
+ keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
+ updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
+ )
)
self.mock_txn.execute.assert_called_with(
@@ -181,8 +198,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_delete_one(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db.simple_delete_one(
- table="tablename", keyvalues={"keycol": "Go away"}
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_delete_one(
+ table="tablename", keyvalues={"keycol": "Go away"}
+ )
)
self.mock_txn.execute.assert_called_with(
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 43425c969a..080761d1d2 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID("alice", "test")
- self.requester = Requester(self.user, None, False, None, None)
+ self.requester = Requester(self.user, None, False, False, None, None)
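+ # Requester gained an extra boolean field (presumably the shadow-banned flag); keep it False.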
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
@@ -47,12 +47,12 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
"""
# Make sure we don't clash with in progress updates.
self.assertTrue(
- self.store.db.updates._all_done, "Background updates are still ongoing"
+ self.store.db_pool.updates._all_done, "Background updates are still ongoing"
)
schema_path = os.path.join(
prepare_database.dir_path,
- "data_stores",
+ "databases",
"main",
"schema",
"delta",
@@ -64,19 +64,19 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
prepare_database.executescript(txn, schema_path)
self.get_success(
- self.store.db.runInteraction(
+ self.store.db_pool.runInteraction(
"test_delete_forward_extremities", run_delta_file
)
)
# Ugh, have to reset this flag
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
def test_soft_failed_extremities_handled_correctly(self):
@@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID.from_string(self.register_user("user1", "password"))
self.token1 = self.login("user1", "password")
- self.requester = Requester(self.user, None, False, None, None)
+ self.requester = Requester(self.user, None, False, False, None, None)
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
self.event_creator = homeserver.get_event_creation_handler()
@@ -271,7 +271,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
# Pump the reactor repeatedly so that the background updates have a
# chance to run.
- self.pump(10 * 60)
+ self.pump(20)
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
@@ -353,6 +353,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
"3"
] = 300000
+
self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
# All entries within time frame
self.assertEqual(
@@ -362,7 +363,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
3,
)
# Oldest room to expire
- self.pump(1)
+ self.pump(1.01)
self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
self.assertEqual(
len(
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 3b483bc7f0..755c70db31 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -16,13 +16,12 @@
from mock import Mock
-from twisted.internet import defer
-
import synapse.rest.admin
from synapse.http.site import XForwardedForRequest
from synapse.rest.client.v1 import login
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.unittest import override_config
@@ -86,7 +85,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.pump(0)
result = self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -117,7 +116,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.pump(0)
result = self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -155,7 +154,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
user_id = "@user:server"
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(lots_of_users)
+ return_value=make_awaitable(lots_of_users)
)
self.get_success(
self.store.insert_client_ip(
@@ -204,10 +203,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
def test_devices_last_seen_bg_update(self):
# First make sure we have completed all updates.
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
user_id = "@user:id"
@@ -225,7 +224,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# But clear the associated entry in devices table
self.get_success(
- self.store.db.simple_update(
+ self.store.db_pool.simple_update(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
updatevalues={"last_seen": None, "ip": None, "user_agent": None},
@@ -252,7 +251,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# Register the background update to run again.
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
table="background_updates",
values={
"update_name": "devices_last_seen",
@@ -263,14 +262,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
# ... and tell the DataStore that it hasn't finished all updates yet
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
# Now let's actually drive the updates to completion
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
# We should now get the correct result again
@@ -293,10 +292,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
def test_old_user_ips_pruned(self):
# First make sure we have completed all updates.
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
user_id = "@user:id"
@@ -315,7 +314,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# We should see that in the DB
result = self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -341,7 +340,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# We should get no results.
result = self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index c2539b353a..ecb00f4e02 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -23,7 +23,7 @@ import tests.utils
class DeviceStoreTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
- super(DeviceStoreTestCase, self).__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
@defer.inlineCallbacks
@@ -34,9 +34,11 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_store_new_device(self):
- yield self.store.store_device("user_id", "device_id", "display_name")
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id", "device_id", "display_name")
+ )
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertDictContainsSubset(
{
"user_id": "user_id",
@@ -48,11 +50,17 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_get_devices_by_user(self):
- yield self.store.store_device("user_id", "device1", "display_name 1")
- yield self.store.store_device("user_id", "device2", "display_name 2")
- yield self.store.store_device("user_id2", "device3", "display_name 3")
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id", "device1", "display_name 1")
+ )
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id", "device2", "display_name 2")
+ )
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id2", "device3", "display_name 3")
+ )
- res = yield self.store.get_devices_by_user("user_id")
+ res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id"))
self.assertEqual(2, len(res.keys()))
self.assertDictContainsSubset(
{
@@ -76,13 +84,13 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
device_ids = ["device_id1", "device_id2"]
# Add two device updates with a single stream_id
- yield self.store.add_device_change_to_streams(
- "user_id", device_ids, ["somehost"]
+ yield defer.ensureDeferred(
+ self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
)
# Get all device updates ever meant for this remote
- now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
- "somehost", -1, limit=100
+ now_stream_id, device_updates = yield defer.ensureDeferred(
+ self.store.get_device_updates_by_remote("somehost", -1, limit=100)
)
# Check original device_ids are contained within these updates
@@ -99,29 +107,35 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_update_device(self):
- yield self.store.store_device("user_id", "device_id", "display_name 1")
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id", "device_id", "display_name 1")
+ )
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
- yield self.store.update_device("user_id", "device_id")
- res = yield self.store.get_device("user_id", "device_id")
+ yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do the update
- yield self.store.update_device(
- "user_id", "device_id", new_display_name="display_name 2"
+ yield defer.ensureDeferred(
+ self.store.update_device(
+ "user_id", "device_id", new_display_name="display_name 2"
+ )
)
# check it worked
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 2", res["display_name"])
@defer.inlineCallbacks
def test_update_unknown_device(self):
with self.assertRaises(synapse.api.errors.StoreError) as cm:
- yield self.store.update_device(
- "user_id", "unknown_device_id", new_display_name="display_name 2"
+ yield defer.ensureDeferred(
+ self.store.update_device(
+ "user_id", "unknown_device_id", new_display_name="display_name 2"
+ )
)
self.assertEqual(404, cm.exception.code)
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index 4e128e1047..da93ca3980 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -34,35 +34,53 @@ class DirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_room_to_alias(self):
- yield self.store.create_room_alias_association(
- room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ yield defer.ensureDeferred(
+ self.store.create_room_alias_association(
+ room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ )
)
self.assertEquals(
["#my-room:test"],
- (yield self.store.get_aliases_for_room(self.room.to_string())),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_aliases_for_room(self.room.to_string())
+ )
+ ),
)
@defer.inlineCallbacks
def test_alias_to_room(self):
- yield self.store.create_room_alias_association(
- room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ yield defer.ensureDeferred(
+ self.store.create_room_alias_association(
+ room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ )
)
self.assertObjectHasAttributes(
{"room_id": self.room.to_string(), "servers": ["test"]},
- (yield self.store.get_association_from_room_alias(self.alias)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_association_from_room_alias(self.alias)
+ )
+ ),
)
@defer.inlineCallbacks
def test_delete_alias(self):
- yield self.store.create_room_alias_association(
- room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ yield defer.ensureDeferred(
+ self.store.create_room_alias_association(
+ room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ )
)
- room_id = yield self.store.delete_room_alias(self.alias)
+ room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias))
self.assertEqual(self.room.to_string(), room_id)
self.assertIsNone(
- (yield self.store.get_association_from_room_alias(self.alias))
+ (
+ yield defer.ensureDeferred(
+ self.store.get_association_from_room_alias(self.alias)
+ )
+ )
)
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 398d546280..3fc4bb13b6 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -30,11 +30,15 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
now = 1470174257070
json = {"key": "value"}
- yield self.store.store_device("user", "device", None)
+ yield defer.ensureDeferred(self.store.store_device("user", "device", None))
- yield self.store.set_e2e_device_keys("user", "device", now, json)
+ yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user", "device", now, json)
+ )
- res = yield self.store.get_e2e_device_keys((("user", "device"),))
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
+ )
self.assertIn("user", res)
self.assertIn("device", res["user"])
dev = res["user"]["device"]
@@ -45,14 +49,18 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
now = 1470174257070
json = {"key": "value"}
- yield self.store.store_device("user", "device", None)
+ yield defer.ensureDeferred(self.store.store_device("user", "device", None))
- changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
+ changed = yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user", "device", now, json)
+ )
self.assertTrue(changed)
# If we try to upload the same key then we should be told nothing
# changed
- changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
+ changed = yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user", "device", now, json)
+ )
self.assertFalse(changed)
@defer.inlineCallbacks
@@ -60,10 +68,16 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
now = 1470174257070
json = {"key": "value"}
- yield self.store.set_e2e_device_keys("user", "device", now, json)
- yield self.store.store_device("user", "device", "display_name")
+ yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user", "device", now, json)
+ )
+ yield defer.ensureDeferred(
+ self.store.store_device("user", "device", "display_name")
+ )
- res = yield self.store.get_e2e_device_keys((("user", "device"),))
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
+ )
self.assertIn("user", res)
self.assertIn("device", res["user"])
dev = res["user"]["device"]
@@ -75,18 +89,28 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
def test_multiple_devices(self):
now = 1470174257070
- yield self.store.store_device("user1", "device1", None)
- yield self.store.store_device("user1", "device2", None)
- yield self.store.store_device("user2", "device1", None)
- yield self.store.store_device("user2", "device2", None)
+ yield defer.ensureDeferred(self.store.store_device("user1", "device1", None))
+ yield defer.ensureDeferred(self.store.store_device("user1", "device2", None))
+ yield defer.ensureDeferred(self.store.store_device("user2", "device1", None))
+ yield defer.ensureDeferred(self.store.store_device("user2", "device2", None))
- yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
- yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
- yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
- yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
+ yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
+ )
+ yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
+ )
+ yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
+ )
+ yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
+ )
- res = yield self.store.get_e2e_device_keys(
- (("user1", "device1"), ("user2", "device2"))
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_device_keys_for_cs_api(
+ (("user1", "device1"), ("user2", "device2"))
+ )
)
self.assertIn("user1", res)
self.assertIn("device1", res["user1"])
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 3aeec0dc0f..d4c3b867e3 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -56,7 +56,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
for i in range(0, 20):
- self.get_success(self.store.db.runInteraction("insert", insert_event, i))
+ self.get_success(
+ self.store.db_pool.runInteraction("insert", insert_event, i)
+ )
# this should get the last ten
r = self.get_success(self.store.get_prev_events_for_room(room_id))
@@ -81,13 +83,13 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
for i in range(0, 20):
self.get_success(
- self.store.db.runInteraction("insert", insert_event, i, room1)
+ self.store.db_pool.runInteraction("insert", insert_event, i, room1)
)
self.get_success(
- self.store.db.runInteraction("insert", insert_event, i, room2)
+ self.store.db_pool.runInteraction("insert", insert_event, i, room2)
)
self.get_success(
- self.store.db.runInteraction("insert", insert_event, i, room3)
+ self.store.db_pool.runInteraction("insert", insert_event, i, room3)
)
# Test simple case
@@ -164,7 +166,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
depth = depth_map[event_id]
- self.store.db.simple_insert_txn(
+ self.store.db_pool.simple_insert_txn(
txn,
table="events",
values={
@@ -179,7 +181,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
},
)
- self.store.db.simple_insert_many_txn(
+ self.store.db_pool.simple_insert_many_txn(
txn,
table="event_auth",
values=[
@@ -192,7 +194,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
for event_id in auth_graph:
next_stream_ordering += 1
self.get_success(
- self.store.db.runInteraction(
+ self.store.db_pool.runInteraction(
"insert", insert_event, event_id, next_stream_ordering
)
)
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index a7b85004e5..3957471f3f 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
room_creator = self.hs.get_room_creation_handler()
user = UserID("alice", "test")
- requester = Requester(user, None, False, None, None)
+ requester = Requester(user, None, False, False, None, None)
# Real events, forward extremities
events = [(3, 2), (6, 2), (4, 6)]
@@ -52,14 +52,14 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
self.reactor.advance(60 * 60 * 1000)
self.pump(1)
- items = set(
+ items = list(
filter(
lambda x: b"synapse_forward_extremities_" in x,
- generate_latest(REGISTRY).split(b"\n"),
+ generate_latest(REGISTRY, emit_help=False).split(b"\n"),
)
)
- expected = {
+ expected = [
b'synapse_forward_extremities_bucket{le="1.0"} 0.0',
b'synapse_forward_extremities_bucket{le="2.0"} 2.0',
b'synapse_forward_extremities_bucket{le="3.0"} 2.0',
@@ -72,9 +72,12 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
b'synapse_forward_extremities_bucket{le="100.0"} 3.0',
b'synapse_forward_extremities_bucket{le="200.0"} 3.0',
b'synapse_forward_extremities_bucket{le="500.0"} 3.0',
- b'synapse_forward_extremities_bucket{le="+Inf"} 3.0',
- b"synapse_forward_extremities_count 3.0",
- b"synapse_forward_extremities_sum 10.0",
- }
-
+ # per https://docs.google.com/document/d/1KwV0mAXwwbvvifBvDKH_LU1YjyXE_wxCkHNoCGq1GX0/edit#heading=h.wghdjzzh72j9,
+ # "inf" is valid: "this includes variants such as inf"
+ b'synapse_forward_extremities_bucket{le="inf"} 3.0',
+ b"# TYPE synapse_forward_extremities_gcount gauge",
+ b"synapse_forward_extremities_gcount 3.0",
+ b"# TYPE synapse_forward_extremities_gsum gauge",
+ b"synapse_forward_extremities_gsum 10.0",
+ ]
self.assertEqual(items, expected)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index b45bc9c115..c0595963dd 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -39,14 +39,18 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_http(self):
- yield self.store.get_unread_push_actions_for_user_in_range_for_http(
- USER_ID, 0, 1000, 20
+ yield defer.ensureDeferred(
+ self.store.get_unread_push_actions_for_user_in_range_for_http(
+ USER_ID, 0, 1000, 20
+ )
)
@defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_email(self):
- yield self.store.get_unread_push_actions_for_user_in_range_for_email(
- USER_ID, 0, 1000, 20
+ yield defer.ensureDeferred(
+ self.store.get_unread_push_actions_for_user_in_range_for_email(
+ USER_ID, 0, 1000, 20
+ )
)
@defer.inlineCallbacks
@@ -56,12 +60,18 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def _assert_counts(noitf_count, highlight_count):
- counts = yield self.store.db.runInteraction(
- "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
+ counts = yield defer.ensureDeferred(
+ self.store.db_pool.runInteraction(
+ "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
+ )
)
self.assertEquals(
counts,
- {"notify_count": noitf_count, "highlight_count": highlight_count},
+ {
+ "notify_count": noitf_count,
+ "unread_count": 0, # Unread counts are tested in the sync tests.
+ "highlight_count": highlight_count,
+ },
)
@defer.inlineCallbacks
@@ -72,28 +82,36 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
event.internal_metadata.stream_ordering = stream
event.depth = stream
- yield self.store.add_push_actions_to_staging(
- event.event_id, {user_id: action}
+ yield defer.ensureDeferred(
+ self.store.add_push_actions_to_staging(
+ event.event_id, {user_id: action}, False,
+ )
)
- yield self.store.db.runInteraction(
- "",
- self.persist_events_store._set_push_actions_for_event_and_users_txn,
- [(event, None)],
- [(event, None)],
+ yield defer.ensureDeferred(
+ self.store.db_pool.runInteraction(
+ "",
+ self.persist_events_store._set_push_actions_for_event_and_users_txn,
+ [(event, None)],
+ [(event, None)],
+ )
)
def _rotate(stream):
- return self.store.db.runInteraction(
- "", self.store._rotate_notifs_before_txn, stream
+ return defer.ensureDeferred(
+ self.store.db_pool.runInteraction(
+ "", self.store._rotate_notifs_before_txn, stream
+ )
)
def _mark_read(stream, depth):
- return self.store.db.runInteraction(
- "",
- self.store._remove_old_push_actions_before_txn,
- room_id,
- user_id,
- stream,
+ return defer.ensureDeferred(
+ self.store.db_pool.runInteraction(
+ "",
+ self.store._remove_old_push_actions_before_txn,
+ room_id,
+ user_id,
+ stream,
+ )
)
yield _assert_counts(0, 0)
@@ -117,8 +135,10 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield _inject_actions(6, PlAIN_NOTIF)
yield _rotate(7)
- yield self.store.db.simple_delete(
- table="event_push_actions", keyvalues={"1": 1}, desc=""
+ yield defer.ensureDeferred(
+ self.store.db_pool.simple_delete(
+ table="event_push_actions", keyvalues={"1": 1}, desc=""
+ )
)
yield _assert_counts(1, 0)
@@ -136,33 +156,43 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts):
- return self.store.db.simple_insert(
- "events",
- {
- "stream_ordering": so,
- "received_ts": ts,
- "event_id": "event%i" % so,
- "type": "",
- "room_id": "",
- "content": "",
- "processed": True,
- "outlier": False,
- "topological_ordering": 0,
- "depth": 0,
- },
+ return defer.ensureDeferred(
+ self.store.db_pool.simple_insert(
+ "events",
+ {
+ "stream_ordering": so,
+ "received_ts": ts,
+ "event_id": "event%i" % so,
+ "type": "",
+ "room_id": "",
+ "content": "",
+ "processed": True,
+ "outlier": False,
+ "topological_ordering": 0,
+ "depth": 0,
+ },
+ )
)
# start with the base case where there are no events in the table
- r = yield self.store.find_first_stream_ordering_after_ts(11)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(11)
+ )
self.assertEqual(r, 0)
# now with one event
yield add_event(2, 10)
- r = yield self.store.find_first_stream_ordering_after_ts(9)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(9)
+ )
self.assertEqual(r, 2)
- r = yield self.store.find_first_stream_ordering_after_ts(10)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(10)
+ )
self.assertEqual(r, 2)
- r = yield self.store.find_first_stream_ordering_after_ts(11)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(11)
+ )
self.assertEqual(r, 3)
# add a bunch of dummy events to the events table
@@ -175,25 +205,37 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
):
yield add_event(stream_ordering, ts)
- r = yield self.store.find_first_stream_ordering_after_ts(110)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(110)
+ )
self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r)
# 4 and 5 are both after 120: we want 4 rather than 5
- r = yield self.store.find_first_stream_ordering_after_ts(120)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(120)
+ )
self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r)
- r = yield self.store.find_first_stream_ordering_after_ts(129)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(129)
+ )
self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r)
# check we can get the last event
- r = yield self.store.find_first_stream_ordering_after_ts(140)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(140)
+ )
self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r)
# off the end
- r = yield self.store.find_first_stream_ordering_after_ts(160)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(160)
+ )
self.assertEqual(r, 21)
# check we can find an event at ordering zero
yield add_event(0, 5)
- r = yield self.store.find_first_stream_ordering_after_ts(1)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(1)
+ )
self.assertEqual(r, 0)
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 55e9ecf264..392b08832b 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -12,9 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from tests.unittest import HomeserverTestCase
@@ -27,9 +26,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- self.db = self.store.db # type: Database
+ self.db_pool = self.store.db_pool # type: DatabasePool
- self.get_success(self.db.runInteraction("_setup_db", self._setup_db))
+ self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
def _setup_db(self, txn):
txn.execute("CREATE SEQUENCE foobar_seq")
@@ -43,29 +42,64 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""
)
- def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create_id_generator(
+ self, instance_name="master", writers=["master"]
+ ) -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
conn,
- self.db,
+ self.db_pool,
+ stream_name="test_stream",
instance_name=instance_name,
table="foobar",
instance_column="instance_name",
id_column="stream_id",
sequence_name="foobar_seq",
+ writers=writers,
)
- return self.get_success(self.db.runWithConnection(_create))
+ return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
def _insert_rows(self, instance_name: str, number: int):
+ """Insert N rows as the given instance, inserting with stream IDs pulled
+ from the postgres sequence.
+ """
+
def _insert(txn):
for _ in range(number):
txn.execute(
"INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
(instance_name,),
)
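+            # Record this instance's latest allocated stream ID in stream_positions (upsert).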
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
+ """,
+ (instance_name,),
+ )
+
+ self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
- self.get_success(self.db.runInteraction("test_single_instance", _insert))
+ def _insert_row_with_id(self, instance_name: str, stream_id: int):
+ """Insert one row as the given instance with given stream_id, updating
+ the postgres sequence position to match.
+ """
+
+ def _insert(txn):
+ txn.execute(
+ "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+ )
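+            # Advance the sequence so that nextval() continues after this stream ID.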
+ txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+ """,
+ (instance_name, stream_id, stream_id),
+ )
+
+ self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
def test_empty(self):
"""Test an ID generator against an empty database gives sensible
@@ -88,22 +122,72 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
async def _get_next_async():
- with await id_gen.get_next() as stream_id:
+ async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
self.get_success(_get_next_async())
self.assertEqual(id_gen.get_positions(), {"master": 8})
- self.assertEqual(id_gen.get_current_token("master"), 8)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
+
+ def test_out_of_order_finish(self):
+ """Test that IDs persisted out of order are correctly handled
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
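+        # Manually enter the async context managers so they can be exited out of order below.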
+ ctx1 = self.get_success(id_gen.get_next())
+ ctx2 = self.get_success(id_gen.get_next())
+ ctx3 = self.get_success(id_gen.get_next())
+ ctx4 = self.get_success(id_gen.get_next())
+
+ s1 = self.get_success(ctx1.__aenter__())
+ s2 = self.get_success(ctx2.__aenter__())
+ s3 = self.get_success(ctx3.__aenter__())
+ s4 = self.get_success(ctx4.__aenter__())
+
+ self.assertEqual(s1, 8)
+ self.assertEqual(s2, 9)
+ self.assertEqual(s3, 10)
+ self.assertEqual(s4, 11)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ self.get_success(ctx2.__aexit__(None, None, None))
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ self.get_success(ctx1.__aexit__(None, None, None))
+
+ self.assertEqual(id_gen.get_positions(), {"master": 9})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
+
+ self.get_success(ctx4.__aexit__(None, None, None))
+
+ self.assertEqual(id_gen.get_positions(), {"master": 9})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
+
+ self.get_success(ctx3.__aexit__(None, None, None))
+
+ self.assertEqual(id_gen.get_positions(), {"master": 11})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
def test_multi_instance(self):
"""Test that reads and writes from multiple processes are handled
@@ -112,18 +196,18 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_rows("first", 3)
self._insert_rows("second", 4)
- first_id_gen = self._create_id_generator("first")
- second_id_gen = self._create_id_generator("second")
+ first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+ second_id_gen = self._create_id_generator("second", writers=["first", "second"])
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
- self.assertEqual(first_id_gen.get_current_token("first"), 3)
- self.assertEqual(first_id_gen.get_current_token("second"), 7)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
async def _get_next_async():
- with await first_id_gen.get_next() as stream_id:
+ async with first_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8)
self.assertEqual(
@@ -141,7 +225,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# stream ID
async def _get_next_async():
- with await second_id_gen.get_next() as stream_id:
+ async with second_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 9)
self.assertEqual(
@@ -166,7 +250,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -176,9 +260,315 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
- self.get_success(self.db.runInteraction("test", _get_next_txn))
+ self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
self.assertEqual(id_gen.get_positions(), {"master": 8})
- self.assertEqual(id_gen.get_current_token("master"), 8)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
+
+ def test_get_persisted_upto_position(self):
+ """Test that `get_persisted_upto_position` correctly tracks updates to
+ positions.
+ """
+
+ # The following tests are a bit cheeky in that we notify about new
+ # positions via `advance` without *actually* advancing the postgres
+ # sequence.
+
+ self._insert_row_with_id("first", 3)
+ self._insert_row_with_id("second", 5)
+
+ id_gen = self._create_id_generator("first", writers=["first", "second"])
+
+ self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+
+        # Min is 3 and there is a gap between 3 and 5, so we expect it to be 3.
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+        # We advance "first" straight to 6. Min is now 5 but there is no gap so
+        # we expect it to be 6.
+ id_gen.advance("first", 6)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 6)
+
+ # No gap, so we expect 7.
+ id_gen.advance("second", 7)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 7)
+
+ # We haven't seen 8 yet, so we expect 7 still.
+ id_gen.advance("second", 9)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 7)
+
+        # Now that we've seen 7, 8 and 9 we can go straight to 9.
+ id_gen.advance("first", 8)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 9)
+
+        # Jump forward with gaps. The minimum is 11: even though we haven't seen
+        # 10, we know that everything before 11 must be persisted.
+ id_gen.advance("first", 11)
+ id_gen.advance("second", 15)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 11)
+
+ def test_get_persisted_upto_position_get_next(self):
+ """Test that `get_persisted_upto_position` correctly tracks updates to
+ positions when `get_next` is called.
+ """
+
+ self._insert_row_with_id("first", 3)
+ self._insert_row_with_id("second", 5)
+
+ id_gen = self._create_id_generator("first", writers=["first", "second"])
+
+ self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
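+        # Allocating and persisting the next ID (6) should advance the persisted position to 6.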
+ async def _get_next_async():
+ async with id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 6)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(id_gen.get_persisted_upto_position(), 6)
+
+        # We assume that, so long as `get_next` correctly advances
+        # `persisted_upto_position` in this case, it will also be correct in the
+        # other cases tested above (since they hit the same code).
+
+ def test_restart_during_out_of_order_persistence(self):
+        """Test that restarting a process while another process is writing
+        out-of-order updates is handled correctly.
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ # Persist two rows at once
+ ctx1 = self.get_success(id_gen.get_next())
+ ctx2 = self.get_success(id_gen.get_next())
+
+ s1 = self.get_success(ctx1.__aenter__())
+ s2 = self.get_success(ctx2.__aenter__())
+
+ self.assertEqual(s1, 8)
+ self.assertEqual(s2, 9)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ # We finish persisting the second row before restart
+ self.get_success(ctx2.__aexit__(None, None, None))
+
+ # We simulate a restart of another worker by just creating a new ID gen.
+ id_gen_worker = self._create_id_generator("worker")
+
+ # Restarted worker should not see the second persisted row
+ self.assertEqual(id_gen_worker.get_positions(), {"master": 7})
+ self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7)
+
+ # Now if we persist the first row then both instances should jump ahead
+ # correctly.
+ self.get_success(ctx1.__aexit__(None, None, None))
+
+ self.assertEqual(id_gen.get_positions(), {"master": 9})
+ id_gen_worker.advance("master", 9)
+ self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
+
+ def test_writer_config_change(self):
+        """Test that changing the writer config works correctly.
+ """
+
+ self._insert_row_with_id("first", 3)
+ self._insert_row_with_id("second", 5)
+
+ # Initial config has two writers
+ id_gen = self._create_id_generator("first", writers=["first", "second"])
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+ self.assertEqual(id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(id_gen.get_current_token_for_writer("second"), 5)
+
+        # The new config removes one of the writers. Note that if a writer is
+        # removed from the config we assume that it has been shut down and has
+        # finished persisting, which is why the persisted-up-to position is 5.
+ id_gen_2 = self._create_id_generator("second", writers=["second"])
+ self.assertEqual(id_gen_2.get_persisted_upto_position(), 5)
+ self.assertEqual(id_gen_2.get_current_token_for_writer("second"), 5)
+
+ # This config points to a single, previously unused writer.
+ id_gen_3 = self._create_id_generator("third", writers=["third"])
+ self.assertEqual(id_gen_3.get_persisted_upto_position(), 5)
+
+ # For new writers we assume their initial position to be the current
+ # persisted up to position. This stops Synapse from doing a full table
+ # scan when a new writer comes along.
+ self.assertEqual(id_gen_3.get_current_token_for_writer("third"), 5)
+
+ id_gen_4 = self._create_id_generator("fourth", writers=["third"])
+ self.assertEqual(id_gen_4.get_current_token_for_writer("third"), 5)
+
+ # Check that we get a sane next stream ID with this new config.
+
+ async def _get_next_async():
+ async with id_gen_3.get_next() as stream_id:
+ self.assertEqual(stream_id, 6)
+
+ self.get_success(_get_next_async())
+ self.assertEqual(id_gen_3.get_persisted_upto_position(), 6)
+
+        # If we add back the old "first" writer then the persisted-up-to
+        # position should not revert to 3.
+ id_gen_5 = self._create_id_generator("five", writers=["first", "third"])
+ self.assertEqual(id_gen_5.get_persisted_upto_position(), 6)
+ self.assertEqual(id_gen_5.get_current_token_for_writer("first"), 6)
+ self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6)
+
+ def test_sequence_consistency(self):
+        """Test that we error out if the table and sequence diverge.
+ """
+
+ # Prefill with some rows
+ self._insert_row_with_id("master", 3)
+
+        # Now we add a row *without* advancing the sequence
+ def _insert(txn):
+ txn.execute("INSERT INTO foobar VALUES (26, 'master')")
+
+ self.get_success(self.db_pool.runInteraction("_insert", _insert))
+
+ # Creating the ID gen should error
+ with self.assertRaises(IncorrectDatabaseSetup):
+ self._create_id_generator("first")
+
+
+class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+    """Tests a MultiWriterIdGenerator that produces *negative* stream IDs.
+ """
+
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.db_pool = self.store.db_pool # type: DatabasePool
+
+ self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+
+ def _setup_db(self, txn):
+ txn.execute("CREATE SEQUENCE foobar_seq")
+ txn.execute(
+ """
+ CREATE TABLE foobar (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+
+ def _create_id_generator(
+ self, instance_name="master", writers=["master"]
+ ) -> MultiWriterIdGenerator:
+ def _create(conn):
+ return MultiWriterIdGenerator(
+ conn,
+ self.db_pool,
+ stream_name="test_stream",
+ instance_name=instance_name,
+ table="foobar",
+ instance_column="instance_name",
+ id_column="stream_id",
+ sequence_name="foobar_seq",
+ writers=writers,
+ positive=False,
+ )
+
+ return self.get_success(self.db_pool.runWithConnection(_create))
+
+ def _insert_row(self, instance_name: str, stream_id: int):
+ """Insert one row as the given instance with given stream_id.
+ """
+
+ def _insert(txn):
+ txn.execute(
+ "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+ )
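+            # Stream IDs are negative for this backwards stream, so stream_positions stores the negated value.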
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+ """,
+ (instance_name, -stream_id, -stream_id),
+ )
+
+ self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
+
+ def test_single_instance(self):
+ """Test that reads and writes from a single process are handled
+ correctly.
+ """
+ id_gen = self._create_id_generator()
+
+ async def _get_next_async():
+ async with id_gen.get_next() as stream_id:
+ self._insert_row("master", stream_id)
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(id_gen.get_positions(), {"master": -1})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
+ self.assertEqual(id_gen.get_persisted_upto_position(), -1)
+
+ async def _get_next_async2():
+ async with id_gen.get_next_mult(3) as stream_ids:
+ for stream_id in stream_ids:
+ self._insert_row("master", stream_id)
+
+ self.get_success(_get_next_async2())
+
+ self.assertEqual(id_gen.get_positions(), {"master": -4})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
+ self.assertEqual(id_gen.get_persisted_upto_position(), -4)
+
+ # Test loading from DB by creating a second ID gen
+ second_id_gen = self._create_id_generator()
+
+ self.assertEqual(second_id_gen.get_positions(), {"master": -4})
+ self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
+ self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
+
+ def test_multiple_instance(self):
+ """Tests that having multiple instances that get advanced over
+        federation works correctly.
+ """
+ id_gen_1 = self._create_id_generator("first", writers=["first", "second"])
+ id_gen_2 = self._create_id_generator("second", writers=["first", "second"])
+
+ async def _get_next_async():
+ async with id_gen_1.get_next() as stream_id:
+ self._insert_row("first", stream_id)
+ id_gen_2.advance("first", stream_id)
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(id_gen_1.get_positions(), {"first": -1})
+ self.assertEqual(id_gen_2.get_positions(), {"first": -1})
+ self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
+ self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
+
+ async def _get_next_async2():
+ async with id_gen_2.get_next() as stream_id:
+ self._insert_row("second", stream_id)
+ id_gen_1.advance("second", stream_id)
+
+ self.get_success(_get_next_async2())
+
+ self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
+ self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
+ self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
+ self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index 0155ffd04e..fe37d2ed5a 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -34,14 +34,16 @@ class DataStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_users_paginate(self):
- yield self.store.register_user(self.user.to_string(), "pass")
- yield self.store.create_profile(self.user.localpart)
- yield self.store.set_profile_displayname(
- self.user.localpart, self.displayname, 1
+ yield defer.ensureDeferred(
+ self.store.register_user(self.user.to_string(), "pass")
+ )
+ yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.user.localpart, self.displayname, 1)
)
- users, total = yield self.store.get_users_paginate(
- 0, 10, name="bc", guests=False
+ users, total = yield defer.ensureDeferred(
+ self.store.get_users_paginate(0, 10, name="bc", guests=False)
)
self.assertEquals(1, total)
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 9c04e92577..8d97b6d4cd 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.api.constants import UserTypes
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.unittest import default_config, override_config
FORTY_DAYS = 40 * 24 * 60 * 60
@@ -78,7 +79,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
# XXX why are we doing this here? this function is only run at startup
# so it is odd to re-run it here.
self.get_success(
- self.store.db.runInteraction(
+ self.store.db_pool.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
)
@@ -136,6 +137,21 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, 1)
+ def test_appservice_user_not_counted_in_mau(self):
+ self.get_success(
+ self.store.register_user(
+ user_id="@appservice_user:server", appservice_id="wibble"
+ )
+ )
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 0)
+
+ d = self.store.upsert_monthly_active_user("@appservice_user:server")
+ self.get_success(d)
+
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 0)
+
def test_user_last_seen_monthly_active(self):
user_id1 = "@user1:server"
user_id2 = "@user2:server"
@@ -204,7 +220,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.user_add_threepid(user, "email", email, now, now)
)
- d = self.store.db.runInteraction(
+ d = self.store.db_pool.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
self.get_success(d)
@@ -230,7 +246,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
)
self.get_success(d)
- self.store.upsert_monthly_active_user = Mock()
+ self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
d = self.store.populate_monthly_active_users(user_id)
self.get_success(d)
@@ -238,7 +254,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_not_called()
def test_populate_monthly_users_should_update(self):
- self.store.upsert_monthly_active_user = Mock()
+ self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
@@ -251,7 +267,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_called_once()
def test_populate_monthly_users_should_not_update(self):
- self.store.upsert_monthly_active_user = Mock()
+ self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
self.store.user_last_seen_monthly_active = Mock(
@@ -280,7 +296,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
]
self.hs.config.mau_limits_reserved_threepids = threepids
- d = self.store.db.runInteraction(
+ d = self.store.db_pool.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
self.get_success(d)
@@ -293,8 +309,12 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.register_user(user_id=user2, password_hash=None))
now = int(self.hs.get_clock().time_msec())
- self.store.user_add_threepid(user1, "email", user1_email, now, now)
- self.store.user_add_threepid(user2, "email", user2_email, now, now)
+ self.get_success(
+ self.store.user_add_threepid(user1, "email", user1_email, now, now)
+ )
+ self.get_success(
+ self.store.user_add_threepid(user2, "email", user2_email, now, now)
+ )
users = self.get_success(self.store.get_registered_reserved_users())
self.assertEqual(len(users), len(threepids))
@@ -333,7 +353,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
def test_no_users_when_not_tracking(self):
- self.store.upsert_monthly_active_user = Mock()
+ self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
self.get_success(self.store.populate_monthly_active_users("@user:sever"))
@@ -378,7 +398,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.upsert_monthly_active_user(appservice2_user1))
count = self.get_success(self.store.get_monthly_active_count())
- self.assertEqual(count, 4)
+ self.assertEqual(count, 1)
d = self.store.get_monthly_active_count_by_service()
result = self.get_success(d)
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 7458a37e54..7a38022e71 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -33,19 +33,36 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_displayname(self):
- yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank", 1)
+ yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
+
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.u_frank.localpart, "Frank", 1)
+ )
self.assertEquals(
- "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
+ "Frank",
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.u_frank.localpart)
+ )
+ ),
)
@defer.inlineCallbacks
def test_avatar_url(self):
- yield self.store.set_profile_avatar_url(
- self.u_frank.localpart, "http://my.site/here", 1
+ yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
+
+ yield defer.ensureDeferred(
+ self.store.set_profile_avatar_url(
+ self.u_frank.localpart, "http://my.site/here", 1
+ )
)
self.assertEquals(
"http://my.site/here",
- (yield self.store.get_profile_avatar_url(self.u_frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.u_frank.localpart)
+ )
+ ),
)
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index b9fafaa1a6..cc1f3c53c5 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
+
+from synapse.api.errors import NotFoundError
from synapse.rest.client.v1 import room
from tests.unittest import HomeserverTestCase
@@ -44,28 +47,22 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_storage()
# Get the topological token
- event = store.get_topological_token_for_event(last["event_id"])
- self.pump()
- event = self.successResultOf(event)
+ token = self.get_success(
+ store.get_topological_token_for_event(last["event_id"])
+ )
+ token_str = self.get_success(token.to_string(self.hs.get_datastore()))
# Purge everything before this topological token
- purge = storage.purge_events.purge_history(self.room_id, event, True)
- self.pump()
- self.assertEqual(self.successResultOf(purge), None)
-
- # Try and get the events
- get_first = store.get_event(first["event_id"])
- get_second = store.get_event(second["event_id"])
- get_third = store.get_event(third["event_id"])
- get_last = store.get_event(last["event_id"])
- self.pump()
+ self.get_success(
+ storage.purge_events.purge_history(self.room_id, token_str, True)
+ )
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
# and last is not.
- self.failureResultOf(get_first)
- self.failureResultOf(get_second)
- self.failureResultOf(get_third)
- self.successResultOf(get_last)
+ self.get_failure(store.get_event(first["event_id"]), NotFoundError)
+ self.get_failure(store.get_event(second["event_id"]), NotFoundError)
+ self.get_failure(store.get_event(third["event_id"]), NotFoundError)
+ self.get_success(store.get_event(last["event_id"]))
def test_purge_wont_delete_extrems(self):
"""
@@ -80,28 +77,19 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_datastore()
# Set the topological token higher than it should be
- event = storage.get_topological_token_for_event(last["event_id"])
- self.pump()
- event = self.successResultOf(event)
- event = "t{}-{}".format(
- *list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
+ token = self.get_success(
+ storage.get_topological_token_for_event(last["event_id"])
)
+ event = "t{}-{}".format(token.topological + 1, token.stream + 1)
# Purge everything before this topological token
- purge = storage.purge_history(self.room_id, event, True)
+ purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
self.pump()
f = self.failureResultOf(purge)
self.assertIn("greater than forward", f.value.args[0])
# Try and get the events
- get_first = storage.get_event(first["event_id"])
- get_second = storage.get_event(second["event_id"])
- get_third = storage.get_event(third["event_id"])
- get_last = storage.get_event(last["event_id"])
- self.pump()
-
- # Nothing is deleted.
- self.successResultOf(get_first)
- self.successResultOf(get_second)
- self.successResultOf(get_third)
- self.successResultOf(get_last)
+ self.get_success(storage.get_event(first["event_id"]))
+ self.get_success(storage.get_event(second["event_id"]))
+ self.get_success(storage.get_event(third["event_id"]))
+ self.get_success(storage.get_event(last["event_id"]))
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index db3667dc43..1ea35d60c1 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -237,7 +237,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def build(self, prev_event_ids):
- built_event = yield self._base_builder.build(prev_event_ids)
+ built_event = yield defer.ensureDeferred(
+ self._base_builder.build(prev_event_ids)
+ )
built_event._event_id = self._event_id
built_event._dict["event_id"] = self._event_id
@@ -249,6 +251,10 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def room_id(self):
return self._base_builder.room_id
+ @property
+ def type(self):
+ return self._base_builder.type
+
event_1, context_1 = self.get_success(
self.event_creation_handler.create_new_client_event(
EventIdManglingBuilder(
@@ -341,7 +347,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
event_json = self.get_success(
- self.store.db.simple_select_one_onecol(
+ self.store.db_pool.simple_select_one_onecol(
table="event_json",
keyvalues={"event_id": msg_event.event_id},
retcol="json",
@@ -359,7 +365,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.reactor.advance(60 * 60 * 2)
event_json = self.get_success(
- self.store.db.simple_select_one_onecol(
+ self.store.db_pool.simple_select_one_onecol(
table="event_json",
keyvalues={"event_id": msg_event.event_id},
retcol="json",
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 71a40a0a49..6b582771fe 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -17,6 +17,7 @@
from twisted.internet import defer
from synapse.api.constants import UserTypes
+from synapse.api.errors import ThreepidValidationError
from tests import unittest
from tests.utils import setup_test_homeserver
@@ -36,7 +37,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_register(self):
- yield self.store.register_user(self.user_id, self.pwhash)
+ yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
self.assertEquals(
{
@@ -52,17 +53,21 @@ class RegistrationStoreTestCase(unittest.TestCase):
"user_type": None,
"deactivated": 0,
},
- (yield self.store.get_user_by_id(self.user_id)),
+ (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
)
@defer.inlineCallbacks
def test_add_tokens(self):
- yield self.store.register_user(self.user_id, self.pwhash)
- yield self.store.add_access_token_to_user(
- self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+ yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
+ yield defer.ensureDeferred(
+ self.store.add_access_token_to_user(
+ self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+ )
)
- result = yield self.store.get_user_by_access_token(self.tokens[1])
+ result = yield defer.ensureDeferred(
+ self.store.get_user_by_access_token(self.tokens[1])
+ )
self.assertDictContainsSubset(
{"name": self.user_id, "device_id": self.device_id}, result
@@ -73,31 +78,41 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_user_delete_access_tokens(self):
# add some tokens
- yield self.store.register_user(self.user_id, self.pwhash)
- yield self.store.add_access_token_to_user(
- self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
+ yield defer.ensureDeferred(
+ self.store.add_access_token_to_user(
+ self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
+ )
)
- yield self.store.add_access_token_to_user(
- self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.store.add_access_token_to_user(
+ self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+ )
)
# now delete some
- yield self.store.user_delete_access_tokens(
- self.user_id, device_id=self.device_id
+ yield defer.ensureDeferred(
+ self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id)
)
# check they were deleted
- user = yield self.store.get_user_by_access_token(self.tokens[1])
+ user = yield defer.ensureDeferred(
+ self.store.get_user_by_access_token(self.tokens[1])
+ )
self.assertIsNone(user, "access token was not deleted by device_id")
# check the one not associated with the device was not deleted
- user = yield self.store.get_user_by_access_token(self.tokens[0])
+ user = yield defer.ensureDeferred(
+ self.store.get_user_by_access_token(self.tokens[0])
+ )
self.assertEqual(self.user_id, user["name"])
# now delete the rest
- yield self.store.user_delete_access_tokens(self.user_id)
+ yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
- user = yield self.store.get_user_by_access_token(self.tokens[0])
+ user = yield defer.ensureDeferred(
+ self.store.get_user_by_access_token(self.tokens[0])
+ )
self.assertIsNone(user, "access token was not deleted without device_id")
@defer.inlineCallbacks
@@ -105,14 +120,48 @@ class RegistrationStoreTestCase(unittest.TestCase):
TEST_USER = "@test:test"
SUPPORT_USER = "@support:test"
- res = yield self.store.is_support_user(None)
+ res = yield defer.ensureDeferred(self.store.is_support_user(None))
self.assertFalse(res)
- yield self.store.register_user(user_id=TEST_USER, password_hash=None)
- res = yield self.store.is_support_user(TEST_USER)
+ yield defer.ensureDeferred(
+ self.store.register_user(user_id=TEST_USER, password_hash=None)
+ )
+ res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER))
self.assertFalse(res)
- yield self.store.register_user(
- user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
+ yield defer.ensureDeferred(
+ self.store.register_user(
+ user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
+ )
)
- res = yield self.store.is_support_user(SUPPORT_USER)
+ res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER))
self.assertTrue(res)
+
+ @defer.inlineCallbacks
+ def test_3pid_inhibit_invalid_validation_session_error(self):
+ """Tests that enabling the configuration option to inhibit 3PID errors on
+ /requestToken also inhibits validation errors caused by an unknown session ID.
+ """
+
+ # Check that, with the config setting set to false (the default value), a
+ # validation error is caused by the unknown session ID.
+ try:
+ yield defer.ensureDeferred(
+ self.store.validate_threepid_session(
+ "fake_sid", "fake_client_secret", "fake_token", 0,
+ )
+ )
+ except ThreepidValidationError as e:
+ self.assertEquals(e.msg, "Unknown session_id", e)
+
+ # Set the config setting to true.
+ self.store._ignore_unknown_session_error = True
+
+ # Check that now the validation error is caused by the token not matching.
+ try:
+ yield defer.ensureDeferred(
+ self.store.validate_threepid_session(
+ "fake_sid", "fake_client_secret", "fake_token", 0,
+ )
+ )
+ except ThreepidValidationError as e:
+ self.assertEquals(e.msg, "Validation token not found or has expired", e)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 1d77b4a2d6..bc8400f240 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -37,11 +37,13 @@ class RoomStoreTestCase(unittest.TestCase):
self.alias = RoomAlias.from_string("#a-room-name:test")
self.u_creator = UserID.from_string("@creator:test")
- yield self.store.store_room(
- self.room.to_string(),
- room_creator_user_id=self.u_creator.to_string(),
- is_public=True,
- room_version=RoomVersions.V1,
+ yield defer.ensureDeferred(
+ self.store.store_room(
+ self.room.to_string(),
+ room_creator_user_id=self.u_creator.to_string(),
+ is_public=True,
+ room_version=RoomVersions.V1,
+ )
)
@defer.inlineCallbacks
@@ -52,12 +54,14 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"is_public": True,
},
- (yield self.store.get_room(self.room.to_string())),
+ (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
)
@defer.inlineCallbacks
def test_get_room_unknown_room(self):
- self.assertIsNone((yield self.store.get_room("!uknown:test")),)
+ self.assertIsNone(
+ (yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
+ )
@defer.inlineCallbacks
def test_get_room_with_stats(self):
@@ -67,12 +71,22 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"public": True,
},
- (yield self.store.get_room_with_stats(self.room.to_string())),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_room_with_stats(self.room.to_string())
+ )
+ ),
)
@defer.inlineCallbacks
def test_get_room_with_stats_unknown_room(self):
- self.assertIsNone((yield self.store.get_room_with_stats("!uknown:test")),)
+ self.assertIsNone(
+ (
+ yield defer.ensureDeferred(
+ self.store.get_room_with_stats("!uknown:test")
+ )
+ ),
+ )
class RoomEventsStoreTestCase(unittest.TestCase):
@@ -88,17 +102,21 @@ class RoomEventsStoreTestCase(unittest.TestCase):
self.room = RoomID.from_string("!abcde:test")
- yield self.store.store_room(
- self.room.to_string(),
- room_creator_user_id="@creator:text",
- is_public=True,
- room_version=RoomVersions.V1,
+ yield defer.ensureDeferred(
+ self.store.store_room(
+ self.room.to_string(),
+ room_creator_user_id="@creator:text",
+ is_public=True,
+ room_version=RoomVersions.V1,
+ )
)
@defer.inlineCallbacks
def inject_room_event(self, **kwargs):
- yield self.storage.persistence.persist_event(
- self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
+ yield defer.ensureDeferred(
+ self.storage.persistence.persist_event(
+ self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
+ )
)
@defer.inlineCallbacks
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index f282921538..12ccc1f53e 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -87,7 +87,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN)
- self.pump(20)
+ self.pump()
self.assertTrue("_known_servers_count" not in self.store.__dict__.keys())
@@ -101,7 +101,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# Initialises to 1 -- itself
self.assertEqual(self.store._known_servers_count, 1)
- self.pump(20)
+ self.pump()
# No rooms have been joined, so technically the SQL returns 0, but it
# will still say it knows about itself.
@@ -111,7 +111,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN)
- self.pump(20)
+ self.pump(1)
# It now knows about Charlie's server.
self.assertEqual(self.store._known_servers_count, 2)
@@ -179,20 +179,20 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def test_can_rerun_update(self):
# First make sure we have completed all updates.
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
# Now let's create a room, which will insert a membership
user = UserID("alice", "test")
- requester = Requester(user, None, False, None, None)
+ requester = Requester(user, None, False, False, None, None)
self.get_success(self.room_creator.create_room(requester, {}))
# Register the background update to run again.
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
table="background_updates",
values={
"update_name": "current_state_events_membership",
@@ -203,12 +203,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
)
# ... and tell the DataStore that it hasn't finished all updates yet
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
# Now let's actually drive the updates to completion
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index a0e133cd4a..8bd12fa847 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -44,11 +44,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room = RoomID.from_string("!abc123:test")
- yield self.store.store_room(
- self.room.to_string(),
- room_creator_user_id="@creator:text",
- is_public=True,
- room_version=RoomVersions.V1,
+ yield defer.ensureDeferred(
+ self.store.store_room(
+ self.room.to_string(),
+ room_creator_user_id="@creator:text",
+ is_public=True,
+ room_version=RoomVersions.V1,
+ )
)
@defer.inlineCallbacks
@@ -68,7 +70,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- yield self.storage.persistence.persist_event(event, context)
+ yield defer.ensureDeferred(
+ self.storage.persistence.persist_event(event, context)
+ )
return event
@@ -87,8 +91,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield self.storage.state.get_state_groups_ids(
- self.room, [e2.event_id]
+ state_group_map = yield defer.ensureDeferred(
+ self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
state_map = list(state_group_map.values())[0]
@@ -106,8 +110,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield self.storage.state.get_state_groups(
- self.room, [e2.event_id]
+ state_group_map = yield defer.ensureDeferred(
+ self.storage.state.get_state_groups(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0]
@@ -148,7 +152,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we get the full state as of the final event
- state = yield self.storage.state.get_state_for_event(e5.event_id)
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(e5.event_id)
+ )
self.assertIsNotNone(e4)
@@ -164,22 +170,28 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we can filter to the m.room.name event (with a '' state key)
- state = yield self.storage.state.get_state_for_event(
- e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(
+ e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
+ )
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key)
- state = yield self.storage.state.get_state_for_event(
- e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(
+ e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
+ )
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key)
- state = yield self.storage.state.get_state_for_event(
- e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(
+ e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
+ )
)
self.assertStateMapEqual(
@@ -188,12 +200,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can grab a specific room member without filtering out the
# other event types
- state = yield self.storage.state.get_state_for_event(
- e5.event_id,
- state_filter=StateFilter(
- types={EventTypes.Member: {self.u_alice.to_string()}},
- include_others=True,
- ),
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(
+ e5.event_id,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {self.u_alice.to_string()}},
+ include_others=True,
+ ),
+ )
)
self.assertStateMapEqual(
@@ -206,11 +220,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check that we can grab everything except members
- state = yield self.storage.state.get_state_for_event(
- e5.event_id,
- state_filter=StateFilter(
- types={EventTypes.Member: set()}, include_others=True
- ),
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(
+ e5.event_id,
+ state_filter=StateFilter(
+ types={EventTypes.Member: set()}, include_others=True
+ ),
+ )
)
self.assertStateMapEqual(
@@ -222,8 +238,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
room_id = self.room.to_string()
- group_ids = yield self.storage.state.get_state_groups_ids(
- room_id, [e5.event_id]
+ group_ids = yield defer.ensureDeferred(
+ self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
)
group = list(group_ids.keys())[0]
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 6a545d2eb0..738e912468 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -31,16 +31,24 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
# alice and bob are both in !room_id. bobby is not but shares
# a homeserver with alice.
- yield self.store.update_profile_in_user_dir(ALICE, "alice", None)
- yield self.store.update_profile_in_user_dir(BOB, "bob", None)
- yield self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
- yield self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
+ yield defer.ensureDeferred(
+ self.store.update_profile_in_user_dir(ALICE, "alice", None)
+ )
+ yield defer.ensureDeferred(
+ self.store.update_profile_in_user_dir(BOB, "bob", None)
+ )
+ yield defer.ensureDeferred(
+ self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
+ )
+ yield defer.ensureDeferred(
+ self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
+ )
@defer.inlineCallbacks
def test_search_user_dir(self):
# normally when alice searches the directory she should just find
# bob because bobby doesn't share a room with her.
- r = yield self.store.search_user_dir(ALICE, "bob", 10)
+ r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"])
self.assertEqual(1, len(r["results"]))
self.assertDictEqual(
@@ -51,7 +59,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
def test_search_user_dir_all_users(self):
self.hs.config.user_directory_search_all_users = True
try:
- r = yield self.store.search_user_dir(ALICE, "bob", 10)
+ r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"])
self.assertEqual(2, len(r["results"]))
self.assertDictEqual(
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 87a16d7d7a..27a7fc9ed7 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -1,7 +1,23 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from mock import Mock
-from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
+from twisted.internet.defer import succeed
+from synapse.api.errors import FederationError
from synapse.events import make_event_from_dict
from synapse.logging.context import LoggingContext
from synapse.types import Requester, UserID
@@ -10,6 +26,7 @@ from synapse.util.retryutils import NotRetryingDestination
from tests import unittest
from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
+from tests.test_utils import make_awaitable
class MessageAcceptTests(unittest.HomeserverTestCase):
@@ -26,24 +43,19 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
user_id = UserID("us", "test")
- our_user = Requester(user_id, None, False, None, None)
+ our_user = Requester(user_id, None, False, False, None, None)
room_creator = self.homeserver.get_room_creation_handler()
- room_deferred = ensureDeferred(
+ self.room_id = self.get_success(
room_creator.create_room(
our_user, room_creator._presets_dict["public_chat"], ratelimit=False
)
- )
- self.reactor.advance(0.1)
- self.room_id = self.successResultOf(room_deferred)[0]["room_id"]
+ )[0]["room_id"]
self.store = self.homeserver.get_datastore()
# Figure out what the most recent event is
- most_recent = self.successResultOf(
- maybeDeferred(
- self.homeserver.get_datastore().get_latest_event_ids_in_room,
- self.room_id,
- )
+ most_recent = self.get_success(
+ self.homeserver.get_datastore().get_latest_event_ids_in_room(self.room_id)
)[0]
join_event = make_event_from_dict(
@@ -73,19 +85,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
# Send the join, it should return None (which is not an error)
- d = ensureDeferred(
- self.handler.on_receive_pdu(
- "test.serv", join_event, sent_to_us_directly=True
- )
+ self.assertEqual(
+ self.get_success(
+ self.handler.on_receive_pdu(
+ "test.serv", join_event, sent_to_us_directly=True
+ )
+ ),
+ None,
)
- self.reactor.advance(1)
- self.assertEqual(self.successResultOf(d), None)
# Make sure we actually joined the room
self.assertEqual(
- self.successResultOf(
- maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
- )[0],
+ self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0],
"$join:test.serv",
)
@@ -95,7 +106,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
prev_events that said event references.
"""
- def post_json(destination, path, data, headers=None, timeout=0):
+ async def post_json(destination, path, data, headers=None, timeout=0):
# If it asks us for new missing events, give them NOTHING
if path.startswith("/_matrix/federation/v1/get_missing_events/"):
return {"events": []}
@@ -103,8 +114,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
self.http_client.post_json = post_json
# Figure out what the most recent event is
- most_recent = self.successResultOf(
- maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
+ most_recent = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
)[0]
# Now lie about an event
@@ -124,17 +135,14 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
with LoggingContext(request="lying_event"):
- d = ensureDeferred(
+ failure = self.get_failure(
self.handler.on_receive_pdu(
"test.serv", lying_event, sent_to_us_directly=True
- )
+ ),
+ FederationError,
)
- # Step the reactor, so the database fetches come back
- self.reactor.advance(1)
-
# on_receive_pdu should throw an error
- failure = self.failureResultOf(d)
self.assertEqual(
failure.value.args[0],
(
@@ -144,8 +152,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
# Make sure the invalid event isn't there
- extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
- self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
+ extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
+ self.assertEqual(extrem[0], "$join:test.serv")
def test_retry_device_list_resync(self):
"""Tests that device lists are marked as stale if they couldn't be synced, and
@@ -173,7 +181,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register a mock on the store so that the incoming update doesn't fail because
# we don't share a room with the user.
store = self.homeserver.get_datastore()
- store.get_rooms_for_user = Mock(return_value=succeed(["!someroom:test"]))
+ store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"]))
# Manually inject a fake device list update. We need this update to include at
# least one prev_id so that the user's device list will need to be retried.
diff --git a/tests/test_server.py b/tests/test_server.py
index 073b2362cc..655c918a15 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -157,6 +157,28 @@ class JsonResourceTests(unittest.TestCase):
self.assertEqual(channel.json_body["error"], "Unrecognized request")
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+ def test_head_request(self):
+ """
+ JsonResource.handler_for_request gives correctly decoded URL args to
+ the callback, while Twisted will give the raw bytes of URL query
+ arguments.
+ """
+
+ def _callback(request, **kwargs):
+ return 200, {"result": True}
+
+ res = JsonResource(self.homeserver)
+ res.register_paths(
+ "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet",
+ )
+
+ # The path was registered as GET, but this is a HEAD request.
+ request, channel = make_request(self.reactor, b"HEAD", b"/_matrix/foo")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertNotIn("body", channel.result)
+
class OptionsResourceTests(unittest.TestCase):
def setUp(self):
@@ -255,7 +277,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.reactor = ThreadedMemoryReactorClock()
def test_good_response(self):
- def callback(request):
+ async def callback(request):
request.write(b"response")
request.finish()
@@ -275,7 +297,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
with the right location.
"""
- def callback(request, **kwargs):
+ async def callback(request, **kwargs):
raise RedirectException(b"/look/an/eagle", 301)
res = WrapHtmlRequestHandlerTests.TestResource()
@@ -295,7 +317,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
returned too
"""
- def callback(request, **kwargs):
+ async def callback(request, **kwargs):
e = RedirectException(b"/no/over/there", 304)
e.cookies.append(b"session=yespls")
raise e
@@ -312,3 +334,19 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.assertEqual(location_headers, [b"/no/over/there"])
cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
self.assertEqual(cookies_headers, [b"session=yespls"])
+
+ def test_head_request(self):
+ """A head request should work by being turned into a GET request."""
+
+ async def callback(request):
+ request.write(b"response")
+ request.finish()
+
+ res = WrapHtmlRequestHandlerTests.TestResource()
+ res.callback = callback
+
+ request, channel = make_request(self.reactor, b"HEAD", b"/path")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertNotIn("body", channel.result)
diff --git a/tests/test_state.py b/tests/test_state.py
index 4858e8fc59..80b0ccbc40 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -71,7 +71,7 @@ def create_event(
return event
-class StateGroupStore(object):
+class StateGroupStore:
def __init__(self):
self._event_to_state_group = {}
self._group_to_state = {}
@@ -80,16 +80,16 @@ class StateGroupStore(object):
self._next_group = 1
- def get_state_groups_ids(self, room_id, event_ids):
+ async def get_state_groups_ids(self, room_id, event_ids):
groups = {}
for event_id in event_ids:
group = self._event_to_state_group.get(event_id)
if group:
groups[group] = self._group_to_state[group]
- return defer.succeed(groups)
+ return groups
- def store_state_group(
+ async def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
):
state_group = self._next_group
@@ -97,19 +97,17 @@ class StateGroupStore(object):
self._group_to_state[state_group] = dict(current_state_ids)
- return defer.succeed(state_group)
+ return state_group
- def get_events(self, event_ids, **kwargs):
- return defer.succeed(
- {
- e_id: self._event_id_to_event[e_id]
- for e_id in event_ids
- if e_id in self._event_id_to_event
- }
- )
+ async def get_events(self, event_ids, **kwargs):
+ return {
+ e_id: self._event_id_to_event[e_id]
+ for e_id in event_ids
+ if e_id in self._event_id_to_event
+ }
- def get_state_group_delta(self, name):
- return defer.succeed((None, None))
+ async def get_state_group_delta(self, name):
+ return (None, None)
def register_events(self, events):
for e in events:
@@ -121,17 +119,17 @@ class StateGroupStore(object):
def register_event_id_state_group(self, event_id, state_group):
self._event_to_state_group[event_id] = state_group
- def get_room_version_id(self, room_id):
- return defer.succeed(RoomVersions.V1.identifier)
+ async def get_room_version_id(self, room_id):
+ return RoomVersions.V1.identifier
class DictObj(dict):
def __init__(self, **kwargs):
- super(DictObj, self).__init__(kwargs)
+ super().__init__(kwargs)
self.__dict__ = self
-class Graph(object):
+class Graph:
def __init__(self, nodes, edges):
events = {}
clobbered = set(events.keys())
@@ -213,7 +211,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_d = context_store["D"]
- prev_state_ids = yield ctx_d.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertEqual(2, len(prev_state_ids))
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@@ -259,7 +257,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_d = context_store["D"]
- prev_state_ids = yield ctx_d.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@@ -318,7 +316,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_e = context_store["E"]
- prev_state_ids = yield ctx_e.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids())
self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
@@ -393,7 +391,7 @@ class StateTestCase(unittest.TestCase):
ctx_b = context_store["B"]
ctx_d = context_store["D"]
- prev_state_ids = yield ctx_d.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
@@ -425,7 +423,7 @@ class StateTestCase(unittest.TestCase):
self.state.compute_event_context(event, old_state=old_state)
)
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
@@ -450,7 +448,7 @@ class StateTestCase(unittest.TestCase):
self.state.compute_event_context(event, old_state=old_state)
)
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
@@ -476,12 +474,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = yield self.store.store_state_group(
- prev_event_id,
- event.room_id,
- None,
- None,
- {(e.type, e.state_key): e.event_id for e in old_state},
+ group_name = yield defer.ensureDeferred(
+ self.store.store_state_group(
+ prev_event_id,
+ event.room_id,
+ None,
+ None,
+ {(e.type, e.state_key): e.event_id for e in old_state},
+ )
)
self.store.register_event_id_state_group(prev_event_id, group_name)
@@ -508,18 +508,20 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = yield self.store.store_state_group(
- prev_event_id,
- event.room_id,
- None,
- None,
- {(e.type, e.state_key): e.event_id for e in old_state},
+ group_name = yield defer.ensureDeferred(
+ self.store.store_state_group(
+ prev_event_id,
+ event.room_id,
+ None,
+ None,
+ {(e.type, e.state_key): e.event_id for e in old_state},
+ )
)
self.store.register_event_id_state_group(prev_event_id, group_name)
context = yield defer.ensureDeferred(self.state.compute_event_context(event))
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))
@@ -691,21 +693,25 @@ class StateTestCase(unittest.TestCase):
def _get_context(
self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
):
- sg1 = yield self.store.store_state_group(
- prev_event_id_1,
- event.room_id,
- None,
- None,
- {(e.type, e.state_key): e.event_id for e in old_state_1},
+ sg1 = yield defer.ensureDeferred(
+ self.store.store_state_group(
+ prev_event_id_1,
+ event.room_id,
+ None,
+ None,
+ {(e.type, e.state_key): e.event_id for e in old_state_1},
+ )
)
self.store.register_event_id_state_group(prev_event_id_1, sg1)
- sg2 = yield self.store.store_state_group(
- prev_event_id_2,
- event.room_id,
- None,
- None,
- {(e.type, e.state_key): e.event_id for e in old_state_2},
+ sg2 = yield defer.ensureDeferred(
+ self.store.store_state_group(
+ prev_event_id_2,
+ event.room_id,
+ None,
+ None,
+ {(e.type, e.state_key): e.event_id for e in old_state_2},
+ )
)
self.store.register_event_id_state_group(prev_event_id_2, sg2)
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 508aeba078..a298cc0fd3 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,6 +17,7 @@
"""
Utilities for running the unit tests
"""
+from asyncio import Future
from typing import Any, Awaitable, TypeVar
TV = TypeVar("TV")
@@ -38,6 +39,12 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
raise Exception("awaitable has not yet completed")
-async def make_awaitable(result: Any):
- """Create an awaitable that just returns a result."""
- return result
+def make_awaitable(result: Any) -> Awaitable[Any]:
+ """
+ Makes an awaitable, suitable for mocking an `async` function.
+ This uses Futures as they can be awaited multiple times so can be returned
+ to multiple callers.
+ """
+ future = Future() # type: ignore
+ future.set_result(result)
+ return future
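The docstring above points out that a resolved `Future` can be awaited more than once, which is what lets a single `Mock` return value serve repeated callers. A small illustration of that property using plain `asyncio` and `unittest.mock`, independent of Synapse's test harness:

```python
import asyncio
from asyncio import Future
from unittest.mock import Mock


def make_awaitable(result):
    # A resolved Future can be awaited any number of times, unlike a
    # coroutine object, which fails on the second await.
    future = Future()
    future.set_result(result)
    return future


async def main():
    store = Mock()
    store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"]))

    first = await store.get_rooms_for_user("@alice:test")
    second = await store.get_rooms_for_user("@alice:test")  # second await is fine
    assert first == second == ["!someroom:test"]


asyncio.run(main())
```

Had the mock returned a coroutine object instead, the second call would raise RuntimeError ("cannot reuse already awaited coroutine"), which is why the helper builds a Future.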
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 8522c6fc09..e93aa84405 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -13,14 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
import synapse.server
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.types import Collection
"""
Utility functions for poking events into the storage of the server under test.
@@ -58,7 +57,7 @@ async def inject_member_event(
async def inject_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
- prev_event_ids: Optional[Collection[str]] = None,
+ prev_event_ids: Optional[List[str]] = None,
**kwargs
) -> EventBase:
"""Inject a generic event into a room
@@ -72,7 +71,10 @@ async def inject_event(
"""
event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
- await hs.get_storage().persistence.persist_event(event, context)
+ persistence = hs.get_storage().persistence
+ assert persistence is not None
+
+ await persistence.persist_event(event, context)
return event
@@ -80,7 +82,7 @@ async def inject_event(
async def create_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
- prev_event_ids: Optional[Collection[str]] = None,
+ prev_event_ids: Optional[List[str]] = None,
**kwargs
) -> Tuple[EventBase, EventContext]:
if room_version is None:
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 2d96b0fa8d..fdfb840b62 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -29,8 +29,7 @@ class ToTwistedHandler(logging.Handler):
log_entry = self.format(record)
log_level = record.levelname.lower().replace("warning", "warn")
self.tx_log.emit(
- twisted.logger.LogLevel.levelWithName(log_level),
- log_entry.replace("{", r"(").replace("}", r")"),
+ twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
)
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index b371efc0df..510b630114 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -37,10 +37,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.hs = yield setup_test_homeserver(self.addCleanup)
self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
- self.store = self.hs.get_datastore()
self.storage = self.hs.get_storage()
- yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
+ yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
@defer.inlineCallbacks
def test_filtering(self):
@@ -64,8 +63,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
evt = yield self.inject_room_member(user, extra_content={"a": "b"})
events_to_filter.append(evt)
- filtered = yield filter_events_for_server(
- self.storage, "test_server", events_to_filter
+ filtered = yield defer.ensureDeferred(
+ filter_events_for_server(self.storage, "test_server", events_to_filter)
)
# the result should be 5 redacted events, and 5 unredacted events.
@@ -99,11 +98,13 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
events_to_filter.append(evt)
# the erasey user gets erased
- yield self.hs.get_datastore().mark_user_erased("@erased:local_hs")
+ yield defer.ensureDeferred(
+ self.hs.get_datastore().mark_user_erased("@erased:local_hs")
+ )
# ... and the filtering happens.
- filtered = yield filter_events_for_server(
- self.storage, "test_server", events_to_filter
+ filtered = yield defer.ensureDeferred(
+ filter_events_for_server(self.storage, "test_server", events_to_filter)
)
for i in range(0, len(events_to_filter)):
@@ -140,7 +141,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
event, context = yield defer.ensureDeferred(
self.event_creation_handler.create_new_client_event(builder)
)
- yield self.storage.persistence.persist_event(event, context)
+ yield defer.ensureDeferred(
+ self.storage.persistence.persist_event(event, context)
+ )
return event
@defer.inlineCallbacks
@@ -162,7 +165,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- yield self.storage.persistence.persist_event(event, context)
+ yield defer.ensureDeferred(
+ self.storage.persistence.persist_event(event, context)
+ )
return event
@defer.inlineCallbacks
@@ -183,7 +188,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- yield self.storage.persistence.persist_event(event, context)
+ yield defer.ensureDeferred(
+ self.storage.persistence.persist_event(event, context)
+ )
return event
@defer.inlineCallbacks
@@ -265,8 +272,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
storage.main = test_store
storage.state = test_store
- filtered = yield filter_events_for_server(
- test_store, "test_server", events_to_filter
+ filtered = yield defer.ensureDeferred(
+ filter_events_for_server(test_store, "test_server", events_to_filter)
)
logger.info("Filtering took %f seconds", time.time() - start)
@@ -287,7 +294,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
test_large_room.skip = "Disabled by default because it's slow"
-class _TestStore(object):
+class _TestStore:
"""Implements a few methods of the DataStore, so that we can test
filter_events_for_server
diff --git a/tests/unittest.py b/tests/unittest.py
index 68d2586efd..e654c0442d 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import gc
import hashlib
import hmac
@@ -23,11 +22,12 @@ import logging
import time
from typing import Optional, Tuple, Type, TypeVar, Union
-from mock import Mock
+from mock import Mock, patch
from canonicaljson import json
from twisted.internet.defer import Deferred, ensureDeferred, succeed
+from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
@@ -92,7 +92,7 @@ class TestCase(unittest.TestCase):
root logger's logging level while that test (case|method) runs."""
def __init__(self, methodName, *args, **kwargs):
- super(TestCase, self).__init__(methodName, *args, **kwargs)
+ super().__init__(methodName, *args, **kwargs)
method = getattr(self, methodName)
@@ -169,6 +169,19 @@ def INFO(target):
return target
+def logcontext_clean(target):
+ """A decorator which marks the TestCase or method as 'logcontext_clean'
+
+ ... ie, any logcontext errors should cause a test failure
+ """
+
+ def logcontext_error(msg):
+ raise AssertionError("logcontext error: %s" % (msg))
+
+ patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error)
+ return patcher(target)
+
+
class HomeserverTestCase(TestCase):
"""
A base TestCase that reduces boilerplate for HomeServer-using test cases.
@@ -241,20 +254,20 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "user_id"):
if self.hijack_auth:
- def get_user_by_access_token(token=None, allow_guest=False):
- return succeed(
- {
- "user": UserID.from_string(self.helper.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
- )
-
- def get_user_by_req(request, allow_guest=False, rights="access"):
- return succeed(
- create_requester(
- UserID.from_string(self.helper.auth_user_id), 1, False, None
- )
+ async def get_user_by_access_token(token=None, allow_guest=False):
+ return {
+ "user": UserID.from_string(self.helper.auth_user_id),
+ "token_id": 1,
+ "is_guest": False,
+ }
+
+ async def get_user_by_req(request, allow_guest=False, rights="access"):
+ return create_requester(
+ UserID.from_string(self.helper.auth_user_id),
+ 1,
+ False,
+ False,
+ None,
)
self.hs.get_auth().get_user_by_req = get_user_by_req
@@ -353,6 +366,7 @@ class HomeserverTestCase(TestCase):
request: Type[T] = SynapseRequest,
shorthand: bool = True,
federation_auth_origin: str = None,
+ content_is_form: bool = False,
) -> Tuple[T, FakeChannel]:
"""
Create a SynapseRequest at the path using the method and containing the
@@ -368,6 +382,8 @@ class HomeserverTestCase(TestCase):
with the usual REST API path, if it doesn't contain it.
federation_auth_origin (bytes|None): if set to not-None, we will add a fake
Authorization header pretending to be the given server name.
+ content_is_form: Whether the content is URL encoded form data. Adds the
+ 'Content-Type': 'application/x-www-form-urlencoded' header.
Returns:
Tuple[synapse.http.site.SynapseRequest, channel]
@@ -384,6 +400,7 @@ class HomeserverTestCase(TestCase):
request,
shorthand,
federation_auth_origin,
+ content_is_form,
)
def render(self, request):
@@ -422,8 +439,8 @@ class HomeserverTestCase(TestCase):
async def run_bg_updates():
with LoggingContext("run_bg_updates", request="run_bg_updates-1"):
- while not await stor.db.updates.has_completed_background_updates():
- await stor.db.updates.do_next_background_update(1)
+ while not await stor.db_pool.updates.has_completed_background_updates():
+ await stor.db_pool.updates.do_next_background_update(1)
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
stor = hs.get_datastore()
@@ -459,6 +476,35 @@ class HomeserverTestCase(TestCase):
self.pump()
return self.failureResultOf(d, exc)
+ def get_success_or_raise(self, d, by=0.0):
+ """Drive deferred to completion and return result or raise exception
+ on failure.
+ """
+
+ if inspect.isawaitable(d):
+ deferred = ensureDeferred(d)
+ if not isinstance(deferred, Deferred):
+ return d
+
+ results = [] # type: list
+ deferred.addBoth(results.append)
+
+ self.pump(by=by)
+
+ if not results:
+ self.fail(
+ "Success result expected on {!r}, found no result instead".format(
+ deferred
+ )
+ )
+
+ result = results[0]
+
+ if isinstance(result, Failure):
+ result.raiseException()
+
+ return result
+
def register_user(self, username, password, admin=False):
"""
Register a user. Requires the Admin API be registered.
@@ -544,7 +590,7 @@ class HomeserverTestCase(TestCase):
"""
event_creator = self.hs.get_event_creation_handler()
secrets = self.hs.get_secrets()
- requester = Requester(user, None, False, None, None)
+ requester = Requester(user, None, False, False, None, None)
event, context = self.get_success(
event_creator.create_event(
@@ -571,7 +617,7 @@ class HomeserverTestCase(TestCase):
Add the given event as an extremity to the room.
"""
self.get_success(
- self.hs.get_datastore().db.simple_insert(
+ self.hs.get_datastore().db_pool.simple_insert(
table="event_forward_extremities",
values={"room_id": room_id, "event_id": event_id},
desc="test_add_extremity",
@@ -614,7 +660,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
"""
def prepare(self, reactor, clock, homeserver):
- class Authenticator(object):
+ class Authenticator:
def authenticate_request(self, request, content):
return succeed("other.example.com")
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 4d2b9e0d64..677e925477 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -88,7 +88,7 @@ class CacheTestCase(unittest.TestCase):
class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cache(self):
- class Cls(object):
+ class Cls:
def __init__(self):
self.mock = mock.Mock()
@@ -122,7 +122,7 @@ class DescriptorTestCase(unittest.TestCase):
def test_cache_num_args(self):
"""Only the first num_args arguments should matter to the cache"""
- class Cls(object):
+ class Cls:
def __init__(self):
self.mock = mock.Mock()
@@ -156,7 +156,7 @@ class DescriptorTestCase(unittest.TestCase):
"""If the wrapped function throws synchronously, things should continue to work
"""
- class Cls(object):
+ class Cls:
@cached()
def fn(self, arg1):
raise SynapseError(100, "mai spoon iz too big!!1")
@@ -180,7 +180,7 @@ class DescriptorTestCase(unittest.TestCase):
complete_lookup = defer.Deferred()
- class Cls(object):
+ class Cls:
@descriptors.cached()
def fn(self, arg1):
@defer.inlineCallbacks
@@ -223,7 +223,7 @@ class DescriptorTestCase(unittest.TestCase):
"""Check that the cache sets and restores logcontexts correctly when
the lookup function throws an exception"""
- class Cls(object):
+ class Cls:
@descriptors.cached()
def fn(self, arg1):
@defer.inlineCallbacks
@@ -263,7 +263,7 @@ class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cache_default_args(self):
- class Cls(object):
+ class Cls:
def __init__(self):
self.mock = mock.Mock()
@@ -300,7 +300,7 @@ class DescriptorTestCase(unittest.TestCase):
obj.mock.assert_not_called()
def test_cache_iterable(self):
- class Cls(object):
+ class Cls:
def __init__(self):
self.mock = mock.Mock()
@@ -336,7 +336,7 @@ class DescriptorTestCase(unittest.TestCase):
"""If the wrapped function throws synchronously, things should continue to work
"""
- class Cls(object):
+ class Cls:
@descriptors.cached(iterable=True)
def fn(self, arg1):
raise SynapseError(100, "mai spoon iz too big!!1")
@@ -358,7 +358,7 @@ class DescriptorTestCase(unittest.TestCase):
class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cache(self):
- class Cls(object):
+ class Cls:
def __init__(self):
self.mock = mock.Mock()
@@ -366,11 +366,11 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def fn(self, arg1, arg2):
pass
- @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
- def list_fn(self, args1, arg2):
+ @descriptors.cachedList("fn", "args1")
+ async def list_fn(self, args1, arg2):
assert current_context().request == "c1"
# we want this to behave like an asynchronous function
- yield run_on_reactor()
+ await run_on_reactor()
assert current_context().request == "c1"
return self.mock(args1, arg2)
@@ -408,7 +408,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def test_invalidate(self):
"""Make sure that invalidation callbacks are called."""
- class Cls(object):
+ class Cls:
def __init__(self):
self.mock = mock.Mock()
@@ -416,10 +416,10 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def fn(self, arg1, arg2):
pass
- @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
- def list_fn(self, args1, arg2):
+ @descriptors.cachedList("fn", "args1")
+ async def list_fn(self, args1, arg2):
# we want this to behave like an asynchronous function
- yield run_on_reactor()
+ await run_on_reactor()
return self.mock(args1, arg2)
obj = Cls()
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
index 8d6627ec33..2012263184 100644
--- a/tests/util/test_file_consumer.py
+++ b/tests/util/test_file_consumer.py
@@ -112,7 +112,7 @@ class FileConsumerTests(unittest.TestCase):
self.assertTrue(string_file.closed)
-class DummyPullProducer(object):
+class DummyPullProducer:
def __init__(self):
self.consumer = None
self.deferred = defer.Deferred()
@@ -134,7 +134,7 @@ class DummyPullProducer(object):
return d
-class BlockingStringWrite(object):
+class BlockingStringWrite:
def __init__(self):
self.buffer = ""
self.closed = False
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 9e348694ad..5f46ed0cef 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -26,9 +26,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
def test_new_destination(self):
"""A happy-path case with a new destination and a successful operation"""
store = self.hs.get_datastore()
- d = get_retry_limiter("test_dest", self.clock, store)
- self.pump()
- limiter = self.successResultOf(d)
+ limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
# advance the clock a bit before making the request
self.pump(1)
@@ -36,18 +34,14 @@ class RetryLimiterTestCase(HomeserverTestCase):
with limiter:
pass
- d = store.get_destination_retry_timings("test_dest")
- self.pump()
- new_timings = self.successResultOf(d)
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertIsNone(new_timings)
def test_limiter(self):
"""General test case which walks through the process of a failing request"""
store = self.hs.get_datastore()
- d = get_retry_limiter("test_dest", self.clock, store)
- self.pump()
- limiter = self.successResultOf(d)
+ limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
self.pump(1)
try:
@@ -58,29 +52,22 @@ class RetryLimiterTestCase(HomeserverTestCase):
except AssertionError:
pass
- # wait for the update to land
- self.pump()
-
- d = store.get_destination_retry_timings("test_dest")
- self.pump()
- new_timings = self.successResultOf(d)
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertEqual(new_timings["failure_ts"], failure_ts)
self.assertEqual(new_timings["retry_last_ts"], failure_ts)
self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL)
# now if we try again we should get a failure
- d = get_retry_limiter("test_dest", self.clock, store)
- self.pump()
- self.failureResultOf(d, NotRetryingDestination)
+ self.get_failure(
+ get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination
+ )
#
# advance the clock and try again
#
self.pump(MIN_RETRY_INTERVAL)
- d = get_retry_limiter("test_dest", self.clock, store)
- self.pump()
- limiter = self.successResultOf(d)
+ limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
self.pump(1)
try:
@@ -91,12 +78,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
except AssertionError:
pass
- # wait for the update to land
- self.pump()
-
- d = store.get_destination_retry_timings("test_dest")
- self.pump()
- new_timings = self.successResultOf(d)
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertEqual(new_timings["failure_ts"], failure_ts)
self.assertEqual(new_timings["retry_last_ts"], retry_ts)
self.assertGreaterEqual(
@@ -109,10 +91,8 @@ class RetryLimiterTestCase(HomeserverTestCase):
#
# one more go, with success
#
- self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
- d = get_retry_limiter("test_dest", self.clock, store)
- self.pump()
- limiter = self.successResultOf(d)
+ self.reactor.advance(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
+ limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
self.pump(1)
with limiter:
@@ -121,7 +101,5 @@ class RetryLimiterTestCase(HomeserverTestCase):
# wait for the update to land
self.pump()
- d = store.get_destination_retry_timings("test_dest")
- self.pump()
- new_timings = self.successResultOf(d)
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertIsNone(new_timings)
diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py
index bd32e2cee7..d3dea3b52a 100644
--- a/tests/util/test_rwlock.py
+++ b/tests/util/test_rwlock.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
from synapse.util.async_helpers import ReadWriteLock
@@ -43,6 +44,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
rwlock.read(key), # 5
rwlock.write(key), # 6
]
+ ds = [defer.ensureDeferred(d) for d in ds]
self._assert_called_before_not_after(ds, 2)
@@ -73,12 +75,12 @@ class ReadWriteLockTestCase(unittest.TestCase):
with ds[6].result:
pass
- d = rwlock.write(key)
+ d = defer.ensureDeferred(rwlock.write(key))
self.assertTrue(d.called)
with d.result:
pass
- d = rwlock.read(key)
+ d = defer.ensureDeferred(rwlock.read(key))
self.assertTrue(d.called)
with d.result:
pass
diff --git a/tests/util/test_stringutils.py b/tests/util/test_stringutils.py
index 4f4da29a98..8491f7cc83 100644
--- a/tests/util/test_stringutils.py
+++ b/tests/util/test_stringutils.py
@@ -28,9 +28,6 @@ class StringUtilsTestCase(unittest.TestCase):
"_--something==_",
"...--==-18913",
"8Dj2odd-e9asd.cd==_--ddas-secret-",
- # We temporarily allow : characters: https://github.com/matrix-org/synapse/issues/6766
- # To be removed in a future release
- "SECRET:1234567890",
]
bad = [
diff --git a/tests/utils.py b/tests/utils.py
index 8eb6fdd98e..867cf88977 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -154,6 +154,10 @@ def default_config(name, parse=False):
"account": {"per_second": 10000, "burst_count": 10000},
"failed_attempts": {"per_second": 10000, "burst_count": 10000},
},
+ "rc_joins": {
+ "local": {"per_second": 10000, "burst_count": 10000},
+ "remote": {"per_second": 10000, "burst_count": 10000},
+ },
"saml2_enabled": False,
"public_baseurl": None,
"default_identity_server": None,
@@ -470,7 +474,7 @@ class MockHttpResource(HttpServer):
self.callbacks.append((method, path_pattern, callback))
-class MockKey(object):
+class MockKey:
alg = "mock_alg"
version = "mock_version"
signature = b"\x9a\x87$"
@@ -489,7 +493,7 @@ class MockKey(object):
return b"<fake_encoded_key>"
-class MockClock(object):
+class MockClock:
now = 1000
def __init__(self):
@@ -566,7 +570,7 @@ def _format_call(args, kwargs):
)
-class DeferredMockCallable(object):
+class DeferredMockCallable:
"""A callable instance that stores a set of pending call expectations and
return values for them. It allows a unit test to assert that the given set
of function calls are eventually made, by awaiting on them to be called.
@@ -640,14 +644,8 @@ class DeferredMockCallable(object):
)
-@defer.inlineCallbacks
-def create_room(hs, room_id, creator_id):
+async def create_room(hs, room_id: str, creator_id: str):
"""Creates and persist a creation event for the given room
-
- Args:
- hs
- room_id (str)
- creator_id (str)
"""
persistence_store = hs.get_storage().persistence
@@ -655,7 +653,7 @@ def create_room(hs, room_id, creator_id):
event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler()
- yield store.store_room(
+ await store.store_room(
room_id=room_id,
room_creator_user_id=creator_id,
is_public=False,
@@ -673,8 +671,6 @@ def create_room(hs, room_id, creator_id):
},
)
- event, context = yield defer.ensureDeferred(
- event_creation_handler.create_new_client_event(builder)
- )
+ event, context = await event_creation_handler.create_new_client_event(builder)
- yield persistence_store.persist_event(event, context)
+ await persistence_store.persist_event(event, context)
|