summary refs log tree commit diff
path: root/synapse/storage/room.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2019-04-17 19:44:40 +0100
committerErik Johnston <erik@matrix.org>2019-04-17 19:44:40 +0100
commitca90336a6935b36b5761244005b0f68b496d5d79 (patch)
tree6bbce5eafc0db3b24ccc3b59b051da850382ae09 /synapse/storage/room.py
parentAdd management endpoints for account validity (diff)
parentMerge pull request #5047 from matrix-org/babolivier/account_expiration (diff)
downloadsynapse-ca90336a6935b36b5761244005b0f68b496d5d79.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into babolivier/account_expiration
Diffstat (limited to 'synapse/storage/room.py')
-rw-r--r--synapse/storage/room.py150
1 files changed, 66 insertions, 84 deletions
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index a979d4860a..fe9d79d792 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -30,13 +30,11 @@ logger = logging.getLogger(__name__)
 
 
 OpsLevel = collections.namedtuple(
-    "OpsLevel",
-    ("ban_level", "kick_level", "redact_level",)
+    "OpsLevel", ("ban_level", "kick_level", "redact_level")
 )
 
 RatelimitOverride = collections.namedtuple(
-    "RatelimitOverride",
-    ("messages_per_second", "burst_count",)
+    "RatelimitOverride", ("messages_per_second", "burst_count")
 )
 
 
@@ -60,9 +58,7 @@ class RoomWorkerStore(SQLBaseStore):
     def get_public_room_ids(self):
         return self._simple_select_onecol(
             table="rooms",
-            keyvalues={
-                "is_public": True,
-            },
+            keyvalues={"is_public": True},
             retcol="room_id",
             desc="get_public_room_ids",
         )
@@ -79,11 +75,11 @@ class RoomWorkerStore(SQLBaseStore):
         return self.runInteraction(
             "get_public_room_ids_at_stream_id",
             self.get_public_room_ids_at_stream_id_txn,
-            stream_id, network_tuple=network_tuple
+            stream_id,
+            network_tuple=network_tuple,
         )
 
-    def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
-                                             network_tuple):
+    def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, network_tuple):
         return {
             rm
             for rm, vis in self.get_published_at_stream_id_txn(
@@ -96,7 +92,7 @@ class RoomWorkerStore(SQLBaseStore):
         if network_tuple:
             # We want to get from a particular list. No aggregation required.
 
-            sql = ("""
+            sql = """
                 SELECT room_id, visibility FROM public_room_list_stream
                 INNER JOIN (
                     SELECT room_id, max(stream_id) AS stream_id
@@ -104,25 +100,22 @@ class RoomWorkerStore(SQLBaseStore):
                     WHERE stream_id <= ? %s
                     GROUP BY room_id
                 ) grouped USING (room_id, stream_id)
-            """)
+            """
 
             if network_tuple.appservice_id is not None:
                 txn.execute(
                     sql % ("AND appservice_id = ? AND network_id = ?",),
-                    (stream_id, network_tuple.appservice_id, network_tuple.network_id,)
+                    (stream_id, network_tuple.appservice_id, network_tuple.network_id),
                 )
             else:
-                txn.execute(
-                    sql % ("AND appservice_id IS NULL",),
-                    (stream_id,)
-                )
+                txn.execute(sql % ("AND appservice_id IS NULL",), (stream_id,))
             return dict(txn)
         else:
             # We want to get from all lists, so we need to aggregate the results
 
             logger.info("Executing full list")
 
-            sql = ("""
+            sql = """
                 SELECT room_id, visibility
                 FROM public_room_list_stream
                 INNER JOIN (
@@ -133,12 +126,9 @@ class RoomWorkerStore(SQLBaseStore):
                     WHERE stream_id <= ?
                     GROUP BY room_id, appservice_id, network_id
                 ) grouped USING (room_id, stream_id)
-            """)
+            """
 
-            txn.execute(
-                sql,
-                (stream_id,)
-            )
+            txn.execute(sql, (stream_id,))
 
             results = {}
             # A room is visible if its visible on any list.
@@ -147,8 +137,7 @@ class RoomWorkerStore(SQLBaseStore):
 
             return results
 
-    def get_public_room_changes(self, prev_stream_id, new_stream_id,
-                                network_tuple):
+    def get_public_room_changes(self, prev_stream_id, new_stream_id, network_tuple):
         def get_public_room_changes_txn(txn):
             then_rooms = self.get_public_room_ids_at_stream_id_txn(
                 txn, prev_stream_id, network_tuple
@@ -158,9 +147,7 @@ class RoomWorkerStore(SQLBaseStore):
                 txn, new_stream_id, network_tuple
             )
 
-            now_rooms_visible = set(
-                rm for rm, vis in now_rooms_dict.items() if vis
-            )
+            now_rooms_visible = set(rm for rm, vis in now_rooms_dict.items() if vis)
             now_rooms_not_visible = set(
                 rm for rm, vis in now_rooms_dict.items() if not vis
             )
@@ -178,9 +165,7 @@ class RoomWorkerStore(SQLBaseStore):
     def is_room_blocked(self, room_id):
         return self._simple_select_one_onecol(
             table="blocked_rooms",
-            keyvalues={
-                "room_id": room_id,
-            },
+            keyvalues={"room_id": room_id},
             retcol="1",
             allow_none=True,
             desc="is_room_blocked",
@@ -208,16 +193,17 @@ class RoomWorkerStore(SQLBaseStore):
         )
 
         if row:
-            defer.returnValue(RatelimitOverride(
-                messages_per_second=row["messages_per_second"],
-                burst_count=row["burst_count"],
-            ))
+            defer.returnValue(
+                RatelimitOverride(
+                    messages_per_second=row["messages_per_second"],
+                    burst_count=row["burst_count"],
+                )
+            )
         else:
             defer.returnValue(None)
 
 
 class RoomStore(RoomWorkerStore, SearchStore):
-
     @defer.inlineCallbacks
     def store_room(self, room_id, room_creator_user_id, is_public):
         """Stores a room.
@@ -231,6 +217,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
             StoreError if the room could not be stored.
         """
         try:
+
             def store_room_txn(txn, next_id):
                 self._simple_insert_txn(
                     txn,
@@ -249,13 +236,11 @@ class RoomStore(RoomWorkerStore, SearchStore):
                             "stream_id": next_id,
                             "room_id": room_id,
                             "visibility": is_public,
-                        }
+                        },
                     )
+
             with self._public_room_id_gen.get_next() as next_id:
-                yield self.runInteraction(
-                    "store_room_txn",
-                    store_room_txn, next_id,
-                )
+                yield self.runInteraction("store_room_txn", store_room_txn, next_id)
         except Exception as e:
             logger.error("store_room with room_id=%s failed: %s", room_id, e)
             raise StoreError(500, "Problem creating room.")
@@ -297,19 +282,19 @@ class RoomStore(RoomWorkerStore, SearchStore):
                         "visibility": is_public,
                         "appservice_id": None,
                         "network_id": None,
-                    }
+                    },
                 )
 
         with self._public_room_id_gen.get_next() as next_id:
             yield self.runInteraction(
-                "set_room_is_public",
-                set_room_is_public_txn, next_id,
+                "set_room_is_public", set_room_is_public_txn, next_id
             )
         self.hs.get_notifier().on_new_replication_data()
 
     @defer.inlineCallbacks
