summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_appservice.py122
-rw-r--r--tests/handlers/test_space_summary.py185
-rw-r--r--tests/module_api/test_api.py10
-rw-r--r--tests/push/test_email.py20
-rw-r--r--tests/rest/client/v1/test_rooms.py92
-rw-r--r--tests/rest/client/v2_alpha/test_account.py33
-rw-r--r--tests/rest/client/v2_alpha/test_register.py12
-rw-r--r--tests/storage/databases/main/test_events_worker.py50
-rw-r--r--tests/test_federation.py6
9 files changed, 419 insertions, 111 deletions
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 024c5e963c..43998020b2 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -133,11 +133,131 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         self.assertEquals(result.room_id, room_id)
         self.assertEquals(result.servers, servers)
 
-    def _mkservice(self, is_interested):
+    def test_get_3pe_protocols_no_appservices(self):
+        self.mock_store.get_app_services.return_value = []
+        response = self.successResultOf(
+            defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol"))
+        )
+        self.mock_as_api.get_3pe_protocol.assert_not_called()
+        self.assertEquals(response, {})
+
+    def test_get_3pe_protocols_no_protocols(self):
+        service = self._mkservice(False, [])
+        self.mock_store.get_app_services.return_value = [service]
+        response = self.successResultOf(
+            defer.ensureDeferred(self.handler.get_3pe_protocols())
+        )
+        self.mock_as_api.get_3pe_protocol.assert_not_called()
+        self.assertEquals(response, {})
+
+    def test_get_3pe_protocols_protocol_no_response(self):
+        service = self._mkservice(False, ["my-protocol"])
+        self.mock_store.get_app_services.return_value = [service]
+        self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None)
+        response = self.successResultOf(
+            defer.ensureDeferred(self.handler.get_3pe_protocols())
+        )
+        self.mock_as_api.get_3pe_protocol.assert_called_once_with(
+            service, "my-protocol"
+        )
+        self.assertEquals(response, {})
+
+    def test_get_3pe_protocols_select_one_protocol(self):
+        service = self._mkservice(False, ["my-protocol"])
+        self.mock_store.get_app_services.return_value = [service]
+        self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
+            {"x-protocol-data": 42, "instances": []}
+        )
+        response = self.successResultOf(
+            defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol"))
+        )
+        self.mock_as_api.get_3pe_protocol.assert_called_once_with(
+            service, "my-protocol"
+        )
+        self.assertEquals(
+            response, {"my-protocol": {"x-protocol-data": 42, "instances": []}}
+        )
+
+    def test_get_3pe_protocols_one_protocol(self):
+        service = self._mkservice(False, ["my-protocol"])
+        self.mock_store.get_app_services.return_value = [service]
+        self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
+            {"x-protocol-data": 42, "instances": []}
+        )
+        response = self.successResultOf(
+            defer.ensureDeferred(self.handler.get_3pe_protocols())
+        )
+        self.mock_as_api.get_3pe_protocol.assert_called_once_with(
+            service, "my-protocol"
+        )
+        self.assertEquals(
+            response, {"my-protocol": {"x-protocol-data": 42, "instances": []}}
+        )
+
+    def test_get_3pe_protocols_multiple_protocol(self):
+        service_one = self._mkservice(False, ["my-protocol"])
+        service_two = self._mkservice(False, ["other-protocol"])
+        self.mock_store.get_app_services.return_value = [service_one, service_two]
+        self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
+            {"x-protocol-data": 42, "instances": []}
+        )
+        response = self.successResultOf(
+            defer.ensureDeferred(self.handler.get_3pe_protocols())
+        )
+        self.mock_as_api.get_3pe_protocol.assert_called()
+        self.assertEquals(
+            response,
+            {
+                "my-protocol": {"x-protocol-data": 42, "instances": []},
+                "other-protocol": {"x-protocol-data": 42, "instances": []},
+            },
+        )
+
+    def test_get_3pe_protocols_multiple_info(self):
+        service_one = self._mkservice(False, ["my-protocol"])
+        service_two = self._mkservice(False, ["my-protocol"])
+
+        async def get_3pe_protocol(service, unusedProtocol):
+            if service == service_one:
+                return {
+                    "x-protocol-data": 42,
+                    "instances": [{"desc": "Alice's service"}],
+                }
+            if service == service_two:
+                return {
+                    "x-protocol-data": 36,
+                    "x-not-used": 45,
+                    "instances": [{"desc": "Bob's service"}],
+                }
+            raise Exception("Unexpected service")
+
+        self.mock_store.get_app_services.return_value = [service_one, service_two]
+        self.mock_as_api.get_3pe_protocol = get_3pe_protocol
+        response = self.successResultOf(
+            defer.ensureDeferred(self.handler.get_3pe_protocols())
+        )
+        # It's expected that the second service's data doesn't appear in the response
+        self.assertEquals(
+            response,
+            {
+                "my-protocol": {
+                    "x-protocol-data": 42,
+                    "instances": [
+                        {
+                            "desc": "Alice's service",
+                        },
+                        {"desc": "Bob's service"},
+                    ],
+                },
+            },
+        )
+
+    def _mkservice(self, is_interested, protocols=None):
         service = Mock()
         service.is_interested.return_value = make_awaitable(is_interested)
         service.token = "mock_service_token"
         service.url = "mock_service_url"
