diff --git a/changelog.d/14053.bugfix b/changelog.d/14053.bugfix
new file mode 100644
index 0000000000..07769f51d0
--- /dev/null
+++ b/changelog.d/14053.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in Synapse 1.53.0 when querying `/publicRooms` with both a `room_type` filter and a `third_party_instance_id`.
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 7412bce255..e41c99027a 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -207,21 +207,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def _construct_room_type_where_clause(
self, room_types: Union[List[Union[str, None]], None]
- ) -> Tuple[Union[str, None], List[str]]:
+ ) -> Tuple[Union[str, None], list]:
if not room_types:
return None, []
- else:
- # We use None when we want get rooms without a type
- is_null_clause = ""
- if None in room_types:
- is_null_clause = "OR room_type IS NULL"
- room_types = [value for value in room_types if value is not None]
+ # Since None is used to represent a room without a type, care needs to
+ # be taken into account when constructing the where clause.
+ clauses = []
+ args: list = []
+
+ room_types_set = set(room_types)
+
+ # We use None to represent a room without a type.
+ if None in room_types_set:
+ clauses.append("room_type IS NULL")
+ room_types_set.remove(None)
+
+ # If there are other room types, generate the proper clause.
+ if room_types:
list_clause, args = make_in_list_sql_clause(
- self.database_engine, "room_type", room_types
+ self.database_engine, "room_type", room_types_set
)
+ clauses.append(list_clause)
- return f"({list_clause} {is_null_clause})", args
+ return f"({' OR '.join(clauses)})", args
async def count_public_rooms(
self,
@@ -241,14 +250,6 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
query_args = []
- room_type_clause, args = self._construct_room_type_where_clause(
- search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
- if search_filter
- else None
- )
- room_type_clause = f" AND {room_type_clause}" if room_type_clause else ""
- query_args += args
-
if network_tuple:
if network_tuple.appservice_id:
published_sql = """
@@ -268,6 +269,14 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
UNION SELECT room_id from appservice_room_list
"""
+ room_type_clause, args = self._construct_room_type_where_clause(
+ search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
+ if search_filter
+ else None
+ )
+ room_type_clause = f" AND {room_type_clause}" if room_type_clause else ""
+ query_args += args
+
sql = f"""
SELECT
COUNT(*)
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 5e66b5b26c..3612ebe7b9 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -2213,14 +2213,17 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
)
def make_public_rooms_request(
- self, room_types: Union[List[Union[str, None]], None]
+ self,
+ room_types: Optional[List[Union[str, None]]],
+ instance_id: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
- channel = self.make_request(
- "POST",
- self.url,
- {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}},
- self.token,
- )
+ body: JsonDict = {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}}
+ if instance_id:
+ body["third_party_instance_id"] = "test|test"
+
+ channel = self.make_request("POST", self.url, body, self.token)
+ self.assertEqual(channel.code, 200)
+
chunk = channel.json_body["chunk"]
count = channel.json_body["total_room_count_estimate"]
@@ -2230,31 +2233,49 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
def test_returns_both_rooms_and_spaces_if_no_filter(self) -> None:
chunk, count = self.make_public_rooms_request(None)
-
self.assertEqual(count, 2)
+ # Also check if there's no filter property at all in the body.
+ channel = self.make_request("POST", self.url, {}, self.token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(len(channel.json_body["chunk"]), 2)
+ self.assertEqual(channel.json_body["total_room_count_estimate"], 2)
+
+ chunk, count = self.make_public_rooms_request(None, "test|test")
+ self.assertEqual(count, 0)
+
def test_returns_only_rooms_based_on_filter(self) -> None:
chunk, count = self.make_public_rooms_request([None])
self.assertEqual(count, 1)
self.assertEqual(chunk[0].get("room_type", None), None)
+ chunk, count = self.make_public_rooms_request([None], "test|test")
+ self.assertEqual(count, 0)
+
def test_returns_only_space_based_on_filter(self) -> None:
chunk, count = self.make_public_rooms_request(["m.space"])
self.assertEqual(count, 1)
self.assertEqual(chunk[0].get("room_type", None), "m.space")
+ chunk, count = self.make_public_rooms_request(["m.space"], "test|test")
+ self.assertEqual(count, 0)
+
def test_returns_both_rooms_and_space_based_on_filter(self) -> None:
chunk, count = self.make_public_rooms_request(["m.space", None])
-
self.assertEqual(count, 2)
+ chunk, count = self.make_public_rooms_request(["m.space", None], "test|test")
+ self.assertEqual(count, 0)
+
def test_returns_both_rooms_and_spaces_if_array_is_empty(self) -> None:
chunk, count = self.make_public_rooms_request([])
-
self.assertEqual(count, 2)
+ chunk, count = self.make_public_rooms_request([], "test|test")
+ self.assertEqual(count, 0)
+
class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
"""Test that we correctly fallback to local filtering if a remote server
|