summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_e2e_keys.py8
-rw-r--r--tests/replication/slave/storage/_base.py5
-rw-r--r--tests/rest/client/v2_alpha/test_register.py3
-rw-r--r--tests/storage/test_appservice.py17
-rw-r--r--tests/storage/test_base.py3
-rw-r--r--tests/storage/test_client_ips.py49
-rw-r--r--tests/storage/test_profile.py3
-rw-r--r--tests/storage/test_user_directory.py4
-rw-r--r--tests/test_federation.py30
-rw-r--r--tests/util/test_logcontext.py24
10 files changed, 94 insertions, 52 deletions
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 854eb6c024..fdfa2cbbc4 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -183,6 +183,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
 
+    test_replace_master_key.skip = (
+        "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486"
+    )
+
     @defer.inlineCallbacks
     def test_reupload_signatures(self):
         """re-uploading a signature should not fail"""
@@ -503,3 +507,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             ],
             other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
         )
+
+    test_upload_signatures.skip = (
+        "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486"
+    )
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index e7472e3a93..3dae83c543 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -20,6 +20,7 @@ from synapse.replication.tcp.client import (
     ReplicationClientHandler,
 )
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.storage.database import Database
 
 from tests import unittest
 from tests.server import FakeTransport
@@ -42,7 +43,9 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
 
         self.master_store = self.hs.get_datastore()
         self.storage = hs.get_storage()
-        self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
+        self.slaved_store = self.STORE_TYPE(
+            Database(hs), self.hs.get_db_conn(), self.hs
+        )
         self.event_id = 0
 
         server_factory = ReplicationStreamProtocolFactory(self.hs)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index c0d0d2b44e..d0c997e385 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -391,9 +391,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         # Email config.
         self.email_attempts = []
 
-        def sendmail(*args, **kwargs):
+        async def sendmail(*args, **kwargs):
             self.email_attempts.append((args, kwargs))
-            return
 
         config["email"] = {
             "enable_notifs": True,
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index dfeea24599..2e521e9ab7 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -28,6 +28,7 @@ from synapse.storage.data_stores.main.appservice import (
     ApplicationServiceStore,
     ApplicationServiceTransactionStore,
 )
+from synapse.storage.database import Database
 
 from tests import unittest
 from tests.utils import setup_test_homeserver
@@ -54,7 +55,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
         self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
         # must be done after inserts
-        self.store = ApplicationServiceStore(hs.get_db_conn(), hs)
+        database = Database(hs)
+        self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs)
 
     def tearDown(self):
         # TODO: suboptimal that we need to create files for tests!
@@ -123,7 +125,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
         self.as_yaml_files = []
 
-        self.store = TestTransactionStore(hs.get_db_conn(), hs)
+        database = Database(hs)
+        self.store = TestTransactionStore(database, hs.get_db_conn(), hs)
 
     def _add_service(self, url, as_token, id):
         as_yaml = dict(
@@ -382,8 +385,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
 # required for ApplicationServiceTransactionStoreTestCase tests
 class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
-    def __init__(self, db_conn, hs):
-        super(TestTransactionStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(TestTransactionStore, self).__init__(database, db_conn, hs)
 
 
 class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
@@ -416,7 +419,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.event_cache_size = 1
         hs.config.password_providers = []
 
-        ApplicationServiceStore(hs.get_db_conn(), hs)
+        ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
 
     @defer.inlineCallbacks
     def test_duplicate_ids(self):
@@ -432,7 +435,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.password_providers = []
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(hs.get_db_conn(), hs)
+            ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
 
         e = cm.exception
         self.assertIn(f1, str(e))
@@ -453,7 +456,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.password_providers = []
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(hs.get_db_conn(), hs)
+            ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
 
         e = cm.exception
         self.assertIn(f1, str(e))
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 7915d48a9e..537cfe9f64 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -21,6 +21,7 @@ from mock import Mock
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.storage.engines import create_engine
 
 from tests import unittest
@@ -59,7 +60,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
             "test", db_pool=self.db_pool, config=config, database_engine=fake_engine
         )
 
-        self.datastore = SQLBaseStore(None, hs)
+        self.datastore = SQLBaseStore(Database(hs), None, hs)
 
     @defer.inlineCallbacks
     def test_insert_1col(self):
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index fc279340d4..bf674dd184 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -37,9 +37,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(12345678)
 
         user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(self.store.store_device(user_id, device_id, "display name",))
         self.get_success(
             self.store.insert_client_ip(
-                user_id, "access_token", "ip", "user_agent", "device_id"
+                user_id, "access_token", "ip", "user_agent", device_id
             )
         )
 
@@ -47,14 +51,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(10)
 
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 12345678000,
@@ -209,14 +213,16 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
                 self.store.db.updates.do_next_background_update(100), by=0.1
             )
 
-        # Insert a user IP
         user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(self.store.store_device(user_id, device_id, "display name",))
         self.get_success(
             self.store.insert_client_ip(
-                user_id, "access_token", "ip", "user_agent", "device_id"
+                user_id, "access_token", "ip", "user_agent", device_id
             )
         )
-
         # Force persisting to disk
         self.reactor.advance(200)
 
@@ -224,7 +230,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.get_success(
             self.store.db.simple_update(
                 table="devices",
-                keyvalues={"user_id": user_id, "device_id": "device_id"},
+                keyvalues={"user_id": user_id, "device_id": device_id},
                 updatevalues={"last_seen": None, "ip": None, "user_agent": None},
                 desc="test_devices_last_seen_bg_update",
             )
@@ -232,14 +238,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # We should now get nulls when querying
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": None,
                 "user_agent": None,
                 "last_seen": None,
@@ -272,14 +278,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # We should now get the correct result again
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 0,
@@ -296,11 +302,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
                 self.store.db.updates.do_next_background_update(100), by=0.1
             )
 
