summary refs log tree commit diff
path: root/tests/rest/client
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client')
-rw-r--r--tests/rest/client/v1/test_rooms.py92
1 files changed, 91 insertions, 1 deletions
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 3df070c936..1a9528ec20 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -19,11 +19,14 @@
 
 import json
 from typing import Iterable
-from unittest.mock import Mock
+from unittest.mock import Mock, call
 from urllib import parse as urlparse
 
+from twisted.internet import defer
+
 import synapse.rest.admin
 from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.api.errors import HttpResponseException
 from synapse.handlers.pagination import PurgeStatus
 from synapse.rest import admin
 from synapse.rest.client.v1 import directory, login, profile, room
@@ -1124,6 +1127,93 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.result)
 
 
+class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
+    """Test that we correctly fallback to local filtering if a remote server
+    doesn't support search.
+    """
+
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        return self.setup_test_homeserver(federation_client=Mock())
+
+    def prepare(self, reactor, clock, hs):
+        self.register_user("user", "pass")
+        self.token = self.login("user", "pass")
+
+        self.federation_client = hs.get_federation_client()
+
+    def test_simple(self):
+        "Simple test for searching rooms over federation"
+        self.federation_client.get_public_rooms.side_effect = (
+            lambda *a, **k: defer.succeed({})
+        )
+
+        search_filter = {"generic_search_term": "foobar"}
+
+        channel = self.make_request(
+            "POST",
+            b"/_matrix/client/r0/publicRooms?server=testserv",
+            content={"filter": search_filter},
+            access_token=self.token,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+        self.federation_client.get_public_rooms.assert_called_once_with(
+            "testserv",
+            limit=100,
+            since_token=None,
+            search_filter=search_filter,
+            include_all_networks=False,
+            third_party_instance_id=None,
+        )
+
+    def test_fallback(self):
+        "Test that searching public rooms over federation falls back if it gets a 404"
+
+        # The `get_public_rooms` should be called again if the first call fails
+        # with a 404, when using search filters.
+        self.federation_client.get_public_rooms.side_effect = (
+            HttpResponseException(404, "Not Found", b""),
+            defer.succeed({}),
+        )
+
+        search_filter = {"generic_search_term": "foobar"}
+
+        channel = self.make_request(
+            "POST",
+            b"/_matrix/client/r0/publicRooms?server=testserv",
+            content={"filter": search_filter},
+            access_token=self.token,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+        self.federation_client.get_public_rooms.assert_has_calls(
+            [
+                call(
+                    "testserv",
+                    limit=100,
+                    since_token=None,
+                    search_filter=search_filter,
+                    include_all_networks=False,
+                    third_party_instance_id=None,
+                ),
+                call(
+                    "testserv",
+                    limit=None,
+                    since_token=None,
+                    search_filter=None,
+                    include_all_networks=False,
+                    third_party_instance_id=None,
+                ),
+            ]
+        )
+
+
 class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
 
     servlets = [