-    def set_room_is_public_appservice(self, room_id, appservice_id, network_id,
-                                      is_public):
+    def set_room_is_public_appservice(
+        self, room_id, appservice_id, network_id, is_public
+    ):
         """Edit the appservice/network specific public room list.
 
         Each appservice can have a number of published room lists associated
@@ -324,6 +309,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
             is_public (bool): Whether to publish or unpublish the room from the
                 list.
         """
+
         def set_room_is_public_appservice_txn(txn, next_id):
             if is_public:
                 try:
@@ -333,7 +319,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
                         values={
                             "appservice_id": appservice_id,
                             "network_id": network_id,
-                            "room_id": room_id
+                            "room_id": room_id,
                         },
                     )
                 except self.database_engine.module.IntegrityError:
@@ -346,7 +332,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
                     keyvalues={
                         "appservice_id": appservice_id,
                         "network_id": network_id,
-                        "room_id": room_id
+                        "room_id": room_id,
                     },
                 )
 
@@ -377,13 +363,14 @@ class RoomStore(RoomWorkerStore, SearchStore):
                         "visibility": is_public,
                         "appservice_id": appservice_id,
                         "network_id": network_id,
-                    }
+                    },
                 )
 
         with self._public_room_id_gen.get_next() as next_id:
             yield self.runInteraction(
                 "set_room_is_public_appservice",
-                set_room_is_public_appservice_txn, next_id,
+                set_room_is_public_appservice_txn,
+                next_id,
             )
         self.hs.get_notifier().on_new_replication_data()
 
@@ -397,9 +384,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
             row = txn.fetchone()
             return row[0] or 0
 
-        return self.runInteraction(
-            "get_rooms", f
-        )
+        return self.runInteraction("get_rooms", f)
 
     def _store_room_topic_txn(self, txn, event):
         if hasattr(event, "content") and "topic" in event.content:
@@ -414,7 +399,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
             )
 
             self.store_event_search_txn(
-                txn, event, "content.topic", event.content["topic"],
+                txn, event, "content.topic", event.content["topic"]
             )
 
     def _store_room_name_txn(self, txn, event):
@@ -426,17 +411,17 @@ class RoomStore(RoomWorkerStore, SearchStore):
                     "event_id": event.event_id,
                     "room_id": event.room_id,
                     "name": event.content["name"],
