diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 5d45689c8c..8ab56ec94c 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -36,7 +36,7 @@ from tests import unittest
from tests.utils import mock_getRawHeaders, setup_test_homeserver
-class TestHandlers(object):
+class TestHandlers:
def __init__(self, hs):
self.auth_handler = synapse.handlers.auth.AuthHandler(hs)
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 1fab1d6b69..d2d535d23c 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -369,8 +369,10 @@ class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_filter_presence_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}}
- filter_id = yield self.datastore.add_user_filter(
- user_localpart=user_localpart, user_filter=user_filter_json
+ filter_id = yield defer.ensureDeferred(
+ self.datastore.add_user_filter(
+ user_localpart=user_localpart, user_filter=user_filter_json
+ )
)
event = MockEvent(sender="@foo:bar", type="m.profile")
events = [event]
@@ -388,8 +390,10 @@ class FilteringTestCase(unittest.TestCase):
def test_filter_presence_no_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}}
- filter_id = yield self.datastore.add_user_filter(
- user_localpart=user_localpart + "2", user_filter=user_filter_json
+ filter_id = yield defer.ensureDeferred(
+ self.datastore.add_user_filter(
+ user_localpart=user_localpart + "2", user_filter=user_filter_json
+ )
)
event = MockEvent(
event_id="$asdasd:localhost",
@@ -410,8 +414,10 @@ class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_filter_room_state_match(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
- filter_id = yield self.datastore.add_user_filter(
- user_localpart=user_localpart, user_filter=user_filter_json
+ filter_id = yield defer.ensureDeferred(
+ self.datastore.add_user_filter(
+ user_localpart=user_localpart, user_filter=user_filter_json
+ )
)
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
events = [event]
@@ -428,8 +434,10 @@ class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_filter_room_state_no_match(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
- filter_id = yield self.datastore.add_user_filter(
- user_localpart=user_localpart, user_filter=user_filter_json
+ filter_id = yield defer.ensureDeferred(
+ self.datastore.add_user_filter(
+ user_localpart=user_localpart, user_filter=user_filter_json
+ )
)
event = MockEvent(
sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
@@ -465,8 +473,10 @@ class FilteringTestCase(unittest.TestCase):
def test_add_filter(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
- filter_id = yield self.filtering.add_user_filter(
- user_localpart=user_localpart, user_filter=user_filter_json
+ filter_id = yield defer.ensureDeferred(
+ self.filtering.add_user_filter(
+ user_localpart=user_localpart, user_filter=user_filter_json
+ )
)
self.assertEquals(filter_id, 0)
@@ -485,8 +495,10 @@ class FilteringTestCase(unittest.TestCase):
def test_get_filter(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
- filter_id = yield self.datastore.add_user_filter(
- user_localpart=user_localpart, user_filter=user_filter_json
+ filter_id = yield defer.ensureDeferred(
+ self.datastore.add_user_filter(
+ user_localpart=user_localpart, user_filter=user_filter_json
+ )
)
filter = yield defer.ensureDeferred(
diff --git a/tests/config/test_base.py b/tests/config/test_base.py
new file mode 100644
index 0000000000..42ee5f56d9
--- /dev/null
+++ b/tests/config/test_base.py
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path
+import tempfile
+
+from synapse.config import ConfigError
+from synapse.util.stringutils import random_string
+
+from tests import unittest
+
+
+class BaseConfigTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.hs = hs
+
+ def test_loading_missing_templates(self):
+ # Use a temporary directory that exists on the system, but that isn't likely to
+ # contain template files
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ # Attempt to load an HTML template from our custom template directory
+ template = self.hs.config.read_templates(["sso_error.html"], tmp_dir)[0]
+
+ # If no errors, we should've gotten the default template instead
+
+ # Render the template
+ a_random_string = random_string(5)
+ html_content = template.render({"error_description": a_random_string})
+
+ # Check that our string exists in the template
+ self.assertIn(
+ a_random_string,
+ html_content,
+ "Template file did not contain our test string",
+ )
+
+ def test_loading_custom_templates(self):
+ # Use a temporary directory that exists on the system
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ # Create a temporary bogus template file
+ with tempfile.NamedTemporaryFile(dir=tmp_dir) as tmp_template:
+ # Get temporary file's filename
+ template_filename = os.path.basename(tmp_template.name)
+
+ # Write a custom HTML template
+ contents = b"{{ test_variable }}"
+ tmp_template.write(contents)
+ tmp_template.flush()
+
+ # Attempt to load the template from our custom template directory
+ template = (
+ self.hs.config.read_templates([template_filename], tmp_dir)
+ )[0]
+
+ # Render the template
+ a_random_string = random_string(5)
+ html_content = template.render({"test_variable": a_random_string})
+
+ # Check that our string exists in the template
+ self.assertIn(
+ a_random_string,
+ html_content,
+ "Template file did not contain our test string",
+ )
+
+ def test_loading_template_from_nonexistent_custom_directory(self):
+ with self.assertRaises(ConfigError):
+ self.hs.config.read_templates(
+ ["some_filename.html"], "a_nonexistent_directory"
+ )
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 0d4b05304b..2e6e7abf1f 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -43,7 +43,7 @@ from tests import unittest
from tests.test_utils import make_awaitable
-class MockPerspectiveServer(object):
+class MockPerspectiveServer:
def __init__(self):
self.server_name = "mock_server"
self.key = signedjson.key.generate_signing_key(0)
@@ -190,7 +190,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# should fail immediately on an unsigned object
d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
- self.failureResultOf(d, SynapseError)
+ self.get_failure(d, SynapseError)
# should succeed on a signed object
d = _verify_json_for_server(kr, "server9", json1, 500, "test signed")
@@ -221,7 +221,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# should fail immediately on an unsigned object
d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
- self.failureResultOf(d, SynapseError)
+ self.get_failure(d, SynapseError)
# should fail on a signed object with a non-zero minimum_valid_until_ms,
# as it tries to refetch the keys and fails.
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index b8ca118716..1471cc1a28 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -15,8 +15,6 @@
from mock import Mock
-from twisted.internet import defer
-
from synapse.api.errors import Codes, SynapseError
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
@@ -60,7 +58,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Artificially raise the complexity
store = self.hs.get_datastore()
- store.get_current_state_event_counts = lambda x: defer.succeed(500 * 1.23)
+ store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
# Get the room complexity again -- make sure it's our artificial value
request, channel = self.make_request(
@@ -154,7 +152,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
)
# Artificially raise the complexity
- self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed(
+ self.hs.get_datastore().get_current_state_event_counts = lambda x: make_awaitable(
600
)
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
new file mode 100644
index 0000000000..cc52c3dfac
--- /dev/null
+++ b/tests/federation/test_federation_catch_up.py
@@ -0,0 +1,323 @@
+from typing import List, Tuple
+
+from mock import Mock
+
+from synapse.events import EventBase
+from synapse.federation.sender import PerDestinationQueue, TransactionManager
+from synapse.federation.units import Edu
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.test_utils import event_injection, make_awaitable
+from tests.unittest import FederatingHomeserverTestCase, override_config
+
+
+class FederationCatchUpTestCases(FederatingHomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(
+ federation_transport_client=Mock(spec=["send_transaction"]),
+ )
+
+ def prepare(self, reactor, clock, hs):
+ # stub out get_current_hosts_in_room
+ state_handler = hs.get_state_handler()
+
+ # This mock is crucial for destination_rooms to be populated.
+ state_handler.get_current_hosts_in_room = Mock(
+ return_value=make_awaitable(["test", "host2"])
+ )
+
+ # whenever send_transaction is called, record the pdu data
+ self.pdus = []
+ self.failed_pdus = []
+ self.is_online = True
+ self.hs.get_federation_transport_client().send_transaction.side_effect = (
+ self.record_transaction
+ )
+
+ async def record_transaction(self, txn, json_cb):
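+        # Stand-in for the transport client's send_transaction: while "online",
+        # record the PDUs and succeed; while "offline", record them as failed
+        # and raise, giving the catch-up logic something to retransmit later.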
+ if self.is_online:
+ data = json_cb()
+ self.pdus.extend(data["pdus"])
+ return {}
+ else:
+ data = json_cb()
+ self.failed_pdus.extend(data["pdus"])
+ raise IOError("Failed to connect because this is a test!")
+
+ def get_destination_room(self, room: str, destination: str = "host2") -> dict:
+ """
+ Gets the destination_rooms entry for a (destination, room_id) pair.
+
+ Args:
+ room: room ID
+            destination: the destination to look up; defaults to "host2"
+
+ Returns:
+ Dictionary of { event_id: str, stream_ordering: int }
+ """
+ event_id, stream_ordering = self.get_success(
+ self.hs.get_datastore().db_pool.execute(
+ "test:get_destination_rooms",
+ None,
+ """
+ SELECT event_id, stream_ordering
+ FROM destination_rooms dr
+ JOIN events USING (stream_ordering)
+ WHERE dr.destination = ? AND dr.room_id = ?
+ """,
+ destination,
+ room,
+ )
+ )[0]
+ return {"event_id": event_id, "stream_ordering": stream_ordering}
+
+ @override_config({"send_federation": True})
+ def test_catch_up_destination_rooms_tracking(self):
+ """
+ Tests that we populate the `destination_rooms` table as needed.
+ """
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room = self.helper.create_room_as("u1", tok=u1_token)
+
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room, "@user:host2", "join")
+ )
+
+ event_id_1 = self.helper.send(room, "wombats!", tok=u1_token)["event_id"]
+
+ row_1 = self.get_destination_room(room)
+
+ event_id_2 = self.helper.send(room, "rabbits!", tok=u1_token)["event_id"]
+
+ row_2 = self.get_destination_room(room)
+
+ # check: events correctly registered in order
+ self.assertEqual(row_1["event_id"], event_id_1)
+ self.assertEqual(row_2["event_id"], event_id_2)
+ self.assertEqual(row_1["stream_ordering"], row_2["stream_ordering"] - 1)
+
+ @override_config({"send_federation": True})
+ def test_catch_up_last_successful_stream_ordering_tracking(self):
+ """
+        Tests that `last_successful_stream_ordering` is only advanced once a
+        transaction has actually been sent to the destination successfully.
+ """
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room = self.helper.create_room_as("u1", tok=u1_token)
+
+ # take the remote offline
+ self.is_online = False
+
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room, "@user:host2", "join")
+ )
+
+ self.helper.send(room, "wombats!", tok=u1_token)
+ self.pump()
+
+ lsso_1 = self.get_success(
+ self.hs.get_datastore().get_destination_last_successful_stream_ordering(
+ "host2"
+ )
+ )
+
+ self.assertIsNone(
+ lsso_1,
+ "There should be no last successful stream ordering for an always-offline destination",
+ )
+
+ # bring the remote online
+ self.is_online = True
+
+ event_id_2 = self.helper.send(room, "rabbits!", tok=u1_token)["event_id"]
+
+ lsso_2 = self.get_success(
+ self.hs.get_datastore().get_destination_last_successful_stream_ordering(
+ "host2"
+ )
+ )
+ row_2 = self.get_destination_room(room)
+
+ self.assertEqual(
+ self.pdus[0]["content"]["body"],
+ "rabbits!",
+ "Test fault: didn't receive the right PDU",
+ )
+ self.assertEqual(
+ row_2["event_id"],
+ event_id_2,
+ "Test fault: destination_rooms not updated correctly",
+ )
+ self.assertEqual(
+ lsso_2,
+ row_2["stream_ordering"],
+ "Send succeeded but not marked as last_successful_stream_ordering",
+ )
+
+ @override_config({"send_federation": True}) # critical to federate
+ def test_catch_up_from_blank_state(self):
+ """
+ Runs an overall test of federation catch-up from scratch.
+        Further tests will focus on narrower aspects and edge cases, but I
+ hope to provide an overall view with this test.
+ """
+ # bring the other server online
+ self.is_online = True
+
+ # let's make some events for the other server to receive
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room_1 = self.helper.create_room_as("u1", tok=u1_token)
+ room_2 = self.helper.create_room_as("u1", tok=u1_token)
+
+ # also critical to federate
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join")
+ )
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room_2, "@user:host2", "join")
+ )
+
+ self.helper.send_state(
+ room_1, event_type="m.room.topic", body={"topic": "wombat"}, tok=u1_token
+ )
+
+ # check: PDU received for topic event
+ self.assertEqual(len(self.pdus), 1)
+ self.assertEqual(self.pdus[0]["type"], "m.room.topic")
+
+ # take the remote offline
+ self.is_online = False
+
+ # send another event
+ self.helper.send(room_1, "hi user!", tok=u1_token)
+
+ # check: things didn't go well since the remote is down
+ self.assertEqual(len(self.failed_pdus), 1)
+ self.assertEqual(self.failed_pdus[0]["content"]["body"], "hi user!")
+
+ # let's delete the federation transmission queue
+ # (this pretends we are starting up fresh.)
+ self.assertFalse(
+ self.hs.get_federation_sender()
+ ._per_destination_queues["host2"]
+ .transmission_loop_running
+ )
+ del self.hs.get_federation_sender()._per_destination_queues["host2"]
+
+ # let's also clear any backoffs
+ self.get_success(
+ self.hs.get_datastore().set_destination_retry_timings("host2", None, 0, 0)
+ )
+
+ # bring the remote online and clear the received pdu list
+ self.is_online = True
+ self.pdus = []
+
+ # now we need to initiate a federation transaction somehow…
+ # to do that, let's send another event (because it's simple to do)
+ # (do it to another room otherwise the catch-up logic decides it doesn't
+ # need to catch up room_1 — something I overlooked when first writing
+ # this test)
+ self.helper.send(room_2, "wombats!", tok=u1_token)
+
+ # we should now have received both PDUs
+ self.assertEqual(len(self.pdus), 2)
+ self.assertEqual(self.pdus[0]["content"]["body"], "hi user!")
+ self.assertEqual(self.pdus[1]["content"]["body"], "wombats!")
+
+ def make_fake_destination_queue(
+ self, destination: str = "host2"
+ ) -> Tuple[PerDestinationQueue, List[EventBase]]:
+ """
+ Makes a fake per-destination queue.
+ """
+ transaction_manager = TransactionManager(self.hs)
+ per_dest_queue = PerDestinationQueue(self.hs, transaction_manager, destination)
+ results_list = []
+
+ async def fake_send(
+ destination_tm: str,
+ pending_pdus: List[EventBase],
+ _pending_edus: List[Edu],
+ ) -> bool:
+ assert destination == destination_tm
+ results_list.extend(pending_pdus)
+ return True # success!
+
+ transaction_manager.send_new_transaction = fake_send
+
+ return per_dest_queue, results_list
+
+ @override_config({"send_federation": True})
+ def test_catch_up_loop(self):
+ """
+ Tests the behaviour of _catch_up_transmission_loop.
+ """
+
+ # ARRANGE:
+ # - a local user (u1)
+ # - 3 rooms which u1 is joined to (and remote user @user:host2 is
+ # joined to)
+ # - some events (1 to 5) in those rooms
+ # we have 'already sent' events 1 and 2 to host2
+ per_dest_queue, sent_pdus = self.make_fake_destination_queue()
+
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room_1 = self.helper.create_room_as("u1", tok=u1_token)
+ room_2 = self.helper.create_room_as("u1", tok=u1_token)
+ room_3 = self.helper.create_room_as("u1", tok=u1_token)
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join")
+ )
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room_2, "@user:host2", "join")
+ )
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room_3, "@user:host2", "join")
+ )
+
+ # create some events
+ self.helper.send(room_1, "you hear me!!", tok=u1_token)
+ event_id_2 = self.helper.send(room_2, "wombats!", tok=u1_token)["event_id"]
+ self.helper.send(room_3, "Matrix!", tok=u1_token)
+ event_id_4 = self.helper.send(room_2, "rabbits!", tok=u1_token)["event_id"]
+ event_id_5 = self.helper.send(room_3, "Synapse!", tok=u1_token)["event_id"]
+
+ # destination_rooms should already be populated, but let us pretend that we already
+ # sent (successfully) up to and including event id 2
+ event_2 = self.get_success(self.hs.get_datastore().get_event(event_id_2))
+
+        # also fetch event 5 so we can later check that its stream ordering
+        # became the destination's last_successful_stream_ordering
+ event_5 = self.get_success(self.hs.get_datastore().get_event(event_id_5))
+
+ self.get_success(
+ self.hs.get_datastore().set_destination_last_successful_stream_ordering(
+ "host2", event_2.internal_metadata.stream_ordering
+ )
+ )
+
+ # ACT
+ self.get_success(per_dest_queue._catch_up_transmission_loop())
+
+ # ASSERT, noticing in particular:
+ # - event 3 not sent out, because event 5 replaces it
+ # - order is least recent first, so event 5 comes after event 4
+ # - catch-up is completed
+ self.assertEqual(len(sent_pdus), 2)
+ self.assertEqual(sent_pdus[0].event_id, event_id_4)
+ self.assertEqual(sent_pdus[1].event_id, event_id_5)
+ self.assertFalse(per_dest_queue._catching_up)
+ self.assertEqual(
+ per_dest_queue._last_successful_stream_ordering,
+ event_5.internal_metadata.stream_ordering,
+ )
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 5f512ff8bf..917762e6b6 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -34,7 +34,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
-        # Ensure a new Awaitable is created for each call.
-        mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_Id: make_awaitable(
+        # make_awaitable returns an awaitable that can be awaited more than
+        # once, so a single return_value suffices here.
+        mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable(
["test", "host2"]
)
return self.setup_test_homeserver(
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 296dc887be..da933ecd75 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -15,6 +15,8 @@
# limitations under the License.
import logging
+from parameterized import parameterized
+
from synapse.events import make_event_from_dict
from synapse.federation.federation_server import server_matches_acl_event
from synapse.rest import admin
@@ -23,6 +25,37 @@ from synapse.rest.client.v1 import login, room
from tests import unittest
+class FederationServerTests(unittest.FederatingHomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ @parameterized.expand([(b"",), (b"foo",), (b'{"limit": Infinity}',)])
+ def test_bad_request(self, query_content):
+ """
+ Querying with bad data returns a reasonable error code.
+ """
+ u1 = self.register_user("u1", "pass")
+ u1_token = self.login("u1", "pass")
+
+ room_1 = self.helper.create_room_as(u1, tok=u1_token)
+ self.inject_room_member(room_1, "@user:other.example.com", "join")
+
+ "/get_missing_events/(?P<room_id>[^/]*)/?"
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
+ query_content,
+ )
+ self.render(request)
+ self.assertEquals(400, channel.code, channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON")
+
+
class ServerACLsTestCase(unittest.TestCase):
def test_blacklisted_server(self):
e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index 27d83bb7d9..72e22d655f 100644
--- a/tests/federation/transport/test_server.py
+++ b/tests/federation/transport/test_server.py
@@ -26,7 +26,7 @@ from tests.unittest import override_config
class RoomDirectoryFederationTests(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
- class Authenticator(object):
+ class Authenticator:
def authenticate_request(self, request, content):
return defer.succeed("otherserver.nottld")
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index c01b04e1dc..97877c2e42 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -24,10 +24,11 @@ from synapse.api.errors import ResourceLimitError
from synapse.handlers.auth import AuthHandler
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
-class AuthHandlers(object):
+class AuthHandlers:
def __init__(self, hs):
self.auth_handler = AuthHandler(hs)
@@ -142,7 +143,7 @@ class AuthTestCase(unittest.TestCase):
def test_mau_limits_exceeded_large(self):
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.large_number_of_users)
+ return_value=make_awaitable(self.large_number_of_users)
)
with self.assertRaises(ResourceLimitError):
@@ -153,7 +154,7 @@ class AuthTestCase(unittest.TestCase):
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.large_number_of_users)
+ return_value=make_awaitable(self.large_number_of_users)
)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
@@ -168,7 +169,7 @@ class AuthTestCase(unittest.TestCase):
# If not in monthly active cohort
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.auth_blocking._max_mau_value)
+ return_value=make_awaitable(self.auth_blocking._max_mau_value)
)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
@@ -178,7 +179,7 @@ class AuthTestCase(unittest.TestCase):
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.auth_blocking._max_mau_value)
+ return_value=make_awaitable(self.auth_blocking._max_mau_value)
)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
@@ -188,10 +189,10 @@ class AuthTestCase(unittest.TestCase):
)
# If in monthly active cohort
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
- return_value=defer.succeed(self.hs.get_clock().time_msec())
+ return_value=make_awaitable(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.auth_blocking._max_mau_value)
+ return_value=make_awaitable(self.auth_blocking._max_mau_value)
)
yield defer.ensureDeferred(
self.auth_handler.get_access_token_for_user_id(
@@ -199,10 +200,10 @@ class AuthTestCase(unittest.TestCase):
)
)
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
- return_value=defer.succeed(self.hs.get_clock().time_msec())
+ return_value=make_awaitable(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.auth_blocking._max_mau_value)
+ return_value=make_awaitable(self.auth_blocking._max_mau_value)
)
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
@@ -215,7 +216,7 @@ class AuthTestCase(unittest.TestCase):
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.small_number_of_users)
+ return_value=make_awaitable(self.small_number_of_users)
)
# Ensure does not raise exception
yield defer.ensureDeferred(
@@ -225,7 +226,7 @@ class AuthTestCase(unittest.TestCase):
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.small_number_of_users)
+ return_value=make_awaitable(self.small_number_of_users)
)
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 1bb25ab684..89ec5fcb31 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -75,7 +75,17 @@ COMMON_CONFIG = {
COOKIE_NAME = b"oidc_session"
COOKIE_PATH = "/_synapse/oidc"
-MockedMappingProvider = Mock(OidcMappingProvider)
+
+class TestMappingProvider(OidcMappingProvider):
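+    """A minimal mapping provider: the MXID localpart is taken straight from
+    the "username" claim and no display name is set."""
+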
+ @staticmethod
+ def parse_config(config):
+ return
+
+ def get_remote_user_id(self, userinfo):
+ return userinfo["sub"]
+
+ async def map_user_attributes(self, userinfo, token):
+ return {"localpart": userinfo["username"], "display_name": None}
def simple_async_mock(return_value=None, raises=None):
@@ -123,7 +133,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
oidc_config["issuer"] = ISSUER
oidc_config["scopes"] = SCOPES
oidc_config["user_mapping_provider"] = {
- "module": __name__ + ".MockedMappingProvider"
+ "module": __name__ + ".TestMappingProvider",
}
config["oidc_config"] = oidc_config
@@ -374,12 +384,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
self.handler._auth_handler.complete_sso_login = simple_async_mock()
- request = Mock(spec=["args", "getCookie", "addCookie"])
+ request = Mock(
+ spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+ )
code = "code"
state = "state"
nonce = "nonce"
client_redirect_url = "http://client/redirect"
+ user_agent = "Browser"
+ ip_address = "10.0.0.1"
session = self.handler._generate_oidc_session_token(
state=state,
nonce=nonce,
@@ -392,6 +406,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")]
+ request.requestHeaders = Mock(spec=["getRawHeaders"])
+ request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
+ request.getClientIP.return_value = ip_address
+
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
@@ -399,7 +417,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._map_userinfo_to_user.assert_called_once_with(
+ userinfo, token, user_agent, ip_address
+ )
self.handler._fetch_userinfo.assert_not_called()
self.handler._render_error.assert_not_called()
@@ -431,7 +451,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called()
- self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._map_userinfo_to_user.assert_called_once_with(
+ userinfo, token, user_agent, ip_address
+ )
self.handler._fetch_userinfo.assert_called_once_with(token)
self.handler._render_error.assert_not_called()
@@ -568,3 +590,30 @@ class OidcHandlerTestCase(HomeserverTestCase):
with self.assertRaises(OidcError) as exc:
yield defer.ensureDeferred(self.handler._exchange_code(code))
self.assertEqual(exc.exception.error, "some_error")
+
+ def test_map_userinfo_to_user(self):
+ """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
+ userinfo = {
+ "sub": "test_user",
+ "username": "test_user",
+ }
+ # The token doesn't matter with the default user mapping provider.
+ token = {}
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user:test")
+
+ # Some providers return an integer ID.
+ userinfo = {
+ "sub": 1234,
+ "username": "test_user_2",
+ }
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user_2:test")
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 05ea40a7de..306dcfe944 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -19,6 +19,7 @@ from mock import Mock, call
from signedjson.key import generate_signing_key
from synapse.api.constants import EventTypes, Membership, PresenceState
+from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events.builder import EventBuilder
from synapse.handlers.presence import (
@@ -32,7 +33,6 @@ from synapse.handlers.presence import (
handle_update,
)
from synapse.rest.client.v1 import room
-from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, get_domain_from_id
from tests import unittest
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index d70e1fc608..8e95e53d9e 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -28,7 +28,7 @@ from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
-class ProfileHandlers(object):
+class ProfileHandlers:
def __init__(self, hs):
self.profile_handler = MasterProfileHandler(hs)
@@ -64,14 +64,16 @@ class ProfileTestCase(unittest.TestCase):
self.bob = UserID.from_string("@4567:test")
self.alice = UserID.from_string("@alice:remote")
- yield self.store.create_profile(self.frank.localpart)
+ yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
self.handler = hs.get_profile_handler()
self.hs = hs
@defer.inlineCallbacks
def test_get_my_name(self):
- yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ )
displayname = yield defer.ensureDeferred(
self.handler.get_displayname(self.frank)
@@ -104,7 +106,12 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.frank.localpart)
+ )
+ ),
+ "Frank",
)
@defer.inlineCallbacks
@@ -112,10 +119,17 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_displayname = False
# Setting displayname for the first time is allowed
- yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ )
self.assertEquals(
- (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.frank.localpart)
+ )
+ ),
+ "Frank",
)
# Setting displayname a second time is forbidden
@@ -157,8 +171,10 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_incoming_fed_query(self):
- yield self.store.create_profile("caroline")
- yield self.store.set_profile_displayname("caroline", "Caroline")
+ yield defer.ensureDeferred(self.store.create_profile("caroline"))
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname("caroline", "Caroline")
+ )
response = yield defer.ensureDeferred(
self.query_handlers["profile"](
@@ -170,8 +186,10 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_avatar(self):
- yield self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png"
+ yield defer.ensureDeferred(
+ self.store.set_profile_avatar_url(
+ self.frank.localpart, "http://my.server/me.png"
+ )
)
avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
@@ -188,7 +206,11 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.frank.localpart)
+ )
+ ),
"http://my.server/pic.gif",
)
@@ -202,7 +224,11 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.frank.localpart)
+ )
+ ),
"http://my.server/me.png",
)
@@ -211,12 +237,18 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_avatar_url = False
-        # Setting displayname for the first time is allowed
+        # Setting an avatar URL for the first time is allowed
- yield self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png"
+ yield defer.ensureDeferred(
+ self.store.set_profile_avatar_url(
+ self.frank.localpart, "http://my.server/me.png"
+ )
)
self.assertEquals(
- (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.frank.localpart)
+ )
+ ),
"http://my.server/me.png",
)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index e364b1bd62..cb7c0ed51a 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -15,20 +15,21 @@
from mock import Mock
-from twisted.internet import defer
-
+from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler
+from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, UserID, create_requester
from tests.test_utils import make_awaitable
from tests.unittest import override_config
+from tests.utils import mock_getRawHeaders
from .. import unittest
-class RegistrationHandlers(object):
+class RegistrationHandlers:
def __init__(self, hs):
self.registration_handler = RegistrationHandler(hs)
@@ -99,7 +100,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_get_or_create_user_mau_not_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.count_monthly_users = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value - 1)
+ return_value=make_awaitable(self.hs.config.max_mau_value - 1)
)
# Ensure does not throw exception
self.get_success(self.get_or_create_user(self.requester, "c", "User"))
@@ -107,7 +108,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_get_or_create_user_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.lots_of_users)
+ return_value=make_awaitable(self.lots_of_users)
)
self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"),
@@ -115,7 +116,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.max_mau_value)
)
self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"),
@@ -125,14 +126,14 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_register_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.lots_of_users)
+ return_value=make_awaitable(self.lots_of_users)
)
self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError
)
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.max_mau_value)
)
self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError
@@ -475,6 +476,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
+ def test_spam_checker_deny(self):
+ """A spam checker can deny registration, which results in an error."""
+
+ class DenyAll:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info
+ ):
+ return RegistrationBehaviour.DENY
+
+ # Configure a spam checker that denies all users.
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [DenyAll()]
+
+ self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
+
+ def test_spam_checker_shadow_ban(self):
+ """A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
+
+ class BanAll:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info
+ ):
+ return RegistrationBehaviour.SHADOW_BAN
+
+        # Configure a spam checker that shadow-bans all users.
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [BanAll()]
+
+ user_id = self.get_success(self.handler.register_user(localpart="user"))
+
+ # Get an access token.
+ token = self.macaroon_generator.generate_access_token(user_id)
+ self.get_success(
+ self.store.add_access_token_to_user(
+ user_id=user_id, token=token, device_id=None, valid_until_ms=None
+ )
+ )
+
+ # Ensure the user was marked as shadow-banned.
+ request = Mock(args={})
+ request.args[b"access_token"] = [token.encode("ascii")]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ auth = Auth(self.hs)
+ requester = self.get_success(auth.get_user_by_req(request))
+
+ self.assertTrue(requester.shadow_banned)
+
async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None
):
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 0e666492f6..312c0a0d41 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -54,7 +54,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.store.db_pool.simple_insert(
"background_updates",
{
- "update_name": "populate_stats_process_rooms_2",
+ "update_name": "populate_stats_process_rooms",
"progress_json": "{}",
"depends_on": "populate_stats_prepare",
},
@@ -66,7 +66,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
{
"update_name": "populate_stats_process_users",
"progress_json": "{}",
- "depends_on": "populate_stats_process_rooms_2",
+ "depends_on": "populate_stats_process_rooms",
},
)
)
@@ -81,8 +81,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
- def get_all_room_state(self):
- return self.store.db_pool.simple_select_list(
+ async def get_all_room_state(self):
+ return await self.store.db_pool.simple_select_list(
"room_stats_state", None, retcols=("name", "topic", "canonical_alias")
)
@@ -219,10 +219,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.get_success(
self.store.db_pool.simple_insert(
"background_updates",
- {
- "update_name": "populate_stats_process_rooms_2",
- "progress_json": "{}",
- },
+ {"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
)
)
self.get_success(
@@ -231,7 +228,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
{
"update_name": "populate_stats_cleanup",
"progress_json": "{}",
- "depends_on": "populate_stats_process_rooms_2",
+ "depends_on": "populate_stats_process_rooms",
},
)
)
@@ -256,7 +253,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# self.handler.notify_new_event()
# We need to let the delta processor advance…
- self.pump(10 * 60)
+ self.reactor.advance(10 * 60)
# Get the slices! There should be two -- day 1, and day 2.
r = self.get_success(self.store.get_statistics_for_subject("room", room_1, 0))
@@ -728,7 +725,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.store.db_pool.simple_insert(
"background_updates",
{
- "update_name": "populate_stats_process_rooms_2",
+ "update_name": "populate_stats_process_rooms",
"progress_json": "{}",
"depends_on": "populate_stats_prepare",
},
@@ -740,7 +737,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
{
"update_name": "populate_stats_process_users",
"progress_json": "{}",
- "depends_on": "populate_stats_process_rooms_2",
+ "depends_on": "populate_stats_process_rooms",
},
)
)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 64afd581bc..3fec09ea8a 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -21,7 +21,7 @@ from mock import ANY, Mock, call
from twisted.internet import defer
from synapse.api.errors import AuthError
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
from tests import unittest
from tests.test_utils import make_awaitable
@@ -73,6 +73,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"delivered_txn",
"get_received_txn_response",
"set_received_txn_response",
+ "get_destination_last_successful_stream_ordering",
"get_destination_retry_timings",
"get_devices_by_remote",
"maybe_store_room_on_invite",
@@ -80,6 +81,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"get_user_directory_stream_pos",
"get_current_state_deltas",
"get_device_updates_by_remote",
+ "get_room_max_stream_ordering",
]
)
@@ -116,10 +118,14 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
retry_timings_res
)
- self.datastore.get_device_updates_by_remote.side_effect = lambda destination, from_stream_id, limit: make_awaitable(
+ self.datastore.get_device_updates_by_remote.return_value = make_awaitable(
(0, [])
)
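+        # The federation sender's catch-up logic reads this, so stub it out as
+        # "no transaction has ever been successfully sent".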
+ self.datastore.get_destination_last_successful_stream_ordering.return_value = make_awaitable(
+ None
+ )
+
def get_received_txn_response(*args):
return defer.succeed(None)
@@ -144,9 +150,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_users_in_room = get_users_in_room
- self.datastore.get_user_directory_stream_pos.return_value = (
+ self.datastore.get_user_directory_stream_pos.side_effect = (
# we deliberately return a non-None stream pos to avoid doing an initial_spam
- defer.succeed(1)
+ lambda: make_awaitable(1)
)
self.datastore.get_current_state_deltas.return_value = (0, None)
@@ -155,8 +161,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
([], 0)
)
- self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
- self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
+ self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
+ None
+ )
+ self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
None
)
@@ -167,7 +175,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.handler.started_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
+ timeout=20000,
)
)
@@ -194,7 +205,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.handler.started_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
+ timeout=20000,
)
)
@@ -269,7 +283,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.handler.stopped_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
)
)
@@ -309,7 +325,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.handler.started_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
+ timeout=10000,
)
)
@@ -348,7 +367,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.handler.started_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
+ timeout=10000,
)
)
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 31ed89a5cd..87be94111f 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -238,7 +238,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def test_spam_checker(self):
"""
- A user which fails to the spam checks will not appear in search results.
+ A user which fails the spam checks will not appear in search results.
"""
u1 = self.register_user("user1", "pass")
u1_token = self.login(u1, "pass")
@@ -269,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Configure a spam checker that does not filter any users.
spam_checker = self.hs.get_spam_checker()
- class AllowAll(object):
+ class AllowAll:
def check_username_for_spam(self, user_profile):
# Allow all users.
return False
@@ -282,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users.
- class BlockAll(object):
+ class BlockAll:
def check_username_for_spam(self, user_profile):
# All users are spammy.
return True
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index 2096ba3c91..3e5a856584 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -133,7 +133,7 @@ def create_test_cert_file(sanlist):
@implementer(IOpenSSLServerConnectionCreator)
-class TestServerTLSConnectionFactory(object):
+class TestServerTLSConnectionFactory:
"""An SSL connection creator which returns connections which present a certificate
signed by our test CA."""
@@ -145,7 +145,7 @@ class TestServerTLSConnectionFactory(object):
self._cert_file = create_test_cert_file(sanlist)
def serverConnectionForTLS(self, tlsProtocol):
- ctx = SSL.Context(SSL.TLSv1_METHOD)
+ ctx = SSL.Context(SSL.SSLv23_METHOD)
ctx.use_certificate_file(self._cert_file)
ctx.use_privatekey_file(get_test_key_file())
return Connection(ctx, None)
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 69945a8f98..8b5ad4574f 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -972,7 +972,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_well_known_cache(self):
self.reactor.lookups["testserv"] = "1.2.3.4"
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
@@ -995,7 +997,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
well_known_server.loseConnection()
# repeat the request: it should hit the cache
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, b"target-server")
@@ -1003,7 +1007,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((1000.0,))
# now it should connect again
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
@@ -1026,7 +1032,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.lookups["testserv"] = "1.2.3.4"
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
@@ -1052,7 +1060,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
# another lookup.
self.reactor.pump((900.0,))
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
# The resolver may retry a few times, so fonx all requests that come along
attempts = 0
@@ -1082,7 +1092,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((10000.0,))
-        # Repated the request, this time it should fail if the lookup fails.
+        # Repeat the request; this time it should fail if the lookup fails.
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
clients = self.reactor.tcpClients
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
@@ -1252,7 +1264,7 @@ def _log_request(request):
@implementer(IPolicyForHTTPS)
-class TrustingTLSPolicyForHTTPS(object):
+class TrustingTLSPolicyForHTTPS:
"""An IPolicyForHTTPS which checks that the certificate belongs to the
right server, but doesn't check the certificate chain."""
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index ac598249e4..5604af3795 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -16,6 +16,7 @@
from mock import Mock
from netaddr import IPSet
+from parameterized import parameterized
from twisted.internet import defer
from twisted.internet.defer import TimeoutError
@@ -508,6 +509,53 @@ class FederationClientTests(HomeserverTestCase):
self.assertFalse(conn.disconnecting)
# wait for a while
- self.pump(120)
+ self.reactor.advance(120)
self.assertTrue(conn.disconnecting)
+
+ @parameterized.expand([(b"",), (b"foo",), (b'{"a": Infinity}',)])
+ def test_json_error(self, return_value):
+ """
+ Test what happens if invalid JSON is returned from the remote endpoint.
+ """
+
+ test_d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8008)
+
+ # complete the connection and wire it up to a fake transport
+ protocol = factory.buildProtocol(None)
+ transport = StringTransport()
+ protocol.makeConnection(transport)
+
+ # that should have made it send the request to the transport
+ self.assertRegex(transport.value(), b"^GET /foo/bar")
+ self.assertRegex(transport.value(), b"Host: testserv:8008")
+
+ # Deferred is still without a result
+ self.assertNoResult(test_d)
+
+ # Send it the HTTP response
+ protocol.dataReceived(
+ b"HTTP/1.1 200 OK\r\n"
+ b"Server: Fake\r\n"
+ b"Content-Type: application/json\r\n"
+ b"Content-Length: %i\r\n"
+ b"\r\n"
+ b"%s" % (len(return_value), return_value)
+ )
+
+ self.pump()
+
+ f = self.failureResultOf(test_d)
+ self.assertIsInstance(f.value, ValueError)
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
new file mode 100644
index 0000000000..45089158ce
--- /dev/null
+++ b/tests/http/test_servlet.py
@@ -0,0 +1,80 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+from io import BytesIO
+
+from mock import Mock
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import (
+ parse_json_object_from_request,
+ parse_json_value_from_request,
+)
+
+from tests import unittest
+
+
+def make_request(content):
+ """Make an object that acts enough like a request."""
+ request = Mock(spec=["content"])
+
+ if isinstance(content, dict):
+ content = json.dumps(content).encode("utf8")
+
+ request.content = BytesIO(content)
+ return request
+
+
+class TestServletUtils(unittest.TestCase):
+ def test_parse_json_value(self):
+ """Basic tests for parse_json_value_from_request."""
+ # Test round-tripping.
+ obj = {"foo": 1}
+ result = parse_json_value_from_request(make_request(obj))
+ self.assertEqual(result, obj)
+
+ # Results don't have to be objects.
+ result = parse_json_value_from_request(make_request(b'["foo"]'))
+ self.assertEqual(result, ["foo"])
+
+ # Test empty.
+ with self.assertRaises(SynapseError):
+ parse_json_value_from_request(make_request(b""))
+
+ result = parse_json_value_from_request(make_request(b""), allow_empty_body=True)
+ self.assertIsNone(result)
+
+ # Invalid UTF-8.
+ with self.assertRaises(SynapseError):
+ parse_json_value_from_request(make_request(b"\xFF\x00"))
+
+ # Invalid JSON.
+ with self.assertRaises(SynapseError):
+ parse_json_value_from_request(make_request(b"foo"))
+
+ with self.assertRaises(SynapseError):
+ parse_json_value_from_request(make_request(b'{"foo": Infinity}'))
+
+ def test_parse_json_object(self):
+ """Basic tests for parse_json_object_from_request."""
+ # Test empty.
+ result = parse_json_object_from_request(
+ make_request(b""), allow_empty_body=True
+ )
+ self.assertEqual(result, {})
+
+ # Test not an object
+ with self.assertRaises(SynapseError):
+ parse_json_object_from_request(make_request(b'["foo"]'))
diff --git a/tests/logging/test_structured.py b/tests/logging/test_structured.py
index 451d05c0f0..d36f5f426c 100644
--- a/tests/logging/test_structured.py
+++ b/tests/logging/test_structured.py
@@ -29,12 +29,12 @@ from synapse.logging.context import LoggingContext
from tests.unittest import DEBUG, HomeserverTestCase
-class FakeBeginner(object):
+class FakeBeginner:
def beginLoggingTo(self, observers, **kwargs):
self.observers = observers
-class StructuredLoggingTestBase(object):
+class StructuredLoggingTestBase:
"""
Test base that registers a cleanup handler to reset the stdlib log handler
to 'unset'.
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 807cd65dd6..04de0b9dbe 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -35,7 +35,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that the new user exists with all provided attributes
self.assertEqual(user_id, "@bob:test")
self.assertTrue(access_token)
- self.assertTrue(self.store.get_user_by_id(user_id))
+ self.assertTrue(self.get_success(self.store.get_user_by_id(user_id)))
# Check that the email was assigned
emails = self.get_success(self.store.user_get_threepids(user_id))
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 83032cc9ea..3224568640 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -27,7 +27,7 @@ from tests.unittest import HomeserverTestCase
@attr.s
-class _User(object):
+class _User:
"Helper wrapper for user ID and access token"
id = attr.ib()
token = attr.ib()
@@ -170,7 +170,7 @@ class EmailPusherTests(HomeserverTestCase):
last_stream_ordering = pushers[0]["last_stream_ordering"]
# Advance time a bit, so the pusher will register something has happened
- self.pump(100)
+ self.pump(10)
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
pushers = self.get_success(
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 0b5204654c..561258a356 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -160,7 +160,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 0, "notify_count": 0},
+ {"highlight_count": 0, "unread_count": 0, "notify_count": 0},
)
self.persist(
@@ -173,7 +173,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 0, "notify_count": 1},
+ {"highlight_count": 0, "unread_count": 0, "notify_count": 1},
)
self.persist(
@@ -188,7 +188,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 1, "notify_count": 2},
+ {"highlight_count": 1, "unread_count": 0, "notify_count": 2},
)
def test_get_rooms_for_user_with_stream_ordering(self):
@@ -368,7 +368,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.get_success(
self.master_store.add_push_actions_to_staging(
- event.event_id, {user_id: actions for user_id, actions in push_actions}
+ event.event_id,
+ {user_id: actions for user_id, actions in push_actions},
+ False,
)
)
return event, context
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 83f9aa291c..1d7edee5ba 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -20,7 +20,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.events.builder import EventBuilderFactory
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client.v1 import login, room
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.test_utils import make_awaitable
@@ -45,7 +45,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new event.
"""
mock_client = Mock(spec=["put_json"])
- mock_client.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ mock_client.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
@@ -73,7 +73,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new events.
"""
mock_client1 = Mock(spec=["put_json"])
- mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ mock_client1.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
@@ -85,7 +85,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
)
mock_client2 = Mock(spec=["put_json"])
- mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ mock_client2.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
@@ -136,7 +136,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new typing EDUs.
"""
mock_client1 = Mock(spec=["put_json"])
- mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ mock_client1.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
@@ -148,7 +148,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
)
mock_client2 = Mock(spec=["put_json"])
- mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ mock_client2.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
@@ -175,7 +175,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.get_success(
typing_handler.started_typing(
target_user=UserID.from_string(user),
- auth_user=UserID.from_string(user),
+ requester=create_requester(user),
room_id=room,
timeout=20000,
)
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 408c568a27..6dfc709dc5 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1174,6 +1174,8 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertIn("room_id", channel.json_body)
self.assertIn("name", channel.json_body)
+ self.assertIn("topic", channel.json_body)
+ self.assertIn("avatar", channel.json_body)
self.assertIn("canonical_alias", channel.json_body)
self.assertIn("joined_members", channel.json_body)
self.assertIn("joined_local_members", channel.json_body)
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 17d0aae2e9..b8b7758d24 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -20,8 +20,6 @@ import urllib.parse
from mock import Mock
-from twisted.internet import defer
-
import synapse.rest.admin
from synapse.api.constants import UserTypes
from synapse.api.errors import HttpResponseException, ResourceLimitError
@@ -29,6 +27,7 @@ from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import sync
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.unittest import override_config
@@ -338,7 +337,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@@ -592,7 +591,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@@ -632,7 +631,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 0b191d13c6..7d3773ff78 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -45,50 +45,63 @@ class RetentionTestCase(unittest.HomeserverTestCase):
}
self.hs = self.setup_test_homeserver(config=config)
+
return self.hs
def prepare(self, reactor, clock, homeserver):
self.user_id = self.register_user("user", "password")
self.token = self.login("user", "password")
- def test_retention_state_event(self):
- """Tests that the server configuration can limit the values a user can set to the
- room's retention policy.
+ self.store = self.hs.get_datastore()
+ self.serializer = self.hs.get_event_client_serializer()
+ self.clock = self.hs.get_clock()
+
+ def test_retention_event_purged_with_state_event(self):
+ """Tests that expired events are correctly purged when the room's retention policy
+ is defined by a state event.
"""
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ # Set the room's retention period to 2 days.
+ lifetime = one_day_ms * 2
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": one_day_ms * 4},
+ body={"max_lifetime": lifetime},
tok=self.token,
- expect_code=400,
)
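+ # With a 1.5-day increment the first event is ~3 days old when it is checked
+ # (outside the 2-day policy), while the second is only ~1.5 days old.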
+ self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+
+ def test_retention_event_purged_with_state_event_outside_allowed(self):
+ """Tests that the server configuration can override the policy for a room when
+ running the purge jobs.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ # Set a max_lifetime higher than the maximum allowed value.
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": one_hour_ms},
+ body={"max_lifetime": one_day_ms * 4},
tok=self.token,
- expect_code=400,
)
- def test_retention_event_purged_with_state_event(self):
- """Tests that expired events are correctly purged when the room's retention policy
- is defined by a state event.
- """
- room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ # Check that the event is purged after waiting for the maximum allowed duration
+ # instead of the one specified in the room's policy.
+ self._test_retention_event_purged(room_id, one_day_ms * 1.5)
- # Set the room's retention period to 2 days.
- lifetime = one_day_ms * 2
+ # Set a max_lifetime lower than the minimum allowed value.
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": lifetime},
+ body={"max_lifetime": one_hour_ms},
tok=self.token,
)
- self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+ # Check that the event is purged after waiting for the minimum allowed duration
+ # instead of the one specified in the room's policy.
+ self._test_retention_event_purged(room_id, one_day_ms * 0.5)
def test_retention_event_purged_without_state_event(self):
"""Tests that expired events are correctly purged when the room's retention policy
@@ -140,12 +153,32 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# That event should be the second, not outdated event.
self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
- def _test_retention_event_purged(self, room_id, increment):
+ def _test_retention_event_purged(self, room_id: str, increment: float):
+ """Run the following test scenario to test the message retention policy support:
+
+ 1. Send event 1
+ 2. Increment time by `increment`
+ 3. Send event 2
+ 4. Increment time by `increment`
+ 5. Check that event 1 has been purged
+ 6. Check that event 2 has not been purged
+ 7. Check that state events that were sent before event 1 aren't purged.
+
+ The main reason for sending a second event is that Synapse currently won't purge
+ the latest message in a room, as doing so would leave the room without any forward
+ extremities. Sending a second event also ensures the purge jobs aren't too greedy
+ and don't purge messages they shouldn't.
+
+ Args:
+ room_id: The ID of the room to test retention in.
+ increment: The number of milliseconds to advance the clock each time. Must be
+ defined so that events in the room aren't purged if they are `increment`
+ old but are purged if they are `increment * 2` old.
+ """
# Get the create event to, later, check that we can still access it.
message_handler = self.hs.get_message_handler()
create_event = self.get_success(
message_handler.get_room_data(
- self.user_id, room_id, EventTypes.Create, state_key="", is_guest=False
+ self.user_id, room_id, EventTypes.Create, state_key=""
)
)
@@ -156,7 +189,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
expired_event_id = resp.get("event_id")
# Check that we can retrieve the event.
- expired_event = self.get_event(room_id, expired_event_id)
+ expired_event = self.get_event(expired_event_id)
self.assertEqual(
expired_event.get("content", {}).get("body"), "1", expired_event
)
@@ -174,26 +207,31 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# one should still be kept.
self.reactor.advance(increment / 1000)
- # Check that the event has been purged from the database.
- self.get_event(room_id, expired_event_id, expected_code=404)
+ # Check that the first event has been purged from the database, i.e. that we
+ # can't retrieve it anymore, because it has expired.
+ self.get_event(expired_event_id, expect_none=True)
- # Check that the event that hasn't been purged can still be retrieved.
- valid_event = self.get_event(room_id, valid_event_id)
+ # Check that the event that hasn't expired can still be retrieved.
+ valid_event = self.get_event(valid_event_id)
self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event)
# Check that we can still access state events that were sent before the event that
# has been purged.
self.get_event(room_id, create_event.event_id)
- def get_event(self, room_id, event_id, expected_code=200):
- url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+ def get_event(self, event_id, expect_none=False):
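+ """Fetch an event directly from the database and serialize it as a client
+ would see it. If expect_none is True, assert that the event has been purged
+ instead and return an empty dict.
+ """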
+ event = self.get_success(self.store.get_event(event_id, allow_none=True))
- request, channel = self.make_request("GET", url, access_token=self.token)
- self.render(request)
+ if expect_none:
+ self.assertIsNone(event)
+ return {}
- self.assertEqual(channel.code, expected_code, channel.result)
+ self.assertIsNotNone(event)
- return channel.json_body
+ time_now = self.clock.time_msec()
+ serialized = self.get_success(self.serializer.serialize_event(event, time_now))
+
+ return serialized
class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
new file mode 100644
index 0000000000..dfe4bf7762
--- /dev/null
+++ b/tests/rest/client/test_shadow_banned.py
@@ -0,0 +1,312 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock, patch
+
+import synapse.rest.admin
+from synapse.api.constants import EventTypes
+from synapse.rest.client.v1 import directory, login, profile, room
+from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
+
+from tests import unittest
+
+
+class _ShadowBannedBase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, homeserver):
+ # Create two users, one of which is shadow-banned.
+ self.banned_user_id = self.register_user("banned", "test")
+ self.banned_access_token = self.login("banned", "test")
+
+ self.store = self.hs.get_datastore()
+
+ self.get_success(
+ self.store.db_pool.simple_update(
+ table="users",
+ keyvalues={"name": self.banned_user_id},
+ updatevalues={"shadow_banned": True},
+ desc="shadow_ban",
+ )
+ )
+
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+
+# To avoid the tests timing out, don't add a delay to "annoy the requester".
+@patch("random.randint", new=lambda a, b: 0)
+class RoomTestCase(_ShadowBannedBase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ directory.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ room_upgrade_rest_servlet.register_servlets,
+ ]
+
+ def test_invite(self):
+ """Invites from shadow-banned users don't actually get sent."""
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # Inviting the user completes successfully.
+ self.helper.invite(
+ room=room_id,
+ src=self.banned_user_id,
+ tok=self.banned_access_token,
+ targ=self.other_user_id,
+ )
+
+ # But the user wasn't actually invited.
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(self.other_user_id)
+ )
+ self.assertEqual(invited_rooms, [])
+
+ def test_invite_3pid(self):
+ """Ensure that a 3PID invite does not attempt to contact the identity server."""
+ identity_handler = self.hs.get_handlers().identity_handler
+ identity_handler.lookup_3pid = Mock(
+ side_effect=AssertionError("This should not get called")
+ )
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # Inviting the user completes successfully.
+ request, channel = self.make_request(
+ "POST",
+ "/rooms/%s/invite" % (room_id,),
+ {"id_server": "test", "medium": "email", "address": "test@test.test"},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ # This should have raised an error earlier, but double check this wasn't called.
+ identity_handler.lookup_3pid.assert_not_called()
+
+ def test_create_room(self):
+ """Invitations during a room creation should be discarded, but the room still gets created."""
+ # The room creation is successful.
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/createRoom",
+ {"visibility": "public", "invite": [self.other_user_id]},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ room_id = channel.json_body["room_id"]
+
+ # But the user wasn't actually invited.
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(self.other_user_id)
+ )
+ self.assertEqual(invited_rooms, [])
+
+ # Since a real room was created, the other user should be able to join it.
+ self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
+
+ # Both users should be in the room.
+ users = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])
+
+ def test_message(self):
+ """Messages from shadow-banned users don't actually get sent."""
+
+ room_id = self.helper.create_room_as(
+ self.other_user_id, tok=self.other_access_token
+ )
+
+ # The user should be in the room.
+ self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token)
+
+ # Sending a message should complete successfully.
+ result = self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "with right label"},
+ tok=self.banned_access_token,
+ )
+ self.assertIn("event_id", result)
+ event_id = result["event_id"]
+
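+ # If the event had actually been sent it would be the room's newest forward
+ # extremity, so its absence from the latest event IDs shows it was dropped.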
+ latest_events = self.get_success(
+ self.store.get_latest_event_ids_in_room(room_id)
+ )
+ self.assertNotIn(event_id, latest_events)
+
+ def test_upgrade(self):
+ """A room upgrade should fail, but look like it succeeded."""
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
+ {"new_version": "6"},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ # A new room_id should be returned.
+ self.assertIn("replacement_room", channel.json_body)
+
+ new_room_id = channel.json_body["replacement_room"]
+
+ # It doesn't really matter what API we use here; we just want to assert
+ # that the room doesn't exist.
+ summary = self.get_success(self.store.get_room_summary(new_room_id))
+ # The summary should be empty since the room doesn't exist.
+ self.assertEqual(summary, {})
+
+ def test_typing(self):
+ """Typing notifications should not be propagated into the room."""
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ request, channel = self.make_request(
+ "PUT",
+ "/rooms/%s/typing/%s" % (room_id, self.banned_user_id),
+ {"typing": True, "timeout": 30000},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ # There should be no typing events.
+ event_source = self.hs.get_event_sources().sources["typing"]
+ self.assertEquals(event_source.get_current_key(), 0)
+
+ # The other user can join and send typing events.
+ self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
+
+ request, channel = self.make_request(
+ "PUT",
+ "/rooms/%s/typing/%s" % (room_id, self.other_user_id),
+ {"typing": True, "timeout": 30000},
+ access_token=self.other_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ # These appear in the room.
+ self.assertEquals(event_source.get_current_key(), 1)
+ events = self.get_success(
+ event_source.get_new_events(from_key=0, room_ids=[room_id])
+ )
+ self.assertEquals(
+ events[0],
+ [
+ {
+ "type": "m.typing",
+ "room_id": room_id,
+ "content": {"user_ids": [self.other_user_id]},
+ }
+ ],
+ )
+
+
+# To avoid the tests timing out, don't add a delay to "annoy the requester".
+@patch("random.randint", new=lambda a, b: 0)
+class ProfileTestCase(_ShadowBannedBase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ profile.register_servlets,
+ room.register_servlets,
+ ]
+
+ def test_displayname(self):
+ """Profile changes should succeed, but don't end up in a room."""
+ original_display_name = "banned"
+ new_display_name = "new name"
+
+ # Join a room.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # The update should succeed.
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/profile/%s/displayname" % (self.banned_user_id,),
+ {"displayname": new_display_name},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEqual(channel.json_body, {})
+
+ # The user's display name should be updated.
+ request, channel = self.make_request(
+ "GET", "/profile/%s/displayname" % (self.banned_user_id,)
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.json_body["displayname"], new_display_name)
+
+ # But the display name in the room should not be.
+ message_handler = self.hs.get_message_handler()
+ event = self.get_success(
+ message_handler.get_room_data(
+ self.banned_user_id, room_id, "m.room.member", self.banned_user_id,
+ )
+ )
+ self.assertEqual(
+ event.content, {"membership": "join", "displayname": original_display_name}
+ )
+
+ def test_room_displayname(self):
+ """Changes to state events for a room should be processed, but not end up in the room."""
+ original_display_name = "banned"
+ new_display_name = "new name"
+
+ # Join a room.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # The update should succeed.
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
+ % (room_id, self.banned_user_id),
+ {"membership": "join", "displayname": new_display_name},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertIn("event_id", channel.json_body)
+
+ # The display name in the room should not be changed.
+ message_handler = self.hs.get_message_handler()
+ event = self.get_success(
+ message_handler.get_room_data(
+ self.banned_user_id, room_id, "m.room.member", self.banned_user_id,
+ )
+ )
+ self.assertEqual(
+ event.content, {"membership": "join", "displayname": original_display_name}
+ )
diff --git a/tests/rest/client/third_party_rules.py b/tests/rest/client/third_party_rules.py
index 7167fc56b6..8c24add530 100644
--- a/tests/rest/client/third_party_rules.py
+++ b/tests/rest/client/third_party_rules.py
@@ -19,7 +19,7 @@ from synapse.rest.client.v1 import login, room
from tests import unittest
-class ThirdPartyRulesTestModule(object):
+class ThirdPartyRulesTestModule:
def __init__(self, config):
pass
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index db52725cfe..2668662c9e 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -62,8 +62,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -76,14 +75,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -111,8 +109,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -132,7 +129,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -160,8 +156,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -174,14 +169,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py
new file mode 100644
index 0000000000..081052f6a6
--- /dev/null
+++ b/tests/rest/client/v1/test_push_rule_attrs.py
@@ -0,0 +1,448 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import synapse
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import login, push_rule, room
+
+from tests.unittest import HomeserverTestCase
+
+
+class PushRuleAttributesTestCase(HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ push_rule.register_servlets,
+ ]
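+ # Don't hijack auth: these tests register real users and use their access tokens.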
+ hijack_auth = False
+
+ def test_enabled_on_creation(self):
+ """
+ Tests the GET and PUT of push rules' `enabled` endpoints.
+ Tests that a rule is enabled upon creation.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # GET enabled for that new rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["enabled"], True)
+
+ def test_enabled_on_recreation(self):
+ """
+ Tests the GET and PUT of push rules' `enabled` endpoints.
+ Tests that a rule is enabled upon creation, even if a rule with that
+ ruleId existed previously and was disabled.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # disable the rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/enabled",
+ {"enabled": False},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # check rule disabled
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["enabled"], False)
+
+ # DELETE the rule
+ request, channel = self.make_request(
+ "DELETE", "/pushrules/global/override/best.friend", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # GET enabled for that new rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["enabled"], True)
+
+ def test_enabled_disable(self):
+ """
+ Tests the GET and PUT of push rules' `enabled` endpoints.
+ Tests that a rule is disabled and enabled when we ask for it.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # disable the rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/enabled",
+ {"enabled": False},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # check rule disabled
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["enabled"], False)
+
+ # re-enable the rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/enabled",
+ {"enabled": True},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # check rule enabled
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["enabled"], True)
+
+ def test_enabled_404_when_get_non_existent(self):
+ """
+ Tests that `enabled` gives 404 when the rule doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # GET enabled for that new rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # DELETE the rule
+ request, channel = self.make_request(
+ "DELETE", "/pushrules/global/override/best.friend", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # check 404 for deleted rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_enabled_404_when_get_non_existent_server_rule(self):
+ """
+ Tests that `enabled` gives 404 when the server-default rule doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_enabled_404_when_put_non_existent_rule(self):
+ """
+ Tests that `enabled` gives 404 when we put to a rule that doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # enable & check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/enabled",
+ {"enabled": True},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_enabled_404_when_put_non_existent_server_rule(self):
+ """
+ Tests that `enabled` gives 404 when we put to a server-default rule that doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # enable & check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/.m.muahahah/enabled",
+ {"enabled": True},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_actions_get(self):
+ """
+ Tests that `actions` gives you what you expect on a fresh rule.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # GET actions for that new rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/actions", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}]
+ )
+
+ def test_actions_put(self):
+ """
+ Tests that PUT on actions updates the value you'd get from GET.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # change the rule actions
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/actions",
+ {"actions": ["dont_notify"]},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # GET actions for that new rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/actions", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["actions"], ["dont_notify"])
+
+ def test_actions_404_when_get_non_existent(self):
+ """
+ Tests that `actions` gives 404 when the rule doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # DELETE the rule
+ request, channel = self.make_request(
+ "DELETE", "/pushrules/global/override/best.friend", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # check 404 for deleted rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_actions_404_when_get_non_existent_server_rule(self):
+ """
+ Tests that `actions` gives 404 when the server-default rule doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_actions_404_when_put_non_existent_rule(self):
+ """
+ Tests that `actions` gives 404 when putting to a rule that doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # enable & check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/actions",
+ {"actions": ["dont_notify"]},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_actions_404_when_put_non_existent_server_rule(self):
+ """
+ Tests that `actions` gives 404 when putting to a server-default rule that doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # enable & check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/.m.muahahah/actions",
+ {"actions": ["dont_notify"]},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index e74bddc1e5..0a567b032f 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -684,38 +684,39 @@ class RoomJoinRatelimitTestCase(RoomBase):
]
@unittest.override_config(
- {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
def test_join_local_ratelimit(self):
"""Tests that local joins are actually rate-limited."""
- for i in range(5):
+ for i in range(3):
self.helper.create_room_as(self.user_id)
self.helper.create_room_as(self.user_id, expect_code=429)
@unittest.override_config(
- {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
def test_join_local_ratelimit_profile_change(self):
"""Tests that sending a profile update into all of the user's joined rooms isn't
rate-limited by the rate-limiter on joins."""
- # Create and join more rooms than the rate-limiting config allows in a second.
+ # Create and join as many rooms as the rate-limiting config allows in a second.
room_ids = [
self.helper.create_room_as(self.user_id),
self.helper.create_room_as(self.user_id),
self.helper.create_room_as(self.user_id),
]
- self.reactor.advance(1)
- room_ids = room_ids + [
- self.helper.create_room_as(self.user_id),
- self.helper.create_room_as(self.user_id),
- self.helper.create_room_as(self.user_id),
- ]
+ # Give the rate-limiter some time to forget about the rooms we've just joined.
+ self.reactor.advance(2)
+ # Add one to make sure we're joined to more rooms than the config allows us to
+ # join in a second.
+ room_ids.append(self.helper.create_room_as(self.user_id))
# Create a profile for the user, since it hasn't been done on registration.
store = self.hs.get_datastore()
- store.create_profile(UserID.from_string(self.user_id).localpart)
+ self.get_success(
+ store.create_profile(UserID.from_string(self.user_id).localpart)
+ )
# Update the display name for the user.
path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
@@ -738,7 +739,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
self.assertEquals(channel.json_body["displayname"], "John Doe")
@unittest.override_config(
- {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
def test_join_local_ratelimit_idempotent(self):
"""Tests that the room join endpoints remain idempotent despite rate-limiting
@@ -754,7 +755,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
for path in paths_to_test:
# Make sure we send more requests than the rate-limiting config would allow
# if all of these requests ended up joining the user to a room.
- for i in range(6):
+ for i in range(4):
request, channel = self.make_request("POST", path % room_id, {})
self.render(request)
self.assertEquals(channel.code, 200)
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index e66c9a4c4c..afaf9f7b85 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -30,7 +30,7 @@ from tests.server import make_request, render
@attr.s
-class RestHelper(object):
+class RestHelper:
"""Contains extra helper functions to quickly and clearly perform a given
REST action, which isn't the focus of the test.
"""
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 152a5182fa..93f899d861 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -14,11 +14,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import json
import os
import re
from email.parser import Parser
+from typing import Optional
+from urllib.parse import urlencode
import pkg_resources
@@ -27,8 +28,10 @@ from synapse.api.constants import LoginType, Membership
from synapse.api.errors import Codes
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, register
+from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from tests import unittest
+from tests.unittest import override_config
class PasswordResetTestCase(unittest.HomeserverTestCase):
@@ -69,6 +72,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
+ self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
def test_basic_password_reset(self):
"""Test basic password reset flow
@@ -250,8 +254,32 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Remove the host
path = link.replace("https://example.com", "")
+ # Load the password reset confirmation page
request, channel = self.make_request("GET", path, shorthand=False)
- self.render(request)
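+ # Render the request against the submit-token resource directly instead of
+ # going through self.render().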
+ request.render(self.submit_token_resource)
+ self.pump()
+ self.assertEquals(200, channel.code, channel.result)
+
+ # Now POST to the same endpoint, mimicking the same behaviour as clicking the
+ # password reset confirm button
+
+ # Send arguments as url-encoded form data, matching the template's behaviour
+ form_args = []
+ for key, value_list in request.args.items():
+ for value in value_list:
+ arg = (key, value)
+ form_args.append(arg)
+
+ # Confirm the password reset
+ request, channel = self.make_request(
+ "POST",
+ path,
+ content=urlencode(form_args).encode("utf8"),
+ shorthand=False,
+ content_is_form=True,
+ )
+ request.render(self.submit_token_resource)
+ self.pump()
self.assertEquals(200, channel.code, channel.result)
def _get_link_from_email(self):
@@ -668,16 +696,104 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
- def _request_token(self, email, client_secret):
+ @override_config({"next_link_domain_whitelist": None})
+ def test_next_link(self):
+ """Tests a valid next_link parameter value with no whitelist (good case)"""
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://example.com/a/good/site",
+ expect_code=200,
+ )
+
+ @override_config({"next_link_domain_whitelist": None})
+ def test_next_link_exotic_protocol(self):
+ """Tests using a esoteric protocol as a next_link parameter value.
+ Someone may be hosting a client on IPFS etc.
+ """
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="some-protocol://abcdefghijklmopqrstuvwxyz",
+ expect_code=200,
+ )
+
+ @override_config({"next_link_domain_whitelist": None})
+ def test_next_link_file_uri(self):
+ """Tests next_link parameters cannot be file URI"""
+ # Attempt to use a next_link value that points to the local disk
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="file:///host/path",
+ expect_code=400,
+ )
+
+ @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
+ def test_next_link_domain_whitelist(self):
+ """Tests next_link parameters must fit the whitelist if provided"""
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://example.com/some/good/page",
+ expect_code=200,
+ )
+
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://example.org/some/also/good/page",
+ expect_code=200,
+ )
+
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://bad.example.org/some/bad/page",
+ expect_code=400,
+ )
+
+ @override_config({"next_link_domain_whitelist": []})
+ def test_empty_next_link_domain_whitelist(self):
+ """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially
+ disallowed
+ """
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://example.com/a/page",
+ expect_code=400,
+ )
+
+ def _request_token(
+ self,
+ email: str,
+ client_secret: str,
+ next_link: Optional[str] = None,
+ expect_code: int = 200,
+ ) -> str:
+ """Request a validation token to add an email address to a user's account
+
+ Args:
+ email: The email address to validate
+ client_secret: A secret string
+ next_link: A link to redirect the user to after validation
+ expect_code: Expected return code of the call
+
+ Returns:
+ The ID of the new threepid validation session
+ """
+ body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
+ if next_link:
+ body["next_link"] = next_link
+
request, channel = self.make_request(
- "POST",
- b"account/3pid/email/requestToken",
- {"client_secret": client_secret, "email": email, "send_attempt": 1},
+ "POST", b"account/3pid/email/requestToken", body,
)
self.render(request)
- self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(expect_code, channel.code, channel.result)
- return channel.json_body["sid"]
+ return channel.json_body.get("sid")
def _request_token_invalid_email(
self, email, expected_errcode, expected_error, client_secret="foobar",
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index e0e9e94fbf..de00350580 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
+
from synapse.api.errors import Codes
from synapse.rest.client.v2_alpha import filter
@@ -73,8 +75,10 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
def test_get_filter(self):
- filter_id = self.filtering.add_user_filter(
- user_localpart="apple", user_filter=self.EXAMPLE_FILTER
+ filter_id = defer.ensureDeferred(
+ self.filtering.add_user_filter(
+ user_localpart="apple", user_filter=self.EXAMPLE_FILTER
+ )
)
self.reactor.advance(1)
filter_id = filter_id.result
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 53a43038f0..2fc3a60fc5 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -160,7 +160,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
@@ -186,7 +186,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py
new file mode 100644
index 0000000000..5ae72fd008
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_shared_rooms.py
@@ -0,0 +1,138 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Half-Shot
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import synapse.rest.admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import shared_rooms
+
+from tests import unittest
+
+
+class UserSharedRoomsTest(unittest.HomeserverTestCase):
+ """
+ Tests the UserSharedRoomsServlet.
+ """
+
+ servlets = [
+ login.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ shared_rooms.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["update_user_directory"] = True
+ return self.setup_test_homeserver(config=config)
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.handler = hs.get_user_directory_handler()
+
+ def _get_shared_rooms(self, token, other_user):
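+ """Call the MSC2666 shared rooms endpoint for other_user and return the
+ (request, channel) pair."""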
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s"
+ % other_user,
+ access_token=token,
+ )
+ self.render(request)
+ return request, channel
+
+ def test_shared_room_list_public(self):
+ """
+ A room should show up in the shared list of rooms between two users
+ if it is public.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+ self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
+
+ request, channel = self._get_shared_rooms(u1_token, u2)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 1)
+ self.assertEquals(channel.json_body["joined"][0], room)
+
+ def test_shared_room_list_private(self):
+ """
+ A room should show up in the shared list of rooms between two users
+ if it is private.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
+ self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
+
+ request, channel = self._get_shared_rooms(u1_token, u2)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 1)
+ self.assertEquals(channel.json_body["joined"][0], room)
+
+ def test_shared_room_list_mixed(self):
+ """
+ The shared room list between two users should contain both public and private
+ rooms.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ room_public = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+ room_private = self.helper.create_room_as(u2, is_public=False, tok=u2_token)
+ self.helper.invite(room_public, src=u1, targ=u2, tok=u1_token)
+ self.helper.invite(room_private, src=u2, targ=u1, tok=u2_token)
+ self.helper.join(room_public, user=u2, tok=u2_token)
+ self.helper.join(room_private, user=u1, tok=u1_token)
+
+ request, channel = self._get_shared_rooms(u1_token, u2)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 2)
+ self.assertTrue(room_public in channel.json_body["joined"])
+ self.assertTrue(room_private in channel.json_body["joined"])
+
+ def test_shared_room_list_after_leave(self):
+ """
+ A room should no longer be considered shared if the other
+ user has left it.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+ self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
+
+ # Assert the room shows up in the shared rooms list before anyone leaves.
+ request, channel = self._get_shared_rooms(u1_token, u2)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 1)
+ self.assertEquals(channel.json_body["joined"][0], room)
+
+ self.helper.leave(room, user=u1, tok=u1_token)
+
+ request, channel = self._get_shared_rooms(u2_token, u1)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 0)
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index fa3a3ec1bd..a31e44c97e 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -16,9 +16,9 @@
import json
import synapse.rest.admin
-from synapse.api.constants import EventContentFields, EventTypes
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import sync
+from synapse.rest.client.v2_alpha import read_marker, sync
from tests import unittest
from tests.server import TimedOutException
@@ -324,3 +324,156 @@ class SyncTypingTests(unittest.HomeserverTestCase):
"GET", sync_url % (access_token, next_batch)
)
self.assertRaises(TimedOutException, self.render, request)
+
+
+class UnreadMessagesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ read_marker.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.url = "/sync?since=%s"
+ self.next_batch = "s0"
+
+ # Register the first user (used to check the unread counts).
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ # Create the room we'll check unread counts for.
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ # Register the second user (used to send events to the room).
+ self.user2 = self.register_user("kermit2", "monkey")
+ self.tok2 = self.login("kermit2", "monkey")
+
+ # Change the power levels of the room so that the second user can send state
+ # events.
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.PowerLevels,
+ {
+ "users": {self.user_id: 100, self.user2: 100},
+ "users_default": 0,
+ "events": {
+ "m.room.name": 50,
+ "m.room.power_levels": 100,
+ "m.room.history_visibility": 100,
+ "m.room.canonical_alias": 50,
+ "m.room.avatar": 50,
+ "m.room.tombstone": 100,
+ "m.room.server_acl": 100,
+ "m.room.encryption": 100,
+ },
+ "events_default": 0,
+ "state_default": 50,
+ "ban": 50,
+ "kick": 50,
+ "redact": 50,
+ "invite": 0,
+ },
+ tok=self.tok,
+ )
+
+ def test_unread_counts(self):
+ """Tests that /sync returns the right value for the unread count (MSC2654)."""
+
+ # Check that our own messages don't increase the unread count.
+ self.helper.send(self.room_id, "hello", tok=self.tok)
+ self._check_unread_count(0)
+
+ # Join the new user and check that this doesn't increase the unread count.
+ self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
+ self._check_unread_count(0)
+
+ # Check that the new user sending a message increases our unread count.
+ res = self.helper.send(self.room_id, "hello", tok=self.tok2)
+ self._check_unread_count(1)
+
+ # Send a read receipt to tell the server we've read the latest event.
+ body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
+ request, channel = self.make_request(
+ "POST",
+ "/rooms/%s/read_markers" % self.room_id,
+ body,
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check that the unread counter is back to 0.
+ self._check_unread_count(0)
+
+ # Check that room name changes increase the unread counter.
+ self.helper.send_state(
+ self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
+ )
+ self._check_unread_count(1)
+
+ # Check that room topic changes increase the unread counter.
+ self.helper.send_state(
+ self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
+ )
+ self._check_unread_count(2)
+
+ # Check that encrypted messages increase the unread counter.
+ self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2)
+ self._check_unread_count(3)
+
+ # Check that custom events with a body increase the unread counter.
+ self.helper.send_event(
+ self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
+ )
+ self._check_unread_count(4)
+
+ # Check that edits don't increase the unread counter.
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "body": "hello",
+ "msgtype": "m.text",
+ "m.relates_to": {"rel_type": RelationTypes.REPLACE},
+ },
+ tok=self.tok2,
+ )
+ self._check_unread_count(4)
+
+ # Check that notices don't increase the unread counter.
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"body": "hello", "msgtype": "m.notice"},
+ tok=self.tok2,
+ )
+ self._check_unread_count(4)
+
+        # Check that tombstone events increase the unread counter.
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.Tombstone,
+ {"replacement_room": "!someroom:test"},
+ tok=self.tok2,
+ )
+ self._check_unread_count(5)
+
+    def _check_unread_count(self, expected_count: int):
+ """Syncs and compares the unread count with the expected value."""
+
+ request, channel = self.make_request(
+ "GET", self.url % self.next_batch, access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ room_entry = channel.json_body["rooms"]["join"][self.room_id]
+ self.assertEqual(
+ room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
+ )
+
+ # Store the next batch for the next request.
+ self.next_batch = channel.json_body["next_batch"]
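
For reference, the counter asserted by `_check_unread_count` sits inside the per-room join section of the sync response. A tiny, self-contained sketch (helper name hypothetical) of pulling it out of a sync body:

    def get_unread_count(sync_body, room_id):
        # Hypothetical helper: extract the MSC2654 unread count for one room
        # from a /sync response body, defaulting to 0 if the room is absent.
        room = sync_body.get("rooms", {}).get("join", {}).get(room_id, {})
        return room.get("org.matrix.msc2654.unread_count", 0)
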
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index f4f3e56777..5f897d49cf 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -120,12 +120,13 @@ class _TestImage:
extension = attr.ib(type=bytes)
expected_cropped = attr.ib(type=Optional[bytes])
expected_scaled = attr.ib(type=Optional[bytes])
+ expected_found = attr.ib(default=True, type=bool)
@parameterized_class(
("test_image",),
[
- # smol png
+ # smoll png
(
_TestImage(
unhexlify(
@@ -161,6 +162,8 @@ class _TestImage:
None,
),
),
+ # an empty file
+ (_TestImage(b"", b"image/gif", b".gif", None, None, False,),),
],
)
class MediaRepoTests(unittest.HomeserverTestCase):
@@ -303,12 +306,16 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
def test_thumbnail_crop(self):
- self._test_thumbnail("crop", self.test_image.expected_cropped)
+ self._test_thumbnail(
+ "crop", self.test_image.expected_cropped, self.test_image.expected_found
+ )
def test_thumbnail_scale(self):
- self._test_thumbnail("scale", self.test_image.expected_scaled)
+ self._test_thumbnail(
+ "scale", self.test_image.expected_scaled, self.test_image.expected_found
+ )
- def _test_thumbnail(self, method, expected_body):
+ def _test_thumbnail(self, method, expected_body, expected_found):
params = "?width=32&height=32&method=" + method
request, channel = self.make_request(
"GET", self.media_id + params, shorthand=False
@@ -325,11 +332,23 @@ class MediaRepoTests(unittest.HomeserverTestCase):
)
self.pump()
- self.assertEqual(channel.code, 200)
- if expected_body is not None:
+ if expected_found:
+ self.assertEqual(channel.code, 200)
+ if expected_body is not None:
+ self.assertEqual(
+ channel.result["body"], expected_body, channel.result["body"]
+ )
+ else:
+ # ensure that the result is at least some valid image
+ Image.open(BytesIO(channel.result["body"]))
+ else:
+ # A 404 with a JSON body.
+ self.assertEqual(channel.code, 404)
self.assertEqual(
- channel.result["body"], expected_body, channel.result["body"]
+ channel.json_body,
+ {
+ "errcode": "M_NOT_FOUND",
+ "error": "Not found [b'example.com', b'12345?width=32&height=32&method=%s']"
+ % method,
+ },
)
- else:
- # ensure that the result is at least some valid image
- Image.open(BytesIO(channel.result["body"]))
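
The new `expected_found` attribute defaults to True, which is why the existing parameterized image entries above don't need a sixth argument; only the new empty-file case passes False. A small standalone sketch of that attrs defaulting behaviour (class and values are illustrative only):

    import attr

    @attr.s
    class _Example:
        name = attr.ib(type=str)
        expected_found = attr.ib(default=True, type=bool)

    # Existing entries keep working with the original arguments...
    assert _Example("png").expected_found is True
    # ...while the new empty-file case opts out explicitly.
    assert _Example("empty gif", expected_found=False).expected_found is False
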
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 74765a582b..c00a7b9114 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -32,7 +32,7 @@ from tests.server import FakeTransport
@attr.s
-class FakeResponse(object):
+class FakeResponse:
version = attr.ib()
code = attr.ib()
phrase = attr.ib()
@@ -43,7 +43,7 @@ class FakeResponse(object):
@property
def request(self):
@attr.s
- class FakeTransport(object):
+ class FakeTransport:
absoluteURI = self.absoluteURI
return FakeTransport()
@@ -111,7 +111,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.lookups = {}
- class Resolver(object):
+ class Resolver:
def resolveHostName(
_self,
resolutionReceiver,
diff --git a/tests/server.py b/tests/server.py
index b6e0b14e78..61ec670155 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -1,6 +1,6 @@
import json
import logging
-from io import BytesIO
+from io import SEEK_END, BytesIO
import attr
from zope.interface import implementer
@@ -35,7 +35,7 @@ class TimedOutException(Exception):
@attr.s
-class FakeChannel(object):
+class FakeChannel:
"""
A fake Twisted Web Channel (the part that interfaces with the
wire).
@@ -135,6 +135,7 @@ def make_request(
request=SynapseRequest,
shorthand=True,
federation_auth_origin=None,
+ content_is_form=False,
):
"""
Make a web request using the given method and path, feed it the
@@ -150,6 +151,8 @@ def make_request(
with the usual REST API path, if it doesn't contain it.
federation_auth_origin (bytes|None): if set to not-None, we will add a fake
Authorization header pretending to be the given server name.
+ content_is_form: Whether the content is URL encoded form data. Adds the
+ 'Content-Type': 'application/x-www-form-urlencoded' header.
Returns:
Tuple[synapse.http.site.SynapseRequest, channel]
@@ -181,6 +184,8 @@ def make_request(
req = request(channel)
req.process = lambda: b""
req.content = BytesIO(content)
+ # Twisted expects to be at the end of the content when parsing the request.
+ req.content.seek(SEEK_END)
req.postpath = list(map(unquote, path[1:].split(b"/")))
if access_token:
@@ -195,7 +200,13 @@ def make_request(
)
if content:
- req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
+ if content_is_form:
+ req.requestHeaders.addRawHeader(
+ b"Content-Type", b"application/x-www-form-urlencoded"
+ )
+ else:
+ # Assume the body is JSON
+ req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
req.requestReceived(method, path, b"1.1")
@@ -242,7 +253,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
lookups = self.lookups = {}
@implementer(IResolverSimple)
- class FakeResolver(object):
+ class FakeResolver:
def getHostByName(self, name, timeout=None):
if name not in lookups:
return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
@@ -371,7 +382,7 @@ def get_clock():
@attr.s(cmp=False)
-class FakeTransport(object):
+class FakeTransport:
"""
A twisted.internet.interfaces.ITransport implementation which sends all its data
straight into an IProtocol object: it exists to connect two IProtocols together.
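
A hedged usage sketch of the new `content_is_form` flag, framed as a HomeserverTestCase test method (the endpoint path and payload are purely illustrative): it submits URL-encoded form data instead of the default JSON body.

    from urllib.parse import urlencode

    def test_form_post(self):
        # Hypothetical test sketch: POST form data, letting make_request set
        # Content-Type: application/x-www-form-urlencoded via content_is_form.
        body = urlencode({"code": "abc123", "state": "xyz"}).encode("utf-8")
        request, channel = self.make_request(
            "POST", "/some/form/handler", body, content_is_form=True,
        )
        self.render(request)
        self.assertEqual(channel.code, 200)
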
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 2858d13558..6382b19dc3 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -67,7 +67,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
raise Exception("Failed to find reference to ResourceLimitsServerNotices")
self._rlsn._store.user_last_seen_monthly_active = Mock(
- return_value=defer.succeed(1000)
+ return_value=make_awaitable(1000)
)
self._rlsn._server_notices_manager.send_notice = Mock(
return_value=defer.succeed(Mock())
@@ -80,9 +80,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
return_value=defer.succeed("!something:localhost")
)
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
- self._rlsn._store.get_tags_for_room = Mock(
- side_effect=lambda user_id, room_id: make_awaitable({})
- )
+ self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({}))
@override_config({"hs_disabled": True})
def test_maybe_send_server_notice_disabled_hs(self):
@@ -104,7 +102,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
- return_value=defer.succeed({"123": mock_event})
+ return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event
@@ -122,7 +120,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
- return_value=defer.succeed({"123": mock_event})
+ return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -158,7 +156,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self._rlsn._store.user_last_seen_monthly_active = Mock(
- return_value=defer.succeed(None)
+ return_value=make_awaitable(None)
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -217,7 +215,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
- return_value=defer.succeed({"123": mock_event})
+ return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -261,10 +259,10 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.user_id = "@user_id:test"
def test_server_notice_only_sent_once(self):
- self.store.get_monthly_active_count = Mock(return_value=defer.succeed(1000))
+ self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000))
self.store.user_last_seen_monthly_active = Mock(
- return_value=defer.succeed(1000)
+ return_value=make_awaitable(1000)
)
# Call the function multiple times to ensure we only send the notice once
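
Several hunks in this diff swap `defer.succeed(...)` and `side_effect=lambda ...` mocks for `make_awaitable(...)`, the helper from `tests.test_utils`. A plausible sketch of that helper, assuming it wraps the value in an already-completed Future, which (unlike a coroutine) can be awaited more than once and is therefore safe as a Mock `return_value`:

    from asyncio import Future
    from typing import Any, Awaitable

    def make_awaitable(result: Any) -> Awaitable[Any]:
        # Sketch: a completed Future resolving to `result`.
        future = Future()
        future.set_result(result)
        return future
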
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index f2955a9c69..ad9bbef9d2 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -49,7 +49,7 @@ class FakeClock:
return defer.succeed(None)
-class FakeEvent(object):
+class FakeEvent:
"""A fake event we use as a convenience.
NOTE: Again as a convenience we use "node_ids" rather than event_ids to
@@ -595,7 +595,7 @@ def pairwise(iterable):
@attr.s
-class TestStateResolutionStore(object):
+class TestStateResolutionStore:
event_map = attr.ib()
def get_events(self, event_ids, allow_rejected=False):
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 319e2c2325..f5afed017c 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -99,7 +99,7 @@ class CacheTestCase(unittest.HomeserverTestCase):
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def test_passthrough(self):
- class A(object):
+ class A:
@cached()
def func(self, key):
return key
@@ -113,7 +113,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
def test_hit(self):
callcount = [0]
- class A(object):
+ class A:
@cached()
def func(self, key):
callcount[0] += 1
@@ -131,7 +131,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
def test_invalidate(self):
callcount = [0]
- class A(object):
+ class A:
@cached()
def func(self, key):
callcount[0] += 1
@@ -149,7 +149,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
self.assertEquals(callcount[0], 2)
def test_invalidate_missing(self):
- class A(object):
+ class A:
@cached()
def func(self, key):
return key
@@ -160,7 +160,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
def test_max_entries(self):
callcount = [0]
- class A(object):
+ class A:
@cached(max_entries=10)
def func(self, key):
callcount[0] += 1
@@ -187,7 +187,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
d = defer.succeed(123)
- class A(object):
+ class A:
@cached()
def func(self, key):
callcount[0] += 1
@@ -205,7 +205,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
callcount = [0]
callcount2 = [0]
- class A(object):
+ class A:
@cached()
def func(self, key):
callcount[0] += 1
@@ -238,7 +238,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
callcount = [0]
callcount2 = [0]
- class A(object):
+ class A:
@cached(max_entries=2)
def func(self, key):
callcount[0] += 1
@@ -275,7 +275,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
callcount = [0]
callcount2 = [0]
- class A(object):
+ class A:
@cached()
def func(self, key):
callcount[0] += 1
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 98b74890d5..cb808d4de4 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import (
)
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
@@ -207,7 +208,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_appservices_state_down(self):
service = Mock(id=self.as_list[1]["id"])
- yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ )
rows = yield self.db_pool.runQuery(
self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
@@ -219,9 +222,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_appservices_state_multiple_up(self):
service = Mock(id=self.as_list[1]["id"])
- yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
- yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
- yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.UP)
+ )
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ )
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.UP)
+ )
rows = yield self.db_pool.runQuery(
self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
@@ -234,7 +243,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
def test_create_appservice_txn_first(self):
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
- txn = yield self.store.create_appservice_txn(service, events)
+ txn = yield defer.ensureDeferred(
+ self.store.create_appservice_txn(service, events)
+ )
self.assertEquals(txn.id, 1)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@@ -246,7 +257,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._set_last_txn(service.id, 9643) # AS is falling behind
yield self._insert_txn(service.id, 9644, events)
yield self._insert_txn(service.id, 9645, events)
- txn = yield self.store.create_appservice_txn(service, events)
+ txn = yield defer.ensureDeferred(
+ self.store.create_appservice_txn(service, events)
+ )
self.assertEquals(txn.id, 9646)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@@ -256,7 +269,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
yield self._set_last_txn(service.id, 9643)
- txn = yield self.store.create_appservice_txn(service, events)
+ txn = yield defer.ensureDeferred(
+ self.store.create_appservice_txn(service, events)
+ )
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@@ -277,7 +292,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._insert_txn(self.as_list[2]["id"], 10, events)
yield self._insert_txn(self.as_list[3]["id"], 9643, events)
- txn = yield self.store.create_appservice_txn(service, events)
+ txn = yield defer.ensureDeferred(
+ self.store.create_appservice_txn(service, events)
+ )
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@@ -289,7 +306,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
txn_id = 1
yield self._insert_txn(service.id, txn_id, events)
- yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+ yield defer.ensureDeferred(
+ self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+ )
res = yield self.db_pool.runQuery(
self.engine.convert_param_style(
@@ -315,7 +334,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
txn_id = 5
yield self._set_last_txn(service.id, 4)
yield self._insert_txn(service.id, txn_id, events)
- yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+ yield defer.ensureDeferred(
+ self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+ )
res = yield self.db_pool.runQuery(
self.engine.convert_param_style(
@@ -349,7 +370,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
# we aren't testing store._base stuff here, so mock this out
- self.store.get_events_as_list = Mock(return_value=defer.succeed(events))
+ self.store.get_events_as_list = Mock(return_value=make_awaitable(events))
yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
yield self._insert_txn(service.id, 10, events)
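
Most of the storage-test hunks below follow the same mechanical pattern: store methods that used to return Deferreds are now coroutines, and yielding a bare coroutine from an `@defer.inlineCallbacks` generator just hands the coroutine object back without running it, so each call is wrapped in `defer.ensureDeferred`. A minimal sketch of the pattern (the store method name and expected value are hypothetical):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def test_example(self):
        # Before: `result = yield self.store.some_method("arg")` worked while
        # the method returned a Deferred.  Once it becomes `async def`, the
        # coroutine has to be wrapped so inlineCallbacks can drive it:
        result = yield defer.ensureDeferred(self.store.some_method("arg"))
        self.assertEqual(result, "expected")
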
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 2efbc97c2e..02aae1c13d 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -1,7 +1,5 @@
from mock import Mock
-from twisted.internet import defer
-
from synapse.storage.background_updates import BackgroundUpdater
from tests import unittest
@@ -38,11 +36,10 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
)
# first step: make a bit of progress
- @defer.inlineCallbacks
- def update(progress, count):
- yield self.clock.sleep((count * duration_ms) / 1000)
+ async def update(progress, count):
+ await self.clock.sleep((count * duration_ms) / 1000)
progress = {"my_key": progress["my_key"] + 1}
- yield store.db_pool.runInteraction(
+ await store.db_pool.runInteraction(
"update_progress",
self.updates._background_update_progress_txn,
"test_update",
@@ -67,13 +64,12 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# second step: complete the update
# we should now get run with a much bigger number of items to update
- @defer.inlineCallbacks
- def update(progress, count):
+ async def update(progress, count):
self.assertEqual(progress, {"my_key": 2})
self.assertAlmostEqual(
count, target_background_update_duration_ms / duration_ms, places=0,
)
- yield self.updates._end_background_update("test_update")
+ await self.updates._end_background_update("test_update")
return count
self.update_handler.side_effect = update
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index efcaeef1e7..40ba652248 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -66,8 +66,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_1col(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db_pool.simple_insert(
- table="tablename", values={"columname": "Value"}
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_insert(
+ table="tablename", values={"columname": "Value"}
+ )
)
self.mock_txn.execute.assert_called_with(
@@ -78,10 +80,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_3cols(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db_pool.simple_insert(
- table="tablename",
- # Use OrderedDict() so we can assert on the SQL generated
- values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_insert(
+ table="tablename",
+ # Use OrderedDict() so we can assert on the SQL generated
+ values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
+ )
)
self.mock_txn.execute.assert_called_with(
@@ -93,8 +97,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
- value = yield self.datastore.db_pool.simple_select_one_onecol(
- table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
+ value = yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_select_one_onecol(
+ table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
+ )
)
self.assertEquals("Value", value)
@@ -107,10 +113,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = (1, 2, 3)
- ret = yield self.datastore.db_pool.simple_select_one(
- table="tablename",
- keyvalues={"keycol": "TheKey"},
- retcols=["colA", "colB", "colC"],
+ ret = yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_select_one(
+ table="tablename",
+ keyvalues={"keycol": "TheKey"},
+ retcols=["colA", "colB", "colC"],
+ )
)
self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
@@ -123,11 +131,13 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 0
self.mock_txn.fetchone.return_value = None
- ret = yield self.datastore.db_pool.simple_select_one(
- table="tablename",
- keyvalues={"keycol": "Not here"},
- retcols=["colA"],
- allow_none=True,
+ ret = yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_select_one(
+ table="tablename",
+ keyvalues={"keycol": "Not here"},
+ retcols=["colA"],
+ allow_none=True,
+ )
)
self.assertFalse(ret)
@@ -138,8 +148,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
- ret = yield self.datastore.db_pool.simple_select_list(
- table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
+ ret = yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_select_list(
+ table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
+ )
)
self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
@@ -151,10 +163,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_update_one_1col(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db_pool.simple_update_one(
- table="tablename",
- keyvalues={"keycol": "TheKey"},
- updatevalues={"columnname": "New Value"},
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_update_one(
+ table="tablename",
+ keyvalues={"keycol": "TheKey"},
+ updatevalues={"columnname": "New Value"},
+ )
)
self.mock_txn.execute.assert_called_with(
@@ -166,10 +180,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_update_one_4cols(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db_pool.simple_update_one(
- table="tablename",
- keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
- updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_update_one(
+ table="tablename",
+ keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
+ updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
+ )
)
self.mock_txn.execute.assert_called_with(
@@ -181,8 +197,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_delete_one(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db_pool.simple_delete_one(
- table="tablename", keyvalues={"keycol": "Go away"}
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_delete_one(
+ table="tablename", keyvalues={"keycol": "Go away"}
+ )
)
self.mock_txn.execute.assert_called_with(
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 3fab5a5248..080761d1d2 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID("alice", "test")
- self.requester = Requester(self.user, None, False, None, None)
+ self.requester = Requester(self.user, None, False, False, None, None)
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
@@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID.from_string(self.register_user("user1", "password"))
self.token1 = self.login("user1", "password")
- self.requester = Requester(self.user, None, False, None, None)
+ self.requester = Requester(self.user, None, False, False, None, None)
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
self.event_creator = homeserver.get_event_creation_handler()
@@ -271,7 +271,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
# Pump the reactor repeatedly so that the background updates have a
# chance to run.
- self.pump(10 * 60)
+ self.pump(20)
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
@@ -353,6 +353,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
"3"
] = 300000
+
self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
# All entries within time frame
self.assertEqual(
@@ -362,7 +363,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
3,
)
# Oldest room to expire
- self.pump(1)
+ self.pump(1.01)
self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
self.assertEqual(
len(
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 224ea6fd79..755c70db31 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -16,13 +16,12 @@
from mock import Mock
-from twisted.internet import defer
-
import synapse.rest.admin
from synapse.http.site import XForwardedForRequest
from synapse.rest.client.v1 import login
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.unittest import override_config
@@ -155,7 +154,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
user_id = "@user:server"
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(lots_of_users)
+ return_value=make_awaitable(lots_of_users)
)
self.get_success(
self.store.insert_client_ip(
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 87ed8f8cd1..34ae8c9da7 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -38,7 +38,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self.store.store_device("user_id", "device_id", "display_name")
)
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertDictContainsSubset(
{
"user_id": "user_id",
@@ -111,12 +111,12 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self.store.store_device("user_id", "device_id", "display_name 1")
)
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do the update
@@ -127,7 +127,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
# check it worked
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 2", res["display_name"])
@defer.inlineCallbacks
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index daac947cb2..da93ca3980 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -42,7 +42,11 @@ class DirectoryStoreTestCase(unittest.TestCase):
self.assertEquals(
["#my-room:test"],
- (yield self.store.get_aliases_for_room(self.room.to_string())),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_aliases_for_room(self.room.to_string())
+ )
+ ),
)
@defer.inlineCallbacks
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index d57cdffd8b..3fc4bb13b6 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -32,10 +32,12 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
yield defer.ensureDeferred(self.store.store_device("user", "device", None))
- yield self.store.set_e2e_device_keys("user", "device", now, json)
+ yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user", "device", now, json)
+ )
res = yield defer.ensureDeferred(
- self.store.get_e2e_device_keys((("user", "device"),))
+ self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
self.assertIn("device", res["user"])
@@ -49,12 +51,16 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
yield defer.ensureDeferred(self.store.store_device("user", "device", None))
- changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
+ changed = yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user", "device", now, json)
+ )
self.assertTrue(changed)
# If we try to upload the same key then we should be told nothing
# changed
- changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
+ changed = yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user", "device", now, json)
+ )
self.assertFalse(changed)
@defer.inlineCallbacks
@@ -62,13 +68,15 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
now = 1470174257070
json = {"key": "value"}
- yield self.store.set_e2e_device_keys("user", "device", now, json)
+ yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user", "device", now, json)
+ )
yield defer.ensureDeferred(
self.store.store_device("user", "device", "display_name")
)
res = yield defer.ensureDeferred(
- self.store.get_e2e_device_keys((("user", "device"),))
+ self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
self.assertIn("device", res["user"])
@@ -86,13 +94,23 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
yield defer.ensureDeferred(self.store.store_device("user2", "device1", None))
yield defer.ensureDeferred(self.store.store_device("user2", "device2", None))
- yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
- yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
- yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
- yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
+ yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
+ )
+ yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
+ )
+ yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
+ )
+ yield defer.ensureDeferred(
+ self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
+ )
res = yield defer.ensureDeferred(
- self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2")))
+ self.store.get_e2e_device_keys_for_cs_api(
+ (("user1", "device1"), ("user2", "device2"))
+ )
)
self.assertIn("user1", res)
self.assertIn("device1", res["user1"])
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index a7b85004e5..949846fe33 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
room_creator = self.hs.get_room_creation_handler()
user = UserID("alice", "test")
- requester = Requester(user, None, False, None, None)
+ requester = Requester(user, None, False, False, None, None)
# Real events, forward extremities
events = [(3, 2), (6, 2), (4, 6)]
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 857db071d4..c0595963dd 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -60,12 +60,18 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def _assert_counts(noitf_count, highlight_count):
- counts = yield self.store.db_pool.runInteraction(
- "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
+ counts = yield defer.ensureDeferred(
+ self.store.db_pool.runInteraction(
+ "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
+ )
)
self.assertEquals(
counts,
- {"notify_count": noitf_count, "highlight_count": highlight_count},
+ {
+ "notify_count": noitf_count,
+ "unread_count": 0, # Unread counts are tested in the sync tests.
+ "highlight_count": highlight_count,
+ },
)
@defer.inlineCallbacks
@@ -78,28 +84,34 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield defer.ensureDeferred(
self.store.add_push_actions_to_staging(
- event.event_id, {user_id: action}
+ event.event_id, {user_id: action}, False,
)
)
- yield self.store.db_pool.runInteraction(
- "",
- self.persist_events_store._set_push_actions_for_event_and_users_txn,
- [(event, None)],
- [(event, None)],
+ yield defer.ensureDeferred(
+ self.store.db_pool.runInteraction(
+ "",
+ self.persist_events_store._set_push_actions_for_event_and_users_txn,
+ [(event, None)],
+ [(event, None)],
+ )
)
def _rotate(stream):
- return self.store.db_pool.runInteraction(
- "", self.store._rotate_notifs_before_txn, stream
+ return defer.ensureDeferred(
+ self.store.db_pool.runInteraction(
+ "", self.store._rotate_notifs_before_txn, stream
+ )
)
def _mark_read(stream, depth):
- return self.store.db_pool.runInteraction(
- "",
- self.store._remove_old_push_actions_before_txn,
- room_id,
- user_id,
- stream,
+ return defer.ensureDeferred(
+ self.store.db_pool.runInteraction(
+ "",
+ self.store._remove_old_push_actions_before_txn,
+ room_id,
+ user_id,
+ stream,
+ )
)
yield _assert_counts(0, 0)
@@ -123,8 +135,10 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield _inject_actions(6, PlAIN_NOTIF)
yield _rotate(7)
- yield self.store.db_pool.simple_delete(
- table="event_push_actions", keyvalues={"1": 1}, desc=""
+ yield defer.ensureDeferred(
+ self.store.db_pool.simple_delete(
+ table="event_push_actions", keyvalues={"1": 1}, desc=""
+ )
)
yield _assert_counts(1, 0)
@@ -142,33 +156,43 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts):
- return self.store.db_pool.simple_insert(
- "events",
- {
- "stream_ordering": so,
- "received_ts": ts,
- "event_id": "event%i" % so,
- "type": "",
- "room_id": "",
- "content": "",
- "processed": True,
- "outlier": False,
- "topological_ordering": 0,
- "depth": 0,
- },
+ return defer.ensureDeferred(
+ self.store.db_pool.simple_insert(
+ "events",
+ {
+ "stream_ordering": so,
+ "received_ts": ts,
+ "event_id": "event%i" % so,
+ "type": "",
+ "room_id": "",
+ "content": "",
+ "processed": True,
+ "outlier": False,
+ "topological_ordering": 0,
+ "depth": 0,
+ },
+ )
)
# start with the base case where there are no events in the table
- r = yield self.store.find_first_stream_ordering_after_ts(11)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(11)
+ )
self.assertEqual(r, 0)
# now with one event
yield add_event(2, 10)
- r = yield self.store.find_first_stream_ordering_after_ts(9)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(9)
+ )
self.assertEqual(r, 2)
- r = yield self.store.find_first_stream_ordering_after_ts(10)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(10)
+ )
self.assertEqual(r, 2)
- r = yield self.store.find_first_stream_ordering_after_ts(11)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(11)
+ )
self.assertEqual(r, 3)
# add a bunch of dummy events to the events table
@@ -181,25 +205,37 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
):
yield add_event(stream_ordering, ts)
- r = yield self.store.find_first_stream_ordering_after_ts(110)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(110)
+ )
self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r)
# 4 and 5 are both after 120: we want 4 rather than 5
- r = yield self.store.find_first_stream_ordering_after_ts(120)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(120)
+ )
self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r)
- r = yield self.store.find_first_stream_ordering_after_ts(129)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(129)
+ )
self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r)
# check we can get the last event
- r = yield self.store.find_first_stream_ordering_after_ts(140)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(140)
+ )
self.assertEqual(r, 20, "First event after 140ms should be 20, was %i" % r)
# off the end
- r = yield self.store.find_first_stream_ordering_after_ts(160)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(160)
+ )
self.assertEqual(r, 21)
# check we can find an event at ordering zero
yield add_event(0, 5)
- r = yield self.store.find_first_stream_ordering_after_ts(1)
+ r = yield defer.ensureDeferred(
+ self.store.find_first_stream_ordering_after_ts(1)
+ )
self.assertEqual(r, 0)
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index e845410dae..20636fc400 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -58,6 +58,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
return self.get_success(self.db_pool.runWithConnection(_create))
def _insert_rows(self, instance_name: str, number: int):
+ """Insert N rows as the given instance, inserting with stream IDs pulled
+ from the postgres sequence.
+ """
+
def _insert(txn):
for _ in range(number):
txn.execute(
@@ -65,7 +69,20 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
(instance_name,),
)
- self.get_success(self.db_pool.runInteraction("test_single_instance", _insert))
+ self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
+
+ def _insert_row_with_id(self, instance_name: str, stream_id: int):
+ """Insert one row as the given instance with given stream_id, updating
+ the postgres sequence position to match.
+ """
+
+ def _insert(txn):
+ txn.execute(
+ "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+ )
+ txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
+
+ self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
def test_empty(self):
"""Test an ID generator against an empty database gives sensible
@@ -88,7 +105,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -98,12 +115,62 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
self.get_success(_get_next_async())
self.assertEqual(id_gen.get_positions(), {"master": 8})
- self.assertEqual(id_gen.get_current_token("master"), 8)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
+
+ def test_out_of_order_finish(self):
+ """Test that IDs persisted out of order are correctly handled
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ ctx1 = self.get_success(id_gen.get_next())
+ ctx2 = self.get_success(id_gen.get_next())
+ ctx3 = self.get_success(id_gen.get_next())
+ ctx4 = self.get_success(id_gen.get_next())
+
+ s1 = ctx1.__enter__()
+ s2 = ctx2.__enter__()
+ s3 = ctx3.__enter__()
+ s4 = ctx4.__enter__()
+
+ self.assertEqual(s1, 8)
+ self.assertEqual(s2, 9)
+ self.assertEqual(s3, 10)
+ self.assertEqual(s4, 11)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ ctx2.__exit__(None, None, None)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ ctx1.__exit__(None, None, None)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 9})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
+
+ ctx4.__exit__(None, None, None)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 9})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
+
+ ctx3.__exit__(None, None, None)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 11})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
def test_multi_instance(self):
"""Test that reads and writes from multiple processes are handled
@@ -116,8 +183,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen = self._create_id_generator("second")
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
- self.assertEqual(first_id_gen.get_current_token("first"), 3)
- self.assertEqual(first_id_gen.get_current_token("second"), 7)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -166,7 +233,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -176,9 +243,179 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
self.assertEqual(id_gen.get_positions(), {"master": 8})
- self.assertEqual(id_gen.get_current_token("master"), 8)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
+
+ def test_get_persisted_upto_position(self):
+ """Test that `get_persisted_upto_position` correctly tracks updates to
+ positions.
+ """
+
+ # The following tests are a bit cheeky in that we notify about new
+ # positions via `advance` without *actually* advancing the postgres
+ # sequence.
+
+ self._insert_row_with_id("first", 3)
+ self._insert_row_with_id("second", 5)
+
+ id_gen = self._create_id_generator("first")
+
+ self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+
+        # Min is 3 and there is a gap between 3 and 5, so we expect it to be 3.
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+ # We advance "first" straight to 6. Min is now 5 but there is no gap so
+ # we expect it to be 6
+ id_gen.advance("first", 6)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 6)
+
+ # No gap, so we expect 7.
+ id_gen.advance("second", 7)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 7)
+
+ # We haven't seen 8 yet, so we expect 7 still.
+ id_gen.advance("second", 9)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 7)
+
+        # Now that we've seen 7, 8 and 9 we can go straight to 9.
+ id_gen.advance("first", 8)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 9)
+
+        # Jump forward with gaps. The minimum is 11; even though we haven't seen
+        # 10, we know that everything before 11 must be persisted.
+ id_gen.advance("first", 11)
+ id_gen.advance("second", 15)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 11)
+
+ def test_get_persisted_upto_position_get_next(self):
+ """Test that `get_persisted_upto_position` correctly tracks updates to
+ positions when `get_next` is called.
+ """
+
+ self._insert_row_with_id("first", 3)
+ self._insert_row_with_id("second", 5)
+
+ id_gen = self._create_id_generator("first")
+
+ self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+ with self.get_success(id_gen.get_next()) as stream_id:
+ self.assertEqual(stream_id, 6)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+ self.assertEqual(id_gen.get_persisted_upto_position(), 6)
+
+        # We assume that, so long as `get_next` correctly advances the
+        # `persisted_upto_position` in this case, it will be correct in the
+        # other cases tested above (since they'll hit the same code).
+
+
+class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+    """Tests a MultiWriterIdGenerator that produces *negative* stream IDs.
+ """
+
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.db_pool = self.store.db_pool # type: DatabasePool
+
+ self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+
+ def _setup_db(self, txn):
+ txn.execute("CREATE SEQUENCE foobar_seq")
+ txn.execute(
+ """
+ CREATE TABLE foobar (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+
+ def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create(conn):
+ return MultiWriterIdGenerator(
+ conn,
+ self.db_pool,
+ instance_name=instance_name,
+ table="foobar",
+ instance_column="instance_name",
+ id_column="stream_id",
+ sequence_name="foobar_seq",
+ positive=False,
+ )
+
+ return self.get_success(self.db_pool.runWithConnection(_create))
+
+ def _insert_row(self, instance_name: str, stream_id: int):
+ """Insert one row as the given instance with given stream_id.
+ """
+
+ def _insert(txn):
+ txn.execute(
+ "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+ )
+
+ self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
+
+ def test_single_instance(self):
+ """Test that reads and writes from a single process are handled
+ correctly.
+ """
+ id_gen = self._create_id_generator()
+
+ with self.get_success(id_gen.get_next()) as stream_id:
+ self._insert_row("master", stream_id)
+
+ self.assertEqual(id_gen.get_positions(), {"master": -1})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
+ self.assertEqual(id_gen.get_persisted_upto_position(), -1)
+
+ with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
+ for stream_id in stream_ids:
+ self._insert_row("master", stream_id)
+
+ self.assertEqual(id_gen.get_positions(), {"master": -4})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
+ self.assertEqual(id_gen.get_persisted_upto_position(), -4)
+
+ # Test loading from DB by creating a second ID gen
+ second_id_gen = self._create_id_generator()
+
+ self.assertEqual(second_id_gen.get_positions(), {"master": -4})
+ self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
+ self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
+
+ def test_multiple_instance(self):
+ """Tests that having multiple instances that get advanced over
+        federation works correctly.
+ """
+ id_gen_1 = self._create_id_generator("first")
+ id_gen_2 = self._create_id_generator("second")
+
+ with self.get_success(id_gen_1.get_next()) as stream_id:
+ self._insert_row("first", stream_id)
+ id_gen_2.advance("first", stream_id)
+
+ self.assertEqual(id_gen_1.get_positions(), {"first": -1})
+ self.assertEqual(id_gen_2.get_positions(), {"first": -1})
+ self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
+ self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
+
+ with self.get_success(id_gen_2.get_next()) as stream_id:
+ self._insert_row("second", stream_id)
+ id_gen_1.advance("second", stream_id)
+
+ self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
+ self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
+ self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
+ self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
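
As a reading aid for `test_out_of_order_finish` above: the published position only advances past an ID once that ID and every lower ID have been finished. A toy model of that behaviour (an illustrative sketch, not the real MultiWriterIdGenerator):

    class ToyPositionTracker:
        """Toy sketch of the out-of-order completion behaviour asserted above."""

        def __init__(self, start):
            self.position = start      # last ID known to be fully persisted
            self._next = start + 1     # next ID to hand out
            self._unfinished = set()   # IDs handed out but not yet finished

        def get_next(self):
            stream_id = self._next
            self._next += 1
            self._unfinished.add(stream_id)
            return stream_id

        def finish(self, stream_id):
            self._unfinished.discard(stream_id)
            # Advance up to, but not past, the lowest still-unfinished ID.
            if self._unfinished:
                self.position = max(self.position, min(self._unfinished) - 1)
            else:
                self.position = self._next - 1

    # Mirrors the assertions in test_out_of_order_finish: finishing 9 before 8
    # leaves the position at 7; finishing 8 then advances it to 9.
    tracker = ToyPositionTracker(7)
    ids = [tracker.get_next() for _ in range(4)]  # 8, 9, 10, 11
    tracker.finish(9)
    assert tracker.position == 7
    tracker.finish(8)
    assert tracker.position == 9
    tracker.finish(11)
    assert tracker.position == 9
    tracker.finish(10)
    assert tracker.position == 11
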
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index ab0df5ea93..7e7f1286d9 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -34,12 +34,16 @@ class DataStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_users_paginate(self):
- yield self.store.register_user(self.user.to_string(), "pass")
- yield self.store.create_profile(self.user.localpart)
- yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
+ yield defer.ensureDeferred(
+ self.store.register_user(self.user.to_string(), "pass")
+ )
+ yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.user.localpart, self.displayname)
+ )
- users, total = yield self.store.get_users_paginate(
- 0, 10, name="bc", guests=False
+ users, total = yield defer.ensureDeferred(
+ self.store.get_users_paginate(0, 10, name="bc", guests=False)
)
self.assertEquals(1, total)
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 9870c74883..643072bbaf 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -231,9 +231,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
)
self.get_success(d)
- self.store.upsert_monthly_active_user = Mock(
- side_effect=lambda user_id: make_awaitable(None)
- )
+ self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
d = self.store.populate_monthly_active_users(user_id)
self.get_success(d)
@@ -241,9 +239,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_not_called()
def test_populate_monthly_users_should_update(self):
- self.store.upsert_monthly_active_user = Mock(
- side_effect=lambda user_id: make_awaitable(None)
- )
+ self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
@@ -256,9 +252,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_called_once()
def test_populate_monthly_users_should_not_update(self):
- self.store.upsert_monthly_active_user = Mock(
- side_effect=lambda user_id: make_awaitable(None)
- )
+ self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
self.store.user_last_seen_monthly_active = Mock(
@@ -344,9 +338,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
def test_no_users_when_not_tracking(self):
- self.store.upsert_monthly_active_user = Mock(
- side_effect=lambda user_id: make_awaitable(None)
- )
+ self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
self.get_success(self.store.populate_monthly_active_users("@user:sever"))
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 9b6f7211ae..3fd0a38cf5 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -33,23 +33,36 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_displayname(self):
- yield self.store.create_profile(self.u_frank.localpart)
+ yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
- yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+ )
self.assertEquals(
- "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
+ "Frank",
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.u_frank.localpart)
+ )
+ ),
)
@defer.inlineCallbacks
def test_avatar_url(self):
- yield self.store.create_profile(self.u_frank.localpart)
+ yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
- yield self.store.set_profile_avatar_url(
- self.u_frank.localpart, "http://my.site/here"
+ yield defer.ensureDeferred(
+ self.store.set_profile_avatar_url(
+ self.u_frank.localpart, "http://my.site/here"
+ )
)
self.assertEquals(
"http://my.site/here",
- (yield self.store.get_profile_avatar_url(self.u_frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.u_frank.localpart)
+ )
+ ),
)
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index a6012c973d..918387733b 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -15,6 +15,7 @@
from twisted.internet import defer
+from synapse.api.errors import NotFoundError
from synapse.rest.client.v1 import room
from tests.unittest import HomeserverTestCase
@@ -46,30 +47,19 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_storage()
# Get the topological token
- event = store.get_topological_token_for_event(last["event_id"])
- self.pump()
- event = self.successResultOf(event)
-
- # Purge everything before this topological token
- purge = defer.ensureDeferred(
- storage.purge_events.purge_history(self.room_id, event, True)
+ event = self.get_success(
+ store.get_topological_token_for_event(last["event_id"])
)
- self.pump()
- self.assertEqual(self.successResultOf(purge), None)
- # Try and get the events
- get_first = store.get_event(first["event_id"])
- get_second = store.get_event(second["event_id"])
- get_third = store.get_event(third["event_id"])
- get_last = store.get_event(last["event_id"])
- self.pump()
+ # Purge everything before this topological token
+ self.get_success(storage.purge_events.purge_history(self.room_id, event, True))
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
# and last is not.
- self.failureResultOf(get_first)
- self.failureResultOf(get_second)
- self.failureResultOf(get_third)
- self.successResultOf(get_last)
+ self.get_failure(store.get_event(first["event_id"]), NotFoundError)
+ self.get_failure(store.get_event(second["event_id"]), NotFoundError)
+ self.get_failure(store.get_event(third["event_id"]), NotFoundError)
+ self.get_success(store.get_event(last["event_id"]))
def test_purge_wont_delete_extrems(self):
"""
@@ -84,9 +74,9 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_datastore()
# Set the topological token higher than it should be
- event = storage.get_topological_token_for_event(last["event_id"])
- self.pump()
- event = self.successResultOf(event)
+ event = self.get_success(
+ storage.get_topological_token_for_event(last["event_id"])
+ )
event = "t{}-{}".format(
*list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
)
@@ -98,14 +88,7 @@ class PurgeTests(HomeserverTestCase):
self.assertIn("greater than forward", f.value.args[0])
# Try and get the events
- get_first = storage.get_event(first["event_id"])
- get_second = storage.get_event(second["event_id"])
- get_third = storage.get_event(third["event_id"])
- get_last = storage.get_event(last["event_id"])
- self.pump()
-
- # Nothing is deleted.
- self.successResultOf(get_first)
- self.successResultOf(get_second)
- self.successResultOf(get_third)
- self.successResultOf(get_last)
+ self.get_success(storage.get_event(first["event_id"]))
+ self.get_success(storage.get_event(second["event_id"]))
+ self.get_success(storage.get_event(third["event_id"]))
+ self.get_success(storage.get_event(last["event_id"]))
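
The `get_success`/`get_failure` helpers used above come from `tests.unittest.HomeserverTestCase` and subsume the manual pump/successResultOf dance that the removed lines performed. Roughly, and only as an assumed sketch of the helpers' shape rather than their exact implementation:

    from twisted.internet.defer import ensureDeferred

    class _HelperSketch:
        """Assumed shape of the HomeserverTestCase helpers, for reference only."""

        def get_success(self, d, by=0.0):
            # Accept a Deferred or a coroutine, let the fake reactor run, then
            # pull out the synchronous result (failing the test otherwise).
            deferred = ensureDeferred(d)
            self.pump(by)
            return self.successResultOf(deferred)

        def get_failure(self, d, exc):
            # As above, but assert that the awaitable failed with the given
            # exception type, returning the Failure for further inspection.
            deferred = ensureDeferred(d)
            self.pump()
            return self.failureResultOf(deferred, exc)
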
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 840db66072..6b582771fe 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -17,6 +17,7 @@
from twisted.internet import defer
from synapse.api.constants import UserTypes
+from synapse.api.errors import ThreepidValidationError
from tests import unittest
from tests.utils import setup_test_homeserver
@@ -36,7 +37,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_register(self):
- yield self.store.register_user(self.user_id, self.pwhash)
+ yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
self.assertEquals(
{
@@ -52,19 +53,21 @@ class RegistrationStoreTestCase(unittest.TestCase):
"user_type": None,
"deactivated": 0,
},
- (yield self.store.get_user_by_id(self.user_id)),
+ (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
)
@defer.inlineCallbacks
def test_add_tokens(self):
- yield self.store.register_user(self.user_id, self.pwhash)
+ yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
yield defer.ensureDeferred(
self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
)
)
- result = yield self.store.get_user_by_access_token(self.tokens[1])
+ result = yield defer.ensureDeferred(
+ self.store.get_user_by_access_token(self.tokens[1])
+ )
self.assertDictContainsSubset(
{"name": self.user_id, "device_id": self.device_id}, result
@@ -75,7 +78,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_user_delete_access_tokens(self):
# add some tokens
- yield self.store.register_user(self.user_id, self.pwhash)
+ yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
yield defer.ensureDeferred(
self.store.add_access_token_to_user(
self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
@@ -88,22 +91,28 @@ class RegistrationStoreTestCase(unittest.TestCase):
)
# now delete some
- yield self.store.user_delete_access_tokens(
- self.user_id, device_id=self.device_id
+ yield defer.ensureDeferred(
+ self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id)
)
# check they were deleted
- user = yield self.store.get_user_by_access_token(self.tokens[1])
+ user = yield defer.ensureDeferred(
+ self.store.get_user_by_access_token(self.tokens[1])
+ )
self.assertIsNone(user, "access token was not deleted by device_id")
# check the one not associated with the device was not deleted
- user = yield self.store.get_user_by_access_token(self.tokens[0])
+ user = yield defer.ensureDeferred(
+ self.store.get_user_by_access_token(self.tokens[0])
+ )
self.assertEqual(self.user_id, user["name"])
# now delete the rest
- yield self.store.user_delete_access_tokens(self.user_id)
+ yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
- user = yield self.store.get_user_by_access_token(self.tokens[0])
+ user = yield defer.ensureDeferred(
+ self.store.get_user_by_access_token(self.tokens[0])
+ )
self.assertIsNone(user, "access token was not deleted without device_id")
@defer.inlineCallbacks
@@ -111,14 +120,48 @@ class RegistrationStoreTestCase(unittest.TestCase):
TEST_USER = "@test:test"
SUPPORT_USER = "@support:test"
- res = yield self.store.is_support_user(None)
+ res = yield defer.ensureDeferred(self.store.is_support_user(None))
self.assertFalse(res)
- yield self.store.register_user(user_id=TEST_USER, password_hash=None)
- res = yield self.store.is_support_user(TEST_USER)
+ yield defer.ensureDeferred(
+ self.store.register_user(user_id=TEST_USER, password_hash=None)
+ )
+ res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER))
self.assertFalse(res)
- yield self.store.register_user(
- user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
+ yield defer.ensureDeferred(
+ self.store.register_user(
+ user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
+ )
)
- res = yield self.store.is_support_user(SUPPORT_USER)
+ res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER))
self.assertTrue(res)
+
+ @defer.inlineCallbacks
+ def test_3pid_inhibit_invalid_validation_session_error(self):
+ """Tests that enabling the configuration option to inhibit 3PID errors on
+ /requestToken also inhibits validation errors caused by an unknown session ID.
+ """
+
+ # Check that, with the config setting set to false (the default value), a
+ # validation error is caused by the unknown session ID.
+ try:
+ yield defer.ensureDeferred(
+ self.store.validate_threepid_session(
+ "fake_sid", "fake_client_secret", "fake_token", 0,
+ )
+ )
+ except ThreepidValidationError as e:
+ self.assertEquals(e.msg, "Unknown session_id", e)
+
+ # Set the config setting to true.
+ self.store._ignore_unknown_session_error = True
+
+ # Check that now the validation error is caused by the token not matching.
+ try:
+ yield defer.ensureDeferred(
+ self.store.validate_threepid_session(
+ "fake_sid", "fake_client_secret", "fake_token", 0,
+ )
+ )
+ except ThreepidValidationError as e:
+ self.assertEquals(e.msg, "Validation token not found or has expired", e)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index d07b985a8e..bc8400f240 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -54,12 +54,14 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"is_public": True,
},
- (yield self.store.get_room(self.room.to_string())),
+ (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
)
@defer.inlineCallbacks
def test_get_room_unknown_room(self):
- self.assertIsNone((yield self.store.get_room("!uknown:test")),)
+ self.assertIsNone(
+ (yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
+ )
@defer.inlineCallbacks
def test_get_room_with_stats(self):
@@ -69,12 +71,22 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"public": True,
},
- (yield self.store.get_room_with_stats(self.room.to_string())),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_room_with_stats(self.room.to_string())
+ )
+ ),
)
@defer.inlineCallbacks
def test_get_room_with_stats_unknown_room(self):
- self.assertIsNone((yield self.store.get_room_with_stats("!uknown:test")),)
+ self.assertIsNone(
+ (
+ yield defer.ensureDeferred(
+ self.store.get_room_with_stats("!uknown:test")
+ )
+ ),
+ )
class RoomEventsStoreTestCase(unittest.TestCase):
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 17c9da4838..12ccc1f53e 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -87,7 +87,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN)
- self.pump(20)
+ self.pump()
self.assertTrue("_known_servers_count" not in self.store.__dict__.keys())
@@ -101,7 +101,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# Initialises to 1 -- itself
self.assertEqual(self.store._known_servers_count, 1)
- self.pump(20)
+ self.pump()
# No rooms have been joined, so technically the SQL returns 0, but it
# will still say it knows about itself.
@@ -111,7 +111,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN)
- self.pump(20)
+ self.pump(1)
# It now knows about Charlie's server.
self.assertEqual(self.store._known_servers_count, 2)
@@ -187,7 +187,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
# Now let's create a room, which will insert a membership
user = UserID("alice", "test")
- requester = Requester(user, None, False, None, None)
+ requester = Requester(user, None, False, False, None, None)
self.get_success(self.room_creator.create_room(requester, {}))
# Register the background update to run again.
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index ecfafe68a9..738e912468 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -31,10 +31,18 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
# alice and bob are both in !room_id. bobby is not but shares
# a homeserver with alice.
- yield self.store.update_profile_in_user_dir(ALICE, "alice", None)
- yield self.store.update_profile_in_user_dir(BOB, "bob", None)
- yield self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
- yield self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
+ yield defer.ensureDeferred(
+ self.store.update_profile_in_user_dir(ALICE, "alice", None)
+ )
+ yield defer.ensureDeferred(
+ self.store.update_profile_in_user_dir(BOB, "bob", None)
+ )
+ yield defer.ensureDeferred(
+ self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
+ )
+ yield defer.ensureDeferred(
+ self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
+ )
@defer.inlineCallbacks
def test_search_user_dir(self):
diff --git a/tests/test_federation.py b/tests/test_federation.py
index f2fa42bfb9..27a7fc9ed7 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -15,8 +15,9 @@
from mock import Mock
-from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
+from twisted.internet.defer import succeed
+from synapse.api.errors import FederationError
from synapse.events import make_event_from_dict
from synapse.logging.context import LoggingContext
from synapse.types import Requester, UserID
@@ -42,24 +43,19 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
user_id = UserID("us", "test")
- our_user = Requester(user_id, None, False, None, None)
+ our_user = Requester(user_id, None, False, False, None, None)
room_creator = self.homeserver.get_room_creation_handler()
- room_deferred = ensureDeferred(
+ self.room_id = self.get_success(
room_creator.create_room(
our_user, room_creator._presets_dict["public_chat"], ratelimit=False
)
- )
- self.reactor.advance(0.1)
- self.room_id = self.successResultOf(room_deferred)[0]["room_id"]
+ )[0]["room_id"]
self.store = self.homeserver.get_datastore()
# Figure out what the most recent event is
- most_recent = self.successResultOf(
- maybeDeferred(
- self.homeserver.get_datastore().get_latest_event_ids_in_room,
- self.room_id,
- )
+ most_recent = self.get_success(
+ self.homeserver.get_datastore().get_latest_event_ids_in_room(self.room_id)
)[0]
join_event = make_event_from_dict(
@@ -89,19 +85,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
# Send the join, it should return None (which is not an error)
- d = ensureDeferred(
- self.handler.on_receive_pdu(
- "test.serv", join_event, sent_to_us_directly=True
- )
+ self.assertEqual(
+ self.get_success(
+ self.handler.on_receive_pdu(
+ "test.serv", join_event, sent_to_us_directly=True
+ )
+ ),
+ None,
)
- self.reactor.advance(1)
- self.assertEqual(self.successResultOf(d), None)
# Make sure we actually joined the room
self.assertEqual(
- self.successResultOf(
- maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
- )[0],
+ self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0],
"$join:test.serv",
)
@@ -119,8 +114,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
self.http_client.post_json = post_json
# Figure out what the most recent event is
- most_recent = self.successResultOf(
- maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
+ most_recent = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
)[0]
# Now lie about an event
@@ -140,17 +135,14 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
with LoggingContext(request="lying_event"):
- d = ensureDeferred(
+ failure = self.get_failure(
self.handler.on_receive_pdu(
"test.serv", lying_event, sent_to_us_directly=True
- )
+ ),
+ FederationError,
)
- # Step the reactor, so the database fetches come back
- self.reactor.advance(1)
-
# on_receive_pdu should throw an error
- failure = self.failureResultOf(d)
self.assertEqual(
failure.value.args[0],
(
@@ -160,8 +152,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
# Make sure the invalid event isn't there
- extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
- self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
+ extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
+ self.assertEqual(extrem[0], "$join:test.serv")
def test_retry_device_list_resync(self):
"""Tests that device lists are marked as stale if they couldn't be synced, and
diff --git a/tests/test_server.py b/tests/test_server.py
index d628070e48..655c918a15 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -178,7 +178,6 @@ class JsonResourceTests(unittest.TestCase):
self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result)
- self.assertEqual(channel.headers.getRawHeaders(b"Content-Length"), [b"15"])
class OptionsResourceTests(unittest.TestCase):
diff --git a/tests/test_state.py b/tests/test_state.py
index b5c3667d2a..2d58467932 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -71,7 +71,7 @@ def create_event(
return event
-class StateGroupStore(object):
+class StateGroupStore:
def __init__(self):
self._event_to_state_group = {}
self._group_to_state = {}
@@ -80,16 +80,16 @@ class StateGroupStore(object):
self._next_group = 1
- def get_state_groups_ids(self, room_id, event_ids):
+ async def get_state_groups_ids(self, room_id, event_ids):
groups = {}
for event_id in event_ids:
group = self._event_to_state_group.get(event_id)
if group:
groups[group] = self._group_to_state[group]
- return defer.succeed(groups)
+ return groups
- def store_state_group(
+ async def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
):
state_group = self._next_group
@@ -97,19 +97,17 @@ class StateGroupStore(object):
self._group_to_state[state_group] = dict(current_state_ids)
- return defer.succeed(state_group)
+ return state_group
- def get_events(self, event_ids, **kwargs):
- return defer.succeed(
- {
- e_id: self._event_id_to_event[e_id]
- for e_id in event_ids
- if e_id in self._event_id_to_event
- }
- )
+ async def get_events(self, event_ids, **kwargs):
+ return {
+ e_id: self._event_id_to_event[e_id]
+ for e_id in event_ids
+ if e_id in self._event_id_to_event
+ }
- def get_state_group_delta(self, name):
- return defer.succeed((None, None))
+ async def get_state_group_delta(self, name):
+ return (None, None)
def register_events(self, events):
for e in events:
@@ -121,8 +119,8 @@ class StateGroupStore(object):
def register_event_id_state_group(self, event_id, state_group):
self._event_to_state_group[event_id] = state_group
- def get_room_version_id(self, room_id):
- return defer.succeed(RoomVersions.V1.identifier)
+ async def get_room_version_id(self, room_id):
+ return RoomVersions.V1.identifier
class DictObj(dict):
@@ -131,7 +129,7 @@ class DictObj(dict):
self.__dict__ = self
-class Graph(object):
+class Graph:
def __init__(self, nodes, edges):
events = {}
clobbered = set(events.keys())
@@ -476,12 +474,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = yield self.store.store_state_group(
- prev_event_id,
- event.room_id,
- None,
- None,
- {(e.type, e.state_key): e.event_id for e in old_state},
+ group_name = yield defer.ensureDeferred(
+ self.store.store_state_group(
+ prev_event_id,
+ event.room_id,
+ None,
+ None,
+ {(e.type, e.state_key): e.event_id for e in old_state},
+ )
)
self.store.register_event_id_state_group(prev_event_id, group_name)
@@ -508,12 +508,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = yield self.store.store_state_group(
- prev_event_id,
- event.room_id,
- None,
- None,
- {(e.type, e.state_key): e.event_id for e in old_state},
+ group_name = yield defer.ensureDeferred(
+ self.store.store_state_group(
+ prev_event_id,
+ event.room_id,
+ None,
+ None,
+ {(e.type, e.state_key): e.event_id for e in old_state},
+ )
)
self.store.register_event_id_state_group(prev_event_id, group_name)
@@ -691,21 +693,25 @@ class StateTestCase(unittest.TestCase):
def _get_context(
self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
):
- sg1 = yield self.store.store_state_group(
- prev_event_id_1,
- event.room_id,
- None,
- None,
- {(e.type, e.state_key): e.event_id for e in old_state_1},
+ sg1 = yield defer.ensureDeferred(
+ self.store.store_state_group(
+ prev_event_id_1,
+ event.room_id,
+ None,
+ None,
+ {(e.type, e.state_key): e.event_id for e in old_state_1},
+ )
)
self.store.register_event_id_state_group(prev_event_id_1, sg1)
- sg2 = yield self.store.store_state_group(
- prev_event_id_2,
- event.room_id,
- None,
- None,
- {(e.type, e.state_key): e.event_id for e in old_state_2},
+ sg2 = yield defer.ensureDeferred(
+ self.store.store_state_group(
+ prev_event_id_2,
+ event.room_id,
+ None,
+ None,
+ {(e.type, e.state_key): e.event_id for e in old_state_2},
+ )
)
self.store.register_event_id_state_group(prev_event_id_2, sg2)
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 508aeba078..a298cc0fd3 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,6 +17,7 @@
"""
Utilities for running the unit tests
"""
+from asyncio import Future
from typing import Any, Awaitable, TypeVar
TV = TypeVar("TV")
@@ -38,6 +39,12 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
raise Exception("awaitable has not yet completed")
-async def make_awaitable(result: Any):
- """Create an awaitable that just returns a result."""
- return result
+def make_awaitable(result: Any) -> Awaitable[Any]:
+ """
+ Makes an awaitable, suitable for mocking an `async` function.
+ This uses Futures as they can be awaited multiple times so can be returned
+ to multiple callers.
+ """
+ future = Future() # type: ignore
+ future.set_result(result)
+ return future
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 8522c6fc09..e93aa84405 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -13,14 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
import synapse.server
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.types import Collection
"""
Utility functions for poking events into the storage of the server under test.
@@ -58,7 +57,7 @@ async def inject_member_event(
async def inject_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
- prev_event_ids: Optional[Collection[str]] = None,
+ prev_event_ids: Optional[List[str]] = None,
**kwargs
) -> EventBase:
"""Inject a generic event into a room
@@ -72,7 +71,10 @@ async def inject_event(
"""
event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
- await hs.get_storage().persistence.persist_event(event, context)
+ persistence = hs.get_storage().persistence
+ assert persistence is not None
+
+ await persistence.persist_event(event, context)
return event
@@ -80,7 +82,7 @@ async def inject_event(
async def create_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
- prev_event_ids: Optional[Collection[str]] = None,
+ prev_event_ids: Optional[List[str]] = None,
**kwargs
) -> Tuple[EventBase, EventContext]:
if room_version is None:
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 531a9b9118..510b630114 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -37,7 +37,6 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.hs = yield setup_test_homeserver(self.addCleanup)
self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
- self.store = self.hs.get_datastore()
self.storage = self.hs.get_storage()
yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
@@ -99,7 +98,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
events_to_filter.append(evt)
# the erasey user gets erased
- yield self.hs.get_datastore().mark_user_erased("@erased:local_hs")
+ yield defer.ensureDeferred(
+ self.hs.get_datastore().mark_user_erased("@erased:local_hs")
+ )
# ... and the filtering happens.
filtered = yield defer.ensureDeferred(
@@ -293,7 +294,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
test_large_room.skip = "Disabled by default because it's slow"
-class _TestStore(object):
+class _TestStore:
"""Implements a few methods of the DataStore, so that we can test
filter_events_for_server
diff --git a/tests/unittest.py b/tests/unittest.py
index d0bba3ddef..128dd4e19c 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -250,7 +250,11 @@ class HomeserverTestCase(TestCase):
async def get_user_by_req(request, allow_guest=False, rights="access"):
return create_requester(
- UserID.from_string(self.helper.auth_user_id), 1, False, None
+ UserID.from_string(self.helper.auth_user_id),
+ 1,
+ False,
+ False,
+ None,
)
self.hs.get_auth().get_user_by_req = get_user_by_req
@@ -349,6 +353,7 @@ class HomeserverTestCase(TestCase):
request: Type[T] = SynapseRequest,
shorthand: bool = True,
federation_auth_origin: str = None,
+ content_is_form: bool = False,
) -> Tuple[T, FakeChannel]:
"""
Create a SynapseRequest at the path using the method and containing the
@@ -364,6 +369,8 @@ class HomeserverTestCase(TestCase):
with the usual REST API path, if it doesn't contain it.
federation_auth_origin (bytes|None): if set to not-None, we will add a fake
                 Authorization header pretending to be the given server name.

+ content_is_form: Whether the content is URL encoded form data. Adds the
+ 'Content-Type': 'application/x-www-form-urlencoded' header.
Returns:
Tuple[synapse.http.site.SynapseRequest, channel]
@@ -380,6 +387,7 @@ class HomeserverTestCase(TestCase):
request,
shorthand,
federation_auth_origin,
+ content_is_form,
)
def render(self, request):
@@ -540,7 +548,7 @@ class HomeserverTestCase(TestCase):
"""
event_creator = self.hs.get_event_creation_handler()
secrets = self.hs.get_secrets()
- requester = Requester(user, None, False, None, None)
+ requester = Requester(user, None, False, False, None, None)
event, context = self.get_success(
event_creator.create_event(
@@ -610,7 +618,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
"""
def prepare(self, reactor, clock, homeserver):
- class Authenticator(object):
+ class Authenticator:
def authenticate_request(self, request, content):
return succeed("other.example.com")
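
The new `content_is_form` flag threads through `make_request` so a test can submit a URL-encoded form body; a hedged usage sketch (the endpoint and body below are made up for illustration):

    request, channel = self.make_request(
        "POST",
        "/_synapse/example/form-endpoint",   # illustrative path only
        content=b"client_id=abc&grant_type=password",
        content_is_form=True,  # adds Content-Type: application/x-www-form-urlencoded
    )
    self.assertEqual(channel.result["code"], b"200")
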
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 4d2b9e0d64..677e925477 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -88,7 +88,7 @@ class CacheTestCase(unittest.TestCase):
class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cache(self):
- class Cls(object):
+ class Cls:
def __init__(self):
self.mock = mock.Mock()
@@ -122,7 +122,7 @@ class DescriptorTestCase(unittest.TestCase):
def test_cache_num_args(self):
"""Only the first num_args arguments should matter to the cache"""
- class Cls(object):
+ class Cls:
def __init__(self):
self.mock = mock.Mock()
@@ -156,7 +156,7 @@ class DescriptorTestCase(unittest.TestCase):
"""If the wrapped function throws synchronously, things should continue to work
"""
- class Cls(object):
+ class Cls:
@cached()
def fn(self, arg1):
raise SynapseError(100, "mai spoon iz too big!!1")
@@ -180,7 +180,7 @@ class DescriptorTestCase(unittest.TestCase):
complete_lookup = defer.Deferred()
- class Cls(object):
+ class Cls:
@descriptors.cached()
def fn(self, arg1):
@defer.inlineCallbacks
@@ -223,7 +223,7 @@ class DescriptorTestCase(unittest.TestCase):
"""Check that the cache sets and restores logcontexts correctly when
the lookup function throws an exception"""
- class Cls(object):
+ class Cls:
@descriptors.cached()
def fn(self, arg1):
@defer.inlineCallbacks
@@ -263,7 +263,7 @@ class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cache_default_args(self):
- class Cls(object):
+ class Cls:
def __init__(self):
self.mock = mock.Mock()
@@ -300,7 +300,7 @@ class DescriptorTestCase(unittest.TestCase):
obj.mock.assert_not_called()
def test_cache_iterable(self):
- class Cls(object):
+ class Cls:
def __init__(self):
self.mock = mock.Mock()
@@ -336,7 +336,7 @@ class DescriptorTestCase(unittest.TestCase):
"""If the wrapped function throws synchronously, things should continue to work
"""
- class Cls(object):
+ class Cls:
@descriptors.cached(iterable=True)
def fn(self, arg1):
raise SynapseError(100, "mai spoon iz too big!!1")
@@ -358,7 +358,7 @@ class DescriptorTestCase(unittest.TestCase):
class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cache(self):
- class Cls(object):
+ class Cls:
def __init__(self):
self.mock = mock.Mock()
@@ -366,11 +366,11 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def fn(self, arg1, arg2):
pass
- @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
- def list_fn(self, args1, arg2):
+ @descriptors.cachedList("fn", "args1")
+ async def list_fn(self, args1, arg2):
assert current_context().request == "c1"
# we want this to behave like an asynchronous function
- yield run_on_reactor()
+ await run_on_reactor()
assert current_context().request == "c1"
return self.mock(args1, arg2)
@@ -408,7 +408,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def test_invalidate(self):
"""Make sure that invalidation callbacks are called."""
- class Cls(object):
+ class Cls:
def __init__(self):
self.mock = mock.Mock()
@@ -416,10 +416,10 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def fn(self, arg1, arg2):
pass
- @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
- def list_fn(self, args1, arg2):
+ @descriptors.cachedList("fn", "args1")
+ async def list_fn(self, args1, arg2):
# we want this to behave like an asynchronous function
- yield run_on_reactor()
+ await run_on_reactor()
return self.mock(args1, arg2)
obj = Cls()
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
index 8d6627ec33..2012263184 100644
--- a/tests/util/test_file_consumer.py
+++ b/tests/util/test_file_consumer.py
@@ -112,7 +112,7 @@ class FileConsumerTests(unittest.TestCase):
self.assertTrue(string_file.closed)
-class DummyPullProducer(object):
+class DummyPullProducer:
def __init__(self):
self.consumer = None
self.deferred = defer.Deferred()
@@ -134,7 +134,7 @@ class DummyPullProducer(object):
return d
-class BlockingStringWrite(object):
+class BlockingStringWrite:
def __init__(self):
self.buffer = ""
self.closed = False
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index bc42ffce88..5f46ed0cef 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -91,7 +91,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
#
# one more go, with success
#
- self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
+ self.reactor.advance(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
self.pump(1)
diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py
index bd32e2cee7..d3dea3b52a 100644
--- a/tests/util/test_rwlock.py
+++ b/tests/util/test_rwlock.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
from synapse.util.async_helpers import ReadWriteLock
@@ -43,6 +44,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
rwlock.read(key), # 5
rwlock.write(key), # 6
]
+ ds = [defer.ensureDeferred(d) for d in ds]
self._assert_called_before_not_after(ds, 2)
@@ -73,12 +75,12 @@ class ReadWriteLockTestCase(unittest.TestCase):
with ds[6].result:
pass
- d = rwlock.write(key)
+ d = defer.ensureDeferred(rwlock.write(key))
self.assertTrue(d.called)
with d.result:
pass
- d = rwlock.read(key)
+ d = defer.ensureDeferred(rwlock.read(key))
self.assertTrue(d.called)
with d.result:
pass
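
`ReadWriteLock.read`/`write` now return coroutines rather than Deferreds, and only a Deferred exposes `.called` and `.result`, hence the `ensureDeferred` wrapping above. A tiny illustration (`FakeLock` is a hypothetical stand-in, not the real class):

    from twisted.internet import defer

    class FakeLock:
        # Hypothetical stand-in for a lock whose acquire methods are now
        # coroutine functions, as ReadWriteLock's are after this change.
        async def read(self, key):
            return "context-manager placeholder"

    coro = FakeLock().read("k")        # a coroutine object: no .called/.result
    d = defer.ensureDeferred(coro)     # a Deferred; fires synchronously here
    assert d.called and d.result == "context-manager placeholder"
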
diff --git a/tests/util/test_stringutils.py b/tests/util/test_stringutils.py
index 4f4da29a98..8491f7cc83 100644
--- a/tests/util/test_stringutils.py
+++ b/tests/util/test_stringutils.py
@@ -28,9 +28,6 @@ class StringUtilsTestCase(unittest.TestCase):
"_--something==_",
"...--==-18913",
"8Dj2odd-e9asd.cd==_--ddas-secret-",
- # We temporarily allow : characters: https://github.com/matrix-org/synapse/issues/6766
- # To be removed in a future release
- "SECRET:1234567890",
]
bad = [
diff --git a/tests/utils.py b/tests/utils.py
index a61cbdef44..4673872f88 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -472,7 +472,7 @@ class MockHttpResource(HttpServer):
self.callbacks.append((method, path_pattern, callback))
-class MockKey(object):
+class MockKey:
alg = "mock_alg"
version = "mock_version"
signature = b"\x9a\x87$"
@@ -491,7 +491,7 @@ class MockKey(object):
return b"<fake_encoded_key>"
-class MockClock(object):
+class MockClock:
now = 1000
def __init__(self):
@@ -568,7 +568,7 @@ def _format_call(args, kwargs):
)
-class DeferredMockCallable(object):
+class DeferredMockCallable:
"""A callable instance that stores a set of pending call expectations and
return values for them. It allows a unit test to assert that the given set
of function calls are eventually made, by awaiting on them to be called.