diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 319e2c2325..f5afed017c 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -99,7 +99,7 @@ class CacheTestCase(unittest.HomeserverTestCase):
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def test_passthrough(self):
- class A(object):
+ class A:
@cached()
def func(self, key):
return key
@@ -113,7 +113,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
def test_hit(self):
callcount = [0]
- class A(object):
+ class A:
@cached()
def func(self, key):
callcount[0] += 1
@@ -131,7 +131,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
def test_invalidate(self):
callcount = [0]
- class A(object):
+ class A:
@cached()
def func(self, key):
callcount[0] += 1
@@ -149,7 +149,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
self.assertEquals(callcount[0], 2)
def test_invalidate_missing(self):
- class A(object):
+ class A:
@cached()
def func(self, key):
return key
@@ -160,7 +160,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
def test_max_entries(self):
callcount = [0]
- class A(object):
+ class A:
@cached(max_entries=10)
def func(self, key):
callcount[0] += 1
@@ -187,7 +187,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
d = defer.succeed(123)
- class A(object):
+ class A:
@cached()
def func(self, key):
callcount[0] += 1
@@ -205,7 +205,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
callcount = [0]
callcount2 = [0]
- class A(object):
+ class A:
@cached()
def func(self, key):
callcount[0] += 1
@@ -238,7 +238,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
callcount = [0]
callcount2 = [0]
- class A(object):
+ class A:
@cached(max_entries=2)
def func(self, key):
callcount[0] += 1
@@ -275,7 +275,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
callcount = [0]
callcount2 = [0]
- class A(object):
+ class A:
@cached()
def func(self, key):
callcount[0] += 1
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 98b74890d5..cb808d4de4 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import (
)
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
@@ -207,7 +208,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_appservices_state_down(self):
service = Mock(id=self.as_list[1]["id"])
- yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ )
rows = yield self.db_pool.runQuery(
self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
@@ -219,9 +222,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_appservices_state_multiple_up(self):
service = Mock(id=self.as_list[1]["id"])
- yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
- yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
- yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.UP)
+ )
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ )
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.UP)
+ )
rows = yield self.db_pool.runQuery(
self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
@@ -234,7 +243,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
def test_create_appservice_txn_first(self):
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
- txn = yield self.store.create_appservice_txn(service, events)
+ txn = yield defer.ensureDeferred(
+ self.store.create_appservice_txn(service, events)
+ )
self.assertEquals(txn.id, 1)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@@ -246,7 +257,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._set_last_txn(service.id, 9643) # AS is falling behind
yield self._insert_txn(service.id, 9644, events)
yield self._insert_txn(service.id, 9645, events)
- txn = yield self.store.create_appservice_txn(service, events)
+ txn = yield defer.ensureDeferred(
+ self.store.create_appservice_txn(service, events)
+ )
self.assertEquals(txn.id, 9646)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@@ -256,7 +269,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
yield self._set_last_txn(service.id, 9643)
- txn = yield self.store.create_appservice_txn(service, events)
+ txn = yield defer.ensureDeferred(
+ self.store.create_appservice_txn(service, events)
+ )
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@@ -277,7 +292,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._insert_txn(self.as_list[2]["id"], 10, events)
yield self._insert_txn(self.as_list[3]["id"], 9643, events)
- txn = yield self.store.create_appservice_txn(service, events)
+ txn = yield defer.ensureDeferred(
+ self.store.create_appservice_txn(service, events)
+ )
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@@ -289,7 +306,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
txn_id = 1
yield self._insert_txn(service.id, txn_id, events)
- yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+ yield defer.ensureDeferred(
+ self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+ )
res = yield self.db_pool.runQuery(
self.engine.convert_param_style(
@@ -315,7 +334,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
txn_id = 5
yield self._set_last_txn(service.id, 4)
yield self._insert_txn(service.id, txn_id, events)
- yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+ yield defer.ensureDeferred(
+ self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+ )
res = yield self.db_pool.runQuery(
self.engine.convert_param_style(
@@ -349,7 +370,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
# we aren't testing store._base stuff here, so mock this out
- self.store.get_events_as_list = Mock(return_value=defer.succeed(events))
+ self.store.get_events_as_list = Mock(return_value=make_awaitable(events))
yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
yield self._insert_txn(service.id, 10, events)
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 2efbc97c2e..02aae1c13d 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -1,7 +1,5 @@
from mock import Mock
-from twisted.internet import defer
-
from synapse.storage.background_updates import BackgroundUpdater
from tests import unittest
@@ -38,11 +36,10 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
)
# first step: make a bit of progress
- @defer.inlineCallbacks
- def update(progress, count):
- yield self.clock.sleep((count * duration_ms) / 1000)
+ async def update(progress, count):
+ await self.clock.sleep((count * duration_ms) / 1000)
progress = {"my_key": progress["my_key"] + 1}
- yield store.db_pool.runInteraction(
+ await store.db_pool.runInteraction(
"update_progress",
self.updates._background_update_progress_txn,
"test_update",
@@ -67,13 +64,12 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# second step: complete the update
# we should now get run with a much bigger number of items to update
- @defer.inlineCallbacks
- def update(progress, count):
+ async def update(progress, count):
self.assertEqual(progress, {"my_key": 2})
self.assertAlmostEqual(
count, target_background_update_duration_ms / duration_ms, places=0,
)
- yield self.updates._end_background_update("test_update")
+ await self.updates._end_background_update("test_update")
return count
self.update_handler.side_effect = update
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index efcaeef1e7..40ba652248 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -66,8 +66,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_1col(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db_pool.simple_insert(
- table="tablename", values={"columname": "Value"}
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_insert(
+ table="tablename", values={"columname": "Value"}
+ )
)
self.mock_txn.execute.assert_called_with(
@@ -78,10 +80,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_3cols(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db_pool.simple_insert(
- table="tablename",
- # Use OrderedDict() so we can assert on the SQL generated
- values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_insert(
+ table="tablename",
+ # Use OrderedDict() so we can assert on the SQL generated
+ values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
+ )
)
self.mock_txn.execute.assert_called_with(
@@ -93,8 +97,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
- value = yield self.datastore.db_pool.simple_select_one_onecol(
- table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
+ value = yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_select_one_onecol(
+ table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
+ )
)
self.assertEquals("Value", value)
@@ -107,10 +113,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = (1, 2, 3)
- ret = yield self.datastore.db_pool.simple_select_one(
- table="tablename",
- keyvalues={"keycol": "TheKey"},
- retcols=["colA", "colB", "colC"],
+ ret = yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_select_one(
+ table="tablename",
+ keyvalues={"keycol": "TheKey"},
+ retcols=["colA", "colB", "colC"],
+ )
)
self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
@@ -123,11 +131,13 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 0
self.mock_txn.fetchone.return_value = None
- ret = yield self.datastore.db_pool.simple_select_one(
- table="tablename",
- keyvalues={"keycol": "Not here"},
- retcols=["colA"],
- allow_none=True,
+ ret = yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_select_one(
+ table="tablename",
+ keyvalues={"keycol": "Not here"},
+ retcols=["colA"],
+ allow_none=True,
+ )
)
self.assertFalse(ret)
@@ -138,8 +148,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
- ret = yield self.datastore.db_pool.simple_select_list(
- table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
+ ret = yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_select_list(
+ table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
+ )
)
self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
@@ -151,10 +163,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_update_one_1col(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db_pool.simple_update_one(
- table="tablename",
- keyvalues={"keycol": "TheKey"},
- updatevalues={"columnname": "New Value"},
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_update_one(
+ table="tablename",
+ keyvalues={"keycol": "TheKey"},
+ updatevalues={"columnname": "New Value"},
+ )
)
self.mock_txn.execute.assert_called_with(
@@ -166,10 +180,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_update_one_4cols(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db_pool.simple_update_one(
- table="tablename",
- keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
- updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_update_one(
+ table="tablename",
+ keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
+ updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
+ )
)
self.mock_txn.execute.assert_called_with(
@@ -181,8 +197,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_delete_one(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db_pool.simple_delete_one(
- table="tablename", keyvalues={"keycol": "Go away"}
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_delete_one(
+ table="tablename", keyvalues={"keycol": "Go away"}
+ )
)
self.mock_txn.execute.assert_called_with(
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 3fab5a5248..080761d1d2 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID("alice", "test")
- self.requester = Requester(self.user, None, False, None, None)
+ self.requester = Requester(self.user, None, False, False, None, None)
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
@@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID.from_string(self.register_user("user1", "password"))
self.token1 = self.login("user1", "password")
- self.requester = Requester(self.user, None, False, None, None)
+ self.requester = Requester(self.user, None, False, False, None, None)
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
self.event_creator = homeserver.get_event_creation_handler()
@@ -271,7 +271,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
# Pump the reactor repeatedly so that the background updates have a
# chance to run.
- self.pump(10 * 60)
+ self.pump(20)
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
@@ -353,6 +353,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
"3"
] = 300000
+
self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
# All entries within time frame
self.assertEqual(
@@ -362,7 +363,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
3,
)
# Oldest room to expire
- self.pump(1)
+ self.pump(1.01)
self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
self.assertEqual(
len(
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 224ea6fd79..370c247e16 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -16,13 +16,12 @@
from mock import Mock
-from twisted.internet import defer
-
import synapse.rest.admin
from synapse.http.site import XForwardedForRequest
from synapse.rest.client.v1 import login
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.unittest import override_config
@@ -155,7 +154,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
user_id = "@user:server"
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(lots_of_users)
+ side_effect=lambda: make_awaitable(lots_of_users)
)
self.get_success(
self.store.insert_client_ip(
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 87ed8f8cd1..34ae8c9da7 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -38,7 +38,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self.store.store_device("user_id", "device_id", "display_name")
)
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertDictContainsSubset(
{
"user_id": "user_id",
@@ -111,12 +111,12 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self.store.store_device("user_id", "device_id", "display_name 1")
)
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do the update
@@ -127,7 +127,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
# check it worked
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 2", res["display_name"])
@defer.inlineCallbacks
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index daac947cb2..da93ca3980 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -42,7 +42,11 @@ class DirectoryStoreTestCase(unittest.TestCase):
self.assertEquals(
["#my-room:test"],
- (yield self.store.get_aliases_for_room(self.room.to_string())),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_aliases_for_room(self.room.to_string())
+ )
+ ),
)
@defer.inlineCallbacks
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index d57cdffd8b..3fc4bb13b6 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -32,10 +32,12 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
yield defer.ensureDeferred(self.store.store_device("user", "device", None))
- yield self.store.set_e2e_device_keys("user", "device", now, json)
+ yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user", "device", now, json)
+ )
res = yield defer.ensureDeferred(
- self.store.get_e2e_device_keys((("user", "device"),))
+ self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
self.assertIn("device", res["user"])
@@ -49,12 +51,16 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
yield defer.ensureDeferred(self.store.store_device("user", "device", None))
- changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
+ changed = yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user", "device", now, json)
+ )
self.assertTrue(changed)
# If we try to upload the same key then we should be told nothing
# changed
- changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
+ changed = yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user", "device", now, json)
+ )
self.assertFalse(changed)
@defer.inlineCallbacks
@@ -62,13 +68,15 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
now = 1470174257070
json = {"key": "value"}
- yield self.store.set_e2e_device_keys("user", "device", now, json)
+ yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user", "device", now, json)
+ )
yield defer.ensureDeferred(
self.store.store_device("user", "device", "display_name")
)
res = yield defer.ensureDeferred(
- self.store.get_e2e_device_keys((("user", "device"),))
+ self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
self.assertIn("device", res["user"])
@@ -86,13 +94,23 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
yield defer.ensureDeferred(self.store.store_device("user2", "device1", None))
yield defer.ensureDeferred(self.store.store_device("user2", "device2", None))
- yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
- yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
- yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
- yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
+ yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
+ )
+ yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
+ )
+ yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
+ )
+ yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
+ )
res = yield defer.ensureDeferred(
- self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2")))
+ self.store.get_e2e_device_keys_for_cs_api(
+ (("user1", "device1"), ("user2", "device2"))
+ )
)
self.assertIn("user1", res)
self.assertIn("device1", res["user1"])
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index a7b85004e5..949846fe33 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
room_creator = self.hs.get_room_creation_handler()
user = UserID("alice", "test")
- requester = Requester(user, None, False, None, None)
+ requester = Requester(user, None, False, False, None, None)
# Real events, forward extremities
events = [(3, 2), (6, 2), (4, 6)]
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 857db071d4..c0595963dd 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -60,12 +60,18 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def _assert_counts(noitf_count, highlight_count):
- counts = yield self.store.db_pool.runInteraction(
- "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
+ counts = yield defer.ensureDeferred(
+ self.store.db_pool.runInteraction(
+ "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
+ )
)
self.assertEquals(
counts,
- {"notify_count": noitf_count, "highlight_count": highlight_count},
+ {
+ "notify_count": noitf_count,
+ "unread_count": 0, # Unread counts are tested in the sync tests.
+ "highlight_count": highlight_count,
+ },
)
@defer.inlineCallbacks
@@ -78,28 +84,34 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield defer.ensureDeferred(
self.store.add_push_actions_to_staging(
- event.event_id, {user_id: action}
+ event.event_id, {user_id: action}, False,
)
)
- yield self.store.db_pool.runInteraction(
- "",
- self.persist_events_store._set_push_actions_for_event_and_users_txn,
- [(event, None)],
- [(event, None)],
+ yield defer.ensureDeferred(
+ self.store.db_pool.runInteraction(
+ "",
+ self.persist_events_store._set_push_actions_for_event_and_users_txn,
+ [(event, None)],
+ [(event, None)],
+ )
)
def _rotate(stream):
- return self.store.db_pool.runInteraction(
- "", self.store._rotate_notifs_before_txn, stream
+ return defer.ensureDeferred(
+ self.store.db_pool.runInteraction(
+ "", self.store._rotate_notifs_before_txn, stream
+ )
)
def _mark_read(stream, depth):
- return self.store.db_pool.runInteraction(
- "",
- self.store._remove_old_push_actions_before_txn,
- room_id,
- user_id,
- stream,
+ return defer.ensureDeferred(
+ self.store.db_pool.runInteraction(
+ "",
+ self.store._remove_old_push_actions_before_txn,
+ room_id,
+ user_id,
+ stream,
+ )
)
yield _assert_counts(0, 0)
@@ -123,8 +135,10 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield _inject_actions(6, PlAIN_NOTIF)
yield _rotate(7)
- yield self.store.db_pool.simple_delete(
- table="event_push_actions", keyvalues={"1": 1}, desc=""
+ yield defer.ensureDeferred(
+ self.store.db_pool.simple_delete(
+ table="event_push_actions", keyvalues={"1": 1}, desc=""
+ )
)
yield _assert_counts(1, 0)
@@ -142,33 +156,43 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts):
- return self.store.db_pool.simple_insert(
- "events",
- {
- "stream_ordering": so,
- "received_ts": ts,
- "event_id": "event%i" % so,
- "type": "",
- "room_id": "",
- "content": "",
- "processed": True,
- "outlier": False,
- "topological_ordering": 0,
- "depth": 0,
- },
+ return defer.ensureDeferred(
+ self.store.db_pool.simple_insert(
+ "events",
+ {
+ "stream_ordering": so,
+ "received_ts": ts,
+ "event_id": "event%i" % so,
+ "type": "",
+ "room_id": "",
+ "content": "",
+ "processed": True,
+ "outlier": False,
+ "topological_ordering": 0,
+ "depth": 0,
+ },
+ )
)
# start with the base case where there are no events in the table
- r = yield self.store.find_first_stream_ordering_after_ts(11)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(11)
+ )
self.assertEqual(r, 0)
# now with one event
yield add_event(2, 10)
- r = yield self.store.find_first_stream_ordering_after_ts(9)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(9)
+ )
self.assertEqual(r, 2)
- r = yield self.store.find_first_stream_ordering_after_ts(10)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(10)
+ )
self.assertEqual(r, 2)
- r = yield self.store.find_first_stream_ordering_after_ts(11)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(11)
+ )
self.assertEqual(r, 3)
# add a bunch of dummy events to the events table
@@ -181,25 +205,37 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
):
yield add_event(stream_ordering, ts)
- r = yield self.store.find_first_stream_ordering_after_ts(110)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(110)
+ )
self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r)
# 4 and 5 are both after 120: we want 4 rather than 5
- r = yield self.store.find_first_stream_ordering_after_ts(120)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(120)
+ )
self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r)
- r = yield self.store.find_first_stream_ordering_after_ts(129)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(129)
+ )
self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r)
# check we can get the last event
- r = yield self.store.find_first_stream_ordering_after_ts(140)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(140)
+ )
self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r)
# off the end
- r = yield self.store.find_first_stream_ordering_after_ts(160)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(160)
+ )
self.assertEqual(r, 21)
# check we can find an event at ordering zero
yield add_event(0, 5)
- r = yield self.store.find_first_stream_ordering_after_ts(1)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(1)
+ )
self.assertEqual(r, 0)
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index e845410dae..f0a8e32f1e 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -58,6 +58,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
return self.get_success(self.db_pool.runWithConnection(_create))
def _insert_rows(self, instance_name: str, number: int):
+ """Insert N rows as the given instance, inserting with stream IDs pulled
+ from the postgres sequence.
+ """
+
def _insert(txn):
for _ in range(number):
txn.execute(
@@ -65,7 +69,20 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
(instance_name,),
)
- self.get_success(self.db_pool.runInteraction("test_single_instance", _insert))
+ self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
+
+ def _insert_row_with_id(self, instance_name: str, stream_id: int):
+ """Insert one row as the given instance with given stream_id, updating
+ the postgres sequence position to match.
+ """
+
+ def _insert(txn):
+ txn.execute(
+ "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+ )
+ txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
+
+ self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
def test_empty(self):
"""Test an ID generator against an empty database gives sensible
@@ -88,7 +105,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -98,12 +115,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
self.get_success(_get_next_async())
self.assertEqual(id_gen.get_positions(), {"master": 8})
- self.assertEqual(id_gen.get_current_token("master"), 8)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
def test_multi_instance(self):
"""Test that reads and writes from multiple processes are handled
@@ -116,8 +133,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen = self._create_id_generator("second")
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
- self.assertEqual(first_id_gen.get_current_token("first"), 3)
- self.assertEqual(first_id_gen.get_current_token("second"), 7)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -166,7 +183,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -176,9 +193,179 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
self.assertEqual(id_gen.get_positions(), {"master": 8})
- self.assertEqual(id_gen.get_current_token("master"), 8)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
+
+ def test_get_persisted_upto_position(self):
+ """Test that `get_persisted_upto_position` correctly tracks updates to
+ positions.
+ """
+
+ # The following tests are a bit cheeky in that we notify about new
+ # positions via `advance` without *actually* advancing the postgres
+ # sequence.
+
+ self._insert_row_with_id("first", 3)
+ self._insert_row_with_id("second", 5)
+
+ id_gen = self._create_id_generator("first")
+
+ self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+
+ # Min is 3 and there is a gap between 5, so we expect it to be 3.
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+ # We advance "first" straight to 6. Min is now 5 but there is no gap so
+ # we expect it to be 6
+ id_gen.advance("first", 6)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 6)
+
+ # No gap, so we expect 7.
+ id_gen.advance("second", 7)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 7)
+
+ # We haven't seen 8 yet, so we expect 7 still.
+ id_gen.advance("second", 9)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 7)
+
+ # Now that we've seen 7, 8 and 9 we can got straight to 9.
+ id_gen.advance("first", 8)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 9)
+
+ # Jump forward with gaps. The minimum is 11, even though we haven't seen
+ # 10 we know that everything before 11 must be persisted.
+ id_gen.advance("first", 11)
+ id_gen.advance("second", 15)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 11)
+
+ def test_get_persisted_upto_position_get_next(self):
+ """Test that `get_persisted_upto_position` correctly tracks updates to
+ positions when `get_next` is called.
+ """
+
+ self._insert_row_with_id("first", 3)
+ self._insert_row_with_id("second", 5)
+
+ id_gen = self._create_id_generator("first")
+
+ self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+ with self.get_success(id_gen.get_next()) as stream_id:
+ self.assertEqual(stream_id, 6)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+ self.assertEqual(id_gen.get_persisted_upto_position(), 6)
+
+ # We assume that so long as `get_next` does correctly advance the
+ # `persisted_upto_position` in this case, then it will be correct in the
+ # other cases that are tested above (since they'll hit the same code).
+
+
+class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+ """Tests MultiWriterIdGenerator that produce *negative* stream IDs.
+ """
+
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.db_pool = self.store.db_pool # type: DatabasePool
+
+ self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+
+ def _setup_db(self, txn):
+ txn.execute("CREATE SEQUENCE foobar_seq")
+ txn.execute(
+ """
+ CREATE TABLE foobar (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+
+ def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create(conn):
+ return MultiWriterIdGenerator(
+ conn,
+ self.db_pool,
+ instance_name=instance_name,
+ table="foobar",
+ instance_column="instance_name",
+ id_column="stream_id",
+ sequence_name="foobar_seq",
+ positive=False,
+ )
+
+ return self.get_success(self.db_pool.runWithConnection(_create))
+
+ def _insert_row(self, instance_name: str, stream_id: int):
+ """Insert one row as the given instance with given stream_id.
+ """
+
+ def _insert(txn):
+ txn.execute(
+ "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+ )
+
+ self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
+
+ def test_single_instance(self):
+ """Test that reads and writes from a single process are handled
+ correctly.
+ """
+ id_gen = self._create_id_generator()
+
+ with self.get_success(id_gen.get_next()) as stream_id:
+ self._insert_row("master", stream_id)
+
+ self.assertEqual(id_gen.get_positions(), {"master": -1})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
+ self.assertEqual(id_gen.get_persisted_upto_position(), -1)
+
+ with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
+ for stream_id in stream_ids:
+ self._insert_row("master", stream_id)
+
+ self.assertEqual(id_gen.get_positions(), {"master": -4})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
+ self.assertEqual(id_gen.get_persisted_upto_position(), -4)
+
+ # Test loading from DB by creating a second ID gen
+ second_id_gen = self._create_id_generator()
+
+ self.assertEqual(second_id_gen.get_positions(), {"master": -4})
+ self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
+ self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
+
+ def test_multiple_instance(self):
+ """Tests that having multiple instances that get advanced over
+ federation works corretly.
+ """
+ id_gen_1 = self._create_id_generator("first")
+ id_gen_2 = self._create_id_generator("second")
+
+ with self.get_success(id_gen_1.get_next()) as stream_id:
+ self._insert_row("first", stream_id)
+ id_gen_2.advance("first", stream_id)
+
+ self.assertEqual(id_gen_1.get_positions(), {"first": -1})
+ self.assertEqual(id_gen_2.get_positions(), {"first": -1})
+ self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
+ self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
+
+ with self.get_success(id_gen_2.get_next()) as stream_id:
+ self._insert_row("second", stream_id)
+ id_gen_1.advance("second", stream_id)
+
+ self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
+ self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
+ self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
+ self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index ab0df5ea93..7e7f1286d9 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -34,12 +34,16 @@ class DataStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_users_paginate(self):
- yield self.store.register_user(self.user.to_string(), "pass")
- yield self.store.create_profile(self.user.localpart)
- yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
+ yield defer.ensureDeferred(
+ self.store.register_user(self.user.to_string(), "pass")
+ )
+ yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.user.localpart, self.displayname)
+ )
- users, total = yield self.store.get_users_paginate(
- 0, 10, name="bc", guests=False
+ users, total = yield defer.ensureDeferred(
+ self.store.get_users_paginate(0, 10, name="bc", guests=False)
)
self.assertEquals(1, total)
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 9b6f7211ae..3fd0a38cf5 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -33,23 +33,36 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_displayname(self):
- yield self.store.create_profile(self.u_frank.localpart)
+ yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
- yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+ )
self.assertEquals(
- "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
+ "Frank",
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.u_frank.localpart)
+ )
+ ),
)
@defer.inlineCallbacks
def test_avatar_url(self):
- yield self.store.create_profile(self.u_frank.localpart)
+ yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
- yield self.store.set_profile_avatar_url(
- self.u_frank.localpart, "http://my.site/here"
+ yield defer.ensureDeferred(
+ self.store.set_profile_avatar_url(
+ self.u_frank.localpart, "http://my.site/here"
+ )
)
self.assertEquals(
"http://my.site/here",
- (yield self.store.get_profile_avatar_url(self.u_frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.u_frank.localpart)
+ )
+ ),
)
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index a6012c973d..918387733b 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -15,6 +15,7 @@
from twisted.internet import defer
+from synapse.api.errors import NotFoundError
from synapse.rest.client.v1 import room
from tests.unittest import HomeserverTestCase
@@ -46,30 +47,19 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_storage()
# Get the topological token
- event = store.get_topological_token_for_event(last["event_id"])
- self.pump()
- event = self.successResultOf(event)
-
- # Purge everything before this topological token
- purge = defer.ensureDeferred(
- storage.purge_events.purge_history(self.room_id, event, True)
+ event = self.get_success(
+ store.get_topological_token_for_event(last["event_id"])
)
- self.pump()
- self.assertEqual(self.successResultOf(purge), None)
- # Try and get the events
- get_first = store.get_event(first["event_id"])
- get_second = store.get_event(second["event_id"])
- get_third = store.get_event(third["event_id"])
- get_last = store.get_event(last["event_id"])
- self.pump()
+ # Purge everything before this topological token
+ self.get_success(storage.purge_events.purge_history(self.room_id, event, True))
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
# and last is not.
- self.failureResultOf(get_first)
- self.failureResultOf(get_second)
- self.failureResultOf(get_third)
- self.successResultOf(get_last)
+ self.get_failure(store.get_event(first["event_id"]), NotFoundError)
+ self.get_failure(store.get_event(second["event_id"]), NotFoundError)
+ self.get_failure(store.get_event(third["event_id"]), NotFoundError)
+ self.get_success(store.get_event(last["event_id"]))
def test_purge_wont_delete_extrems(self):
"""
@@ -84,9 +74,9 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_datastore()
# Set the topological token higher than it should be
- event = storage.get_topological_token_for_event(last["event_id"])
- self.pump()
- event = self.successResultOf(event)
+ event = self.get_success(
+ storage.get_topological_token_for_event(last["event_id"])
+ )
event = "t{}-{}".format(
*list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
)
@@ -98,14 +88,7 @@ class PurgeTests(HomeserverTestCase):
self.assertIn("greater than forward", f.value.args[0])
# Try and get the events
- get_first = storage.get_event(first["event_id"])
- get_second = storage.get_event(second["event_id"])
- get_third = storage.get_event(third["event_id"])
- get_last = storage.get_event(last["event_id"])
- self.pump()
-
- # Nothing is deleted.
- self.successResultOf(get_first)
- self.successResultOf(get_second)
- self.successResultOf(get_third)
- self.successResultOf(get_last)
+ self.get_success(storage.get_event(first["event_id"]))
+ self.get_success(storage.get_event(second["event_id"]))
+ self.get_success(storage.get_event(third["event_id"]))
+ self.get_success(storage.get_event(last["event_id"]))
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 840db66072..6b582771fe 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -17,6 +17,7 @@
from twisted.internet import defer
from synapse.api.constants import UserTypes
+from synapse.api.errors import ThreepidValidationError
from tests import unittest
from tests.utils import setup_test_homeserver
@@ -36,7 +37,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_register(self):
- yield self.store.register_user(self.user_id, self.pwhash)
+ yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
self.assertEquals(
{
@@ -52,19 +53,21 @@ class RegistrationStoreTestCase(unittest.TestCase):
"user_type": None,
"deactivated": 0,
},
- (yield self.store.get_user_by_id(self.user_id)),
+ (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
)
@defer.inlineCallbacks
def test_add_tokens(self):
- yield self.store.register_user(self.user_id, self.pwhash)
+ yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
yield defer.ensureDeferred(
self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
)
)
- result = yield self.store.get_user_by_access_token(self.tokens[1])
+ result = yield defer.ensureDeferred(
+ self.store.get_user_by_access_token(self.tokens[1])
+ )
self.assertDictContainsSubset(
{"name": self.user_id, "device_id": self.device_id}, result
@@ -75,7 +78,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_user_delete_access_tokens(self):
# add some tokens
- yield self.store.register_user(self.user_id, self.pwhash)
+ yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
yield defer.ensureDeferred(
self.store.add_access_token_to_user(
self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
@@ -88,22 +91,28 @@ class RegistrationStoreTestCase(unittest.TestCase):
)
# now delete some
- yield self.store.user_delete_access_tokens(
- self.user_id, device_id=self.device_id
+ yield defer.ensureDeferred(
+ self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id)
)
# check they were deleted
- user = yield self.store.get_user_by_access_token(self.tokens[1])
+ user = yield defer.ensureDeferred(
+ self.store.get_user_by_access_token(self.tokens[1])
+ )
self.assertIsNone(user, "access token was not deleted by device_id")
# check the one not associated with the device was not deleted
- user = yield self.store.get_user_by_access_token(self.tokens[0])
+ user = yield defer.ensureDeferred(
+ self.store.get_user_by_access_token(self.tokens[0])
+ )
self.assertEqual(self.user_id, user["name"])
# now delete the rest
- yield self.store.user_delete_access_tokens(self.user_id)
+ yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
- user = yield self.store.get_user_by_access_token(self.tokens[0])
+ user = yield defer.ensureDeferred(
+ self.store.get_user_by_access_token(self.tokens[0])
+ )
self.assertIsNone(user, "access token was not deleted without device_id")
@defer.inlineCallbacks
@@ -111,14 +120,48 @@ class RegistrationStoreTestCase(unittest.TestCase):
TEST_USER = "@test:test"
SUPPORT_USER = "@support:test"
- res = yield self.store.is_support_user(None)
+ res = yield defer.ensureDeferred(self.store.is_support_user(None))
self.assertFalse(res)
- yield self.store.register_user(user_id=TEST_USER, password_hash=None)
- res = yield self.store.is_support_user(TEST_USER)
+ yield defer.ensureDeferred(
+ self.store.register_user(user_id=TEST_USER, password_hash=None)
+ )
+ res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER))
self.assertFalse(res)
- yield self.store.register_user(
- user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
+ yield defer.ensureDeferred(
+ self.store.register_user(
+ user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
+ )
)
- res = yield self.store.is_support_user(SUPPORT_USER)
+ res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER))
self.assertTrue(res)
+
+ @defer.inlineCallbacks
+ def test_3pid_inhibit_invalid_validation_session_error(self):
+ """Tests that enabling the configuration option to inhibit 3PID errors on
+ /requestToken also inhibits validation errors caused by an unknown session ID.
+ """
+
+ # Check that, with the config setting set to false (the default value), a
+ # validation error is caused by the unknown session ID.
+ try:
+ yield defer.ensureDeferred(
+ self.store.validate_threepid_session(
+ "fake_sid", "fake_client_secret", "fake_token", 0,
+ )
+ )
+ except ThreepidValidationError as e:
+ self.assertEquals(e.msg, "Unknown session_id", e)
+
+ # Set the config setting to true.
+ self.store._ignore_unknown_session_error = True
+
+ # Check that now the validation error is caused by the token not matching.
+ try:
+ yield defer.ensureDeferred(
+ self.store.validate_threepid_session(
+ "fake_sid", "fake_client_secret", "fake_token", 0,
+ )
+ )
+ except ThreepidValidationError as e:
+ self.assertEquals(e.msg, "Validation token not found or has expired", e)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index d07b985a8e..bc8400f240 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -54,12 +54,14 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"is_public": True,
},
- (yield self.store.get_room(self.room.to_string())),
+ (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
)
@defer.inlineCallbacks
def test_get_room_unknown_room(self):
- self.assertIsNone((yield self.store.get_room("!uknown:test")),)
+ self.assertIsNone(
+ (yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
+ )
@defer.inlineCallbacks
def test_get_room_with_stats(self):
@@ -69,12 +71,22 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"public": True,
},
- (yield self.store.get_room_with_stats(self.room.to_string())),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_room_with_stats(self.room.to_string())
+ )
+ ),
)
@defer.inlineCallbacks
def test_get_room_with_stats_unknown_room(self):
- self.assertIsNone((yield self.store.get_room_with_stats("!uknown:test")),)
+ self.assertIsNone(
+ (
+ yield defer.ensureDeferred(
+ self.store.get_room_with_stats("!uknown:test")
+ )
+ ),
+ )
class RoomEventsStoreTestCase(unittest.TestCase):
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 17c9da4838..12ccc1f53e 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -87,7 +87,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN)
- self.pump(20)
+ self.pump()
self.assertTrue("_known_servers_count" not in self.store.__dict__.keys())
@@ -101,7 +101,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# Initialises to 1 -- itself
self.assertEqual(self.store._known_servers_count, 1)
- self.pump(20)
+ self.pump()
# No rooms have been joined, so technically the SQL returns 0, but it
# will still say it knows about itself.
@@ -111,7 +111,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN)
- self.pump(20)
+ self.pump(1)
# It now knows about Charlie's server.
self.assertEqual(self.store._known_servers_count, 2)
@@ -187,7 +187,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
# Now let's create a room, which will insert a membership
user = UserID("alice", "test")
- requester = Requester(user, None, False, None, None)
+ requester = Requester(user, None, False, False, None, None)
self.get_success(self.room_creator.create_room(requester, {}))
# Register the background update to run again.
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index ecfafe68a9..738e912468 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -31,10 +31,18 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
# alice and bob are both in !room_id. bobby is not but shares
# a homeserver with alice.
- yield self.store.update_profile_in_user_dir(ALICE, "alice", None)
- yield self.store.update_profile_in_user_dir(BOB, "bob", None)
- yield self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
- yield self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
+ yield defer.ensureDeferred(
+ self.store.update_profile_in_user_dir(ALICE, "alice", None)
+ )
+ yield defer.ensureDeferred(
+ self.store.update_profile_in_user_dir(BOB, "bob", None)
+ )
+ yield defer.ensureDeferred(
+ self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
+ )
+ yield defer.ensureDeferred(
+ self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
+ )
@defer.inlineCallbacks
def test_search_user_dir(self):
|