summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/__init__.py4
-rw-r--r--tests/config/test_tls.py25
-rw-r--r--tests/handlers/test_typing.py2
-rw-r--r--tests/patch_inline_callbacks.py94
-rw-r--r--tests/rest/admin/test_admin.py2
-rw-r--r--tests/rest/client/v1/test_rooms.py9
-rw-r--r--tests/rest/client/v2_alpha/test_filter.py2
-rw-r--r--tests/storage/test_event_federation.py2
-rw-r--r--tests/storage/test_monthly_active_users.py58
9 files changed, 80 insertions, 118 deletions
diff --git a/tests/__init__.py b/tests/__init__.py
index f7fc502f01..ed805db1c2 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -16,9 +16,9 @@
 
 from twisted.trial import util
 
-import tests.patch_inline_callbacks
+from synapse.util.patch_inline_callbacks import do_patch
 
 # attempt to do the patch before we load any synapse code
-tests.patch_inline_callbacks.do_patch()
+do_patch()
 
 util.DEFAULT_TIMEOUT_DURATION = 20
diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py
index b02780772a..1be6ff563b 100644
--- a/tests/config/test_tls.py
+++ b/tests/config/test_tls.py
@@ -21,17 +21,24 @@ import yaml
 
 from OpenSSL import SSL
 
+from synapse.config._base import Config, RootConfig
 from synapse.config.tls import ConfigError, TlsConfig
 from synapse.crypto.context_factory import ClientTLSOptionsFactory
 
 from tests.unittest import TestCase
 
 
-class TestConfig(TlsConfig):
+class FakeServer(Config):
+    section = "server"
+
     def has_tls_listener(self):
         return False
 
 
