diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 2dc5052249..63d8633582 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -1,5 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 51714a2b06..24fa8dbb45 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -18,17 +18,14 @@ from mock import Mock
from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
-from synapse.config.ratelimiting import FederationRateLimitConfig
-from synapse.federation.transport import server
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.types import UserID
-from synapse.util.ratelimitutils import FederationRateLimiter
from tests import unittest
-class RoomComplexityTests(unittest.HomeserverTestCase):
+class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
servlets = [
admin.register_servlets,
@@ -41,25 +38,6 @@ class RoomComplexityTests(unittest.HomeserverTestCase):
config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05}
return config
- def prepare(self, reactor, clock, homeserver):
- class Authenticator(object):
- def authenticate_request(self, request, content):
- return defer.succeed("otherserver.nottld")
-
- ratelimiter = FederationRateLimiter(
- clock,
- FederationRateLimitConfig(
- window_size=1,
- sleep_limit=1,
- sleep_msec=1,
- reject_limit=1000,
- concurrent_requests=1000,
- ),
- )
- server.register_servlets(
- homeserver, self.resource, Authenticator(), ratelimiter
- )
-
def test_complexity_simple(self):
u1 = self.register_user("u1", "pass")
@@ -105,7 +83,7 @@ class RoomComplexityTests(unittest.HomeserverTestCase):
d = handler._remote_join(
None,
- ["otherserver.example"],
+ ["other.example.com"],
"roomid",
UserID.from_string(u1),
{"membership": "join"},
@@ -146,7 +124,7 @@ class RoomComplexityTests(unittest.HomeserverTestCase):
d = handler._remote_join(
None,
- ["otherserver.example"],
+ ["other.example.com"],
room_1,
UserID.from_string(u1),
{"membership": "join"},
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index cce8d8c6de..d456267b87 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.types import ReadReceipt
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
class FederationSenderTestCases(HomeserverTestCase):
@@ -29,6 +29,7 @@ class FederationSenderTestCases(HomeserverTestCase):
federation_transport_client=Mock(spec=["send_transaction"]),
)
+ @override_config({"send_federation": True})
def test_send_receipts(self):
mock_state_handler = self.hs.get_state_handler()
mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
@@ -69,6 +70,7 @@ class FederationSenderTestCases(HomeserverTestCase):
],
)
+ @override_config({"send_federation": True})
def test_send_receipts_with_backoff(self):
"""Send two receipts in quick succession; the second should be flushed, but
only after 20ms"""
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index b08be451aa..1ec8c40901 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
+# Copyright 2019 Matrix.org Federation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,6 +17,8 @@ import logging
from synapse.events import FrozenEvent
from synapse.federation.federation_server import server_matches_acl_event
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
from tests import unittest
@@ -41,6 +44,66 @@ class ServerACLsTestCase(unittest.TestCase):
self.assertTrue(server_matches_acl_event("1:2:3:4", e))
+class StateQueryTests(unittest.FederatingHomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def test_without_event_id(self):
+ """
+ Querying v1/state/<room_id> without an event ID will return the current
+ known state.
+ """
+ u1 = self.register_user("u1", "pass")
+ u1_token = self.login("u1", "pass")
+
+ room_1 = self.helper.create_room_as(u1, tok=u1_token)
+ self.inject_room_member(room_1, "@user:other.example.com", "join")
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/federation/v1/state/%s" % (room_1,)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ self.assertEqual(
+ channel.json_body["room_version"],
+ self.hs.config.default_room_version.identifier,
+ )
+
+ members = set(
+ map(
+ lambda x: x["state_key"],
+ filter(
+ lambda x: x["type"] == "m.room.member", channel.json_body["pdus"]
+ ),
+ )
+ )
+
+ self.assertEqual(members, set(["@user:other.example.com", u1]))
+ self.assertEqual(len(channel.json_body["pdus"]), 6)
+
+ def test_needs_to_be_in_room(self):
+ """
+ Querying v1/state/<room_id> requires the server
+ be in the room to provide data.
+ """
+ u1 = self.register_user("u1", "pass")
+ u1_token = self.login("u1", "pass")
+
+ room_1 = self.helper.create_room_as(u1, tok=u1_token)
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/federation/v1/state/%s" % (room_1,)
+ )
+ self.render(request)
+ self.assertEquals(403, channel.code, channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+
+
def _create_acl_event(content):
return FrozenEvent(
{
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
new file mode 100644
index 0000000000..27d83bb7d9
--- /dev/null
+++ b/tests/federation/transport/test_server.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer
+
+from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.federation.transport import server
+from synapse.util.ratelimitutils import FederationRateLimiter
+
+from tests import unittest
+from tests.unittest import override_config
+
+
+class RoomDirectoryFederationTests(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, homeserver):
+ class Authenticator(object):
+ def authenticate_request(self, request, content):
+ return defer.succeed("otherserver.nottld")
+
+ ratelimiter = FederationRateLimiter(clock, FederationRateLimitConfig())
+ server.register_servlets(
+ homeserver, self.resource, Authenticator(), ratelimiter
+ )
+
+ @override_config({"allow_public_rooms_over_federation": False})
+ def test_blocked_public_room_list_over_federation(self):
+ request, channel = self.make_request(
+ "GET", "/_matrix/federation/v1/publicRooms"
+ )
+ self.render(request)
+ self.assertEquals(403, channel.code)
+
+ @override_config({"allow_public_rooms_over_federation": True})
+ def test_open_public_room_list_over_federation(self):
+ request, channel = self.make_request(
+ "GET", "/_matrix/federation/v1/publicRooms"
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 854eb6c024..fdfa2cbbc4 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -183,6 +183,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
)
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
+ test_replace_master_key.skip = (
+ "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486"
+ )
+
@defer.inlineCallbacks
def test_reupload_signatures(self):
"""re-uploading a signature should not fail"""
@@ -503,3 +507,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
],
other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
)
+
+ test_upload_signatures.skip = (
+ "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486"
+ )
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 0bb96674a2..70f172eb02 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2017 New Vector Ltd
+# Copyright 2019 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -94,23 +95,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
+ version_etag = res["etag"]
+ del res["etag"]
self.assertDictEqual(
res,
{
"version": "1",
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
+ "count": 0,
},
)
# check we can retrieve it as a specific version
res = yield self.handler.get_version_info(self.local_user, "1")
+ self.assertEqual(res["etag"], version_etag)
+ del res["etag"]
self.assertDictEqual(
res,
{
"version": "1",
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
+ "count": 0,
},
)
@@ -126,12 +133,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
+ del res["etag"]
self.assertDictEqual(
res,
{
"version": "2",
"algorithm": "m.megolm_backup.v1",
"auth_data": "second_version_auth_data",
+ "count": 0,
},
)
@@ -158,12 +167,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
+ del res["etag"]
self.assertDictEqual(
res,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "revised_first_version_auth_data",
"version": version,
+ "count": 0,
},
)
@@ -207,12 +218,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
+ del res["etag"] # etag is opaque, so don't test its contents
self.assertDictEqual(
res,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "revised_first_version_auth_data",
"version": version,
+ "count": 0,
},
)
@@ -409,6 +422,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+ # get the etag to compare to future versions
+ res = yield self.handler.get_version_info(self.local_user)
+ backup_etag = res["etag"]
+ self.assertEqual(res["count"], 1)
+
new_room_keys = copy.deepcopy(room_keys)
new_room_key = new_room_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]
@@ -423,6 +441,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"SSBBTSBBIEZJU0gK",
)
+ # the etag should be the same since the session did not change
+ res = yield self.handler.get_version_info(self.local_user)
+ self.assertEqual(res["etag"], backup_etag)
+
# test that marking the session as verified however /does/ replace it
new_room_key["is_verified"] = True
yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
@@ -432,6 +454,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
)
+ # the etag should NOT be equal now, since the key changed
+ res = yield self.handler.get_version_info(self.local_user)
+ self.assertNotEqual(res["etag"], backup_etag)
+ backup_etag = res["etag"]
+
# test that a session with a higher forwarded_count doesn't replace one
# with a lower forwarding count
new_room_key["forwarded_count"] = 2
@@ -443,6 +470,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
)
+ # the etag should be the same since the session did not change
+ res = yield self.handler.get_version_info(self.local_user)
+ self.assertEqual(res["etag"], backup_etag)
+
# TODO: check edge cases as well as the common variations here
@defer.inlineCallbacks
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index e0075ccd32..d9d312f0fb 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -42,16 +42,16 @@ class StatsRoomTests(unittest.HomeserverTestCase):
Add the background updates we need to run.
"""
# Ugh, have to reset this flag
- self.store._all_done = False
+ self.store.db.updates._all_done = False
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{"update_name": "populate_stats_prepare", "progress_json": "{}"},
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_rooms",
@@ -61,7 +61,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_users",
@@ -71,7 +71,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
@@ -82,7 +82,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
def get_all_room_state(self):
- return self.store._simple_select_list(
+ return self.store.db.simple_select_list(
"room_stats_state", None, retcols=("name", "topic", "canonical_alias")
)
@@ -96,7 +96,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000)
return self.get_success(
- self.store._simple_select_one(
+ self.store.db.simple_select_one(
table + "_historical",
{id_col: stat_id, end_ts: end_ts},
cols,
@@ -108,8 +108,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Do the initial population of the stats via the background update
self._add_background_updates()
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
def test_initial_room(self):
"""
@@ -141,8 +145,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Do the initial population of the user directory via the background update
self._add_background_updates()
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
r = self.get_success(self.get_all_room_state())
@@ -178,9 +186,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# the position that the deltas should begin at, once they take over.
self.hs.config.stats_enabled = True
self.handler.stats_enabled = True
- self.store._all_done = False
+ self.store.db.updates._all_done = False
self.get_success(
- self.store._simple_update_one(
+ self.store.db.simple_update_one(
table="stats_incremental_position",
keyvalues={},
updatevalues={"stream_id": 0},
@@ -188,14 +196,18 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{"update_name": "populate_stats_prepare", "progress_json": "{}"},
)
)
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
# Now, before the table is actually ingested, add some more events.
self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token)
@@ -205,13 +217,13 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Now do the initial ingestion.
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
@@ -221,9 +233,13 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
- self.store._all_done = False
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ self.store.db.updates._all_done = False
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
self.reactor.advance(86401)
@@ -653,15 +669,15 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# preparation stage of the initial background update
# Ugh, have to reset this flag
- self.store._all_done = False
+ self.store.db.updates._all_done = False
self.get_success(
- self.store._simple_delete(
+ self.store.db.simple_delete(
"room_stats_current", {"1": 1}, "test_delete_stats"
)
)
self.get_success(
- self.store._simple_delete(
+ self.store.db.simple_delete(
"user_stats_current", {"1": 1}, "test_delete_stats"
)
)
@@ -673,9 +689,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# now do the background updates
- self.store._all_done = False
+ self.store.db.updates._all_done = False
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_rooms",
@@ -685,7 +701,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_users",
@@ -695,7 +711,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
@@ -705,8 +721,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
r1stats_complete = self._get_current_stats("room", r1)
u1stats_complete = self._get_current_stats("user", u1)
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 31f54bbd7d..758ee071a5 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -12,54 +12,53 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
-from synapse.handlers.sync import SyncConfig, SyncHandler
+from synapse.handlers.sync import SyncConfig
from synapse.types import UserID
import tests.unittest
import tests.utils
-from tests.utils import setup_test_homeserver
-class SyncTestCase(tests.unittest.TestCase):
+class SyncTestCase(tests.unittest.HomeserverTestCase):
""" Tests Sync Handler. """
- @defer.inlineCallbacks
- def setUp(self):
- self.hs = yield setup_test_homeserver(self.addCleanup)
- self.sync_handler = SyncHandler(self.hs)
+ def prepare(self, reactor, clock, hs):
+ self.hs = hs
+ self.sync_handler = self.hs.get_sync_handler()
self.store = self.hs.get_datastore()
- @defer.inlineCallbacks
def test_wait_for_sync_for_user_auth_blocking(self):
user_id1 = "@user1:server"
user_id2 = "@user2:server"
sync_config = self._generate_sync_config(user_id1)
+ self.reactor.advance(100) # So we get not 0 time
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 1
# Check that the happy case does not throw errors
- yield self.store.upsert_monthly_active_user(user_id1)
- yield self.sync_handler.wait_for_sync_for_user(sync_config)
+ self.get_success(self.store.upsert_monthly_active_user(user_id1))
+ self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
# Test that global lock works
self.hs.config.hs_disabled = True
- with self.assertRaises(ResourceLimitError) as e:
- yield self.sync_handler.wait_for_sync_for_user(sync_config)
- self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ e = self.get_failure(
+ self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+ )
+ self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.hs.config.hs_disabled = False
sync_config = self._generate_sync_config(user_id2)
- with self.assertRaises(ResourceLimitError) as e:
- yield self.sync_handler.wait_for_sync_for_user(sync_config)
- self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ e = self.get_failure(
+ self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+ )
+ self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def _generate_sync_config(self, user_id):
return SyncConfig(
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 5ec568f4e6..92b8726093 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -24,6 +24,7 @@ from synapse.api.errors import AuthError
from synapse.types import UserID
from tests import unittest
+from tests.unittest import override_config
from tests.utils import register_federation_servlets
# Some local users to test with
@@ -162,7 +163,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ )
self.assertEquals(
events[0],
[
@@ -174,6 +177,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
],
)
+ @override_config({"send_federation": True})
def test_started_typing_remote_send(self):
self.room_members = [U_APPLE, U_ONION]
@@ -225,7 +229,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ )
self.assertEquals(
events[0],
[
@@ -237,6 +243,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
],
)
+ @override_config({"send_federation": True})
def test_stopped_typing(self):
self.room_members = [U_APPLE, U_BANANA, U_ONION]
@@ -276,7 +283,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ )
self.assertEquals(
events[0],
[{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
@@ -297,7 +306,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ )
self.assertEquals(
events[0],
[
@@ -314,7 +325,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])])
self.assertEquals(self.event_source.get_current_key(), 2)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1)
+ )
self.assertEquals(
events[0],
[{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
@@ -332,7 +345,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 3)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ )
self.assertEquals(
events[0],
[
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index c5e91a8c41..26071059d2 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -158,7 +158,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def get_users_in_public_rooms(self):
r = self.get_success(
- self.store._simple_select_list(
+ self.store.db.simple_select_list(
"users_in_public_rooms", None, ("user_id", "room_id")
)
)
@@ -169,7 +169,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def get_users_who_share_private_rooms(self):
return self.get_success(
- self.store._simple_select_list(
+ self.store.db.simple_select_list(
"users_who_share_private_rooms",
None,
["user_id", "other_user_id", "room_id"],
@@ -181,10 +181,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
Add the background updates we need to run.
"""
# Ugh, have to reset this flag
- self.store._all_done = False
+ self.store.db.updates._all_done = False
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_createtables",
@@ -193,7 +193,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_process_rooms",
@@ -203,7 +203,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_process_users",
@@ -213,7 +213,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_cleanup",
@@ -255,8 +255,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Do the initial population of the user directory via the background update
self._add_background_updates()
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
shares_private = self.get_users_who_share_private_rooms()
public_users = self.get_users_in_public_rooms()
@@ -290,8 +294,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Do the initial population of the user directory via the background update
self._add_background_updates()
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
shares_private = self.get_users_who_share_private_rooms()
public_users = self.get_users_in_public_rooms()
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 4f924ce451..3dae83c543 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -20,6 +20,7 @@ from synapse.replication.tcp.client import (
ReplicationClientHandler,
)
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.storage.database import Database
from tests import unittest
from tests.server import FakeTransport
@@ -42,13 +43,18 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
self.master_store = self.hs.get_datastore()
self.storage = hs.get_storage()
- self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
+ self.slaved_store = self.STORE_TYPE(
+ Database(hs), self.hs.get_db_conn(), self.hs
+ )
self.event_id = 0
server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = server_factory.streamer
+ handler_factory = Mock()
self.replication_handler = ReplicationClientHandler(self.slaved_store)
+ self.replication_handler.factory = handler_factory
+
client_factory = ReplicationClientFactory(
self.hs, "client_name", self.replication_handler
)
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index ce3835ae6a..1d14e77255 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from mock import Mock
+
from synapse.replication.tcp.commands import ReplicateCommand
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -30,7 +32,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
server = server_factory.buildProtocol(None)
# build a replication client, with a dummy handler
+ handler_factory = Mock()
self.test_handler = TestReplicationClientHandler()
+ self.test_handler.factory = handler_factory
self.client = ClientReplicationStreamProtocol(
"client", "test", clock, self.test_handler
)
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 9575058252..0ed2594381 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -632,7 +632,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
"state_groups_state",
):
count = self.get_success(
- self.store._simple_select_one_onecol(
+ self.store.db.simple_select_one_onecol(
table=table,
keyvalues={"room_id": room_id},
retcol="COUNT(*)",
diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py
new file mode 100644
index 0000000000..5e9c07ebf3
--- /dev/null
+++ b/tests/rest/client/test_ephemeral_message.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.api.constants import EventContentFields, EventTypes
+from synapse.rest import admin
+from synapse.rest.client.v1 import room
+
+from tests import unittest
+
+
+class EphemeralMessageTestCase(unittest.HomeserverTestCase):
+
+ user_id = "@user:test"
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ config["enable_ephemeral_messages"] = True
+
+ self.hs = self.setup_test_homeserver(config=config)
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.room_id = self.helper.create_room_as(self.user_id)
+
+ def test_message_expiry_no_delay(self):
+ """Tests that sending a message sent with a m.self_destruct_after field set to the
+ past results in that event being deleted right away.
+ """
+ # Send a message in the room that has expired. From here, the reactor clock is
+ # at 200ms, so 0 is in the past, and even if that wasn't the case and the clock
+ # is at 0ms the code path is the same if the event's expiry timestamp is the
+ # current timestamp.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "hello",
+ EventContentFields.SELF_DESTRUCT_AFTER: 0,
+ },
+ )
+ event_id = res["event_id"]
+
+ # Check that we can't retrieve the content of the event.
+ event_content = self.get_event(self.room_id, event_id)["content"]
+ self.assertFalse(bool(event_content), event_content)
+
+ def test_message_expiry_delay(self):
+ """Tests that sending a message with a m.self_destruct_after field set to the
+ future results in that event not being deleted right away, but advancing the
+ clock to after that expiry timestamp causes the event to be deleted.
+ """
+ # Send a message in the room that'll expire in 1s.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "hello",
+ EventContentFields.SELF_DESTRUCT_AFTER: self.clock.time_msec() + 1000,
+ },
+ )
+ event_id = res["event_id"]
+
+ # Check that we can retrieve the content of the event before it has expired.
+ event_content = self.get_event(self.room_id, event_id)["content"]
+ self.assertTrue(bool(event_content), event_content)
+
+ # Advance the clock to after the deletion.
+ self.reactor.advance(1)
+
+ # Check that we can't retrieve the content of the event anymore.
+ event_content = self.get_event(self.room_id, event_id)["content"]
+ self.assertFalse(bool(event_content), event_content)
+
+ def get_event(self, room_id, event_id, expected_code=200):
+ url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+
+ request, channel = self.make_request("GET", url)
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ return channel.json_body
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
new file mode 100644
index 0000000000..95475bb651
--- /dev/null
+++ b/tests/rest/client/test_retention.py
@@ -0,0 +1,293 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from mock import Mock
+
+from synapse.api.constants import EventTypes
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.visibility import filter_events_for_client
+
+from tests import unittest
+
+one_hour_ms = 3600000
+one_day_ms = one_hour_ms * 24
+
+
+class RetentionTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["retention"] = {
+ "enabled": True,
+ "default_policy": {
+ "min_lifetime": one_day_ms,
+ "max_lifetime": one_day_ms * 3,
+ },
+ "allowed_lifetime_min": one_day_ms,
+ "allowed_lifetime_max": one_day_ms * 3,
+ }
+
+ self.hs = self.setup_test_homeserver(config=config)
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("user", "password")
+ self.token = self.login("user", "password")
+
+ def test_retention_state_event(self):
+ """Tests that the server configuration can limit the values a user can set to the
+ room's retention policy.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": one_day_ms * 4},
+ tok=self.token,
+ expect_code=400,
+ )
+
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": one_hour_ms},
+ tok=self.token,
+ expect_code=400,
+ )
+
+ def test_retention_event_purged_with_state_event(self):
+ """Tests that expired events are correctly purged when the room's retention policy
+ is defined by a state event.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ # Set the room's retention period to 2 days.
+ lifetime = one_day_ms * 2
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": lifetime},
+ tok=self.token,
+ )
+
+ self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+
+ def test_retention_event_purged_without_state_event(self):
+ """Tests that expired events are correctly purged when the room's retention policy
+ is defined by the server's configuration's default retention policy.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ self._test_retention_event_purged(room_id, one_day_ms * 2)
+
+ def test_visibility(self):
+ """Tests that synapse.visibility.filter_events_for_client correctly filters out
+ outdated events
+ """
+ store = self.hs.get_datastore()
+ storage = self.hs.get_storage()
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ events = []
+
+ # Send a first event, which should be filtered out at the end of the test.
+ resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
+
+ # Get the event from the store so that we end up with a FrozenEvent that we can
+ # give to filter_events_for_client. We need to do this now because the event won't
+ # be in the database anymore after it has expired.
+ events.append(self.get_success(store.get_event(resp.get("event_id"))))
+
+ # Advance the time by 2 days. We're using the default retention policy, therefore
+ # after this the first event will still be valid.
+ self.reactor.advance(one_day_ms * 2 / 1000)
+
+ # Send another event, which shouldn't get filtered out.
+ resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
+
+ valid_event_id = resp.get("event_id")
+
+ events.append(self.get_success(store.get_event(valid_event_id)))
+
+ # Advance the time by anothe 2 days. After this, the first event should be
+ # outdated but not the second one.
+ self.reactor.advance(one_day_ms * 2 / 1000)
+
+ # Run filter_events_for_client with our list of FrozenEvents.
+ filtered_events = self.get_success(
+ filter_events_for_client(storage, self.user_id, events)
+ )
+
+ # We should only get one event back.
+ self.assertEqual(len(filtered_events), 1, filtered_events)
+ # That event should be the second, not outdated event.
+ self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
+
+ def _test_retention_event_purged(self, room_id, increment):
+ # Get the create event to, later, check that we can still access it.
+ message_handler = self.hs.get_message_handler()
+ create_event = self.get_success(
+ message_handler.get_room_data(self.user_id, room_id, EventTypes.Create)
+ )
+
+ # Send a first event to the room. This is the event we'll want to be purged at the
+ # end of the test.
+ resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
+
+ expired_event_id = resp.get("event_id")
+
+ # Check that we can retrieve the event.
+ expired_event = self.get_event(room_id, expired_event_id)
+ self.assertEqual(
+ expired_event.get("content", {}).get("body"), "1", expired_event
+ )
+
+ # Advance the time.
+ self.reactor.advance(increment / 1000)
+
+ # Send another event. We need this because the purge job won't purge the most
+ # recent event in the room.
+ resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
+
+ valid_event_id = resp.get("event_id")
+
+ # Advance the time again. Now our first event should have expired but our second
+ # one should still be kept.
+ self.reactor.advance(increment / 1000)
+
+ # Check that the event has been purged from the database.
+ self.get_event(room_id, expired_event_id, expected_code=404)
+
+ # Check that the event that hasn't been purged can still be retrieved.
+ valid_event = self.get_event(room_id, valid_event_id)
+ self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event)
+
+ # Check that we can still access state events that were sent before the event that
+ # has been purged.
+ self.get_event(room_id, create_event.event_id)
+
+ def get_event(self, room_id, event_id, expected_code=200):
+ url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+
+ request, channel = self.make_request("GET", url, access_token=self.token)
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ return channel.json_body
+
+
+class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["retention"] = {
+ "enabled": True,
+ }
+
+ mock_federation_client = Mock(spec=["backfill"])
+
+ self.hs = self.setup_test_homeserver(
+ config=config, federation_client=mock_federation_client,
+ )
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("user", "password")
+ self.token = self.login("user", "password")
+
+ def test_no_default_policy(self):
+ """Tests that an event doesn't get expired if there is neither a default retention
+ policy nor a policy specific to the room.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ self._test_retention(room_id)
+
+ def test_state_policy(self):
+ """Tests that an event gets correctly expired if there is no default retention
+ policy but there's a policy specific to the room.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ # Set the maximum lifetime to 35 days so that the first event gets expired but not
+ # the second one.
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": one_day_ms * 35},
+ tok=self.token,
+ )
+
+ self._test_retention(room_id, expected_code_for_first_event=404)
+
+ def _test_retention(self, room_id, expected_code_for_first_event=200):
+ # Send a first event to the room. This is the event we'll want to be purged at the
+ # end of the test.
+ resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
+
+ first_event_id = resp.get("event_id")
+
+ # Check that we can retrieve the event.
+ expired_event = self.get_event(room_id, first_event_id)
+ self.assertEqual(
+ expired_event.get("content", {}).get("body"), "1", expired_event
+ )
+
+ # Advance the time by a month.
+ self.reactor.advance(one_day_ms * 30 / 1000)
+
+ # Send another event. We need this because the purge job won't purge the most
+ # recent event in the room.
+ resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
+
+ second_event_id = resp.get("event_id")
+
+ # Advance the time by another month.
+ self.reactor.advance(one_day_ms * 30 / 1000)
+
+ # Check if the event has been purged from the database.
+ first_event = self.get_event(
+ room_id, first_event_id, expected_code=expected_code_for_first_event
+ )
+
+ if expected_code_for_first_event == 200:
+ self.assertEqual(
+ first_event.get("content", {}).get("body"), "1", first_event
+ )
+
+ # Check that the event that hasn't been purged can still be retrieved.
+ second_event = self.get_event(room_id, second_event_id)
+ self.assertEqual(second_event.get("content", {}).get("body"), "2", second_event)
+
+ def get_event(self, room_id, event_id, expected_code=200):
+ url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+
+ request, channel = self.make_request("GET", url, access_token=self.token)
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ return channel.json_body
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 66c2b68707..0fdff79aa7 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -15,6 +15,8 @@
from mock import Mock
+from twisted.internet import defer
+
from synapse.rest.client.v1 import presence
from synapse.types import UserID
@@ -36,6 +38,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
)
hs.presence_handler = Mock()
+ hs.presence_handler.set_state.return_value = defer.succeed(None)
return hs
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 140d8b3772..12c5e95cb5 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -52,6 +52,14 @@ class MockHandlerProfileTestCase(unittest.TestCase):
]
)
+ self.mock_handler.get_displayname.return_value = defer.succeed(Mock())
+ self.mock_handler.set_displayname.return_value = defer.succeed(Mock())
+ self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock())
+ self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock())
+ self.mock_handler.check_profile_query_allowed.return_value = defer.succeed(
+ Mock()
+ )
+
hs = yield setup_test_homeserver(
self.addCleanup,
"test",
@@ -63,7 +71,7 @@ class MockHandlerProfileTestCase(unittest.TestCase):
)
def _get_user_by_req(request=None, allow_guest=False):
- return synapse.types.create_requester(myid)
+ return defer.succeed(synapse.types.create_requester(myid))
hs.get_auth().get_user_by_req = _get_user_by_req
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 5e38fd6ced..1ca7fa742f 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018-2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -25,7 +27,9 @@ from twisted.internet import defer
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.handlers.pagination import PurgeStatus
from synapse.rest.client.v1 import login, profile, room
+from synapse.util.stringutils import random_string
from tests import unittest
@@ -811,104 +815,77 @@ class RoomMessageListTestCase(RoomBase):
self.assertTrue("chunk" in channel.json_body)
self.assertTrue("end" in channel.json_body)
- def test_filter_labels(self):
- """Test that we can filter by a label."""
- message_filter = json.dumps(
- {"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]}
- )
-
- events = self._test_filter_labels(message_filter)
-
- self.assertEqual(len(events), 2, [event["content"] for event in events])
- self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
- self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
+ def test_room_messages_purge(self):
+ store = self.hs.get_datastore()
+ pagination_handler = self.hs.get_pagination_handler()
- def test_filter_not_labels(self):
- """Test that we can filter by the absence of a label."""
- message_filter = json.dumps(
- {"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]}
+ # Send a first message in the room, which will be removed by the purge.
+ first_event_id = self.helper.send(self.room_id, "message 1")["event_id"]
+ first_token = self.get_success(
+ store.get_topological_token_for_event(first_event_id)
)
- events = self._test_filter_labels(message_filter)
-
- self.assertEqual(len(events), 3, [event["content"] for event in events])
- self.assertEqual(events[0]["content"]["body"], "without label", events[0])
- self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1])
- self.assertEqual(
- events[2]["content"]["body"], "with two wrong labels", events[2]
+ # Send a second message in the room, which won't be removed, and which we'll
+ # use as the marker to purge events before.
+ second_event_id = self.helper.send(self.room_id, "message 2")["event_id"]
+ second_token = self.get_success(
+ store.get_topological_token_for_event(second_event_id)
)
- def test_filter_labels_not_labels(self):
- """Test that we can filter by both a label and the absence of another label."""
- sync_filter = json.dumps(
- {
- "types": [EventTypes.Message],
- "org.matrix.labels": ["#work"],
- "org.matrix.not_labels": ["#notfun"],
- }
- )
+ # Send a third event in the room to ensure we don't fall under any edge case
+ # due to our marker being the latest forward extremity in the room.
+ self.helper.send(self.room_id, "message 3")
- events = self._test_filter_labels(sync_filter)
-
- self.assertEqual(len(events), 1, [event["content"] for event in events])
- self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
-
- def _test_filter_labels(self, message_filter):
- self.helper.send_event(
- room_id=self.room_id,
- type=EventTypes.Message,
- content={
- "msgtype": "m.text",
- "body": "with right label",
- EventContentFields.LABELS: ["#fun"],
- },
+ # Check that we get the first and second message when querying /messages.
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
+ % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
)
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.json_body)
- self.helper.send_event(
- room_id=self.room_id,
- type=EventTypes.Message,
- content={"msgtype": "m.text", "body": "without label"},
- )
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 2, [event["content"] for event in chunk])
- self.helper.send_event(
- room_id=self.room_id,
- type=EventTypes.Message,
- content={
- "msgtype": "m.text",
- "body": "with wrong label",
- EventContentFields.LABELS: ["#work"],
- },
+ # Purge every event before the second event.
+ purge_id = random_string(16)
+ pagination_handler._purges_by_id[purge_id] = PurgeStatus()
+ self.get_success(
+ pagination_handler._purge_history(
+ purge_id=purge_id,
+ room_id=self.room_id,
+ token=second_token,
+ delete_local_events=True,
+ )
)
- self.helper.send_event(
- room_id=self.room_id,
- type=EventTypes.Message,
- content={
- "msgtype": "m.text",
- "body": "with two wrong labels",
- EventContentFields.LABELS: ["#work", "#notfun"],
- },
+ # Check that we only get the second message through /message now that the first
+ # has been purged.
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
+ % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
)
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.json_body)
- self.helper.send_event(
- room_id=self.room_id,
- type=EventTypes.Message,
- content={
- "msgtype": "m.text",
- "body": "with right label",
- EventContentFields.LABELS: ["#fun"],
- },
- )
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 1, [event["content"] for event in chunk])
- token = "s0_0_0_0_0_0_0_0_0"
+ # Check that we get no event, but also no error, when querying /messages with
+ # the token that was pointing at the first event, because we don't have it
+ # anymore.
request, channel = self.make_request(
"GET",
- "/rooms/%s/messages?access_token=x&from=%s&filter=%s"
- % (self.room_id, token, message_filter),
+ "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
+ % (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})),
)
self.render(request)
+ self.assertEqual(channel.code, 200, channel.json_body)
- return channel.json_body["chunk"]
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
class RoomSearchTestCase(unittest.HomeserverTestCase):
@@ -1106,3 +1083,517 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
res_displayname = channel.json_body["content"]["displayname"]
self.assertEqual(res_displayname, self.displayname, channel.result)
+
+
+class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
+ """Tests that clients can add a "reason" field to membership events and
+ that they get correctly added to the generated events and propagated.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.creator = self.register_user("creator", "test")
+ self.creator_tok = self.login("creator", "test")
+
+ self.second_user_id = self.register_user("second", "test")
+ self.second_tok = self.login("second", "test")
+
+ self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok)
+
+ def test_join_reason(self):
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/join".format(self.room_id),
+ content={"reason": reason},
+ access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_leave_reason(self):
+ self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
+
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
+ content={"reason": reason},
+ access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_kick_reason(self):
+ self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
+
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/kick".format(self.room_id),
+ content={"reason": reason, "user_id": self.second_user_id},
+ access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_ban_reason(self):
+ self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
+
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/ban".format(self.room_id),
+ content={"reason": reason, "user_id": self.second_user_id},
+ access_token=self.creator_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_unban_reason(self):
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/unban".format(self.room_id),
+ content={"reason": reason, "user_id": self.second_user_id},
+ access_token=self.creator_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_invite_reason(self):
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/invite".format(self.room_id),
+ content={"reason": reason, "user_id": self.second_user_id},
+ access_token=self.creator_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_reject_invite_reason(self):
+ self.helper.invite(
+ self.room_id,
+ src=self.creator,
+ targ=self.second_user_id,
+ tok=self.creator_tok,
+ )
+
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
+ content={"reason": reason},
+ access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def _check_for_reason(self, reason):
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format(
+ self.room_id, self.second_user_id
+ ),
+ access_token=self.creator_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ event_content = channel.json_body
+
+ self.assertEqual(event_content.get("reason"), reason, channel.result)
+
+
+class LabelsTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ profile.register_servlets,
+ ]
+
+ # Filter that should only catch messages with the label "#fun".
+ FILTER_LABELS = {
+ "types": [EventTypes.Message],
+ "org.matrix.labels": ["#fun"],
+ }
+ # Filter that should only catch messages without the label "#fun".
+ FILTER_NOT_LABELS = {
+ "types": [EventTypes.Message],
+ "org.matrix.not_labels": ["#fun"],
+ }
+ # Filter that should only catch messages with the label "#work" but without the label
+ # "#notfun".
+ FILTER_LABELS_NOT_LABELS = {
+ "types": [EventTypes.Message],
+ "org.matrix.labels": ["#work"],
+ "org.matrix.not_labels": ["#notfun"],
+ }
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("test", "test")
+ self.tok = self.login("test", "test")
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ def test_context_filter_labels(self):
+ """Test that we can filter by a label on a /context request."""
+ event_id = self._send_labelled_messages_in_room()
+
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/context/%s?filter=%s"
+ % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ events_before = channel.json_body["events_before"]
+
+ self.assertEqual(
+ len(events_before), 1, [event["content"] for event in events_before]
+ )
+ self.assertEqual(
+ events_before[0]["content"]["body"], "with right label", events_before[0]
+ )
+
+ events_after = channel.json_body["events_before"]
+
+ self.assertEqual(
+ len(events_after), 1, [event["content"] for event in events_after]
+ )
+ self.assertEqual(
+ events_after[0]["content"]["body"], "with right label", events_after[0]
+ )
+
+ def test_context_filter_not_labels(self):
+ """Test that we can filter by the absence of a label on a /context request."""
+ event_id = self._send_labelled_messages_in_room()
+
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/context/%s?filter=%s"
+ % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ events_before = channel.json_body["events_before"]
+
+ self.assertEqual(
+ len(events_before), 1, [event["content"] for event in events_before]
+ )
+ self.assertEqual(
+ events_before[0]["content"]["body"], "without label", events_before[0]
+ )
+
+ events_after = channel.json_body["events_after"]
+
+ self.assertEqual(
+ len(events_after), 2, [event["content"] for event in events_after]
+ )
+ self.assertEqual(
+ events_after[0]["content"]["body"], "with wrong label", events_after[0]
+ )
+ self.assertEqual(
+ events_after[1]["content"]["body"], "with two wrong labels", events_after[1]
+ )
+
+ def test_context_filter_labels_not_labels(self):
+ """Test that we can filter by both a label and the absence of another label on a
+ /context request.
+ """
+ event_id = self._send_labelled_messages_in_room()
+
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/context/%s?filter=%s"
+ % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ events_before = channel.json_body["events_before"]
+
+ self.assertEqual(
+ len(events_before), 0, [event["content"] for event in events_before]
+ )
+
+ events_after = channel.json_body["events_after"]
+
+ self.assertEqual(
+ len(events_after), 1, [event["content"] for event in events_after]
+ )
+ self.assertEqual(
+ events_after[0]["content"]["body"], "with wrong label", events_after[0]
+ )
+
+ def test_messages_filter_labels(self):
+ """Test that we can filter by a label on a /messages request."""
+ self._send_labelled_messages_in_room()
+
+ token = "s0_0_0_0_0_0_0_0_0"
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
+ % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)),
+ )
+ self.render(request)
+
+ events = channel.json_body["chunk"]
+
+ self.assertEqual(len(events), 2, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
+ self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
+
+ def test_messages_filter_not_labels(self):
+ """Test that we can filter by the absence of a label on a /messages request."""
+ self._send_labelled_messages_in_room()
+
+ token = "s0_0_0_0_0_0_0_0_0"
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
+ % (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)),
+ )
+ self.render(request)
+
+ events = channel.json_body["chunk"]
+
+ self.assertEqual(len(events), 4, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "without label", events[0])
+ self.assertEqual(events[1]["content"]["body"], "without label", events[1])
+ self.assertEqual(events[2]["content"]["body"], "with wrong label", events[2])
+ self.assertEqual(
+ events[3]["content"]["body"], "with two wrong labels", events[3]
+ )
+
+ def test_messages_filter_labels_not_labels(self):
+ """Test that we can filter by both a label and the absence of another label on a
+ /messages request.
+ """
+ self._send_labelled_messages_in_room()
+
+ token = "s0_0_0_0_0_0_0_0_0"
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
+ % (
+ self.room_id,
+ self.tok,
+ token,
+ json.dumps(self.FILTER_LABELS_NOT_LABELS),
+ ),
+ )
+ self.render(request)
+
+ events = channel.json_body["chunk"]
+
+ self.assertEqual(len(events), 1, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
+
+ def test_search_filter_labels(self):
+ """Test that we can filter by a label on a /search request."""
+ request_data = json.dumps(
+ {
+ "search_categories": {
+ "room_events": {
+ "search_term": "label",
+ "filter": self.FILTER_LABELS,
+ }
+ }
+ }
+ )
+
+ self._send_labelled_messages_in_room()
+
+ request, channel = self.make_request(
+ "POST", "/search?access_token=%s" % self.tok, request_data
+ )
+ self.render(request)
+
+ results = channel.json_body["search_categories"]["room_events"]["results"]
+
+ self.assertEqual(
+ len(results), 2, [result["result"]["content"] for result in results],
+ )
+ self.assertEqual(
+ results[0]["result"]["content"]["body"],
+ "with right label",
+ results[0]["result"]["content"]["body"],
+ )
+ self.assertEqual(
+ results[1]["result"]["content"]["body"],
+ "with right label",
+ results[1]["result"]["content"]["body"],
+ )
+
+ def test_search_filter_not_labels(self):
+ """Test that we can filter by the absence of a label on a /search request."""
+ request_data = json.dumps(
+ {
+ "search_categories": {
+ "room_events": {
+ "search_term": "label",
+ "filter": self.FILTER_NOT_LABELS,
+ }
+ }
+ }
+ )
+
+ self._send_labelled_messages_in_room()
+
+ request, channel = self.make_request(
+ "POST", "/search?access_token=%s" % self.tok, request_data
+ )
+ self.render(request)
+
+ results = channel.json_body["search_categories"]["room_events"]["results"]
+
+ self.assertEqual(
+ len(results), 4, [result["result"]["content"] for result in results],
+ )
+ self.assertEqual(
+ results[0]["result"]["content"]["body"],
+ "without label",
+ results[0]["result"]["content"]["body"],
+ )
+ self.assertEqual(
+ results[1]["result"]["content"]["body"],
+ "without label",
+ results[1]["result"]["content"]["body"],
+ )
+ self.assertEqual(
+ results[2]["result"]["content"]["body"],
+ "with wrong label",
+ results[2]["result"]["content"]["body"],
+ )
+ self.assertEqual(
+ results[3]["result"]["content"]["body"],
+ "with two wrong labels",
+ results[3]["result"]["content"]["body"],
+ )
+
+ def test_search_filter_labels_not_labels(self):
+ """Test that we can filter by both a label and the absence of another label on a
+ /search request.
+ """
+ request_data = json.dumps(
+ {
+ "search_categories": {
+ "room_events": {
+ "search_term": "label",
+ "filter": self.FILTER_LABELS_NOT_LABELS,
+ }
+ }
+ }
+ )
+
+ self._send_labelled_messages_in_room()
+
+ request, channel = self.make_request(
+ "POST", "/search?access_token=%s" % self.tok, request_data
+ )
+ self.render(request)
+
+ results = channel.json_body["search_categories"]["room_events"]["results"]
+
+ self.assertEqual(
+ len(results), 1, [result["result"]["content"] for result in results],
+ )
+ self.assertEqual(
+ results[0]["result"]["content"]["body"],
+ "with wrong label",
+ results[0]["result"]["content"]["body"],
+ )
+
+ def _send_labelled_messages_in_room(self):
+ """Sends several messages to a room with different labels (or without any) to test
+ filtering by label.
+ Returns:
+ The ID of the event to use if we're testing filtering on /context.
+ """
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with right label",
+ EventContentFields.LABELS: ["#fun"],
+ },
+ tok=self.tok,
+ )
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "without label"},
+ tok=self.tok,
+ )
+
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "without label"},
+ tok=self.tok,
+ )
+ # Return this event's ID when we test filtering in /context requests.
+ event_id = res["event_id"]
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with wrong label",
+ EventContentFields.LABELS: ["#work"],
+ },
+ tok=self.tok,
+ )
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with two wrong labels",
+ EventContentFields.LABELS: ["#work", "#notfun"],
+ },
+ tok=self.tok,
+ )
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with right label",
+ EventContentFields.LABELS: ["#fun"],
+ },
+ tok=self.tok,
+ )
+
+ return event_id
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 30fb77bac8..4bc3aaf02d 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -109,7 +109,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code)
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(from_key=0, room_ids=[self.room_id])
+ events = self.get_success(
+ self.event_source.get_new_events(from_key=0, room_ids=[self.room_id])
+ )
self.assertEquals(
events[0],
[
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 8ea0cb05ea..e7417b3d14 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -1,5 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index dab87e5edf..c0d0d2b44e 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -203,6 +203,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
@unittest.override_config(
{
+ "public_baseurl": "https://test_server",
"enable_registration_captcha": True,
"user_consent": {
"version": "1",
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 3283c0e47b..661c1f88b9 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/server.py b/tests/server.py
index f878aeaada..2b7cf4242e 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -379,6 +379,7 @@ class FakeTransport(object):
disconnecting = False
disconnected = False
+ connected = True
buffer = attr.ib(default=b"")
producer = attr.ib(default=None)
autoflush = attr.ib(default=True)
@@ -402,6 +403,7 @@ class FakeTransport(object):
"FakeTransport: Delaying disconnect until buffer is flushed"
)
else:
+ self.connected = False
self.disconnected = True
def abortConnection(self):
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 9b81b536f5..d491ea2924 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -323,7 +323,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
self.table_name = "table_" + hs.get_secrets().token_hex(6)
self.get_success(
- self.storage.runInteraction(
+ self.storage.db.runInteraction(
"create",
lambda x, *a: x.execute(*a),
"CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)"
@@ -331,7 +331,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.storage.runInteraction(
+ self.storage.db.runInteraction(
"index",
lambda x, *a: x.execute(*a),
"CREATE UNIQUE INDEX %sindex ON %s(id, username)"
@@ -354,9 +354,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
value_values = [["hello"], ["there"]]
self.get_success(
- self.storage.runInteraction(
+ self.storage.db.runInteraction(
"test",
- self.storage._simple_upsert_many_txn,
+ self.storage.db.simple_upsert_many_txn,
self.table_name,
key_names,
key_values,
@@ -367,7 +367,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
# Check results are what we expect
res = self.get_success(
- self.storage._simple_select_list(
+ self.storage.db.simple_select_list(
self.table_name, None, ["id, username, value"]
)
)
@@ -381,9 +381,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
value_values = [["bleb"]]
self.get_success(
- self.storage.runInteraction(
+ self.storage.db.runInteraction(
"test",
- self.storage._simple_upsert_many_txn,
+ self.storage.db.simple_upsert_many_txn,
self.table_name,
key_names,
key_values,
@@ -394,7 +394,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
# Check results are what we expect
res = self.get_success(
- self.storage._simple_select_list(
+ self.storage.db.simple_select_list(
self.table_name, None, ["id, username, value"]
)
)
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index dfeea24599..2e521e9ab7 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -28,6 +28,7 @@ from synapse.storage.data_stores.main.appservice import (
ApplicationServiceStore,
ApplicationServiceTransactionStore,
)
+from synapse.storage.database import Database
from tests import unittest
from tests.utils import setup_test_homeserver
@@ -54,7 +55,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
# must be done after inserts
- self.store = ApplicationServiceStore(hs.get_db_conn(), hs)
+ database = Database(hs)
+ self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs)
def tearDown(self):
# TODO: suboptimal that we need to create files for tests!
@@ -123,7 +125,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.as_yaml_files = []
- self.store = TestTransactionStore(hs.get_db_conn(), hs)
+ database = Database(hs)
+ self.store = TestTransactionStore(database, hs.get_db_conn(), hs)
def _add_service(self, url, as_token, id):
as_yaml = dict(
@@ -382,8 +385,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
# required for ApplicationServiceTransactionStoreTestCase tests
class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
- def __init__(self, db_conn, hs):
- super(TestTransactionStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(TestTransactionStore, self).__init__(database, db_conn, hs)
class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
@@ -416,7 +419,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs.config.event_cache_size = 1
hs.config.password_providers = []
- ApplicationServiceStore(hs.get_db_conn(), hs)
+ ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
@defer.inlineCallbacks
def test_duplicate_ids(self):
@@ -432,7 +435,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm:
- ApplicationServiceStore(hs.get_db_conn(), hs)
+ ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
e = cm.exception
self.assertIn(f1, str(e))
@@ -453,7 +456,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm:
- ApplicationServiceStore(hs.get_db_conn(), hs)
+ ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
e = cm.exception
self.assertIn(f1, str(e))
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 9fabe3fbc0..aec76f4ab1 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -15,7 +15,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
self.update_handler = Mock()
- yield self.store.register_background_update_handler(
+ yield self.store.db.updates.register_background_update_handler(
"test_update", self.update_handler
)
@@ -23,7 +23,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
# (perhaps we should run them as part of the test HS setup, since we
# run all of the other schema setup stuff there?)
while True:
- res = yield self.store.do_next_background_update(1000)
+ res = yield self.store.db.updates.do_next_background_update(1000)
if res is None:
break
@@ -37,9 +37,9 @@ class BackgroundUpdateTestCase(unittest.TestCase):
def update(progress, count):
self.clock.advance_time_msec(count * duration_ms)
progress = {"my_key": progress["my_key"] + 1}
- yield self.store.runInteraction(
+ yield self.store.db.runInteraction(
"update_progress",
- self.store._background_update_progress_txn,
+ self.store.db.updates._background_update_progress_txn,
"test_update",
progress,
)
@@ -47,29 +47,37 @@ class BackgroundUpdateTestCase(unittest.TestCase):
self.update_handler.side_effect = update
- yield self.store.start_background_update("test_update", {"my_key": 1})
+ yield self.store.db.updates.start_background_update(
+ "test_update", {"my_key": 1}
+ )
self.update_handler.reset_mock()
- result = yield self.store.do_next_background_update(duration_ms * desired_count)
+ result = yield self.store.db.updates.do_next_background_update(
+ duration_ms * desired_count
+ )
self.assertIsNotNone(result)
self.update_handler.assert_called_once_with(
- {"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE
+ {"my_key": 1}, self.store.db.updates.DEFAULT_BACKGROUND_BATCH_SIZE
)
# second step: complete the update
@defer.inlineCallbacks
def update(progress, count):
- yield self.store._end_background_update("test_update")
+ yield self.store.db.updates._end_background_update("test_update")
return count
self.update_handler.side_effect = update
self.update_handler.reset_mock()
- result = yield self.store.do_next_background_update(duration_ms * desired_count)
+ result = yield self.store.db.updates.do_next_background_update(
+ duration_ms * desired_count
+ )
self.assertIsNotNone(result)
self.update_handler.assert_called_once_with({"my_key": 2}, desired_count)
# third step: we don't expect to be called any more
self.update_handler.reset_mock()
- result = yield self.store.do_next_background_update(duration_ms * desired_count)
+ result = yield self.store.db.updates.do_next_background_update(
+ duration_ms * desired_count
+ )
self.assertIsNone(result)
self.assertFalse(self.update_handler.called)
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index c778de1f0c..537cfe9f64 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -21,6 +21,7 @@ from mock import Mock
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
from synapse.storage.engines import create_engine
from tests import unittest
@@ -59,13 +60,13 @@ class SQLBaseStoreTestCase(unittest.TestCase):
"test", db_pool=self.db_pool, config=config, database_engine=fake_engine
)
- self.datastore = SQLBaseStore(None, hs)
+ self.datastore = SQLBaseStore(Database(hs), None, hs)
@defer.inlineCallbacks
def test_insert_1col(self):
self.mock_txn.rowcount = 1
- yield self.datastore._simple_insert(
+ yield self.datastore.db.simple_insert(
table="tablename", values={"columname": "Value"}
)
@@ -77,7 +78,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_3cols(self):
self.mock_txn.rowcount = 1
- yield self.datastore._simple_insert(
+ yield self.datastore.db.simple_insert(
table="tablename",
# Use OrderedDict() so we can assert on the SQL generated
values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
@@ -92,7 +93,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
- value = yield self.datastore._simple_select_one_onecol(
+ value = yield self.datastore.db.simple_select_one_onecol(
table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
)
@@ -106,7 +107,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = (1, 2, 3)
- ret = yield self.datastore._simple_select_one(
+ ret = yield self.datastore.db.simple_select_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
retcols=["colA", "colB", "colC"],
@@ -122,7 +123,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 0
self.mock_txn.fetchone.return_value = None
- ret = yield self.datastore._simple_select_one(
+ ret = yield self.datastore.db.simple_select_one(
table="tablename",
keyvalues={"keycol": "Not here"},
retcols=["colA"],
@@ -137,7 +138,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
- ret = yield self.datastore._simple_select_list(
+ ret = yield self.datastore.db.simple_select_list(
table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
)
@@ -150,7 +151,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_update_one_1col(self):
self.mock_txn.rowcount = 1
- yield self.datastore._simple_update_one(
+ yield self.datastore.db.simple_update_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
updatevalues={"columnname": "New Value"},
@@ -165,7 +166,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_update_one_4cols(self):
self.mock_txn.rowcount = 1
- yield self.datastore._simple_update_one(
+ yield self.datastore.db.simple_update_one(
table="tablename",
keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
@@ -180,7 +181,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_delete_one(self):
self.mock_txn.rowcount = 1
- yield self.datastore._simple_delete_one(
+ yield self.datastore.db.simple_delete_one(
table="tablename", keyvalues={"keycol": "Go away"}
)
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 69dcaa63d5..029ac26454 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -46,7 +46,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
"""Re run the background update to clean up the extremities.
"""
# Make sure we don't clash with in progress updates.
- self.assertTrue(self.store._all_done, "Background updates are still ongoing")
+ self.assertTrue(
+ self.store.db.updates._all_done, "Background updates are still ongoing"
+ )
schema_path = os.path.join(
prepare_database.dir_path,
@@ -62,14 +64,20 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
prepare_database.executescript(txn, schema_path)
self.get_success(
- self.store.runInteraction("test_delete_forward_extremities", run_delta_file)
+ self.store.db.runInteraction(
+ "test_delete_forward_extremities", run_delta_file
+ )
)
# Ugh, have to reset this flag
- self.store._all_done = False
+ self.store.db.updates._all_done = False
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
def test_soft_failed_extremities_handled_correctly(self):
"""Test that extremities are correctly calculated in the presence of
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index afac5dec7f..fc279340d4 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -81,7 +81,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.pump(0)
result = self.get_success(
- self.store._simple_select_list(
+ self.store.db.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -112,7 +112,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.pump(0)
result = self.get_success(
- self.store._simple_select_list(
+ self.store.db.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -202,8 +202,12 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
def test_devices_last_seen_bg_update(self):
# First make sure we have completed all updates.
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
# Insert a user IP
user_id = "@user:id"
@@ -218,7 +222,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# But clear the associated entry in devices table
self.get_success(
- self.store._simple_update(
+ self.store.db.simple_update(
table="devices",
keyvalues={"user_id": user_id, "device_id": "device_id"},
updatevalues={"last_seen": None, "ip": None, "user_agent": None},
@@ -245,7 +249,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# Register the background update to run again.
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
table="background_updates",
values={
"update_name": "devices_last_seen",
@@ -256,11 +260,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
# ... and tell the DataStore that it hasn't finished all updates yet
- self.store._all_done = False
+ self.store.db.updates._all_done = False
# Now let's actually drive the updates to completion
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
# We should now get the correct result again
result = self.get_success(
@@ -281,8 +289,12 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
def test_old_user_ips_pruned(self):
# First make sure we have completed all updates.
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
# Insert a user IP
user_id = "@user:id"
@@ -297,7 +309,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# We should see that in the DB
result = self.get_success(
- self.store._simple_select_list(
+ self.store.db.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -323,7 +335,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# We should get no results.
result = self.get_success(
- self.store._simple_select_list(
+ self.store.db.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
index d128fde441..35dafbb904 100644
--- a/tests/storage/test_e2e_room_keys.py
+++ b/tests/storage/test_e2e_room_keys.py
@@ -39,8 +39,8 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
)
self.get_success(
- self.store.set_e2e_room_key(
- "user_id", version1, "room", "session", room_key
+ self.store.add_e2e_room_keys(
+ "user_id", version1, [("room", "session", room_key)]
)
)
@@ -51,8 +51,8 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
)
self.get_success(
- self.store.set_e2e_room_key(
- "user_id", version2, "room", "session", room_key
+ self.store.add_e2e_room_keys(
+ "user_id", version2, [("room", "session", room_key)]
)
)
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 2fe50377f8..eadfb90a22 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -61,7 +61,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
)
for i in range(0, 11):
- yield self.store.runInteraction("insert", insert_event, i)
+ yield self.store.db.runInteraction("insert", insert_event, i)
# this should get the last five and five others
r = yield self.store.get_prev_events_for_room(room_id)
@@ -93,9 +93,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
)
for i in range(0, 20):
- yield self.store.runInteraction("insert", insert_event, i, room1)
- yield self.store.runInteraction("insert", insert_event, i, room2)
- yield self.store.runInteraction("insert", insert_event, i, room3)
+ yield self.store.db.runInteraction("insert", insert_event, i, room1)
+ yield self.store.db.runInteraction("insert", insert_event, i, room2)
+ yield self.store.db.runInteraction("insert", insert_event, i, room3)
# Test simple case
r = yield self.store.get_rooms_with_many_extremities(5, 5, [])
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index b114c6fb1d..d4bcf1821e 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -55,7 +55,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def _assert_counts(noitf_count, highlight_count):
- counts = yield self.store.runInteraction(
+ counts = yield self.store.db.runInteraction(
"", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
)
self.assertEquals(
@@ -74,7 +74,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield self.store.add_push_actions_to_staging(
event.event_id, {user_id: action}
)
- yield self.store.runInteraction(
+ yield self.store.db.runInteraction(
"",
self.store._set_push_actions_for_event_and_users_txn,
[(event, None)],
@@ -82,12 +82,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
def _rotate(stream):
- return self.store.runInteraction(
+ return self.store.db.runInteraction(
"", self.store._rotate_notifs_before_txn, stream
)
def _mark_read(stream, depth):
- return self.store.runInteraction(
+ return self.store.db.runInteraction(
"",
self.store._remove_old_push_actions_before_txn,
room_id,
@@ -116,7 +116,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield _inject_actions(6, PlAIN_NOTIF)
yield _rotate(7)
- yield self.store._simple_delete(
+ yield self.store.db.simple_delete(
table="event_push_actions", keyvalues={"1": 1}, desc=""
)
@@ -135,7 +135,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts):
- return self.store._simple_insert(
+ return self.store.db.simple_insert(
"events",
{
"stream_ordering": so,
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 90a63dc477..3c78faab45 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -65,7 +65,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.user_add_threepid(user1, "email", user1_email, now, now)
self.store.user_add_threepid(user2, "email", user2_email, now, now)
- self.store.runInteraction(
+ self.store.db.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
self.pump()
@@ -183,7 +183,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
)
self.hs.config.mau_limits_reserved_threepids = threepids
- self.store.runInteraction(
+ self.store.db.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
count = self.store.get_monthly_active_count()
@@ -244,7 +244,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
{"medium": "email", "address": user2_email},
]
self.hs.config.mau_limits_reserved_threepids = threepids
- self.store.runInteraction(
+ self.store.db.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 24c7fe16c3..9b6f7211ae 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -16,7 +16,6 @@
from twisted.internet import defer
-from synapse.storage.data_stores.main.profile import ProfileStore
from synapse.types import UserID
from tests import unittest
@@ -28,7 +27,7 @@ class ProfileStoreTestCase(unittest.TestCase):
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
- self.store = ProfileStore(hs.get_db_conn(), hs)
+ self.store = hs.get_datastore()
self.u_frank = UserID.from_string("@frank:test")
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 4561c3e383..dc45173355 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -338,7 +338,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
event_json = self.get_success(
- self.store._simple_select_one_onecol(
+ self.store.db.simple_select_one_onecol(
table="event_json",
keyvalues={"event_id": msg_event.event_id},
retcol="json",
@@ -356,7 +356,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.reactor.advance(60 * 60 * 2)
event_json = self.get_success(
- self.store._simple_select_one_onecol(
+ self.store.db.simple_select_one_onecol(
table="event_json",
keyvalues={"event_id": msg_event.event_id},
retcol="json",
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 9ddd17f73d..7840f63fe3 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -16,8 +16,7 @@
from unittest.mock import Mock
-from synapse.api.constants import EventTypes, Membership
-from synapse.api.room_versions import RoomVersions
+from synapse.api.constants import Membership
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client.v1 import login, room
from synapse.types import Requester, UserID
@@ -44,9 +43,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# We can't test the RoomMemberStore on its own without the other event
# storage logic
self.store = hs.get_datastore()
- self.storage = hs.get_storage()
- self.event_builder_factory = hs.get_event_builder_factory()
- self.event_creation_handler = hs.get_event_creation_handler()
self.u_alice = self.register_user("alice", "pass")
self.t_alice = self.login("alice", "pass")
@@ -55,26 +51,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# User elsewhere on another host
self.u_charlie = UserID.from_string("@charlie:elsewhere")
- def inject_room_member(self, room, user, membership, replaces_state=None):
- builder = self.event_builder_factory.for_room_version(
- RoomVersions.V1,
- {
- "type": EventTypes.Member,
- "sender": user,
- "state_key": user,
- "room_id": room,
- "content": {"membership": membership},
- },
- )
-
- event, context = self.get_success(
- self.event_creation_handler.create_new_client_event(builder)
- )
-
- self.get_success(self.storage.persistence.persist_event(event, context))
-
- return event
-
def test_one_member(self):
# Alice creates the room, and is automatically joined
@@ -146,8 +122,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def test_can_rerun_update(self):
# First make sure we have completed all updates.
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
# Now let's create a room, which will insert a membership
user = UserID("alice", "test")
@@ -156,7 +136,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
# Register the background update to run again.
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
table="background_updates",
values={
"update_name": "current_state_events_membership",
@@ -167,8 +147,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
)
# ... and tell the DataStore that it hasn't finished all updates yet
- self.store._all_done = False
+ self.store.db.updates._all_done = False
# Now let's actually drive the updates to completion
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 7eea57c0e2..6a545d2eb0 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -15,8 +15,6 @@
from twisted.internet import defer
-from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
-
from tests import unittest
from tests.utils import setup_test_homeserver
@@ -29,7 +27,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(self.addCleanup)
- self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs)
+ self.store = self.hs.get_datastore()
# alice and bob are both in !room_id. bobby is not but shares
# a homeserver with alice.
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 7d82b58466..ad165d7295 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -33,6 +33,8 @@ class MessageAcceptTests(unittest.TestCase):
self.reactor.advance(0.1)
self.room_id = self.successResultOf(room)["room_id"]
+ self.store = self.homeserver.get_datastore()
+
# Figure out what the most recent event is
most_recent = self.successResultOf(
maybeDeferred(
@@ -77,10 +79,7 @@ class MessageAcceptTests(unittest.TestCase):
# Make sure we actually joined the room
self.assertEqual(
self.successResultOf(
- maybeDeferred(
- self.homeserver.get_datastore().get_latest_event_ids_in_room,
- self.room_id,
- )
+ maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
)[0],
"$join:test.serv",
)
@@ -100,10 +99,7 @@ class MessageAcceptTests(unittest.TestCase):
# Figure out what the most recent event is
most_recent = self.successResultOf(
- maybeDeferred(
- self.homeserver.get_datastore().get_latest_event_ids_in_room,
- self.room_id,
- )
+ maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
)[0]
# Now lie about an event
@@ -141,7 +137,5 @@ class MessageAcceptTests(unittest.TestCase):
)
# Make sure the invalid event isn't there
- extrem = maybeDeferred(
- self.homeserver.get_datastore().get_latest_event_ids_in_room, self.room_id
- )
+ extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
diff --git a/tests/unittest.py b/tests/unittest.py
index 561cebc223..b30b7d1718 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector
+# Copyright 2019 Matrix.org Federation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,9 +14,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import gc
import hashlib
import hmac
+import inspect
import logging
import time
@@ -23,17 +26,21 @@ from mock import Mock
from canonicaljson import json
-from twisted.internet.defer import Deferred, succeed
+from twisted.internet.defer import Deferred, ensureDeferred, succeed
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.federation.transport import server as federation_server
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
from synapse.logging.context import LoggingContext
from synapse.server import HomeServer
from synapse.types import Requester, UserID, create_requester
+from synapse.util.ratelimitutils import FederationRateLimiter
from tests.server import get_clock, make_request, render, setup_test_homeserver
from tests.test_utils.logging_setup import setup_logging
@@ -395,10 +402,12 @@ class HomeserverTestCase(TestCase):
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
stor = hs.get_datastore()
- # Run the database background updates.
- if hasattr(stor, "do_next_background_update"):
- while not self.get_success(stor.has_completed_background_updates()):
- self.get_success(stor.do_next_background_update(1))
+ # Run the database background updates, when running against "master".
+ if hs.__class__.__name__ == "TestHomeServer":
+ while not self.get_success(
+ stor.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(stor.db.updates.do_next_background_update(1))
return hs
@@ -409,6 +418,8 @@ class HomeserverTestCase(TestCase):
self.reactor.pump([by] * 100)
def get_success(self, d, by=0.0):
+ if inspect.isawaitable(d):
+ d = ensureDeferred(d)
if not isinstance(d, Deferred):
return d
self.pump(by=by)
@@ -418,6 +429,8 @@ class HomeserverTestCase(TestCase):
"""
Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
"""
+ if inspect.isawaitable(d):
+ d = ensureDeferred(d)
if not isinstance(d, Deferred):
return d
self.pump()
@@ -538,7 +551,7 @@ class HomeserverTestCase(TestCase):
Add the given event as an extremity to the room.
"""
self.get_success(
- self.hs.get_datastore()._simple_insert(
+ self.hs.get_datastore().db.simple_insert(
table="event_forward_extremities",
values={"room_id": room_id, "event_id": event_id},
desc="test_add_extremity",
@@ -559,6 +572,66 @@ class HomeserverTestCase(TestCase):
self.render(request)
self.assertEqual(channel.code, 403, channel.result)
+ def inject_room_member(self, room: str, user: str, membership: Membership) -> None:
+ """
+ Inject a membership event into a room.
+
+ Args:
+ room: Room ID to inject the event into.
+ user: MXID of the user to inject the membership for.
+ membership: The membership type.
+ """
+ event_builder_factory = self.hs.get_event_builder_factory()
+ event_creation_handler = self.hs.get_event_creation_handler()
+
+ room_version = self.get_success(self.hs.get_datastore().get_room_version(room))
+
+ builder = event_builder_factory.for_room_version(
+ KNOWN_ROOM_VERSIONS[room_version],
+ {
+ "type": EventTypes.Member,
+ "sender": user,
+ "state_key": user,
+ "room_id": room,
+ "content": {"membership": membership},
+ },
+ )
+
+ event, context = self.get_success(
+ event_creation_handler.create_new_client_event(builder)
+ )
+
+ self.get_success(
+ self.hs.get_storage().persistence.persist_event(event, context)
+ )
+
+
+class FederatingHomeserverTestCase(HomeserverTestCase):
+ """
+ A federating homeserver that authenticates incoming requests as `other.example.com`.
+ """
+
+ def prepare(self, reactor, clock, homeserver):
+ class Authenticator(object):
+ def authenticate_request(self, request, content):
+ return succeed("other.example.com")
+
+ ratelimiter = FederationRateLimiter(
+ clock,
+ FederationRateLimitConfig(
+ window_size=1,
+ sleep_limit=1,
+ sleep_msec=1,
+ reject_limit=1000,
+ concurrent_requests=1000,
+ ),
+ )
+ federation_server.register_servlets(
+ homeserver, self.resource, Authenticator(), ratelimiter
+ )
+
+ return super().prepare(reactor, clock, homeserver)
+
def override_config(extra_config):
"""A decorator which can be applied to test functions to give additional HS config
diff --git a/tests/utils.py b/tests/utils.py
index 7dc9bdc505..c57da59191 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -109,6 +109,7 @@ def default_config(name, parse=False):
"""
config_dict = {
"server_name": name,
+ "send_federation": False,
"media_store_path": "media",
"uploads_path": "uploads",
# the test signing key is just an arbitrary ed25519 key to keep the config
@@ -460,7 +461,9 @@ class MockHttpResource(HttpServer):
try:
args = [urlparse.unquote(u) for u in matcher.groups()]
- (code, response) = yield func(mock_request, *args)
+ (code, response) = yield defer.ensureDeferred(
+ func(mock_request, *args)
+ )
return code, response
except CodeMessageException as e:
return (e.code, cs_error(e.msg, code=e.errcode))
|