summary refs log tree commit diff
path: root/synapse/storage/room.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/room.py')
-rw-r--r--synapse/storage/room.py171
1 files changed, 151 insertions, 20 deletions
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 11813b44f6..4b3605c776 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -106,7 +106,11 @@ class RoomStore(SQLBaseStore):
             entries = self._simple_select_list_txn(
                 txn,
                 table="public_room_list_stream",
-                keyvalues={"room_id": room_id},
+                keyvalues={
+                    "room_id": room_id,
+                    "appservice_id": None,
+                    "network_id": None,
+                },
                 retcols=("stream_id", "visibility"),
             )
 
@@ -124,6 +128,8 @@ class RoomStore(SQLBaseStore):
                         "stream_id": next_id,
                         "room_id": room_id,
                         "visibility": is_public,
+                        "appservice_id": None,
+                        "network_id": None,
                     }
                 )
 
@@ -133,6 +139,73 @@ class RoomStore(SQLBaseStore):
                 set_room_is_public_txn, next_id,
             )
 
+    @defer.inlineCallbacks
+    def set_room_is_public_appservice(self, room_id, appservice_id, network_id,
+                                      is_public):
+        """Edit the appservice/network specific public room list.
+        """
+        def set_room_is_public_appservice_txn(txn, next_id):
+            if is_public:
+                try:
+                    self._simple_insert_txn(
+                        txn,
+                        table="appservice_room_list",
+                        values={
+                            "appservice_id": appservice_id,
+                            "network_id": "network_id",
+                            "room_id": room_id
+                        },
+                    )
+                except self.database_engine.module.IntegrityError:
+                    # We've already inserted, nothing to do.
+                    return
+            else:
+                self._simple_delete_txn(
+                    txn,
+                    table="appservice_room_list",
+                    keyvalues={
+                        "appservice_id": appservice_id,
+                        "network_id": network_id,
+                        "room_id": room_id
+                    },
+                )
+
+            entries = self._simple_select_list_txn(
+                txn,
+                table="public_room_list_stream",
+                keyvalues={
+                    "room_id": room_id,
+                    "appservice_id": appservice_id,
+                    "network_id": network_id,
+                },
+                retcols=("stream_id", "visibility"),
+            )
+
+            entries.sort(key=lambda r: r["stream_id"])
+
+            add_to_stream = True
+            if entries:
+                add_to_stream = bool(entries[-1]["visibility"]) != is_public
+
+            if add_to_stream:
+                self._simple_insert_txn(
+                    txn,
+                    table="public_room_list_stream",
+                    values={
+                        "stream_id": next_id,
+                        "room_id": room_id,
+                        "visibility": is_public,
+                        "appservice_id": appservice_id,
+                        "network_id": network_id,
+                    }
+                )
+
+        with self._public_room_id_gen.get_next() as next_id:
+            yield self.runInteraction(
+                "set_room_is_public_appservice",
+                set_room_is_public_appservice_txn, next_id,
+            )
+
     def get_public_room_ids(self):
         return self._simple_select_onecol(
             table="rooms",
@@ -259,38 +332,95 @@ class RoomStore(SQLBaseStore):
     def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()
 
-    def get_public_room_ids_at_stream_id(self, stream_id):
+    def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
+        """Get pulbic rooms for a particular list, or across all lists.
+
+        Args:
+            stream_id (int)
+            network_tuple (ThirdPartyInstanceID): The list to use (None, None)
+                means the main list, None means all lsits.
+        """
         return self.runInteraction(
             "get_public_room_ids_at_stream_id",
-            self.get_public_room_ids_at_stream_id_txn, stream_id
+            self.get_public_room_ids_at_stream_id_txn,
+            stream_id, network_tuple=network_tuple
         )
 
-    def get_public_room_ids_at_stream_id_txn(self, txn, stream_id):
+    def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
+                                             network_tuple):
         return {
             rm
-            for rm, vis in self.get_published_at_stream_id_txn(txn, stream_id).items()
+            for rm, vis in self.get_published_at_stream_id_txn(
+                txn, stream_id, network_tuple=network_tuple
+            ).items()
             if vis
         }
 
-    def get_published_at_stream_id_txn(self, txn, stream_id):
-        sql = ("""
-            SELECT room_id, visibility FROM public_room_list_stream
-            INNER JOIN (
-                SELECT room_id, max(stream_id) AS stream_id
+    def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
+        if network_tuple:
+            # We want to get from a particular list. No aggregation required.
+
+            sql = ("""
+                SELECT room_id, visibility FROM public_room_list_stream
+                INNER JOIN (
+                    SELECT room_id, max(stream_id) AS stream_id
+                    FROM public_room_list_stream
+                    WHERE stream_id <= ? %s
+                    GROUP BY room_id
+                ) grouped USING (room_id, stream_id)
+            """)
+
+            if network_tuple.appservice_id is not None:
+                txn.execute(
+                    sql % ("AND appservice_id = ? AND network_id = ?",),
+                    (stream_id, network_tuple.appservice_id, network_tuple.network_id,)
+                )
+            else:
+                txn.execute(
+                    sql % ("AND appservice_id IS NULL",),
+                    (stream_id,)
+                )
+            return dict(txn.fetchall())
+        else:
+            # We want to get from all lists, so we need to aggregate the results
+
+            logger.info("Executing full list")
+
+            sql = ("""
+                SELECT room_id, visibility
                 FROM public_room_list_stream
-                WHERE stream_id <= ?
-                GROUP BY room_id
-            ) grouped USING (room_id, stream_id)
-        """)
+                INNER JOIN (
+                    SELECT
+                        room_id, max(stream_id) AS stream_id, appservice_id,
+                        network_id
+                    FROM public_room_list_stream
+                    WHERE stream_id <= ?
+                    GROUP BY room_id, appservice_id, network_id
+                ) grouped USING (room_id, stream_id)
+            """)
 
-        txn.execute(sql, (stream_id,))
-        return dict(txn.fetchall())
+            txn.execute(
+                sql,
+                (stream_id,)
+            )
+
+            results = {}
+            # A room is visible if its visible on any list.
+            for room_id, visibility in txn.fetchall():
+                results[room_id] = bool(visibility) or results.get(room_id, False)
+
+            return results
 
-    def get_public_room_changes(self, prev_stream_id, new_stream_id):
+    def get_public_room_changes(self, prev_stream_id, new_stream_id,
+                                network_tuple):
         def get_public_room_changes_txn(txn):
-            then_rooms = self.get_public_room_ids_at_stream_id_txn(txn, prev_stream_id)
+            then_rooms = self.get_public_room_ids_at_stream_id_txn(
+                txn, prev_stream_id, network_tuple
+            )
 
-            now_rooms_dict = self.get_published_at_stream_id_txn(txn, new_stream_id)
+            now_rooms_dict = self.get_published_at_stream_id_txn(
+                txn, new_stream_id, network_tuple
+            )
 
             now_rooms_visible = set(
                 rm for rm, vis in now_rooms_dict.items() if vis
@@ -311,7 +441,8 @@ class RoomStore(SQLBaseStore):
     def get_all_new_public_rooms(self, prev_id, current_id, limit):
         def get_all_new_public_rooms(txn):
             sql = ("""
-                SELECT stream_id, room_id, visibility FROM public_room_list_stream
+                SELECT stream_id, room_id, visibility, appservice_id, network_id
+                FROM public_room_list_stream
                 WHERE stream_id > ? AND stream_id <= ?
                 ORDER BY stream_id ASC
                 LIMIT ?