+class TestConfig(RootConfig):
+    config_classes = [FakeServer, TlsConfig]
+
+
 class TLSConfigTests(TestCase):
     def test_warn_self_signed(self):
         """
@@ -202,13 +209,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
         conf = TestConfig()
         conf.read_config(
             yaml.safe_load(
-                TestConfig().generate_config_section(
+                TestConfig().generate_config(
                     "/config_dir_path",
                     "my_super_secure_server",
                     "/data_dir_path",
-                    "/tls_cert_path",
-                    "tls_private_key",
-                    None,  # This is the acme_domain
+                    tls_certificate_path="/tls_cert_path",
+                    tls_private_key_path="tls_private_key",
+                    acme_domain=None,  # This is the acme_domain
                 )
             ),
             "/config_dir_path",
@@ -223,13 +230,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
         conf = TestConfig()
         conf.read_config(
             yaml.safe_load(
-                TestConfig().generate_config_section(
+                TestConfig().generate_config(
                     "/config_dir_path",
                     "my_super_secure_server",
                     "/data_dir_path",
-                    "/tls_cert_path",
-                    "tls_private_key",
-                    "my_supe_secure_server",  # This is the acme_domain
+                    tls_certificate_path="/tls_cert_path",
+                    tls_private_key_path="tls_private_key",
+                    acme_domain="my_supe_secure_server",  # This is the acme_domain
                 )
             ),
             "/config_dir_path",
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 1f2ef5d01f..67f1013051 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -139,7 +139,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             defer.succeed(1)
         )
 
-        self.datastore.get_current_state_deltas.return_value = None
+        self.datastore.get_current_state_deltas.return_value = (0, None)
 
         self.datastore.get_to_device_stream_token = lambda: 0
         self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0)
diff --git a/tests/patch_inline_callbacks.py b/tests/patch_inline_callbacks.py
deleted file mode 100644
index 220884311c..0000000000
--- a/tests/patch_inline_callbacks.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import print_function
-
-import functools
-import sys
-
-from twisted.internet import defer
-from twisted.internet.defer import Deferred
-from twisted.python.failure import Failure
-
-
-def do_patch():
-    """
-    Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
-    """
-
-    from synapse.logging.context import LoggingContext
-
-    orig_inline_callbacks = defer.inlineCallbacks
-
-    def new_inline_callbacks(f):
-
-        orig = orig_inline_callbacks(f)
-
-        @functools.wraps(f)
-        def wrapped(*args, **kwargs):
-            start_context = LoggingContext.current_context()
-
-            try:
-                res = orig(*args, **kwargs)
-            except Exception:
-                if LoggingContext.current_context() != start_context:
-                    err = "%s changed context from %s to %s on exception" % (
-                        f,
-                        start_context,
-                        LoggingContext.current_context(),
-                    )
-                    print(err, file=sys.stderr)
-                    raise Exception(err)
-                raise
-
-            if not isinstance(res, Deferred) or res.called:
-                if LoggingContext.current_context() != start_context:
-                    err = "%s changed context from %s to %s" % (
-                        f,
-                        start_context,
-                        LoggingContext.current_context(),
-                    )
-                    # print the error to stderr because otherwise all we
-                    # see in travis-ci is the 500 error
-                    print(err, file=sys.stderr)
-                    raise Exception(err)
-                return res
-
-            if LoggingContext.current_context() != LoggingContext.sentinel:
-                err = (
-                    "%s returned incomplete deferred in non-sentinel context "
-                    "%s (start was %s)"
-                ) % (f, LoggingContext.current_context(), start_context)
-                print(err, file=sys.stderr)
-                raise Exception(err)
-
-            def check_ctx(r):
-                if LoggingContext.current_context() != start_context:
-                    err = "%s completion of %s changed context from %s to %s" % (
-                        "Failure" if isinstance(r, Failure) else "Success",
-                        f,
-                        start_context,
-                        LoggingContext.current_context(),
-                    )
-                    print(err, file=sys.stderr)
-                    raise Exception(err)
-                return r
-
-            res.addBoth(check_ctx)
-            return res
-
-        return wrapped
-
-    defer.inlineCallbacks = new_inline_callbacks
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 5877bb2133..d3a4f717f7 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -62,7 +62,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         self.device_handler.check_device_registered = Mock(return_value="FAKE")
 
         self.datastore = Mock(return_value=Mock())
-        self.datastore.get_current_state_deltas = Mock(return_value=[])
+        self.datastore.get_current_state_deltas = Mock(return_value=(0, []))
 
         self.secrets = Mock()
 
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index fe741637f5..2f2ca74611 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -484,6 +484,15 @@ class RoomsCreateTestCase(RoomBase):
         self.render(request)
         self.assertEquals(400, channel.code)
 
+    def test_post_room_invitees_invalid_mxid(self):
+        # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088
+        # Note the trailing space in the MXID here!
+        request, channel = self.make_request(
+            "POST", "/createRoom", b'{"invite":["@alice:example.com "]}'
+        )
+        self.render(request)
+        self.assertEquals(400, channel.code)
+
 
 class RoomTopicTestCase(RoomBase):
     """ Tests /rooms/$room_id/topic REST events. """
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index f42a8efbf4..e0e9e94fbf 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -92,7 +92,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         )
         self.render(request)
 
-        self.assertEqual(channel.result["code"], b"400")
+        self.assertEqual(channel.result["code"], b"404")
         self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
 
     # Currently invalid params do not have an appropriate errcode
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index b58386994e..2fe50377f8 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -57,7 +57,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
                     "(event_id, algorithm, hash) "
                     "VALUES (?, 'sha256', ?)"
                 ),
-                (event_id, b"ffff"),
+                (event_id, bytearray(b"ffff")),
             )
 
         for i in range(0, 11):
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 1494650d10..90a63dc477 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -50,6 +50,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
             {"medium": "email", "address": user2_email},
             {"medium": "email", "address": user3_email},
         ]
+        self.hs.config.mau_limits_reserved_threepids = threepids
         # -1 because user3 is a support user and does not count
         user_num = len(threepids) - 1
 
@@ -84,6 +85,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.hs.config.max_mau_value = 0
 
         self.reactor.advance(FORTY_DAYS)
+        self.hs.config.max_mau_value = 5
 
         self.store.reap_monthly_active_users()
         self.pump()
@@ -147,9 +149,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.reap_monthly_active_users()
         self.pump()
         count = self.store.get_monthly_active_count()
-        self.assertEquals(
-            self.get_success(count), initial_users - self.hs.config.max_mau_value
-        )
+        self.assertEquals(self.get_success(count), self.hs.config.max_mau_value)
 
         self.reactor.advance(FORTY_DAYS)
         self.store.reap_monthly_active_users()
@@ -158,6 +158,44 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         count = self.store.get_monthly_active_count()
         self.assertEquals(self.get_success(count), 0)
 
+    def test_reap_monthly_active_users_reserved_users(self):
+        """ Tests that reaping correctly handles reaping where reserved users are
+        present"""
+
+        self.hs.config.max_mau_value = 5
+        initial_users = 5
+        reserved_user_number = initial_users - 1
+        threepids = []
+        for i in range(initial_users):
+            user = "@user%d:server" % i
+            email = "user%d@example.com" % i
+            self.get_success(self.store.upsert_monthly_active_user(user))
+            threepids.append({"medium": "email", "address": email})
+            # Need to ensure that the most recent entries in the
+            # monthly_active_users table are reserved
+            now = int(self.hs.get_clock().time_msec())
+            if i != 0:
+                self.get_success(
+                    self.store.register_user(user_id=user, password_hash=None)
+                )
+                self.get_success(
+                    self.store.user_add_threepid(user, "email", email, now, now)
+                )
+
+        self.hs.config.mau_limits_reserved_threepids = threepids
+        self.store.runInteraction(
+            "initialise", self.store._initialise_reserved_users, threepids
+        )
+        count = self.store.get_monthly_active_count()
+        self.assertTrue(self.get_success(count), initial_users)
+
+        users = self.store.get_registered_reserved_users()
+        self.assertEquals(len(self.get_success(users)), reserved_user_number)
+
+        self.get_success(self.store.reap_monthly_active_users())
+        count = self.store.get_monthly_active_count()
+        self.assertEquals(self.get_success(count), self.hs.config.max_mau_value)
+
     def test_populate_monthly_users_is_guest(self):
         # Test that guest users are not added to mau list
         user_id = "@user_id:host"
@@ -192,12 +230,13 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
 
     def test_get_reserved_real_user_account(self):
         # Test no reserved users, or reserved threepids
-        count = self.store.get_registered_reserved_users_count()
-        self.assertEquals(self.get_success(count), 0)
+        users = self.get_success(self.store.get_registered_reserved_users())
+        self.assertEquals(len(users), 0)
         # Test reserved users but no registered users
 
         user1 = "@user1:example.com"
         user2 = "@user2:example.com"
+
         user1_email = "user1@example.com"
         user2_email = "user2@example.com"
         threepids = [
@@ -210,8 +249,8 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         )
 
         self.pump()
-        count = self.store.get_registered_reserved_users_count()
-        self.assertEquals(self.get_success(count), 0)
+        users = self.get_success(self.store.get_registered_reserved_users())
+        self.assertEquals(len(users), 0)
 
         # Test reserved registed users
         self.store.register_user(user_id=user1, password_hash=None)
@@ -221,8 +260,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         now = int(self.hs.get_clock().time_msec())
         self.store.user_add_threepid(user1, "email", user1_email, now, now)
         self.store.user_add_threepid(user2, "email", user2_email, now, now)
-        count = self.store.get_registered_reserved_users_count()
-        self.assertEquals(self.get_success(count), len(threepids))
+
+        users = self.get_success(self.store.get_registered_reserved_users())
+        self.assertEquals(len(users), len(threepids))
 
     def test_support_user_not_add_to_mau_limits(self):
         support_user_id = "@support:test"