-        # Insert a user IP
         user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(self.store.store_device(user_id, device_id, "display name",))
         self.get_success(
             self.store.insert_client_ip(
-                user_id, "access_token", "ip", "user_agent", "device_id"
+                user_id, "access_token", "ip", "user_agent", device_id
             )
         )
 
@@ -324,7 +333,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
                     "access_token": "access_token",
                     "ip": "ip",
                     "user_agent": "user_agent",
-                    "device_id": "device_id",
+                    "device_id": device_id,
                     "last_seen": 0,
                 }
             ],
@@ -347,14 +356,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # But we should still get the correct values for the device
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 0,
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 24c7fe16c3..9b6f7211ae 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -16,7 +16,6 @@
 
 from twisted.internet import defer
 
-from synapse.storage.data_stores.main.profile import ProfileStore
 from synapse.types import UserID
 
 from tests import unittest
@@ -28,7 +27,7 @@ class ProfileStoreTestCase(unittest.TestCase):
     def setUp(self):
         hs = yield setup_test_homeserver(self.addCleanup)
 
-        self.store = ProfileStore(hs.get_db_conn(), hs)
+        self.store = hs.get_datastore()
 
         self.u_frank = UserID.from_string("@frank:test")
 
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 7eea57c0e2..6a545d2eb0 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -15,8 +15,6 @@
 
 from twisted.internet import defer
 
-from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
-
 from tests import unittest
 from tests.utils import setup_test_homeserver
 
@@ -29,7 +27,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         self.hs = yield setup_test_homeserver(self.addCleanup)
-        self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs)
+        self.store = self.hs.get_datastore()
 
         # alice and bob are both in !room_id. bobby is not but shares
         # a homeserver with alice.
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 7d82b58466..68684460c6 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -1,6 +1,6 @@
 from mock import Mock
 
-from twisted.internet.defer import maybeDeferred, succeed
+from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
 
 from synapse.events import FrozenEvent
 from synapse.logging.context import LoggingContext
@@ -33,6 +33,8 @@ class MessageAcceptTests(unittest.TestCase):
         self.reactor.advance(0.1)
         self.room_id = self.successResultOf(room)["room_id"]
 
+        self.store = self.homeserver.get_datastore()
+
         # Figure out what the most recent event is
         most_recent = self.successResultOf(
             maybeDeferred(
@@ -68,8 +70,10 @@ class MessageAcceptTests(unittest.TestCase):
         )
 
         # Send the join, it should return None (which is not an error)
-        d = self.handler.on_receive_pdu(
-            "test.serv", join_event, sent_to_us_directly=True
+        d = ensureDeferred(
+            self.handler.on_receive_pdu(
+                "test.serv", join_event, sent_to_us_directly=True
+            )
         )
         self.reactor.advance(1)
         self.assertEqual(self.successResultOf(d), None)
@@ -77,10 +81,7 @@ class MessageAcceptTests(unittest.TestCase):
         # Make sure we actually joined the room
         self.assertEqual(
             self.successResultOf(
-                maybeDeferred(
-                    self.homeserver.get_datastore().get_latest_event_ids_in_room,
-                    self.room_id,
-                )
+                maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
             )[0],
             "$join:test.serv",
         )
@@ -100,10 +101,7 @@ class MessageAcceptTests(unittest.TestCase):
 
         # Figure out what the most recent event is
         most_recent = self.successResultOf(
-            maybeDeferred(
-                self.homeserver.get_datastore().get_latest_event_ids_in_room,
-                self.room_id,
-            )
+            maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
         )[0]
 
         # Now lie about an event
@@ -123,8 +121,10 @@ class MessageAcceptTests(unittest.TestCase):
         )
 
         with LoggingContext(request="lying_event"):
-            d = self.handler.on_receive_pdu(
-                "test.serv", lying_event, sent_to_us_directly=True
+            d = ensureDeferred(
+                self.handler.on_receive_pdu(
+                    "test.serv", lying_event, sent_to_us_directly=True
+                )
             )
 
             # Step the reactor, so the database fetches come back
@@ -141,7 +141,5 @@ class MessageAcceptTests(unittest.TestCase):
         )
 
         # Make sure the invalid event isn't there
-        extrem = maybeDeferred(
-            self.homeserver.get_datastore().get_latest_event_ids_in_room, self.room_id
-        )
+        extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
         self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 8b8455c8b7..281b32c4b8 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -179,6 +179,30 @@ class LoggingContextTestCase(unittest.TestCase):
             nested_context = nested_logging_context(suffix="bar")
             self.assertEqual(nested_context.request, "foo-bar")
 
+    @defer.inlineCallbacks
+    def test_make_deferred_yieldable_with_await(self):
+        # an async function which retuns an incomplete coroutine, but doesn't
+        # follow the synapse rules.
+
+        async def blocking_function():
+            d = defer.Deferred()
+            reactor.callLater(0, d.callback, None)
+            await d
+
+        sentinel_context = LoggingContext.current_context()
+
+        with LoggingContext() as context_one:
+            context_one.request = "one"
+
+            d1 = make_deferred_yieldable(blocking_function())
+            # make sure that the context was reset by make_deferred_yieldable
+            self.assertIs(LoggingContext.current_context(), sentinel_context)
+
+            yield d1
+
+            # now it should be restored
+            self._check_test_key("one")
+
 
 # a function which returns a deferred which has been "called", but
 # which had a function which returned another incomplete deferred on