-                }
+                },
             )
 
             self.store_event_search_txn(
-                txn, event, "content.name", event.content["name"],
+                txn, event, "content.name", event.content["name"]
             )
 
     def _store_room_message_txn(self, txn, event):
         if hasattr(event, "content") and "body" in event.content:
             self.store_event_search_txn(
-                txn, event, "content.body", event.content["body"],
+                txn, event, "content.body", event.content["body"]
             )
 
     def _store_history_visibility_txn(self, txn, event):
@@ -452,14 +437,11 @@ class RoomStore(RoomWorkerStore, SearchStore):
                 " (event_id, room_id, %(key)s)"
                 " VALUES (?, ?, ?)" % {"key": key}
             )
-            txn.execute(sql, (
-                event.event_id,
-                event.room_id,
-                event.content[key]
-            ))
-
-    def add_event_report(self, room_id, event_id, user_id, reason, content,
-                         received_ts):
+            txn.execute(sql, (event.event_id, event.room_id, event.content[key]))
+
+    def add_event_report(
+        self, room_id, event_id, user_id, reason, content, received_ts
+    ):
         next_id = self._event_reports_id_gen.get_next()
         return self._simple_insert(
             table="event_reports",
@@ -472,7 +454,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
                 "reason": reason,
                 "content": json.dumps(content),
             },
-            desc="add_event_report"
+            desc="add_event_report",
         )
 
     def get_current_public_room_stream_id(self):
@@ -480,23 +462,21 @@ class RoomStore(RoomWorkerStore, SearchStore):
 
     def get_all_new_public_rooms(self, prev_id, current_id, limit):
         def get_all_new_public_rooms(txn):
-            sql = ("""
+            sql = """
                 SELECT stream_id, room_id, visibility, appservice_id, network_id
                 FROM public_room_list_stream
                 WHERE stream_id > ? AND stream_id <= ?
                 ORDER BY stream_id ASC
                 LIMIT ?
-            """)
+            """
 
-            txn.execute(sql, (prev_id, current_id, limit,))
+            txn.execute(sql, (prev_id, current_id, limit))
             return txn.fetchall()
 
         if prev_id == current_id:
             return defer.succeed([])
 
-        return self.runInteraction(
-            "get_all_new_public_rooms", get_all_new_public_rooms
-        )
+        return self.runInteraction("get_all_new_public_rooms", get_all_new_public_rooms)
 
     @defer.inlineCallbacks
     def block_room(self, room_id, user_id):
@@ -511,19 +491,16 @@ class RoomStore(RoomWorkerStore, SearchStore):
         """
         yield self._simple_upsert(
             table="blocked_rooms",
-            keyvalues={
-                "room_id": room_id,
-            },
+            keyvalues={"room_id": room_id},
             values={},
-            insertion_values={
-                "user_id": user_id,
-            },
+            insertion_values={"user_id": user_id},
             desc="block_room",
         )
         yield self.runInteraction(
             "block_room_invalidation",
             self._invalidate_cache_and_stream,
-            self.is_room_blocked, (room_id,),
+            self.is_room_blocked,
+            (room_id,),
         )
 
     def get_media_mxcs_in_room(self, room_id):
@@ -536,6 +513,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
             The local and remote media as a lists of tuples where the key is
             the hostname and the value is the media ID.
         """
+
         def _get_media_mxcs_in_room_txn(txn):
             local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
             local_media_mxcs = []
@@ -548,23 +526,28 @@ class RoomStore(RoomWorkerStore, SearchStore):
                 remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
 
             return local_media_mxcs, remote_media_mxcs
+
         return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
 
     def quarantine_media_ids_in_room(self, room_id, quarantined_by):
         """For a room loops through all events with media and quarantines
         the associated media
         """
+
         def _quarantine_media_in_room_txn(txn):
             local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
             total_media_quarantined = 0
 
             # Now update all the tables to set the quarantined_by flag
 
-            txn.executemany("""
+            txn.executemany(
+                """
                 UPDATE local_media_repository
                 SET quarantined_by = ?
                 WHERE media_id = ?
-            """, ((quarantined_by, media_id) for media_id in local_mxcs))
+            """,
+                ((quarantined_by, media_id) for media_id in local_mxcs),
+            )
 
             txn.executemany(
                 """
@@ -575,7 +558,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
                 (
                     (quarantined_by, origin, media_id)
                     for origin, media_id in remote_mxcs
-                )
+                ),
             )
 
             total_media_quarantined += len(local_mxcs)
@@ -584,8 +567,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
             return total_media_quarantined
 
         return self.runInteraction(
-            "quarantine_media_in_room",
-            _quarantine_media_in_room_txn,
+            "quarantine_media_in_room", _quarantine_media_in_room_txn
         )
 
     def _get_media_mxcs_in_room_txn(self, txn, room_id):