diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 6d6f00c5c5..52eb05bfbf 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -18,14 +18,13 @@ from mock import Mock
from twisted.internet import defer
-from synapse.util.async import ObservableDeferred
+from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import Cache, cached
from tests import unittest
class CacheTestCase(unittest.TestCase):
-
def setUp(self):
self.cache = Cache("test")
@@ -97,7 +96,6 @@ class CacheTestCase(unittest.TestCase):
class CacheDecoratorTestCase(unittest.TestCase):
-
@defer.inlineCallbacks
def test_passthrough(self):
class A(object):
@@ -180,8 +178,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
yield a.func(k)
self.assertTrue(
- callcount[0] >= 14,
- msg="Expected callcount >= 14, got %d" % (callcount[0])
+ callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
)
def test_prefill(self):
diff --git a/tests/storage/test__init__.py b/tests/storage/test__init__.py
deleted file mode 100644
index f19cb1265c..0000000000
--- a/tests/storage/test__init__.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from twisted.internet import defer
-
-import tests.utils
-
-
-class InitTestCase(tests.unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super(InitTestCase, self).__init__(*args, **kwargs)
- self.store = None # type: synapse.storage.DataStore
-
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver()
-
- hs.config.max_mau_value = 50
- hs.config.limit_usage_by_mau = True
- self.store = hs.get_datastore()
- self.clock = hs.get_clock()
-
- @defer.inlineCallbacks
- def test_count_monthly_users(self):
- count = yield self.store.count_monthly_users()
- self.assertEqual(0, count)
-
- yield self._insert_user_ips("@user:server1")
- yield self._insert_user_ips("@user:server2")
-
- count = yield self.store.count_monthly_users()
- self.assertEqual(2, count)
-
- @defer.inlineCallbacks
- def _insert_user_ips(self, user):
- """
- Helper function to populate user_ips without using batch insertion infra
- args:
- user (str): specify username i.e. @user:server.com
- """
- yield self.store._simple_upsert(
- table="user_ips",
- keyvalues={
- "user_id": user,
- "access_token": "access_token",
- "ip": "ip",
- "user_agent": "user_agent",
- "device_id": "device_id",
- },
- values={
- "last_seen": self.clock.time_msec(),
- }
- )
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 099861b27c..3f0083831b 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -34,35 +34,27 @@ from tests.utils import setup_test_homeserver
class ApplicationServiceStoreTestCase(unittest.TestCase):
-
@defer.inlineCallbacks
def setUp(self):
self.as_yaml_files = []
- config = Mock(
- app_service_config_files=self.as_yaml_files,
- event_cache_size=1,
- password_providers=[],
- )
hs = yield setup_test_homeserver(
- config=config,
- federation_sender=Mock(),
- federation_client=Mock(),
+ self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
+ hs.config.app_service_config_files = self.as_yaml_files
+ hs.config.event_cache_size = 1
+ hs.config.password_providers = []
+
self.as_token = "token1"
self.as_url = "some_url"
self.as_id = "as1"
self._add_appservice(
- self.as_token,
- self.as_id,
- self.as_url,
- "some_hs_token",
- "bob"
+ self.as_token, self.as_id, self.as_url, "some_hs_token", "bob"
)
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
# must be done after inserts
- self.store = ApplicationServiceStore(None, hs)
+ self.store = ApplicationServiceStore(hs.get_db_conn(), hs)
def tearDown(self):
# TODO: suboptimal that we need to create files for tests!
@@ -73,8 +65,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
pass
def _add_appservice(self, as_token, id, url, hs_token, sender):
- as_yaml = dict(url=url, as_token=as_token, hs_token=hs_token,
- id=id, sender_localpart=sender, namespaces={})
+ as_yaml = dict(
+ url=url,
+ as_token=as_token,
+ hs_token=hs_token,
+ id=id,
+ sender_localpart=sender,
+ namespaces={},
+ )
# use the token as the filename
with open(as_token, 'w') as outfile:
outfile.write(yaml.dump(as_yaml))
@@ -85,24 +83,13 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self.assertEquals(service, None)
def test_retrieval_of_service(self):
- stored_service = self.store.get_app_service_by_token(
- self.as_token
- )
+ stored_service = self.store.get_app_service_by_token(self.as_token)
self.assertEquals(stored_service.token, self.as_token)
self.assertEquals(stored_service.id, self.as_id)
self.assertEquals(stored_service.url, self.as_url)
- self.assertEquals(
- stored_service.namespaces[ApplicationService.NS_ALIASES],
- []
- )
- self.assertEquals(
- stored_service.namespaces[ApplicationService.NS_ROOMS],
- []
- )
- self.assertEquals(
- stored_service.namespaces[ApplicationService.NS_USERS],
- []
- )
+ self.assertEquals(stored_service.namespaces[ApplicationService.NS_ALIASES], [])
+ self.assertEquals(stored_service.namespaces[ApplicationService.NS_ROOMS], [])
+ self.assertEquals(stored_service.namespaces[ApplicationService.NS_USERS], [])
def test_retrieval_of_all_services(self):
services = self.store.get_app_services()
@@ -110,107 +97,93 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
-
@defer.inlineCallbacks
def setUp(self):
self.as_yaml_files = []
- config = Mock(
- app_service_config_files=self.as_yaml_files,
- event_cache_size=1,
- password_providers=[],
- )
hs = yield setup_test_homeserver(
- config=config,
- federation_sender=Mock(),
- federation_client=Mock(),
+ self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
+
+ hs.config.app_service_config_files = self.as_yaml_files
+ hs.config.event_cache_size = 1
+ hs.config.password_providers = []
+
self.db_pool = hs.get_db_pool()
+ self.engine = hs.database_engine
self.as_list = [
- {
- "token": "token1",
- "url": "https://matrix-as.org",
- "id": "id_1"
- },
- {
- "token": "alpha_tok",
- "url": "https://alpha.com",
- "id": "id_alpha"
- },
- {
- "token": "beta_tok",
- "url": "https://beta.com",
- "id": "id_beta"
- },
- {
- "token": "gamma_tok",
- "url": "https://gamma.com",
- "id": "id_gamma"
- },
+ {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
+ {"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"},
+ {"token": "beta_tok", "url": "https://beta.com", "id": "id_beta"},
+ {"token": "gamma_tok", "url": "https://gamma.com", "id": "id_gamma"},
]
for s in self.as_list:
yield self._add_service(s["url"], s["token"], s["id"])
self.as_yaml_files = []
- self.store = TestTransactionStore(None, hs)
+ self.store = TestTransactionStore(hs.get_db_conn(), hs)
def _add_service(self, url, as_token, id):
- as_yaml = dict(url=url, as_token=as_token, hs_token="something",
- id=id, sender_localpart="a_sender", namespaces={})
+ as_yaml = dict(
+ url=url,
+ as_token=as_token,
+ hs_token="something",
+ id=id,
+ sender_localpart="a_sender",
+ namespaces={},
+ )
# use the token as the filename
with open(as_token, 'w') as outfile:
outfile.write(yaml.dump(as_yaml))
self.as_yaml_files.append(as_token)
def _set_state(self, id, state, txn=None):
- return self.db_pool.runQuery(
- "INSERT INTO application_services_state(as_id, state, last_txn) "
- "VALUES(?,?,?)",
- (id, state, txn)
+ return self.db_pool.runOperation(
+ self.engine.convert_param_style(
+ "INSERT INTO application_services_state(as_id, state, last_txn) "
+ "VALUES(?,?,?)"
+ ),
+ (id, state, txn),
)
def _insert_txn(self, as_id, txn_id, events):
- return self.db_pool.runQuery(
- "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
- "VALUES(?,?,?)",
- (as_id, txn_id, json.dumps([e.event_id for e in events]))
+ return self.db_pool.runOperation(
+ self.engine.convert_param_style(
+ "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
+ "VALUES(?,?,?)"
+ ),
+ (as_id, txn_id, json.dumps([e.event_id for e in events])),
)
def _set_last_txn(self, as_id, txn_id):
- return self.db_pool.runQuery(
- "INSERT INTO application_services_state(as_id, last_txn, state) "
- "VALUES(?,?,?)",
- (as_id, txn_id, ApplicationServiceState.UP)
+ return self.db_pool.runOperation(
+ self.engine.convert_param_style(
+ "INSERT INTO application_services_state(as_id, last_txn, state) "
+ "VALUES(?,?,?)"
+ ),
+ (as_id, txn_id, ApplicationServiceState.UP),
)
@defer.inlineCallbacks
def test_get_appservice_state_none(self):
- service = Mock(id=999)
+ service = Mock(id="999")
state = yield self.store.get_appservice_state(service)
self.assertEquals(None, state)
@defer.inlineCallbacks
def test_get_appservice_state_up(self):
- yield self._set_state(
- self.as_list[0]["id"], ApplicationServiceState.UP
- )
+ yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
service = Mock(id=self.as_list[0]["id"])
state = yield self.store.get_appservice_state(service)
self.assertEquals(ApplicationServiceState.UP, state)
@defer.inlineCallbacks
def test_get_appservice_state_down(self):
- yield self._set_state(
- self.as_list[0]["id"], ApplicationServiceState.UP
- )
- yield self._set_state(
- self.as_list[1]["id"], ApplicationServiceState.DOWN
- )
- yield self._set_state(
- self.as_list[2]["id"], ApplicationServiceState.DOWN
- )
+ yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
+ yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN)
+ yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
service = Mock(id=self.as_list[1]["id"])
state = yield self.store.get_appservice_state(service)
self.assertEquals(ApplicationServiceState.DOWN, state)
@@ -225,34 +198,26 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_appservices_state_down(self):
service = Mock(id=self.as_list[1]["id"])
- yield self.store.set_appservice_state(
- service,
- ApplicationServiceState.DOWN
- )
+ yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
rows = yield self.db_pool.runQuery(
- "SELECT as_id FROM application_services_state WHERE state=?",
- (ApplicationServiceState.DOWN,)
+ self.engine.convert_param_style(
+ "SELECT as_id FROM application_services_state WHERE state=?"
+ ),
+ (ApplicationServiceState.DOWN,),
)
self.assertEquals(service.id, rows[0][0])
@defer.inlineCallbacks
def test_set_appservices_state_multiple_up(self):
service = Mock(id=self.as_list[1]["id"])
- yield self.store.set_appservice_state(
- service,
- ApplicationServiceState.UP
- )
- yield self.store.set_appservice_state(
- service,
- ApplicationServiceState.DOWN
- )
- yield self.store.set_appservice_state(
- service,
- ApplicationServiceState.UP
- )
+ yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
+ yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
rows = yield self.db_pool.runQuery(
- "SELECT as_id FROM application_services_state WHERE state=?",
- (ApplicationServiceState.UP,)
+ self.engine.convert_param_style(
+ "SELECT as_id FROM application_services_state WHERE state=?"
+ ),
+ (ApplicationServiceState.UP,),
)
self.assertEquals(service.id, rows[0][0])
@@ -318,15 +283,19 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
res = yield self.db_pool.runQuery(
- "SELECT last_txn FROM application_services_state WHERE as_id=?",
- (service.id,)
+ self.engine.convert_param_style(
+ "SELECT last_txn FROM application_services_state WHERE as_id=?"
+ ),
+ (service.id,),
)
self.assertEquals(1, len(res))
self.assertEquals(txn_id, res[0][0])
res = yield self.db_pool.runQuery(
- "SELECT * FROM application_services_txns WHERE txn_id=?",
- (txn_id,)
+ self.engine.convert_param_style(
+ "SELECT * FROM application_services_txns WHERE txn_id=?"
+ ),
+ (txn_id,),
)
self.assertEquals(0, len(res))
@@ -340,17 +309,20 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
res = yield self.db_pool.runQuery(
- "SELECT last_txn, state FROM application_services_state WHERE "
- "as_id=?",
- (service.id,)
+ self.engine.convert_param_style(
+ "SELECT last_txn, state FROM application_services_state WHERE as_id=?"
+ ),
+ (service.id,),
)
self.assertEquals(1, len(res))
self.assertEquals(txn_id, res[0][0])
self.assertEquals(ApplicationServiceState.UP, res[0][1])
res = yield self.db_pool.runQuery(
- "SELECT * FROM application_services_txns WHERE txn_id=?",
- (txn_id,)
+ self.engine.convert_param_style(
+ "SELECT * FROM application_services_txns WHERE txn_id=?"
+ ),
+ (txn_id,),
)
self.assertEquals(0, len(res))
@@ -382,12 +354,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_appservices_by_state_single(self):
- yield self._set_state(
- self.as_list[0]["id"], ApplicationServiceState.DOWN
- )
- yield self._set_state(
- self.as_list[1]["id"], ApplicationServiceState.UP
- )
+ yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
+ yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
services = yield self.store.get_appservices_by_state(
ApplicationServiceState.DOWN
@@ -397,18 +365,10 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_appservices_by_state_multiple(self):
- yield self._set_state(
- self.as_list[0]["id"], ApplicationServiceState.DOWN
- )
- yield self._set_state(
- self.as_list[1]["id"], ApplicationServiceState.UP
- )
- yield self._set_state(
- self.as_list[2]["id"], ApplicationServiceState.DOWN
- )
- yield self._set_state(
- self.as_list[3]["id"], ApplicationServiceState.UP
- )
+ yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
+ yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
+ yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
+ yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP)
services = yield self.store.get_appservices_by_state(
ApplicationServiceState.DOWN
@@ -416,20 +376,17 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.assertEquals(2, len(services))
self.assertEquals(
set([self.as_list[2]["id"], self.as_list[0]["id"]]),
- set([services[0].id, services[1].id])
+ set([services[0].id, services[1].id]),
)
# required for ApplicationServiceTransactionStoreTestCase tests
-class TestTransactionStore(ApplicationServiceTransactionStore,
- ApplicationServiceStore):
-
+class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
def __init__(self, db_conn, hs):
super(TestTransactionStore, self).__init__(db_conn, hs)
class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
-
def _write_config(self, suffix, **kwargs):
vals = {
"id": "id" + suffix,
@@ -451,37 +408,31 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f1 = self._write_config(suffix="1")
f2 = self._write_config(suffix="2")
- config = Mock(
- app_service_config_files=[f1, f2], event_cache_size=1,
- password_providers=[]
- )
hs = yield setup_test_homeserver(
- config=config,
- datastore=Mock(),
- federation_sender=Mock(),
- federation_client=Mock(),
+ self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
- ApplicationServiceStore(None, hs)
+ hs.config.app_service_config_files = [f1, f2]
+ hs.config.event_cache_size = 1
+ hs.config.password_providers = []
+
+ ApplicationServiceStore(hs.get_db_conn(), hs)
@defer.inlineCallbacks
def test_duplicate_ids(self):
f1 = self._write_config(id="id", suffix="1")
f2 = self._write_config(id="id", suffix="2")
- config = Mock(
- app_service_config_files=[f1, f2], event_cache_size=1,
- password_providers=[]
- )
hs = yield setup_test_homeserver(
- config=config,
- datastore=Mock(),
- federation_sender=Mock(),
- federation_client=Mock(),
+ self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
+ hs.config.app_service_config_files = [f1, f2]
+ hs.config.event_cache_size = 1
+ hs.config.password_providers = []
+
with self.assertRaises(ConfigError) as cm:
- ApplicationServiceStore(None, hs)
+ ApplicationServiceStore(hs.get_db_conn(), hs)
e = cm.exception
self.assertIn(f1, str(e))
@@ -493,19 +444,16 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f1 = self._write_config(as_token="as_token", suffix="1")
f2 = self._write_config(as_token="as_token", suffix="2")
- config = Mock(
- app_service_config_files=[f1, f2], event_cache_size=1,
- password_providers=[]
- )
hs = yield setup_test_homeserver(
- config=config,
- datastore=Mock(),
- federation_sender=Mock(),
- federation_client=Mock(),
+ self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
+ hs.config.app_service_config_files = [f1, f2]
+ hs.config.event_cache_size = 1
+ hs.config.password_providers = []
+
with self.assertRaises(ConfigError) as cm:
- ApplicationServiceStore(None, hs)
+ ApplicationServiceStore(hs.get_db_conn(), hs)
e = cm.exception
self.assertIn(f1, str(e))
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index ab1f310572..81403727c5 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -7,10 +7,11 @@ from tests.utils import setup_test_homeserver
class BackgroundUpdateTestCase(unittest.TestCase):
-
@defer.inlineCallbacks
def setUp(self):
- hs = yield setup_test_homeserver() # type: synapse.server.HomeServer
+ hs = yield setup_test_homeserver(
+ self.addCleanup
+ ) # type: synapse.server.HomeServer
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@@ -51,9 +52,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
yield self.store.start_background_update("test_update", {"my_key": 1})
self.update_handler.reset_mock()
- result = yield self.store.do_next_background_update(
- duration_ms * desired_count
- )
+ result = yield self.store.do_next_background_update(duration_ms * desired_count)
self.assertIsNotNone(result)
self.update_handler.assert_called_once_with(
{"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE
@@ -67,18 +66,12 @@ class BackgroundUpdateTestCase(unittest.TestCase):
self.update_handler.side_effect = update
self.update_handler.reset_mock()
- result = yield self.store.do_next_background_update(
- duration_ms * desired_count
- )
+ result = yield self.store.do_next_background_update(duration_ms * desired_count)
self.assertIsNotNone(result)
- self.update_handler.assert_called_once_with(
- {"my_key": 2}, desired_count
- )
+ self.update_handler.assert_called_once_with({"my_key": 2}, desired_count)
# third step: we don't expect to be called any more
self.update_handler.reset_mock()
- result = yield self.store.do_next_background_update(
- duration_ms * desired_count
- )
+ result = yield self.store.do_next_background_update(duration_ms * desired_count)
self.assertIsNone(result)
self.assertFalse(self.update_handler.called)
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 1d1234ee39..829f47d2e8 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -20,11 +20,11 @@ from mock import Mock
from twisted.internet import defer
-from synapse.server import HomeServer
from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import create_engine
from tests import unittest
+from tests.utils import TestHomeServer
class SQLBaseStoreTestCase(unittest.TestCase):
@@ -40,16 +40,18 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def runInteraction(func, *args, **kwargs):
return defer.succeed(func(self.mock_txn, *args, **kwargs))
+
self.db_pool.runInteraction = runInteraction
def runWithConnection(func, *args, **kwargs):
return defer.succeed(func(self.mock_conn, *args, **kwargs))
+
self.db_pool.runWithConnection = runWithConnection
config = Mock()
config.event_cache_size = 1
config.database_config = {"name": "sqlite3"}
- hs = HomeServer(
+ hs = TestHomeServer(
"test",
db_pool=self.db_pool,
config=config,
@@ -63,8 +65,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
yield self.datastore._simple_insert(
- table="tablename",
- values={"columname": "Value"}
+ table="tablename", values={"columname": "Value"}
)
self.mock_txn.execute.assert_called_with(
@@ -78,12 +79,11 @@ class SQLBaseStoreTestCase(unittest.TestCase):
yield self.datastore._simple_insert(
table="tablename",
# Use OrderedDict() so we can assert on the SQL generated
- values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)])
+ values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
)
self.mock_txn.execute.assert_called_with(
- "INSERT INTO tablename (colA, colB, colC) VALUES(?, ?, ?)",
- (1, 2, 3,)
+ "INSERT INTO tablename (colA, colB, colC) VALUES(?, ?, ?)", (1, 2, 3)
)
@defer.inlineCallbacks
@@ -92,9 +92,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
value = yield self.datastore._simple_select_one_onecol(
- table="tablename",
- keyvalues={"keycol": "TheKey"},
- retcol="retcol"
+ table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
)
self.assertEquals("Value", value)
@@ -110,13 +108,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
ret = yield self.datastore._simple_select_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
- retcols=["colA", "colB", "colC"]
+ retcols=["colA", "colB", "colC"],
)
self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
self.mock_txn.execute.assert_called_with(
- "SELECT colA, colB, colC FROM tablename WHERE keycol = ?",
- ["TheKey"]
+ "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"]
)
@defer.inlineCallbacks
@@ -128,7 +125,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
table="tablename",
keyvalues={"keycol": "Not here"},
retcols=["colA"],
- allow_none=True
+ allow_none=True,
)
self.assertFalse(ret)
@@ -137,20 +134,15 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_select_list(self):
self.mock_txn.rowcount = 3
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
- self.mock_txn.description = (
- ("colA", None, None, None, None, None, None),
- )
+ self.mock_txn.description = (("colA", None, None, None, None, None, None),)
ret = yield self.datastore._simple_select_list(
- table="tablename",
- keyvalues={"keycol": "A set"},
- retcols=["colA"],
+ table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
)
self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
self.mock_txn.execute.assert_called_with(
- "SELECT colA FROM tablename WHERE keycol = ?",
- ["A set"]
+ "SELECT colA FROM tablename WHERE keycol = ?", ["A set"]
)
@defer.inlineCallbacks
@@ -160,12 +152,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
yield self.datastore._simple_update_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
- updatevalues={"columnname": "New Value"}
+ updatevalues={"columnname": "New Value"},
)
self.mock_txn.execute.assert_called_with(
"UPDATE tablename SET columnname = ? WHERE keycol = ?",
- ["New Value", "TheKey"]
+ ["New Value", "TheKey"],
)
@defer.inlineCallbacks
@@ -175,13 +167,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
yield self.datastore._simple_update_one(
table="tablename",
keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
- updatevalues=OrderedDict([("colC", 3), ("colD", 4)])
+ updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
)
self.mock_txn.execute.assert_called_with(
- "UPDATE tablename SET colC = ?, colD = ? WHERE"
- " colA = ? AND colB = ?",
- [3, 4, 1, 2]
+ "UPDATE tablename SET colC = ?, colD = ? WHERE" " colA = ? AND colB = ?",
+ [3, 4, 1, 2],
)
@defer.inlineCallbacks
@@ -189,8 +180,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
yield self.datastore._simple_delete_one(
- table="tablename",
- keyvalues={"keycol": "Go away"},
+ table="tablename", keyvalues={"keycol": "Go away"}
)
self.mock_txn.execute.assert_called_with(
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index bd6fda6cb1..4577e9422b 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,34 +14,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from mock import Mock
+
from twisted.internet import defer
-import tests.unittest
-import tests.utils
+from synapse.http.site import XForwardedForRequest
+from synapse.rest.client.v1 import admin, login
+
+from tests import unittest
-class ClientIpStoreTestCase(tests.unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super(ClientIpStoreTestCase, self).__init__(*args, **kwargs)
- self.store = None # type: synapse.storage.DataStore
- self.clock = None # type: tests.utils.MockClock
+class ClientIpStoreTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver()
+ return hs
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver()
- self.store = hs.get_datastore()
- self.clock = hs.get_clock()
+ def prepare(self, hs, reactor, clock):
+ self.store = self.hs.get_datastore()
- @defer.inlineCallbacks
def test_insert_new_client_ip(self):
- self.clock.now = 12345678
+ self.reactor.advance(12345678)
+
user_id = "@user:id"
- yield self.store.insert_client_ip(
- user_id,
- "access_token", "ip", "user_agent", "device_id",
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", "device_id"
+ )
)
- result = yield self.store.get_last_client_ip_by_device(user_id, "device_id")
+ # Trigger the storage loop
+ self.reactor.advance(10)
+
+ result = self.get_success(
+ self.store.get_last_client_ip_by_device(user_id, "device_id")
+ )
r = result[(user_id, "device_id")]
self.assertDictContainsSubset(
@@ -52,5 +59,143 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
"user_agent": "user_agent",
"last_seen": 12345678000,
},
- r
+ r,
+ )
+
+ def test_disabled_monthly_active_user(self):
+ self.hs.config.limit_usage_by_mau = False
+ self.hs.config.max_mau_value = 50
+ user_id = "@user:server"
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", "device_id"
+ )
+ )
+ active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
+ self.assertFalse(active)
+
+ def test_adding_monthly_active_user_when_full(self):
+ self.hs.config.limit_usage_by_mau = True
+ self.hs.config.max_mau_value = 50
+ lots_of_users = 100
+ user_id = "@user:server"
+
+ self.store.get_monthly_active_count = Mock(
+ return_value=defer.succeed(lots_of_users)
+ )
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", "device_id"
+ )
+ )
+ active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
+ self.assertFalse(active)
+
+ def test_adding_monthly_active_user_when_space(self):
+ self.hs.config.limit_usage_by_mau = True
+ self.hs.config.max_mau_value = 50
+ user_id = "@user:server"
+ active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
+ self.assertFalse(active)
+
+ # Trigger the saving loop
+ self.reactor.advance(10)
+
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", "device_id"
+ )
+ )
+ active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
+ self.assertTrue(active)
+
+ def test_updating_monthly_active_user_when_space(self):
+ self.hs.config.limit_usage_by_mau = True
+ self.hs.config.max_mau_value = 50
+ user_id = "@user:server"
+ self.get_success(
+ self.store.register(user_id=user_id, token="123", password_hash=None)
+ )
+
+ active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
+ self.assertFalse(active)
+
+ # Trigger the saving loop
+ self.reactor.advance(10)
+
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", "device_id"
+ )
+ )
+ active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
+ self.assertTrue(active)
+
+
+class ClientIpAuthTestCase(unittest.HomeserverTestCase):
+
+ servlets = [admin.register_servlets, login.register_servlets]
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver()
+ return hs
+
+ def prepare(self, hs, reactor, clock):
+ self.store = self.hs.get_datastore()
+ self.user_id = self.register_user("bob", "abc123", True)
+
+ def test_request_with_xforwarded(self):
+ """
+ The IP in X-Forwarded-For is entered into the client IPs table.
+ """
+ self._runtest(
+ {b"X-Forwarded-For": b"127.9.0.1"},
+ "127.9.0.1",
+ {"request": XForwardedForRequest},
+ )
+
+ def test_request_from_getPeer(self):
+ """
+ The IP returned by getPeer is entered into the client IPs table, if
+ there's no X-Forwarded-For header.
+ """
+ self._runtest({}, "127.0.0.1", {})
+
+ def _runtest(self, headers, expected_ip, make_request_args):
+ device_id = "bleb"
+
+ access_token = self.login("bob", "abc123", device_id=device_id)
+
+ # Advance to a known time
+ self.reactor.advance(123456 - self.reactor.seconds())
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/admin/users/" + self.user_id,
+ access_token=access_token,
+ **make_request_args
+ )
+ request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza")
+
+ # Add the optional headers
+ for h, v in headers.items():
+ request.requestHeaders.addRawHeader(h, v)
+ self.render(request)
+
+ # Advance so the save loop occurs
+ self.reactor.advance(100)
+
+ result = self.get_success(
+ self.store.get_last_client_ip_by_device(self.user_id, device_id)
+ )
+ r = result[(self.user_id, device_id)]
+ self.assertDictContainsSubset(
+ {
+ "user_id": self.user_id,
+ "device_id": device_id,
+ "ip": expected_ip,
+ "user_agent": "Mozzila pizza",
+ "last_seen": 123456100,
+ },
+ r,
)
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index a54cc6bc32..aef4dfaf57 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -28,68 +28,64 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
- hs = yield tests.utils.setup_test_homeserver()
+ hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore()
@defer.inlineCallbacks
def test_store_new_device(self):
- yield self.store.store_device(
- "user_id", "device_id", "display_name"
- )
+ yield self.store.store_device("user_id", "device_id", "display_name")
res = yield self.store.get_device("user_id", "device_id")
- self.assertDictContainsSubset({
- "user_id": "user_id",
- "device_id": "device_id",
- "display_name": "display_name",
- }, res)
+ self.assertDictContainsSubset(
+ {
+ "user_id": "user_id",
+ "device_id": "device_id",
+ "display_name": "display_name",
+ },
+ res,
+ )
@defer.inlineCallbacks
def test_get_devices_by_user(self):
- yield self.store.store_device(
- "user_id", "device1", "display_name 1"
- )
- yield self.store.store_device(
- "user_id", "device2", "display_name 2"
- )
- yield self.store.store_device(
- "user_id2", "device3", "display_name 3"
- )
+ yield self.store.store_device("user_id", "device1", "display_name 1")
+ yield self.store.store_device("user_id", "device2", "display_name 2")
+ yield self.store.store_device("user_id2", "device3", "display_name 3")
res = yield self.store.get_devices_by_user("user_id")
self.assertEqual(2, len(res.keys()))
- self.assertDictContainsSubset({
- "user_id": "user_id",
- "device_id": "device1",
- "display_name": "display_name 1",
- }, res["device1"])
- self.assertDictContainsSubset({
- "user_id": "user_id",
- "device_id": "device2",
- "display_name": "display_name 2",
- }, res["device2"])
+ self.assertDictContainsSubset(
+ {
+ "user_id": "user_id",
+ "device_id": "device1",
+ "display_name": "display_name 1",
+ },
+ res["device1"],
+ )
+ self.assertDictContainsSubset(
+ {
+ "user_id": "user_id",
+ "device_id": "device2",
+ "display_name": "display_name 2",
+ },
+ res["device2"],
+ )
@defer.inlineCallbacks
def test_update_device(self):
- yield self.store.store_device(
- "user_id", "device_id", "display_name 1"
- )
+ yield self.store.store_device("user_id", "device_id", "display_name 1")
res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
- yield self.store.update_device(
- "user_id", "device_id",
- )
+ yield self.store.update_device("user_id", "device_id")
res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 1", res["display_name"])
# do the update
yield self.store.update_device(
- "user_id", "device_id",
- new_display_name="display_name 2",
+ "user_id", "device_id", new_display_name="display_name 2"
)
# check it worked
@@ -100,7 +96,6 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
def test_update_unknown_device(self):
with self.assertRaises(synapse.api.errors.StoreError) as cm:
yield self.store.update_device(
- "user_id", "unknown_device_id",
- new_display_name="display_name 2",
+ "user_id", "unknown_device_id", new_display_name="display_name 2"
)
self.assertEqual(404, cm.exception.code)
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index 129ebaf343..4e128e1047 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -16,7 +16,6 @@
from twisted.internet import defer
-from synapse.storage.directory import DirectoryStore
from synapse.types import RoomAlias, RoomID
from tests import unittest
@@ -24,12 +23,11 @@ from tests.utils import setup_test_homeserver
class DirectoryStoreTestCase(unittest.TestCase):
-
@defer.inlineCallbacks
def setUp(self):
- hs = yield setup_test_homeserver()
+ hs = yield setup_test_homeserver(self.addCleanup)
- self.store = DirectoryStore(None, hs)
+ self.store = hs.get_datastore()
self.room = RoomID.from_string("!abcde:test")
self.alias = RoomAlias.from_string("#my-room:test")
@@ -37,38 +35,29 @@ class DirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_room_to_alias(self):
yield self.store.create_room_alias_association(
- room_alias=self.alias,
- room_id=self.room.to_string(),
- servers=["test"],
+ room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
)
self.assertEquals(
["#my-room:test"],
- (yield self.store.get_aliases_for_room(self.room.to_string()))
+ (yield self.store.get_aliases_for_room(self.room.to_string())),
)
@defer.inlineCallbacks
def test_alias_to_room(self):
yield self.store.create_room_alias_association(
- room_alias=self.alias,
- room_id=self.room.to_string(),
- servers=["test"],
+ room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
)
self.assertObjectHasAttributes(
- {
- "room_id": self.room.to_string(),
- "servers": ["test"],
- },
- (yield self.store.get_association_from_room_alias(self.alias))
+ {"room_id": self.room.to_string(), "servers": ["test"]},
+ (yield self.store.get_association_from_room_alias(self.alias)),
)
@defer.inlineCallbacks
def test_delete_alias(self):
yield self.store.create_room_alias_association(
- room_alias=self.alias,
- room_id=self.room.to_string(),
- servers=["test"],
+ room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
)
room_id = yield self.store.delete_room_alias(self.alias)
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 84ce492a2c..b83f7336d3 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -26,8 +26,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
- hs = yield tests.utils.setup_test_homeserver()
-
+ hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore()
@defer.inlineCallbacks
@@ -35,70 +34,64 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
now = 1470174257070
json = {"key": "value"}
- yield self.store.store_device(
- "user", "device", None
- )
+ yield self.store.store_device("user", "device", None)
- yield self.store.set_e2e_device_keys(
- "user", "device", now, json)
+ yield self.store.set_e2e_device_keys("user", "device", now, json)
res = yield self.store.get_e2e_device_keys((("user", "device"),))
self.assertIn("user", res)
self.assertIn("device", res["user"])
dev = res["user"]["device"]
- self.assertDictContainsSubset({
- "keys": json,
- "device_display_name": None,
- }, dev)
+ self.assertDictContainsSubset({"keys": json, "device_display_name": None}, dev)
+
+ @defer.inlineCallbacks
+ def test_reupload_key(self):
+ now = 1470174257070
+ json = {"key": "value"}
+
+ yield self.store.store_device("user", "device", None)
+
+ changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
+ self.assertTrue(changed)
+
+ # If we try to upload the same key then we should be told nothing
+ # changed
+ changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
+ self.assertFalse(changed)
@defer.inlineCallbacks
def test_get_key_with_device_name(self):
now = 1470174257070
json = {"key": "value"}
- yield self.store.set_e2e_device_keys(
- "user", "device", now, json)
- yield self.store.store_device(
- "user", "device", "display_name"
- )
+ yield self.store.set_e2e_device_keys("user", "device", now, json)
+ yield self.store.store_device("user", "device", "display_name")
res = yield self.store.get_e2e_device_keys((("user", "device"),))
self.assertIn("user", res)
self.assertIn("device", res["user"])
dev = res["user"]["device"]
- self.assertDictContainsSubset({
- "keys": json,
- "device_display_name": "display_name",
- }, dev)
+ self.assertDictContainsSubset(
+ {"keys": json, "device_display_name": "display_name"}, dev
+ )
@defer.inlineCallbacks
def test_multiple_devices(self):
now = 1470174257070
- yield self.store.store_device(
- "user1", "device1", None
- )
- yield self.store.store_device(
- "user1", "device2", None
- )
- yield self.store.store_device(
- "user2", "device1", None
- )
- yield self.store.store_device(
- "user2", "device2", None
- )
+ yield self.store.store_device("user1", "device1", None)
+ yield self.store.store_device("user1", "device2", None)
+ yield self.store.store_device("user2", "device1", None)
+ yield self.store.store_device("user2", "device2", None)
- yield self.store.set_e2e_device_keys(
- "user1", "device1", now, 'json11')
- yield self.store.set_e2e_device_keys(
- "user1", "device2", now, 'json12')
- yield self.store.set_e2e_device_keys(
- "user2", "device1", now, 'json21')
- yield self.store.set_e2e_device_keys(
- "user2", "device2", now, 'json22')
-
- res = yield self.store.get_e2e_device_keys((("user1", "device1"),
- ("user2", "device2")))
+ yield self.store.set_e2e_device_keys("user1", "device1", now, 'json11')
+ yield self.store.set_e2e_device_keys("user1", "device2", now, 'json12')
+ yield self.store.set_e2e_device_keys("user2", "device1", now, 'json21')
+ yield self.store.set_e2e_device_keys("user2", "device2", now, 'json22')
+
+ res = yield self.store.get_e2e_device_keys(
+ (("user1", "device1"), ("user2", "device2"))
+ )
self.assertIn("user1", res)
self.assertIn("device1", res["user1"])
self.assertNotIn("device2", res["user1"])
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 30683e7888..0d4e74d637 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -22,7 +22,7 @@ import tests.utils
class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
- hs = yield tests.utils.setup_test_homeserver()
+ hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore()
@defer.inlineCallbacks
@@ -33,23 +33,32 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
def insert_event(txn, i):
event_id = '$event_%i:local' % i
- txn.execute((
- "INSERT INTO events ("
- " room_id, event_id, type, depth, topological_ordering,"
- " content, processed, outlier) "
- "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)"
- ), (room_id, event_id, i, i, True, False))
+ txn.execute(
+ (
+ "INSERT INTO events ("
+ " room_id, event_id, type, depth, topological_ordering,"
+ " content, processed, outlier, stream_ordering) "
+ "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?, ?)"
+ ),
+ (room_id, event_id, i, i, True, False, i),
+ )
- txn.execute((
- 'INSERT INTO event_forward_extremities (room_id, event_id) '
- 'VALUES (?, ?)'
- ), (room_id, event_id))
+ txn.execute(
+ (
+ 'INSERT INTO event_forward_extremities (room_id, event_id) '
+ 'VALUES (?, ?)'
+ ),
+ (room_id, event_id),
+ )
- txn.execute((
- 'INSERT INTO event_reference_hashes '
- '(event_id, algorithm, hash) '
- "VALUES (?, 'sha256', ?)"
- ), (event_id, 'ffff'))
+ txn.execute(
+ (
+ 'INSERT INTO event_reference_hashes '
+ '(event_id, algorithm, hash) '
+ "VALUES (?, 'sha256', ?)"
+ ),
+ (event_id, b'ffff'),
+ )
for i in range(0, 11):
yield self.store.runInteraction("insert", insert_event, i)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 8430fc7ba6..b114c6fb1d 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -24,15 +24,16 @@ USER_ID = "@user:example.com"
PlAIN_NOTIF = ["notify", {"set_tweak": "highlight", "value": False}]
HIGHLIGHT = [
- "notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"}
+ "notify",
+ {"set_tweak": "sound", "value": "default"},
+ {"set_tweak": "highlight"},
]
class EventPushActionsStoreTestCase(tests.unittest.TestCase):
-
@defer.inlineCallbacks
def setUp(self):
- hs = yield tests.utils.setup_test_homeserver()
+ hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore()
@defer.inlineCallbacks
@@ -55,12 +56,11 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def _assert_counts(noitf_count, highlight_count):
counts = yield self.store.runInteraction(
- "", self.store._get_unread_counts_by_pos_txn,
- room_id, user_id, 0
+ "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
)
self.assertEquals(
counts,
- {"notify_count": noitf_count, "highlight_count": highlight_count}
+ {"notify_count": noitf_count, "highlight_count": highlight_count},
)
@defer.inlineCallbacks
@@ -72,11 +72,13 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
event.depth = stream
yield self.store.add_push_actions_to_staging(
- event.event_id, {user_id: action},
+ event.event_id, {user_id: action}
)
yield self.store.runInteraction(
- "", self.store._set_push_actions_for_event_and_users_txn,
- [(event, None)], [(event, None)],
+ "",
+ self.store._set_push_actions_for_event_and_users_txn,
+ [(event, None)],
+ [(event, None)],
)
def _rotate(stream):
@@ -86,8 +88,11 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
def _mark_read(stream, depth):
return self.store.runInteraction(
- "", self.store._remove_old_push_actions_before_txn,
- room_id, user_id, stream
+ "",
+ self.store._remove_old_push_actions_before_txn,
+ room_id,
+ user_id,
+ stream,
)
yield _assert_counts(0, 0)
@@ -112,9 +117,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield _rotate(7)
yield self.store._simple_delete(
- table="event_push_actions",
- keyvalues={"1": 1},
- desc="",
+ table="event_push_actions", keyvalues={"1": 1}, desc=""
)
yield _assert_counts(1, 0)
@@ -132,18 +135,21 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts):
- return self.store._simple_insert("events", {
- "stream_ordering": so,
- "received_ts": ts,
- "event_id": "event%i" % so,
- "type": "",
- "room_id": "",
- "content": "",
- "processed": True,
- "outlier": False,
- "topological_ordering": 0,
- "depth": 0,
- })
+ return self.store._simple_insert(
+ "events",
+ {
+ "stream_ordering": so,
+ "received_ts": ts,
+ "event_id": "event%i" % so,
+ "type": "",
+ "room_id": "",
+ "content": "",
+ "processed": True,
+ "outlier": False,
+ "topological_ordering": 0,
+ "depth": 0,
+ },
+ )
# start with the base case where there are no events in the table
r = yield self.store.find_first_stream_ordering_after_ts(11)
@@ -160,31 +166,27 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
# add a bunch of dummy events to the events table
for (stream_ordering, ts) in (
- (3, 110),
- (4, 120),
- (5, 120),
- (10, 130),
- (20, 140),
+ (3, 110),
+ (4, 120),
+ (5, 120),
+ (10, 130),
+ (20, 140),
):
yield add_event(stream_ordering, ts)
r = yield self.store.find_first_stream_ordering_after_ts(110)
- self.assertEqual(r, 3,
- "First event after 110ms should be 3, was %i" % r)
+ self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r)
# 4 and 5 are both after 120: we want 4 rather than 5
r = yield self.store.find_first_stream_ordering_after_ts(120)
- self.assertEqual(r, 4,
- "First event after 120ms should be 4, was %i" % r)
+ self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r)
r = yield self.store.find_first_stream_ordering_after_ts(129)
- self.assertEqual(r, 10,
- "First event after 129ms should be 10, was %i" % r)
+ self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r)
# check we can get the last event
r = yield self.store.find_first_stream_ordering_after_ts(140)
- self.assertEqual(r, 20,
- "First event after 14ms should be 20, was %i" % r)
+ self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r)
# off the end
r = yield self.store.find_first_stream_ordering_after_ts(160)
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index 3a3d002782..47f4a8ceac 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -28,7 +28,7 @@ class KeyStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
- hs = yield tests.utils.setup_test_homeserver()
+ hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore()
@defer.inlineCallbacks
@@ -39,15 +39,12 @@ class KeyStoreTestCase(tests.unittest.TestCase):
key2 = signedjson.key.decode_verify_key_base64(
"ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
)
- yield self.store.store_server_verify_key(
- "server1", "from_server", 0, key1
- )
- yield self.store.store_server_verify_key(
- "server1", "from_server", 0, key2
- )
+ yield self.store.store_server_verify_key("server1", "from_server", 0, key1)
+ yield self.store.store_server_verify_key("server1", "from_server", 0, key2)
res = yield self.store.get_server_verify_keys(
- "server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"])
+ "server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"]
+ )
self.assertEqual(len(res.keys()), 2)
self.assertEqual(res["ed25519:key1"].version, "key1")
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
new file mode 100644
index 0000000000..8664bc3d54
--- /dev/null
+++ b/tests/storage/test_monthly_active_users.py
@@ -0,0 +1,247 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from mock import Mock
+
+from twisted.internet import defer
+
+from tests.unittest import HomeserverTestCase
+
+FORTY_DAYS = 40 * 24 * 60 * 60
+
+
+class MonthlyActiveUsersTestCase(HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+
+ hs = self.setup_test_homeserver()
+ self.store = hs.get_datastore()
+ hs.config.limit_usage_by_mau = True
+ hs.config.max_mau_value = 50
+ # Advance the clock a bit
+ reactor.advance(FORTY_DAYS)
+
+ return hs
+
+ def test_initialise_reserved_users(self):
+ self.hs.config.max_mau_value = 5
+ user1 = "@user1:server"
+ user1_email = "user1@matrix.org"
+ user2 = "@user2:server"
+ user2_email = "user2@matrix.org"
+ threepids = [
+ {'medium': 'email', 'address': user1_email},
+ {'medium': 'email', 'address': user2_email},
+ ]
+ user_num = len(threepids)
+
+ self.store.register(user_id=user1, token="123", password_hash=None)
+ self.store.register(user_id=user2, token="456", password_hash=None)
+ self.pump()
+
+ now = int(self.hs.get_clock().time_msec())
+ self.store.user_add_threepid(user1, "email", user1_email, now, now)
+ self.store.user_add_threepid(user2, "email", user2_email, now, now)
+
+ self.store.runInteraction(
+ "initialise", self.store._initialise_reserved_users, threepids
+ )
+ self.pump()
+
+ active_count = self.store.get_monthly_active_count()
+
+ # Test total counts
+ self.assertEquals(self.get_success(active_count), user_num)
+
+ # Test user is marked as active
+ timestamp = self.store.user_last_seen_monthly_active(user1)
+ self.assertTrue(self.get_success(timestamp))
+ timestamp = self.store.user_last_seen_monthly_active(user2)
+ self.assertTrue(self.get_success(timestamp))
+
+ # Test that users are never removed from the db.
+ self.hs.config.max_mau_value = 0
+
+ self.reactor.advance(FORTY_DAYS)
+
+ self.store.reap_monthly_active_users()
+ self.pump()
+
+ active_count = self.store.get_monthly_active_count()
+ self.assertEquals(self.get_success(active_count), user_num)
+
+ # Test that regular users are removed from the db
+ ru_count = 2
+ self.store.upsert_monthly_active_user("@ru1:server")
+ self.store.upsert_monthly_active_user("@ru2:server")
+ self.pump()
+
+ active_count = self.store.get_monthly_active_count()
+ self.assertEqual(self.get_success(active_count), user_num + ru_count)
+ self.hs.config.max_mau_value = user_num
+ self.store.reap_monthly_active_users()
+ self.pump()
+
+ active_count = self.store.get_monthly_active_count()
+ self.assertEquals(self.get_success(active_count), user_num)
+
+ def test_can_insert_and_count_mau(self):
+ count = self.store.get_monthly_active_count()
+ self.assertEqual(0, self.get_success(count))
+
+ self.store.upsert_monthly_active_user("@user:server")
+ self.pump()
+
+ count = self.store.get_monthly_active_count()
+ self.assertEqual(1, self.get_success(count))
+
+ def test_user_last_seen_monthly_active(self):
+ user_id1 = "@user1:server"
+ user_id2 = "@user2:server"
+ user_id3 = "@user3:server"
+
+ result = self.store.user_last_seen_monthly_active(user_id1)
+ self.assertFalse(self.get_success(result) == 0)
+
+ self.store.upsert_monthly_active_user(user_id1)
+ self.store.upsert_monthly_active_user(user_id2)
+ self.pump()
+
+ result = self.store.user_last_seen_monthly_active(user_id1)
+ self.assertGreater(self.get_success(result), 0)
+
+ result = self.store.user_last_seen_monthly_active(user_id3)
+ self.assertNotEqual(self.get_success(result), 0)
+
+ def test_reap_monthly_active_users(self):
+ self.hs.config.max_mau_value = 5
+ initial_users = 10
+ for i in range(initial_users):
+ self.store.upsert_monthly_active_user("@user%d:server" % i)
+ self.pump()
+
+ count = self.store.get_monthly_active_count()
+ self.assertTrue(self.get_success(count), initial_users)
+
+ self.store.reap_monthly_active_users()
+ self.pump()
+ count = self.store.get_monthly_active_count()
+ self.assertEquals(
+ self.get_success(count), initial_users - self.hs.config.max_mau_value
+ )
+
+ self.reactor.advance(FORTY_DAYS)
+ self.store.reap_monthly_active_users()
+ self.pump()
+
+ count = self.store.get_monthly_active_count()
+ self.assertEquals(self.get_success(count), 0)
+
+ def test_populate_monthly_users_is_guest(self):
+ # Test that guest users are not added to mau list
+ user_id = "user_id"
+ self.store.register(
+ user_id=user_id, token="123", password_hash=None, make_guest=True
+ )
+ self.store.upsert_monthly_active_user = Mock()
+ self.store.populate_monthly_active_users(user_id)
+ self.pump()
+ self.store.upsert_monthly_active_user.assert_not_called()
+
+ def test_populate_monthly_users_should_update(self):
+ self.store.upsert_monthly_active_user = Mock()
+
+ self.store.is_trial_user = Mock(
+ return_value=defer.succeed(False)
+ )
+
+ self.store.user_last_seen_monthly_active = Mock(
+ return_value=defer.succeed(None)
+ )
+ self.store.populate_monthly_active_users('user_id')
+ self.pump()
+ self.store.upsert_monthly_active_user.assert_called_once()
+
+ def test_populate_monthly_users_should_not_update(self):
+ self.store.upsert_monthly_active_user = Mock()
+
+ self.store.is_trial_user = Mock(
+ return_value=defer.succeed(False)
+ )
+ self.store.user_last_seen_monthly_active = Mock(
+ return_value=defer.succeed(
+ self.hs.get_clock().time_msec()
+ )
+ )
+ self.store.populate_monthly_active_users('user_id')
+ self.pump()
+ self.store.upsert_monthly_active_user.assert_not_called()
+
+ def test_get_reserved_real_user_account(self):
+ # Test no reserved users, or reserved threepids
+ count = self.store.get_registered_reserved_users_count()
+ self.assertEquals(self.get_success(count), 0)
+ # Test reserved users but no registered users
+
+ user1 = '@user1:example.com'
+ user2 = '@user2:example.com'
+ user1_email = 'user1@example.com'
+ user2_email = 'user2@example.com'
+ threepids = [
+ {'medium': 'email', 'address': user1_email},
+ {'medium': 'email', 'address': user2_email},
+ ]
+ self.hs.config.mau_limits_reserved_threepids = threepids
+ self.store.runInteraction(
+ "initialise", self.store._initialise_reserved_users, threepids
+ )
+
+ self.pump()
+ count = self.store.get_registered_reserved_users_count()
+ self.assertEquals(self.get_success(count), 0)
+
+ # Test reserved registed users
+ self.store.register(user_id=user1, token="123", password_hash=None)
+ self.store.register(user_id=user2, token="456", password_hash=None)
+ self.pump()
+
+ now = int(self.hs.get_clock().time_msec())
+ self.store.user_add_threepid(user1, "email", user1_email, now, now)
+ self.store.user_add_threepid(user2, "email", user2_email, now, now)
+ count = self.store.get_registered_reserved_users_count()
+ self.assertEquals(self.get_success(count), len(threepids))
+
+ def test_track_monthly_users_without_cap(self):
+ self.hs.config.limit_usage_by_mau = False
+ self.hs.config.mau_stats_only = True
+ self.hs.config.max_mau_value = 1 # should not matter
+
+ count = self.store.get_monthly_active_count()
+ self.assertEqual(0, self.get_success(count))
+
+ self.store.upsert_monthly_active_user("@user1:server")
+ self.store.upsert_monthly_active_user("@user2:server")
+ self.pump()
+
+ count = self.store.get_monthly_active_count()
+ self.assertEqual(2, self.get_success(count))
+
+ def test_no_users_when_not_tracking(self):
+ self.hs.config.limit_usage_by_mau = False
+ self.hs.config.mau_stats_only = False
+ self.store.upsert_monthly_active_user = Mock()
+
+ self.store.populate_monthly_active_users("@user:sever")
+ self.pump()
+
+ self.store.upsert_monthly_active_user.assert_not_called()
diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py
index 3276b39504..c7a63f39b9 100644
--- a/tests/storage/test_presence.py
+++ b/tests/storage/test_presence.py
@@ -16,20 +16,18 @@
from twisted.internet import defer
-from synapse.storage.presence import PresenceStore
from synapse.types import UserID
from tests import unittest
-from tests.utils import MockClock, setup_test_homeserver
+from tests.utils import setup_test_homeserver
class PresenceStoreTestCase(unittest.TestCase):
-
@defer.inlineCallbacks
def setUp(self):
- hs = yield setup_test_homeserver(clock=MockClock())
+ hs = yield setup_test_homeserver(self.addCleanup)
- self.store = PresenceStore(None, hs)
+ self.store = hs.get_datastore()
self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test")
@@ -38,16 +36,19 @@ class PresenceStoreTestCase(unittest.TestCase):
def test_presence_list(self):
self.assertEquals(
[],
- (yield self.store.get_presence_list(
- observer_localpart=self.u_apple.localpart,
- ))
+ (
+ yield self.store.get_presence_list(
+ observer_localpart=self.u_apple.localpart
+ )
+ ),
)
self.assertEquals(
[],
- (yield self.store.get_presence_list(
- observer_localpart=self.u_apple.localpart,
- accepted=True,
- ))
+ (
+ yield self.store.get_presence_list(
+ observer_localpart=self.u_apple.localpart, accepted=True
+ )
+ ),
)
yield self.store.add_presence_list_pending(
@@ -57,16 +58,19 @@ class PresenceStoreTestCase(unittest.TestCase):
self.assertEquals(
[{"observed_user_id": "@banana:test", "accepted": 0}],
- (yield self.store.get_presence_list(
- observer_localpart=self.u_apple.localpart,
- ))
+ (
+ yield self.store.get_presence_list(
+ observer_localpart=self.u_apple.localpart
+ )
+ ),
)
self.assertEquals(
[],
- (yield self.store.get_presence_list(
- observer_localpart=self.u_apple.localpart,
- accepted=True,
- ))
+ (
+ yield self.store.get_presence_list(
+ observer_localpart=self.u_apple.localpart, accepted=True
+ )
+ ),
)
yield self.store.set_presence_list_accepted(
@@ -76,16 +80,19 @@ class PresenceStoreTestCase(unittest.TestCase):
self.assertEquals(
[{"observed_user_id": "@banana:test", "accepted": 1}],
- (yield self.store.get_presence_list(
- observer_localpart=self.u_apple.localpart,
- ))
+ (
+ yield self.store.get_presence_list(
+ observer_localpart=self.u_apple.localpart
+ )
+ ),
)
self.assertEquals(
[{"observed_user_id": "@banana:test", "accepted": 1}],
- (yield self.store.get_presence_list(
- observer_localpart=self.u_apple.localpart,
- accepted=True,
- ))
+ (
+ yield self.store.get_presence_list(
+ observer_localpart=self.u_apple.localpart, accepted=True
+ )
+ ),
)
yield self.store.del_presence_list(
@@ -95,14 +102,17 @@ class PresenceStoreTestCase(unittest.TestCase):
self.assertEquals(
[],
- (yield self.store.get_presence_list(
- observer_localpart=self.u_apple.localpart,
- ))
+ (
+ yield self.store.get_presence_list(
+ observer_localpart=self.u_apple.localpart
+ )
+ ),
)
self.assertEquals(
[],
- (yield self.store.get_presence_list(
- observer_localpart=self.u_apple.localpart,
- accepted=True,
- ))
+ (
+ yield self.store.get_presence_list(
+ observer_localpart=self.u_apple.localpart, accepted=True
+ )
+ ),
)
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 2c95e5e95a..45824bd3b2 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -24,35 +24,27 @@ from tests.utils import setup_test_homeserver
class ProfileStoreTestCase(unittest.TestCase):
-
@defer.inlineCallbacks
def setUp(self):
- hs = yield setup_test_homeserver()
+ hs = yield setup_test_homeserver(self.addCleanup)
- self.store = ProfileStore(None, hs)
+ self.store = ProfileStore(hs.get_db_conn(), hs)
self.u_frank = UserID.from_string("@frank:test")
@defer.inlineCallbacks
def test_displayname(self):
- yield self.store.create_profile(
- self.u_frank.localpart
- )
+ yield self.store.create_profile(self.u_frank.localpart)
- yield self.store.set_profile_displayname(
- self.u_frank.localpart, "Frank"
- )
+ yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
self.assertEquals(
- "Frank",
- (yield self.store.get_profile_displayname(self.u_frank.localpart))
+ "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
)
@defer.inlineCallbacks
def test_avatar_url(self):
- yield self.store.create_profile(
- self.u_frank.localpart
- )
+ yield self.store.create_profile(self.u_frank.localpart)
yield self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here"
@@ -60,5 +52,5 @@ class ProfileStoreTestCase(unittest.TestCase):
self.assertEquals(
"http://my.site/here",
- (yield self.store.get_profile_avatar_url(self.u_frank.localpart))
+ (yield self.store.get_profile_avatar_url(self.u_frank.localpart)),
)
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
new file mode 100644
index 0000000000..f671599cb8
--- /dev/null
+++ b/tests/storage/test_purge.py
@@ -0,0 +1,106 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.rest.client.v1 import room
+
+from tests.unittest import HomeserverTestCase
+
+
+class PurgeTests(HomeserverTestCase):
+
+ user_id = "@red:server"
+ servlets = [room.register_servlets]
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver("server", http_client=None)
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+ self.room_id = self.helper.create_room_as(self.user_id)
+
+ def test_purge(self):
+ """
+ Purging a room will delete everything before the topological point.
+ """
+ # Send four messages to the room
+ first = self.helper.send(self.room_id, body="test1")
+ second = self.helper.send(self.room_id, body="test2")
+ third = self.helper.send(self.room_id, body="test3")
+ last = self.helper.send(self.room_id, body="test4")
+
+ storage = self.hs.get_datastore()
+
+ # Get the topological token
+ event = storage.get_topological_token_for_event(last["event_id"])
+ self.pump()
+ event = self.successResultOf(event)
+
+ # Purge everything before this topological token
+ purge = storage.purge_history(self.room_id, event, True)
+ self.pump()
+ self.assertEqual(self.successResultOf(purge), None)
+
+ # Try and get the events
+ get_first = storage.get_event(first["event_id"])
+ get_second = storage.get_event(second["event_id"])
+ get_third = storage.get_event(third["event_id"])
+ get_last = storage.get_event(last["event_id"])
+ self.pump()
+
+ # 1-3 should fail and last will succeed, meaning that 1-3 are deleted
+ # and last is not.
+ self.failureResultOf(get_first)
+ self.failureResultOf(get_second)
+ self.failureResultOf(get_third)
+ self.successResultOf(get_last)
+
+ def test_purge_wont_delete_extrems(self):
+ """
+ Purging a room will delete everything before the topological point.
+ """
+ # Send four messages to the room
+ first = self.helper.send(self.room_id, body="test1")
+ second = self.helper.send(self.room_id, body="test2")
+ third = self.helper.send(self.room_id, body="test3")
+ last = self.helper.send(self.room_id, body="test4")
+
+ storage = self.hs.get_datastore()
+
+ # Set the topological token higher than it should be
+ event = storage.get_topological_token_for_event(last["event_id"])
+ self.pump()
+ event = self.successResultOf(event)
+ event = "t{}-{}".format(
+ *list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
+ )
+
+ # Purge everything before this topological token
+ purge = storage.purge_history(self.room_id, event, True)
+ self.pump()
+ f = self.failureResultOf(purge)
+ self.assertIn("greater than forward", f.value.args[0])
+
+ # Try and get the events
+ get_first = storage.get_event(first["event_id"])
+ get_second = storage.get_event(second["event_id"])
+ get_third = storage.get_event(third["event_id"])
+ get_last = storage.get_event(last["event_id"])
+ self.pump()
+
+ # Nothing is deleted.
+ self.successResultOf(get_first)
+ self.successResultOf(get_second)
+ self.successResultOf(get_third)
+ self.successResultOf(get_last)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 475ec900c4..02bf975fbf 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -22,16 +22,14 @@ from synapse.api.constants import EventTypes, Membership
from synapse.types import RoomID, UserID
from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.utils import create_room, setup_test_homeserver
class RedactionTestCase(unittest.TestCase):
-
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver(
- resource_for_federation=Mock(),
- http_client=None,
+ self.addCleanup, resource_for_federation=Mock(), http_client=None
)
self.store = hs.get_datastore()
@@ -43,20 +41,25 @@ class RedactionTestCase(unittest.TestCase):
self.room1 = RoomID.from_string("!abc123:test")
+ yield create_room(hs, self.room1.to_string(), self.u_alice.to_string())
+
self.depth = 1
@defer.inlineCallbacks
- def inject_room_member(self, room, user, membership, replaces_state=None,
- extra_content={}):
+ def inject_room_member(
+ self, room, user, membership, replaces_state=None, extra_content={}
+ ):
content = {"membership": membership}
content.update(extra_content)
- builder = self.event_builder_factory.new({
- "type": EventTypes.Member,
- "sender": user.to_string(),
- "state_key": user.to_string(),
- "room_id": room.to_string(),
- "content": content,
- })
+ builder = self.event_builder_factory.new(
+ {
+ "type": EventTypes.Member,
+ "sender": user.to_string(),
+ "state_key": user.to_string(),
+ "room_id": room.to_string(),
+ "content": content,
+ }
+ )
event, context = yield self.event_creation_handler.create_new_client_event(
builder
@@ -70,13 +73,15 @@ class RedactionTestCase(unittest.TestCase):
def inject_message(self, room, user, body):
self.depth += 1
- builder = self.event_builder_factory.new({
- "type": EventTypes.Message,
- "sender": user.to_string(),
- "state_key": user.to_string(),
- "room_id": room.to_string(),
- "content": {"body": body, "msgtype": u"message"},
- })
+ builder = self.event_builder_factory.new(
+ {
+ "type": EventTypes.Message,
+ "sender": user.to_string(),
+ "state_key": user.to_string(),
+ "room_id": room.to_string(),
+ "content": {"body": body, "msgtype": u"message"},
+ }
+ )
event, context = yield self.event_creation_handler.create_new_client_event(
builder
@@ -88,14 +93,16 @@ class RedactionTestCase(unittest.TestCase):
@defer.inlineCallbacks
def inject_redaction(self, room, event_id, user, reason):
- builder = self.event_builder_factory.new({
- "type": EventTypes.Redaction,
- "sender": user.to_string(),
- "state_key": user.to_string(),
- "room_id": room.to_string(),
- "content": {"reason": reason},
- "redacts": event_id,
- })
+ builder = self.event_builder_factory.new(
+ {
+ "type": EventTypes.Redaction,
+ "sender": user.to_string(),
+ "state_key": user.to_string(),
+ "room_id": room.to_string(),
+ "content": {"reason": reason},
+ "redacts": event_id,
+ }
+ )
event, context = yield self.event_creation_handler.create_new_client_event(
builder
@@ -105,9 +112,7 @@ class RedactionTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_redact(self):
- yield self.inject_room_member(
- self.room1, self.u_alice, Membership.JOIN
- )
+ yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
msg_event = yield self.inject_message(self.room1, self.u_alice, u"t")
@@ -157,13 +162,10 @@ class RedactionTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_redact_join(self):
- yield self.inject_room_member(
- self.room1, self.u_alice, Membership.JOIN
- )
+ yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
msg_event = yield self.inject_room_member(
- self.room1, self.u_bob, Membership.JOIN,
- extra_content={"blue": "red"},
+ self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}
)
event = yield self.store.get_event(msg_event.event_id)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 7821ea3fa3..3dfb7b903a 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -21,19 +21,15 @@ from tests.utils import setup_test_homeserver
class RegistrationStoreTestCase(unittest.TestCase):
-
@defer.inlineCallbacks
def setUp(self):
- hs = yield setup_test_homeserver()
+ hs = yield setup_test_homeserver(self.addCleanup)
self.db_pool = hs.get_db_pool()
self.store = hs.get_datastore()
self.user_id = "@my-user:test"
- self.tokens = [
- "AbCdEfGhIjKlMnOpQrStUvWxYz",
- "BcDeFgHiJkLmNoPqRsTuVwXyZa"
- ]
+ self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz", "BcDeFgHiJkLmNoPqRsTuVwXyZa"]
self.pwhash = "{xx1}123456789"
self.device_id = "akgjhdjklgshg"
@@ -50,35 +46,28 @@ class RegistrationStoreTestCase(unittest.TestCase):
"consent_version": None,
"consent_server_notice_sent": None,
"appservice_id": None,
+ "creation_ts": 1000,
},
- (yield self.store.get_user_by_id(self.user_id))
+ (yield self.store.get_user_by_id(self.user_id)),
)
result = yield self.store.get_user_by_access_token(self.tokens[0])
- self.assertDictContainsSubset(
- {
- "name": self.user_id,
- },
- result
- )
+ self.assertDictContainsSubset({"name": self.user_id}, result)
self.assertTrue("token_id" in result)
@defer.inlineCallbacks
def test_add_tokens(self):
yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
- yield self.store.add_access_token_to_user(self.user_id, self.tokens[1],
- self.device_id)
+ yield self.store.add_access_token_to_user(
+ self.user_id, self.tokens[1], self.device_id
+ )
result = yield self.store.get_user_by_access_token(self.tokens[1])
self.assertDictContainsSubset(
- {
- "name": self.user_id,
- "device_id": self.device_id,
- },
- result
+ {"name": self.user_id, "device_id": self.device_id}, result
)
self.assertTrue("token_id" in result)
@@ -87,12 +76,13 @@ class RegistrationStoreTestCase(unittest.TestCase):
def test_user_delete_access_tokens(self):
# add some tokens
yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
- yield self.store.add_access_token_to_user(self.user_id, self.tokens[1],
- self.device_id)
+ yield self.store.add_access_token_to_user(
+ self.user_id, self.tokens[1], self.device_id
+ )
# now delete some
yield self.store.user_delete_access_tokens(
- self.user_id, device_id=self.device_id,
+ self.user_id, device_id=self.device_id
)
# check they were deleted
@@ -107,8 +97,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
yield self.store.user_delete_access_tokens(self.user_id)
user = yield self.store.get_user_by_access_token(self.tokens[0])
- self.assertIsNone(user,
- "access token was not deleted without device_id")
+ self.assertIsNone(user, "access token was not deleted without device_id")
class TokenGenerator:
@@ -117,4 +106,4 @@ class TokenGenerator:
def generate(self, user_id):
self._last_issued_token += 1
- return u"%s-%d" % (user_id, self._last_issued_token,)
+ return u"%s-%d" % (user_id, self._last_issued_token)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index ae8ae94b6d..a1ea23b068 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -24,10 +24,9 @@ from tests.utils import setup_test_homeserver
class RoomStoreTestCase(unittest.TestCase):
-
@defer.inlineCallbacks
def setUp(self):
- hs = yield setup_test_homeserver()
+ hs = yield setup_test_homeserver(self.addCleanup)
# We can't test RoomStore on its own without the DirectoryStore, for
# management of the 'room_aliases' table
@@ -40,7 +39,7 @@ class RoomStoreTestCase(unittest.TestCase):
yield self.store.store_room(
self.room.to_string(),
room_creator_user_id=self.u_creator.to_string(),
- is_public=True
+ is_public=True,
)
@defer.inlineCallbacks
@@ -49,17 +48,16 @@ class RoomStoreTestCase(unittest.TestCase):
{
"room_id": self.room.to_string(),
"creator": self.u_creator.to_string(),
- "is_public": True
+ "is_public": True,
},
- (yield self.store.get_room(self.room.to_string()))
+ (yield self.store.get_room(self.room.to_string())),
)
class RoomEventsStoreTestCase(unittest.TestCase):
-
@defer.inlineCallbacks
def setUp(self):
- hs = setup_test_homeserver()
+ hs = setup_test_homeserver(self.addCleanup)
# Room events need the full datastore, for persist_event() and
# get_room_state()
@@ -69,18 +67,13 @@ class RoomEventsStoreTestCase(unittest.TestCase):
self.room = RoomID.from_string("!abcde:test")
yield self.store.store_room(
- self.room.to_string(),
- room_creator_user_id="@creator:text",
- is_public=True
+ self.room.to_string(), room_creator_user_id="@creator:text", is_public=True
)
@defer.inlineCallbacks
def inject_room_event(self, **kwargs):
yield self.store.persist_event(
- self.event_factory.create_event(
- room_id=self.room.to_string(),
- **kwargs
- )
+ self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
)
@defer.inlineCallbacks
@@ -88,22 +81,15 @@ class RoomEventsStoreTestCase(unittest.TestCase):
name = u"A-Room-Name"
yield self.inject_room_event(
- etype=EventTypes.Name,
- name=name,
- content={"name": name},
- depth=1,
+ etype=EventTypes.Name, name=name, content={"name": name}, depth=1
)
- state = yield self.store.get_current_state(
- room_id=self.room.to_string()
- )
+ state = yield self.store.get_current_state(room_id=self.room.to_string())
self.assertEquals(1, len(state))
self.assertObjectHasAttributes(
- {"type": "m.room.name",
- "room_id": self.room.to_string(),
- "name": name},
- state[0]
+ {"type": "m.room.name", "room_id": self.room.to_string(), "name": name},
+ state[0],
)
@defer.inlineCallbacks
@@ -111,22 +97,15 @@ class RoomEventsStoreTestCase(unittest.TestCase):
topic = u"A place for things"
yield self.inject_room_event(
- etype=EventTypes.Topic,
- topic=topic,
- content={"topic": topic},
- depth=1,
+ etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
)
- state = yield self.store.get_current_state(
- room_id=self.room.to_string()
- )
+ state = yield self.store.get_current_state(room_id=self.room.to_string())
self.assertEquals(1, len(state))
self.assertObjectHasAttributes(
- {"type": "m.room.topic",
- "room_id": self.room.to_string(),
- "topic": topic},
- state[0]
+ {"type": "m.room.topic", "room_id": self.room.to_string(), "topic": topic},
+ state[0],
)
# Not testing the various 'level' methods for now because there's lots
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index c5fd54f67e..978c66133d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -22,16 +22,14 @@ from synapse.api.constants import EventTypes, Membership
from synapse.types import RoomID, UserID
from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.utils import create_room, setup_test_homeserver
class RoomMemberStoreTestCase(unittest.TestCase):
-
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver(
- resource_for_federation=Mock(),
- http_client=None,
+ self.addCleanup, resource_for_federation=Mock(), http_client=None
)
# We can't test the RoomMemberStore on its own without the other event
# storage logic
@@ -47,15 +45,19 @@ class RoomMemberStoreTestCase(unittest.TestCase):
self.room = RoomID.from_string("!abc123:test")
+ yield create_room(hs, self.room.to_string(), self.u_alice.to_string())
+
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership, replaces_state=None):
- builder = self.event_builder_factory.new({
- "type": EventTypes.Member,
- "sender": user.to_string(),
- "state_key": user.to_string(),
- "room_id": room.to_string(),
- "content": {"membership": membership},
- })
+ builder = self.event_builder_factory.new(
+ {
+ "type": EventTypes.Member,
+ "sender": user.to_string(),
+ "state_key": user.to_string(),
+ "room_id": room.to_string(),
+ "content": {"membership": membership},
+ }
+ )
event, context = yield self.event_creation_handler.create_new_client_event(
builder
@@ -71,9 +73,12 @@ class RoomMemberStoreTestCase(unittest.TestCase):
self.assertEquals(
[self.room.to_string()],
- [m.room_id for m in (
- yield self.store.get_rooms_for_user_where_membership_is(
- self.u_alice.to_string(), [Membership.JOIN]
+ [
+ m.room_id
+ for m in (
+ yield self.store.get_rooms_for_user_where_membership_is(
+ self.u_alice.to_string(), [Membership.JOIN]
+ )
)
- )]
+ ],
)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 7a76d67b8c..086a39d834 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -18,6 +18,7 @@ import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
+from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID
import tests.unittest
@@ -33,7 +34,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
- hs = yield tests.utils.setup_test_homeserver()
+ hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore()
self.event_builder_factory = hs.get_event_builder_factory()
@@ -45,20 +46,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room = RoomID.from_string("!abc123:test")
yield self.store.store_room(
- self.room.to_string(),
- room_creator_user_id="@creator:text",
- is_public=True
+ self.room.to_string(), room_creator_user_id="@creator:text", is_public=True
)
@defer.inlineCallbacks
def inject_state_event(self, room, sender, typ, state_key, content):
- builder = self.event_builder_factory.new({
- "type": typ,
- "sender": sender.to_string(),
- "state_key": state_key,
- "room_id": room.to_string(),
- "content": content,
- })
+ builder = self.event_builder_factory.new(
+ {
+ "type": typ,
+ "sender": sender.to_string(),
+ "state_key": state_key,
+ "room_id": room.to_string(),
+ "content": content,
+ }
+ )
event, context = yield self.event_creation_handler.create_new_client_event(
builder
@@ -75,171 +76,289 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(len(s1), len(s2))
@defer.inlineCallbacks
+ def test_get_state_groups_ids(self):
+ e1 = yield self.inject_state_event(
+ self.room, self.u_alice, EventTypes.Create, '', {}
+ )
+ e2 = yield self.inject_state_event(
+ self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
+ )
+
+ state_group_map = yield self.store.get_state_groups_ids(self.room, [e2.event_id])
+ self.assertEqual(len(state_group_map), 1)
+ state_map = list(state_group_map.values())[0]
+ self.assertDictEqual(
+ state_map,
+ {
+ (EventTypes.Create, ''): e1.event_id,
+ (EventTypes.Name, ''): e2.event_id,
+ },
+ )
+
+ @defer.inlineCallbacks
+ def test_get_state_groups(self):
+ e1 = yield self.inject_state_event(
+ self.room, self.u_alice, EventTypes.Create, '', {}
+ )
+ e2 = yield self.inject_state_event(
+ self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
+ )
+
+ state_group_map = yield self.store.get_state_groups(
+ self.room, [e2.event_id])
+ self.assertEqual(len(state_group_map), 1)
+ state_list = list(state_group_map.values())[0]
+
+ self.assertEqual(
+ {ev.event_id for ev in state_list},
+ {e1.event_id, e2.event_id},
+ )
+
+ @defer.inlineCallbacks
def test_get_state_for_event(self):
# this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room.
e1 = yield self.inject_state_event(
- self.room, self.u_alice, EventTypes.Create, '', {},
+ self.room, self.u_alice, EventTypes.Create, '', {}
)
e2 = yield self.inject_state_event(
- self.room, self.u_alice, EventTypes.Name, '', {
- "name": "test room"
- },
+ self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
)
e3 = yield self.inject_state_event(
- self.room, self.u_alice, EventTypes.Member, self.u_alice.to_string(), {
- "membership": Membership.JOIN
- },
+ self.room,
+ self.u_alice,
+ EventTypes.Member,
+ self.u_alice.to_string(),
+ {"membership": Membership.JOIN},
)
e4 = yield self.inject_state_event(
- self.room, self.u_bob, EventTypes.Member, self.u_bob.to_string(), {
- "membership": Membership.JOIN
- },
+ self.room,
+ self.u_bob,
+ EventTypes.Member,
+ self.u_bob.to_string(),
+ {"membership": Membership.JOIN},
)
e5 = yield self.inject_state_event(
- self.room, self.u_bob, EventTypes.Member, self.u_bob.to_string(), {
- "membership": Membership.LEAVE
- },
+ self.room,
+ self.u_bob,
+ EventTypes.Member,
+ self.u_bob.to_string(),
+ {"membership": Membership.LEAVE},
)
# check we get the full state as of the final event
state = yield self.store.get_state_for_event(
- e5.event_id, None, filtered_types=None
+ e5.event_id,
)
self.assertIsNotNone(e4)
- self.assertStateMapEqual({
- (e1.type, e1.state_key): e1,
- (e2.type, e2.state_key): e2,
- (e3.type, e3.state_key): e3,
- # e4 is overwritten by e5
- (e5.type, e5.state_key): e5,
- }, state)
+ self.assertStateMapEqual(
+ {
+ (e1.type, e1.state_key): e1,
+ (e2.type, e2.state_key): e2,
+ (e3.type, e3.state_key): e3,
+ # e4 is overwritten by e5
+ (e5.type, e5.state_key): e5,
+ },
+ state,
+ )
# check we can filter to the m.room.name event (with a '' state key)
state = yield self.store.get_state_for_event(
- e5.event_id, [(EventTypes.Name, '')], filtered_types=None
+ e5.event_id, StateFilter.from_types([(EventTypes.Name, '')])
)
- self.assertStateMapEqual({
- (e2.type, e2.state_key): e2,
- }, state)
+ self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key)
state = yield self.store.get_state_for_event(
- e5.event_id, [(EventTypes.Name, None)], filtered_types=None
+ e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
)
- self.assertStateMapEqual({
- (e2.type, e2.state_key): e2,
- }, state)
+ self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key)
state = yield self.store.get_state_for_event(
- e5.event_id, [(EventTypes.Member, None)], filtered_types=None
+ e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
)
- self.assertStateMapEqual({
- (e3.type, e3.state_key): e3,
- (e5.type, e5.state_key): e5,
- }, state)
+ self.assertStateMapEqual(
+ {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
+ )
- # check we can use filter_types to grab a specific room member
- # without filtering out the other event types
+ # check we can grab a specific room member without filtering out the
+ # other event types
state = yield self.store.get_state_for_event(
- e5.event_id, [(EventTypes.Member, self.u_alice.to_string())],
- filtered_types=[EventTypes.Member],
+ e5.event_id,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {self.u_alice.to_string()}},
+ include_others=True,
+ )
)
- self.assertStateMapEqual({
- (e1.type, e1.state_key): e1,
- (e2.type, e2.state_key): e2,
- (e3.type, e3.state_key): e3,
- }, state)
+ self.assertStateMapEqual(
+ {
+ (e1.type, e1.state_key): e1,
+ (e2.type, e2.state_key): e2,
+ (e3.type, e3.state_key): e3,
+ },
+ state,
+ )
- # check that types=[], filtered_types=[EventTypes.Member]
- # doesn't return all members
+ # check that we can grab everything except members
state = yield self.store.get_state_for_event(
- e5.event_id, [], filtered_types=[EventTypes.Member],
+ e5.event_id, state_filter=StateFilter(
+ types={EventTypes.Member: set()},
+ include_others=True,
+ ),
)
- self.assertStateMapEqual({
- (e1.type, e1.state_key): e1,
- (e2.type, e2.state_key): e2,
- }, state)
+ self.assertStateMapEqual(
+ {(e1.type, e1.state_key): e1, (e2.type, e2.state_key): e2}, state
+ )
#######################################################
- # _get_some_state_from_cache tests against a full cache
+ # _get_state_for_group_using_cache tests against a full cache
#######################################################
room_id = self.room.to_string()
group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
- group = group_ids.keys()[0]
+ group = list(group_ids.keys())[0]
+
+ # test _get_state_for_group_using_cache correctly filters out members
+ # with types=[]
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_cache, group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: set()},
+ include_others=True,
+ ),
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual(
+ {
+ (e1.type, e1.state_key): e1.event_id,
+ (e2.type, e2.state_key): e2.event_id,
+ },
+ state_dict,
+ )
- # test _get_some_state_from_cache correctly filters out members with types=[]
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
- group, [], filtered_types=[EventTypes.Member]
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_members_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: set()},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
- self.assertDictEqual({
- (e1.type, e1.state_key): e1.event_id,
- (e2.type, e2.state_key): e2.event_id,
- }, state_dict)
+ self.assertDictEqual({}, state_dict)
+
+ # test _get_state_for_group_using_cache correctly filters in members
+ # with wildcard types
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: None},
+ include_others=True,
+ ),
+ )
- # test _get_some_state_from_cache correctly filters in members with wildcard types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
- group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
+ self.assertEqual(is_all, True)
+ self.assertDictEqual(
+ {
+ (e1.type, e1.state_key): e1.event_id,
+ (e2.type, e2.state_key): e2.event_id,
+ },
+ state_dict,
+ )
+
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_members_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: None},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
- self.assertDictEqual({
- (e1.type, e1.state_key): e1.event_id,
- (e2.type, e2.state_key): e2.event_id,
- (e3.type, e3.state_key): e3.event_id,
- # e4 is overwritten by e5
- (e5.type, e5.state_key): e5.event_id,
- }, state_dict)
+ self.assertDictEqual(
+ {
+ (e3.type, e3.state_key): e3.event_id,
+ # e4 is overwritten by e5
+ (e5.type, e5.state_key): e5.event_id,
+ },
+ state_dict,
+ )
- # test _get_some_state_from_cache correctly filters in members with specific types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
- group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member]
+ # test _get_state_for_group_using_cache correctly filters in members
+ # with specific types
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
- self.assertDictEqual({
- (e1.type, e1.state_key): e1.event_id,
- (e2.type, e2.state_key): e2.event_id,
- (e5.type, e5.state_key): e5.event_id,
- }, state_dict)
+ self.assertDictEqual(
+ {
+ (e1.type, e1.state_key): e1.event_id,
+ (e2.type, e2.state_key): e2.event_id,
+ },
+ state_dict,
+ )
+
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_members_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=True,
+ ),
+ )
- # test _get_some_state_from_cache correctly filters in members with specific types
- # and no filtered_types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
- group, [(EventTypes.Member, e5.state_key)], filtered_types=None
+ self.assertEqual(is_all, True)
+ self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
+
+ # test _get_state_for_group_using_cache correctly filters in members
+ # with specific types
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_members_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=False,
+ ),
)
self.assertEqual(is_all, True)
- self.assertDictEqual({
- (e5.type, e5.state_key): e5.event_id,
- }, state_dict)
+ self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
#######################################################
# deliberately remove e2 (room name) from the _state_group_cache
- (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(group)
+ (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
+ group
+ )
self.assertEqual(is_all, True)
self.assertEqual(known_absent, set())
- self.assertDictEqual(state_dict_ids, {
- (e1.type, e1.state_key): e1.event_id,
- (e2.type, e2.state_key): e2.event_id,
- (e3.type, e3.state_key): e3.event_id,
- # e4 is overwritten by e5
- (e5.type, e5.state_key): e5.event_id,
- })
+ self.assertDictEqual(
+ state_dict_ids,
+ {
+ (e1.type, e1.state_key): e1.event_id,
+ (e2.type, e2.state_key): e2.event_id,
+ },
+ )
state_dict_ids.pop((e2.type, e2.state_key))
self.store._state_group_cache.invalidate(group)
@@ -248,72 +367,127 @@ class StateStoreTestCase(tests.unittest.TestCase):
key=group,
value=state_dict_ids,
# list fetched keys so it knows it's partial
- fetched_keys=(
- (e1.type, e1.state_key),
- (e3.type, e3.state_key),
- (e5.type, e5.state_key),
- )
+ fetched_keys=((e1.type, e1.state_key),),
)
- (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(group)
+ (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
+ group
+ )
self.assertEqual(is_all, False)
- self.assertEqual(known_absent, set([
- (e1.type, e1.state_key),
- (e3.type, e3.state_key),
- (e5.type, e5.state_key),
- ]))
- self.assertDictEqual(state_dict_ids, {
- (e1.type, e1.state_key): e1.event_id,
- (e3.type, e3.state_key): e3.event_id,
- (e5.type, e5.state_key): e5.event_id,
- })
+ self.assertEqual(known_absent, set([(e1.type, e1.state_key)]))
+ self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
############################################
# test that things work with a partial cache
- # test _get_some_state_from_cache correctly filters out members with types=[]
+ # test _get_state_for_group_using_cache correctly filters out members
+ # with types=[]
room_id = self.room.to_string()
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
- group, [], filtered_types=[EventTypes.Member]
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_cache, group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: set()},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, False)
- self.assertDictEqual({
- (e1.type, e1.state_key): e1.event_id,
- }, state_dict)
+ self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
+
+ room_id = self.room.to_string()
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_members_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: set()},
+ include_others=True,
+ ),
+ )
- # test _get_some_state_from_cache correctly filters in members wildcard types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
- group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
+ self.assertEqual(is_all, True)
+ self.assertDictEqual({}, state_dict)
+
+ # test _get_state_for_group_using_cache correctly filters in members
+ # wildcard types
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: None},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, False)
- self.assertDictEqual({
- (e1.type, e1.state_key): e1.event_id,
- (e3.type, e3.state_key): e3.event_id,
- # e4 is overwritten by e5
- (e5.type, e5.state_key): e5.event_id,
- }, state_dict)
+ self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
+
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_members_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: None},
+ include_others=True,
+ ),
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual(
+ {
+ (e3.type, e3.state_key): e3.event_id,
+ (e5.type, e5.state_key): e5.event_id,
+ },
+ state_dict,
+ )
- # test _get_some_state_from_cache correctly filters in members with specific types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
- group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member]
+ # test _get_state_for_group_using_cache correctly filters in members
+ # with specific types
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, False)
- self.assertDictEqual({
- (e1.type, e1.state_key): e1.event_id,
- (e5.type, e5.state_key): e5.event_id,
- }, state_dict)
+ self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
+
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_members_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=True,
+ ),
+ )
- # test _get_some_state_from_cache correctly filters in members with specific types
- # and no filtered_types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
- group, [(EventTypes.Member, e5.state_key)], filtered_types=None
+ self.assertEqual(is_all, True)
+ self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
+
+ # test _get_state_for_group_using_cache correctly filters in members
+ # with specific types
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=False,
+ ),
+ )
+
+ self.assertEqual(is_all, False)
+ self.assertDictEqual({}, state_dict)
+
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_members_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=False,
+ ),
)
self.assertEqual(is_all, True)
- self.assertDictEqual({
- (e5.type, e5.state_key): e5.event_id,
- }, state_dict)
+ self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py
new file mode 100644
index 0000000000..14169afa96
--- /dev/null
+++ b/tests/storage/test_transactions.py
@@ -0,0 +1,45 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tests.unittest import HomeserverTestCase
+
+
+class TransactionStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, homeserver):
+ self.store = homeserver.get_datastore()
+
+ def test_get_set_transactions(self):
+ """Tests that we can successfully get a non-existent entry for
+ destination retries, as well as testing tht we can set and get
+ correctly.
+ """
+ d = self.store.get_destination_retry_timings("example.com")
+ r = self.get_success(d)
+ self.assertIsNone(r)
+
+ d = self.store.set_destination_retry_timings("example.com", 50, 100)
+ self.get_success(d)
+
+ d = self.store.get_destination_retry_timings("example.com")
+ r = self.get_success(d)
+
+ self.assert_dict({"retry_last_ts": 50, "retry_interval": 100}, r)
+
+ def test_initial_set_transactions(self):
+ """Tests that we can successfully set the destination retries (there
+ was a bug around invalidating the cache that broke this)
+ """
+ d = self.store.set_destination_retry_timings("example.com", 50, 100)
+ self.get_success(d)
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 23fad12bca..0dde1ab2fe 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -29,8 +29,8 @@ BOBBY = "@bobby:a"
class UserDirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
- self.hs = yield setup_test_homeserver()
- self.store = UserDirectoryStore(None, self.hs)
+ self.hs = yield setup_test_homeserver(self.addCleanup)
+ self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs)
# alice and bob are both in !room_id. bobby is not but shares
# a homeserver with alice.
@@ -39,20 +39,12 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
{
ALICE: ProfileInfo(None, "alice"),
BOB: ProfileInfo(None, "bob"),
- BOBBY: ProfileInfo(None, "bobby")
+ BOBBY: ProfileInfo(None, "bobby"),
},
)
- yield self.store.add_users_to_public_room(
- "!room:id",
- [ALICE, BOB],
- )
+ yield self.store.add_users_to_public_room("!room:id", [ALICE, BOB])
yield self.store.add_users_who_share_room(
- "!room:id",
- False,
- (
- (ALICE, BOB),
- (BOB, ALICE),
- ),
+ "!room:id", False, ((ALICE, BOB), (BOB, ALICE))
)
@defer.inlineCallbacks
@@ -62,11 +54,9 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
r = yield self.store.search_user_dir(ALICE, "bob", 10)
self.assertFalse(r["limited"])
self.assertEqual(1, len(r["results"]))
- self.assertDictEqual(r["results"][0], {
- "user_id": BOB,
- "display_name": "bob",
- "avatar_url": None,
- })
+ self.assertDictEqual(
+ r["results"][0], {"user_id": BOB, "display_name": "bob", "avatar_url": None}
+ )
@defer.inlineCallbacks
def test_search_user_dir_all_users(self):
@@ -75,15 +65,13 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
r = yield self.store.search_user_dir(ALICE, "bob", 10)
self.assertFalse(r["limited"])
self.assertEqual(2, len(r["results"]))
- self.assertDictEqual(r["results"][0], {
- "user_id": BOB,
- "display_name": "bob",
- "avatar_url": None,
- })
- self.assertDictEqual(r["results"][1], {
- "user_id": BOBBY,
- "display_name": "bobby",
- "avatar_url": None,
- })
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": BOB, "display_name": "bob", "avatar_url": None},
+ )
+ self.assertDictEqual(
+ r["results"][1],
+ {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None},
+ )
finally:
self.hs.config.user_directory_search_all_users = False
|