+        service.protocols = protocols
         return service
 
     def _mkservice_alias(self, is_interested_in_alias):
diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py
index 01975c13d4..6cc1a02e12 100644
--- a/tests/handlers/test_space_summary.py
+++ b/tests/handlers/test_space_summary.py
@@ -26,7 +26,7 @@ from synapse.api.constants import (
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import make_event_from_dict
-from synapse.handlers.space_summary import _child_events_comparison_key
+from synapse.handlers.space_summary import _child_events_comparison_key, _RoomEntry
 from synapse.rest import admin
 from synapse.rest.client.v1 import login, room
 from synapse.server import HomeServer
@@ -351,26 +351,30 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             #   events before child events).
 
             # Note that these entries are brief, but should contain enough info.
-            rooms = [
-                {
-                    "room_id": subspace,
-                    "world_readable": True,
-                    "room_type": RoomTypes.SPACE,
-                },
-                {
-                    "room_id": subroom,
-                    "world_readable": True,
-                },
-            ]
-            event_content = {"via": [fed_hostname]}
-            events = [
-                {
-                    "room_id": subspace,
-                    "state_key": subroom,
-                    "content": event_content,
-                },
+            return [
+                _RoomEntry(
+                    subspace,
+                    {
+                        "room_id": subspace,
+                        "world_readable": True,
+                        "room_type": RoomTypes.SPACE,
+                    },
+                    [
+                        {
+                            "room_id": subspace,
+                            "state_key": subroom,
+                            "content": {"via": [fed_hostname]},
+                        }
+                    ],
+                ),
+                _RoomEntry(
+                    subroom,
+                    {
+                        "room_id": subroom,
+                        "world_readable": True,
+                    },
+                ),
             ]
-            return rooms, events
 
         # Add a room to the space which is on another server.
         self._add_child(self.space, subspace, self.token)
@@ -436,70 +440,95 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
         ):
             # Note that these entries are brief, but should contain enough info.
             rooms = [
-                {
-                    "room_id": public_room,
-                    "world_readable": False,
-                    "join_rules": JoinRules.PUBLIC,
-                },
-                {
-                    "room_id": knock_room,
-                    "world_readable": False,
-                    "join_rules": JoinRules.KNOCK,
-                },
-                {
-                    "room_id": not_invited_room,
-                    "world_readable": False,
-                    "join_rules": JoinRules.INVITE,
-                },
-                {
-                    "room_id": invited_room,
-                    "world_readable": False,
-                    "join_rules": JoinRules.INVITE,
-                },
-                {
-                    "room_id": restricted_room,
-                    "world_readable": False,
-                    "join_rules": JoinRules.RESTRICTED,
-                    "allowed_spaces": [],
-                },
-                {
-                    "room_id": restricted_accessible_room,
-                    "world_readable": False,
-                    "join_rules": JoinRules.RESTRICTED,
-                    "allowed_spaces": [self.room],
-                },
-                {
-                    "room_id": world_readable_room,
-                    "world_readable": True,
-                    "join_rules": JoinRules.INVITE,
-                },
-                {
-                    "room_id": joined_room,
-                    "world_readable": False,
-                    "join_rules": JoinRules.INVITE,
-                },
-            ]
-
-            # Place each room in the sub-space.
-            event_content = {"via": [fed_hostname]}
-            events = [
-                {
-                    "room_id": subspace,
-                    "state_key": room["room_id"],
-                    "content": event_content,
-                }
-                for room in rooms
+                _RoomEntry(
+                    public_room,
+                    {
+                        "room_id": public_room,
+                        "world_readable": False,
+                        "join_rules": JoinRules.PUBLIC,
+                    },
+                ),
+                _RoomEntry(
+                    knock_room,
+                    {
+                        "room_id": knock_room,
+                        "world_readable": False,
+                        "join_rules": JoinRules.KNOCK,
+                    },
+                ),
+                _RoomEntry(
+                    not_invited_room,
+                    {
+                        "room_id": not_invited_room,
+                        "world_readable": False,
+                        "join_rules": JoinRules.INVITE,
+                    },
+                ),
+                _RoomEntry(
+                    invited_room,
+                    {
+                        "room_id": invited_room,
+                        "world_readable": False,
+                        "join_rules": JoinRules.INVITE,
+                    },
+                ),
+                _RoomEntry(
+                    restricted_room,
+                    {
+                        "room_id": restricted_room,
+                        "world_readable": False,
+                        "join_rules": JoinRules.RESTRICTED,
+                        "allowed_spaces": [],
+                    },
+                ),
+                _RoomEntry(
+                    restricted_accessible_room,
+                    {
+                        "room_id": restricted_accessible_room,
+                        "world_readable": False,
+                        "join_rules": JoinRules.RESTRICTED,
+                        "allowed_spaces": [self.room],
+                    },
+                ),
+                _RoomEntry(
+                    world_readable_room,
+                    {
+                        "room_id": world_readable_room,
+                        "world_readable": True,
+                        "join_rules": JoinRules.INVITE,
+                    },
+                ),
+                _RoomEntry(
+                    joined_room,
+                    {
+                        "room_id": joined_room,
+                        "world_readable": False,
+                        "join_rules": JoinRules.INVITE,
+                    },
+                ),
             ]
 
             # Also include the subspace.
             rooms.insert(
                 0,
-                {
-                    "room_id": subspace,
-                    "world_readable": True,
-                },
+                _RoomEntry(
+                    subspace,
+                    {
+                        "room_id": subspace,
+                        "world_readable": True,
+                    },
+                    # Place each room in the sub-space.
+                    [
+                        {
+                            "room_id": subspace,
+                            "state_key": room.room_id,
+                            "content": {"via": [fed_hostname]},
+                        }
+                        for room in rooms
+                    ],
+                ),
             )
-            return rooms, events
+            return rooms
 
         # Add a room to the space which is on another server.
         self._add_child(self.space, subspace, self.token)
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 81d9e2f484..0b817cc701 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -79,6 +79,16 @@ class ModuleApiTestCase(HomeserverTestCase):
         displayname = self.get_success(self.store.get_profile_displayname("bob"))
         self.assertEqual(displayname, "Bobberino")
 
+    def test_get_userinfo_by_id(self):
+        user_id = self.register_user("alice", "1234")
+        found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+        self.assertEqual(found_user.user_id.to_string(), user_id)
+        self.assertIdentical(found_user.is_admin, False)
+
+    def test_get_userinfo_by_id__no_user_found(self):
+        found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test"))
+        self.assertIsNone(found_user)
+
     def test_sending_events_into_room(self):
         """Tests that a module can send events into a room"""
         # Mock out create_and_send_nonmember_event to check whether events are being sent
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index e04bc5c9a6..a487706758 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -45,14 +45,6 @@ class EmailPusherTests(HomeserverTestCase):
 
     def make_homeserver(self, reactor, clock):
 
-        # List[Tuple[Deferred, args, kwargs]]
-        self.email_attempts = []
-
-        def sendmail(*args, **kwargs):
-            d = Deferred()
-            self.email_attempts.append((d, args, kwargs))
-            return d
-
         config = self.default_config()
         config["email"] = {
             "enable_notifs": True,
@@ -75,7 +67,17 @@ class EmailPusherTests(HomeserverTestCase):
         config["public_baseurl"] = "aaa"
         config["start_pushers"] = True
 
-        hs = self.setup_test_homeserver(config=config, sendmail=sendmail)
+        hs = self.setup_test_homeserver(config=config)
+
+        # List[Tuple[Deferred, args, kwargs]]
+        self.email_attempts = []
+
+        def sendmail(*args, **kwargs):
+            d = Deferred()
+            self.email_attempts.append((d, args, kwargs))
+            return d
+
+        hs.get_send_email_handler()._sendmail = sendmail
 
         return hs
 
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 3df070c936..1a9528ec20 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -19,11 +19,14 @@
 
 import json
 from typing import Iterable
-from unittest.mock import Mock
+from unittest.mock import Mock, call
 from urllib import parse as urlparse
 
+from twisted.internet import defer
+
 import synapse.rest.admin
 from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.api.errors import HttpResponseException
 from synapse.handlers.pagination import PurgeStatus
 from synapse.rest import admin
 from synapse.rest.client.v1 import directory, login, profile, room
@@ -1124,6 +1127,93 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.result)
 
 
+class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
+    """Test that we correctly fallback to local filtering if a remote server
+    doesn't support search.
+    """
+
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        return self.setup_test_homeserver(federation_client=Mock())
+
+    def prepare(self, reactor, clock, hs):
+        self.register_user("user", "pass")
+        self.token = self.login("user", "pass")
+
+        self.federation_client = hs.get_federation_client()
+
+    def test_simple(self):
+        "Simple test for searching rooms over federation"
+        self.federation_client.get_public_rooms.side_effect = (
+            lambda *a, **k: defer.succeed({})
+        )
+
+        search_filter = {"generic_search_term": "foobar"}
+
+        channel = self.make_request(
+            "POST",
+            b"/_matrix/client/r0/publicRooms?server=testserv",
+            content={"filter": search_filter},
+            access_token=self.token,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+        self.federation_client.get_public_rooms.assert_called_once_with(
+            "testserv",
+            limit=100,
+            since_token=None,
+            search_filter=search_filter,
+            include_all_networks=False,
+            third_party_instance_id=None,
+        )
+
+    def test_fallback(self):
+        "Test that searching public rooms over federation falls back if it gets a 404"
+
+        # The `get_public_rooms` should be called again if the first call fails
+        # with a 404, when using search filters.
+        self.federation_client.get_public_rooms.side_effect = (
+            HttpResponseException(404, "Not Found", b""),
+            defer.succeed({}),
+        )
+
+        search_filter = {"generic_search_term": "foobar"}
+
+        channel = self.make_request(
+            "POST",
+            b"/_matrix/client/r0/publicRooms?server=testserv",
+            content={"filter": search_filter},
+            access_token=self.token,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+        self.federation_client.get_public_rooms.assert_has_calls(
+            [
+                call(
+                    "testserv",
+                    limit=100,
+                    since_token=None,
+                    search_filter=search_filter,
+                    include_all_networks=False,
+                    third_party_instance_id=None,
+                ),
+                call(
+                    "testserv",
+                    limit=None,
+                    since_token=None,
+                    search_filter=None,
+                    include_all_networks=False,
+                    third_party_instance_id=None,
+                ),
+            ]
+        )
+
+
 class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
 
     servlets = [
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 317a2287e3..e7e617e9df 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -47,12 +47,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         config = self.default_config()
 
         # Email config.
-        self.email_attempts = []
-
-        async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
-            self.email_attempts.append(msg)
-            return
-
         config["email"] = {
             "enable_notifs": False,
             "template_dir": os.path.abspath(
@@ -67,7 +61,16 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         }
         config["public_baseurl"] = "https://example.com"
 
-        hs = self.setup_test_homeserver(config=config, sendmail=sendmail)
+        hs = self.setup_test_homeserver(config=config)
+
+        async def sendmail(
+            reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs
+        ):
+            self.email_attempts.append(msg)
+
+        self.email_attempts = []
+        hs.get_send_email_handler()._sendmail = sendmail
+
         return hs
 
     def prepare(self, reactor, clock, hs):
@@ -511,11 +514,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         config = self.default_config()
 
         # Email config.
-        self.email_attempts = []
-
-        async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
-            self.email_attempts.append(msg)
-
         config["email"] = {
             "enable_notifs": False,
             "template_dir": os.path.abspath(
@@ -530,7 +528,16 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         }
         config["public_baseurl"] = "https://example.com"
 
-        self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail)
+        self.hs = self.setup_test_homeserver(config=config)
+
+        async def sendmail(
+            reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs
+        ):
+            self.email_attempts.append(msg)
+
+        self.email_attempts = []
+        self.hs.get_send_email_handler()._sendmail = sendmail
+
         return self.hs
 
     def prepare(self, reactor, clock, hs):
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 1cad5f00eb..a52e5e608a 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -509,10 +509,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         }
 
         # Email config.
-        self.email_attempts = []
-
-        async def sendmail(*args, **kwargs):
-            self.email_attempts.append((args, kwargs))
 
         config["email"] = {
             "enable_notifs": True,
@@ -532,7 +528,13 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         }
         config["public_baseurl"] = "aaa"
 
-        self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail)
+        self.hs = self.setup_test_homeserver(config=config)
+
+        async def sendmail(*args, **kwargs):
+            self.email_attempts.append((args, kwargs))
+
+        self.email_attempts = []
+        self.hs.get_send_email_handler()._sendmail = sendmail
 
         self.store = self.hs.get_datastore()
 
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 932970fd9a..d05d367685 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -14,7 +14,10 @@
 import json
 
 from synapse.logging.context import LoggingContext
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util.async_helpers import yieldable_gather_results
 
 from tests import unittest
 
@@ -94,3 +97,50 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
             res = self.get_success(self.store.have_seen_events("room1", ["event10"]))
             self.assertEquals(res, {"event10"})
             self.assertEquals(ctx.get_resource_usage().db_txn_count, 0)
+
+
+class EventCacheTestCase(unittest.HomeserverTestCase):
+    """Test that the various layers of event cache works."""
+
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store: EventsWorkerStore = hs.get_datastore()
+
+        self.user = self.register_user("user", "pass")
+        self.token = self.login(self.user, "pass")
+
+        self.room = self.helper.create_room_as(self.user, tok=self.token)
+
+        res = self.helper.send(self.room, tok=self.token)
+        self.event_id = res["event_id"]
+
+        # Reset the event cache so the tests start with it empty
+        self.store._get_event_cache.clear()
+
+    def test_simple(self):
+        """Test that we cache events that we pull from the DB."""
+
+        with LoggingContext("test") as ctx:
+            self.get_success(self.store.get_event(self.event_id))
+
+            # We should have fetched the event from the DB
+            self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
+
+    def test_dedupe(self):
+        """Test that if we request the same event multiple times we only pull it
+        out once.
+        """
+
+        with LoggingContext("test") as ctx:
+            d = yieldable_gather_results(
+                self.store.get_event, [self.event_id, self.event_id]
+            )
+            self.get_success(d)
+
+            # We should have fetched the event from the DB
+            self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 0ed8326f55..3785799f46 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -75,10 +75,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         )
 
         self.handler = self.homeserver.get_federation_handler()
-        self.handler._check_event_auth = (
-            lambda origin, event, context, state, auth_events, backfilled: succeed(
-                context
-            )
+        self.handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed(
+            context
         )
         self.client = self.homeserver.get_federation_client()
         self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(