summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/14053.bugfix1
-rw-r--r--synapse/storage/databases/main/room.py43
-rw-r--r--tests/rest/client/test_rooms.py41
3 files changed, 58 insertions, 27 deletions
diff --git a/changelog.d/14053.bugfix b/changelog.d/14053.bugfix
new file mode 100644
index 0000000000..07769f51d0
--- /dev/null
+++ b/changelog.d/14053.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in Synapse 1.53.0 when querying `/publicRooms` with both a `room_type` filter and a `third_party_instance_id`.
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 7412bce255..e41c99027a 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -207,21 +207,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 
     def _construct_room_type_where_clause(
         self, room_types: Union[List[Union[str, None]], None]
-    ) -> Tuple[Union[str, None], List[str]]:
+    ) -> Tuple[Union[str, None], list]:
         if not room_types:
             return None, []
-        else:
-            # We use None when we want get rooms without a type
-            is_null_clause = ""
-            if None in room_types:
-                is_null_clause = "OR room_type IS NULL"
-                room_types = [value for value in room_types if value is not None]
 
+        # Since None is used to represent a room without a type, care needs to
+        # be taken into account when constructing the where clause.
+        clauses = []
+        args: list = []
+
+        room_types_set = set(room_types)
+
+        # We use None to represent a room without a type.
+        if None in room_types_set:
+            clauses.append("room_type IS NULL")
+            room_types_set.remove(None)
+
+        # If there are other room types, generate the proper clause.
+        if room_types:
             list_clause, args = make_in_list_sql_clause(
-                self.database_engine, "room_type", room_types
+                self.database_engine, "room_type", room_types_set
             )
+            clauses.append(list_clause)
 
-            return f"({list_clause} {is_null_clause})", args
+        return f"({' OR '.join(clauses)})", args
 
     async def count_public_rooms(
         self,
@@ -241,14 +250,6 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
             query_args = []
 
-            room_type_clause, args = self._construct_room_type_where_clause(
-                search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
-                if search_filter
-                else None
-            )
-            room_type_clause = f" AND {room_type_clause}" if room_type_clause else ""
-            query_args += args
-
             if network_tuple:
                 if network_tuple.appservice_id:
                     published_sql = """
@@ -268,6 +269,14 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                     UNION SELECT room_id from appservice_room_list
             """
 
+            room_type_clause, args = self._construct_room_type_where_clause(
+                search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
+                if search_filter
+                else None
+            )
+            room_type_clause = f" AND {room_type_clause}" if room_type_clause else ""
+            query_args += args
+
             sql = f"""
                 SELECT
                     COUNT(*)
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 5e66b5b26c..3612ebe7b9 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -2213,14 +2213,17 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
         )
 
     def make_public_rooms_request(
-        self, room_types: Union[List[Union[str, None]], None]
+        self,
+        room_types: Optional[List[Union[str, None]]],
+        instance_id: Optional[str] = None,
     ) -> Tuple[List[Dict[str, Any]], int]:
-        channel = self.make_request(
-            "POST",
-            self.url,
-            {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}},
-            self.token,
-        )
+        body: JsonDict = {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}}
+        if instance_id:
+            body["third_party_instance_id"] = "test|test"
+
+        channel = self.make_request("POST", self.url, body, self.token)
+        self.assertEqual(channel.code, 200)
+
         chunk = channel.json_body["chunk"]
         count = channel.json_body["total_room_count_estimate"]
 
@@ -2230,31 +2233,49 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
 
     def test_returns_both_rooms_and_spaces_if_no_filter(self) -> None:
         chunk, count = self.make_public_rooms_request(None)
-
         self.assertEqual(count, 2)
 
+        # Also check if there's no filter property at all in the body.
+        channel = self.make_request("POST", self.url, {}, self.token)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(len(channel.json_body["chunk"]), 2)
+        self.assertEqual(channel.json_body["total_room_count_estimate"], 2)
+
+        chunk, count = self.make_public_rooms_request(None, "test|test")
+        self.assertEqual(count, 0)
+
     def test_returns_only_rooms_based_on_filter(self) -> None:
         chunk, count = self.make_public_rooms_request([None])
 
         self.assertEqual(count, 1)
         self.assertEqual(chunk[0].get("room_type", None), None)
 
+        chunk, count = self.make_public_rooms_request([None], "test|test")
+        self.assertEqual(count, 0)
+
     def test_returns_only_space_based_on_filter(self) -> None:
         chunk, count = self.make_public_rooms_request(["m.space"])
 
         self.assertEqual(count, 1)
         self.assertEqual(chunk[0].get("room_type", None), "m.space")
 
+        chunk, count = self.make_public_rooms_request(["m.space"], "test|test")
+        self.assertEqual(count, 0)
+
     def test_returns_both_rooms_and_space_based_on_filter(self) -> None:
         chunk, count = self.make_public_rooms_request(["m.space", None])
-
         self.assertEqual(count, 2)
 
+        chunk, count = self.make_public_rooms_request(["m.space", None], "test|test")
+        self.assertEqual(count, 0)
+
     def test_returns_both_rooms_and_spaces_if_array_is_empty(self) -> None:
         chunk, count = self.make_public_rooms_request([])
-
         self.assertEqual(count, 2)
 
+        chunk, count = self.make_public_rooms_request([], "test|test")
+        self.assertEqual(count, 0)
+
 
 class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
     """Test that we correctly fallback to local filtering if a remote server