diff --git a/tests/__init__.py b/tests/__init__.py
index f7fc502f01..ed805db1c2 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -16,9 +16,9 @@
from twisted.trial import util
-import tests.patch_inline_callbacks
+from synapse.util.patch_inline_callbacks import do_patch
# attempt to do the patch before we load any synapse code
-tests.patch_inline_callbacks.do_patch()
+do_patch()
util.DEFAULT_TIMEOUT_DURATION = 20
diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py
index b02780772a..1be6ff563b 100644
--- a/tests/config/test_tls.py
+++ b/tests/config/test_tls.py
@@ -21,17 +21,24 @@ import yaml
from OpenSSL import SSL
+from synapse.config._base import Config, RootConfig
from synapse.config.tls import ConfigError, TlsConfig
from synapse.crypto.context_factory import ClientTLSOptionsFactory
from tests.unittest import TestCase
-class TestConfig(TlsConfig):
+class FakeServer(Config):
+ section = "server"
+
def has_tls_listener(self):
return False
+class TestConfig(RootConfig):
+ config_classes = [FakeServer, TlsConfig]
+
+
class TLSConfigTests(TestCase):
def test_warn_self_signed(self):
"""
@@ -202,13 +209,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
conf = TestConfig()
conf.read_config(
yaml.safe_load(
- TestConfig().generate_config_section(
+ TestConfig().generate_config(
"/config_dir_path",
"my_super_secure_server",
"/data_dir_path",
- "/tls_cert_path",
- "tls_private_key",
- None, # This is the acme_domain
+ tls_certificate_path="/tls_cert_path",
+ tls_private_key_path="tls_private_key",
+ acme_domain=None, # This is the acme_domain
)
),
"/config_dir_path",
@@ -223,13 +230,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
conf = TestConfig()
conf.read_config(
yaml.safe_load(
- TestConfig().generate_config_section(
+ TestConfig().generate_config(
"/config_dir_path",
"my_super_secure_server",
"/data_dir_path",
- "/tls_cert_path",
- "tls_private_key",
- "my_supe_secure_server", # This is the acme_domain
+ tls_certificate_path="/tls_cert_path",
+ tls_private_key_path="tls_private_key",
+ acme_domain="my_supe_secure_server", # This is the acme_domain
)
),
"/config_dir_path",
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index c4503c1611..0bb96674a2 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -187,9 +187,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertEqual(res, 404)
@defer.inlineCallbacks
- def test_update_bad_version(self):
- """Check that we get a 400 if the version in the body is missing or
- doesn't match
+ def test_update_omitted_version(self):
+ """Check that the update succeeds if the version is missing from the body
"""
version = yield self.handler.create_version(
self.local_user,
@@ -197,19 +196,35 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
self.assertEqual(version, "1")
- res = None
- try:
- yield self.handler.update_version(
- self.local_user,
- version,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- },
- )
- except errors.SynapseError as e:
- res = e.code
- self.assertEqual(res, 400)
+ yield self.handler.update_version(
+ self.local_user,
+ version,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ },
+ )
+
+ # check we can retrieve it as the current version
+ res = yield self.handler.get_version_info(self.local_user)
+ self.assertDictEqual(
+ res,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": version,
+ },
+ )
+
+ @defer.inlineCallbacks
+ def test_update_bad_version(self):
+ """Check that we get a 400 if the version in the body doesn't match
+ """
+ version = yield self.handler.create_version(
+ self.local_user,
+ {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ )
+ self.assertEqual(version, "1")
res = None
try:
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 1f2ef5d01f..67f1013051 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -139,7 +139,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
defer.succeed(1)
)
- self.datastore.get_current_state_deltas.return_value = None
+ self.datastore.get_current_state_deltas.return_value = (0, None)
self.datastore.get_to_device_stream_token = lambda: 0
self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0)
diff --git a/tests/patch_inline_callbacks.py b/tests/patch_inline_callbacks.py
deleted file mode 100644
index 220884311c..0000000000
--- a/tests/patch_inline_callbacks.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import print_function
-
-import functools
-import sys
-
-from twisted.internet import defer
-from twisted.internet.defer import Deferred
-from twisted.python.failure import Failure
-
-
-def do_patch():
- """
- Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
- """
-
- from synapse.logging.context import LoggingContext
-
- orig_inline_callbacks = defer.inlineCallbacks
-
- def new_inline_callbacks(f):
-
- orig = orig_inline_callbacks(f)
-
- @functools.wraps(f)
- def wrapped(*args, **kwargs):
- start_context = LoggingContext.current_context()
-
- try:
- res = orig(*args, **kwargs)
- except Exception:
- if LoggingContext.current_context() != start_context:
- err = "%s changed context from %s to %s on exception" % (
- f,
- start_context,
- LoggingContext.current_context(),
- )
- print(err, file=sys.stderr)
- raise Exception(err)
- raise
-
- if not isinstance(res, Deferred) or res.called:
- if LoggingContext.current_context() != start_context:
- err = "%s changed context from %s to %s" % (
- f,
- start_context,
- LoggingContext.current_context(),
- )
- # print the error to stderr because otherwise all we
- # see in travis-ci is the 500 error
- print(err, file=sys.stderr)
- raise Exception(err)
- return res
-
- if LoggingContext.current_context() != LoggingContext.sentinel:
- err = (
- "%s returned incomplete deferred in non-sentinel context "
- "%s (start was %s)"
- ) % (f, LoggingContext.current_context(), start_context)
- print(err, file=sys.stderr)
- raise Exception(err)
-
- def check_ctx(r):
- if LoggingContext.current_context() != start_context:
- err = "%s completion of %s changed context from %s to %s" % (
- "Failure" if isinstance(r, Failure) else "Success",
- f,
- start_context,
- LoggingContext.current_context(),
- )
- print(err, file=sys.stderr)
- raise Exception(err)
- return r
-
- res.addBoth(check_ctx)
- return res
-
- return wrapped
-
- defer.inlineCallbacks = new_inline_callbacks
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 5877bb2133..d3a4f717f7 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -62,7 +62,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.device_handler.check_device_registered = Mock(return_value="FAKE")
self.datastore = Mock(return_value=Mock())
- self.datastore.get_current_state_deltas = Mock(return_value=[])
+ self.datastore.get_current_state_deltas = Mock(return_value=(0, []))
self.secrets = Mock()
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index fe741637f5..2f2ca74611 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -484,6 +484,15 @@ class RoomsCreateTestCase(RoomBase):
self.render(request)
self.assertEquals(400, channel.code)
+ def test_post_room_invitees_invalid_mxid(self):
+ # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088
+ # Note the trailing space in the MXID here!
+ request, channel = self.make_request(
+ "POST", "/createRoom", b'{"invite":["@alice:example.com "]}'
+ )
+ self.render(request)
+ self.assertEquals(400, channel.code)
+
class RoomTopicTestCase(RoomBase):
""" Tests /rooms/$room_id/topic REST events. """
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index f42a8efbf4..e0e9e94fbf 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -92,7 +92,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
)
self.render(request)
- self.assertEqual(channel.result["code"], b"400")
+ self.assertEqual(channel.result["code"], b"404")
self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
# Currently invalid params do not have an appropriate errcode
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index b58386994e..2fe50377f8 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -57,7 +57,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
"(event_id, algorithm, hash) "
"VALUES (?, 'sha256', ?)"
),
- (event_id, b"ffff"),
+ (event_id, bytearray(b"ffff")),
)
for i in range(0, 11):
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 1494650d10..90a63dc477 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -50,6 +50,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
{"medium": "email", "address": user2_email},
{"medium": "email", "address": user3_email},
]
+ self.hs.config.mau_limits_reserved_threepids = threepids
# -1 because user3 is a support user and does not count
user_num = len(threepids) - 1
@@ -84,6 +85,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.hs.config.max_mau_value = 0
self.reactor.advance(FORTY_DAYS)
+ self.hs.config.max_mau_value = 5
self.store.reap_monthly_active_users()
self.pump()
@@ -147,9 +149,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.reap_monthly_active_users()
self.pump()
count = self.store.get_monthly_active_count()
- self.assertEquals(
- self.get_success(count), initial_users - self.hs.config.max_mau_value
- )
+ self.assertEquals(self.get_success(count), self.hs.config.max_mau_value)
self.reactor.advance(FORTY_DAYS)
self.store.reap_monthly_active_users()
@@ -158,6 +158,44 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.store.get_monthly_active_count()
self.assertEquals(self.get_success(count), 0)
+ def test_reap_monthly_active_users_reserved_users(self):
+ """ Tests that reaping correctly handles reaping where reserved users are
+ present"""
+
+ self.hs.config.max_mau_value = 5
+ initial_users = 5
+ reserved_user_number = initial_users - 1
+ threepids = []
+ for i in range(initial_users):
+ user = "@user%d:server" % i
+ email = "user%d@example.com" % i
+ self.get_success(self.store.upsert_monthly_active_user(user))
+ threepids.append({"medium": "email", "address": email})
+ # Need to ensure that the most recent entries in the
+ # monthly_active_users table are reserved
+ now = int(self.hs.get_clock().time_msec())
+ if i != 0:
+ self.get_success(
+ self.store.register_user(user_id=user, password_hash=None)
+ )
+ self.get_success(
+ self.store.user_add_threepid(user, "email", email, now, now)
+ )
+
+ self.hs.config.mau_limits_reserved_threepids = threepids
+ self.store.runInteraction(
+ "initialise", self.store._initialise_reserved_users, threepids
+ )
+ count = self.store.get_monthly_active_count()
+ self.assertTrue(self.get_success(count), initial_users)
+
+ users = self.store.get_registered_reserved_users()
+ self.assertEquals(len(self.get_success(users)), reserved_user_number)
+
+ self.get_success(self.store.reap_monthly_active_users())
+ count = self.store.get_monthly_active_count()
+ self.assertEquals(self.get_success(count), self.hs.config.max_mau_value)
+
def test_populate_monthly_users_is_guest(self):
# Test that guest users are not added to mau list
user_id = "@user_id:host"
@@ -192,12 +230,13 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
def test_get_reserved_real_user_account(self):
# Test no reserved users, or reserved threepids
- count = self.store.get_registered_reserved_users_count()
- self.assertEquals(self.get_success(count), 0)
+ users = self.get_success(self.store.get_registered_reserved_users())
+ self.assertEquals(len(users), 0)
# Test reserved users but no registered users
user1 = "@user1:example.com"
user2 = "@user2:example.com"
+
user1_email = "user1@example.com"
user2_email = "user2@example.com"
threepids = [
@@ -210,8 +249,8 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
)
self.pump()
- count = self.store.get_registered_reserved_users_count()
- self.assertEquals(self.get_success(count), 0)
+ users = self.get_success(self.store.get_registered_reserved_users())
+ self.assertEquals(len(users), 0)
# Test reserved registed users
self.store.register_user(user_id=user1, password_hash=None)
@@ -221,8 +260,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
now = int(self.hs.get_clock().time_msec())
self.store.user_add_threepid(user1, "email", user1_email, now, now)
self.store.user_add_threepid(user2, "email", user2_email, now, now)
- count = self.store.get_registered_reserved_users_count()
- self.assertEquals(self.get_success(count), len(threepids))
+
+ users = self.get_success(self.store.get_registered_reserved_users())
+ self.assertEquals(len(users), len(threepids))
def test_support_user_not_add_to_mau_limits(self):
support_user_id = "@support:test"
|