diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/databases/main/room.py | 43 |
1 files changed, 26 insertions, 17 deletions
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 7412bce255..e41c99027a 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -207,21 +207,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): def _construct_room_type_where_clause( self, room_types: Union[List[Union[str, None]], None] - ) -> Tuple[Union[str, None], List[str]]: + ) -> Tuple[Union[str, None], list]: if not room_types: return None, [] - else: - # We use None when we want get rooms without a type - is_null_clause = "" - if None in room_types: - is_null_clause = "OR room_type IS NULL" - room_types = [value for value in room_types if value is not None] + # Since None is used to represent a room without a type, care needs to + # be taken into account when constructing the where clause. + clauses = [] + args: list = [] + + room_types_set = set(room_types) + + # We use None to represent a room without a type. + if None in room_types_set: + clauses.append("room_type IS NULL") + room_types_set.remove(None) + + # If there are other room types, generate the proper clause. + if room_types: list_clause, args = make_in_list_sql_clause( - self.database_engine, "room_type", room_types + self.database_engine, "room_type", room_types_set ) + clauses.append(list_clause) - return f"({list_clause} {is_null_clause})", args + return f"({' OR '.join(clauses)})", args async def count_public_rooms( self, @@ -241,14 +250,6 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): def _count_public_rooms_txn(txn: LoggingTransaction) -> int: query_args = [] - room_type_clause, args = self._construct_room_type_where_clause( - search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None) - if search_filter - else None - ) - room_type_clause = f" AND {room_type_clause}" if room_type_clause else "" - query_args += args - if network_tuple: if network_tuple.appservice_id: published_sql = """ @@ -268,6 +269,14 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): UNION SELECT room_id from appservice_room_list """ + room_type_clause, args = self._construct_room_type_where_clause( + search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None) + if search_filter + else None + ) + room_type_clause = f" AND {room_type_clause}" if room_type_clause else "" + query_args += args + sql = f""" SELECT COUNT(*) |