diff options
Diffstat (limited to 'synapse/storage/room.py')
-rw-r--r-- | synapse/storage/room.py | 150 |
1 files changed, 66 insertions, 84 deletions
diff --git a/synapse/storage/room.py b/synapse/storage/room.py index a979d4860a..fe9d79d792 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -30,13 +30,11 @@ logger = logging.getLogger(__name__) OpsLevel = collections.namedtuple( - "OpsLevel", - ("ban_level", "kick_level", "redact_level",) + "OpsLevel", ("ban_level", "kick_level", "redact_level") ) RatelimitOverride = collections.namedtuple( - "RatelimitOverride", - ("messages_per_second", "burst_count",) + "RatelimitOverride", ("messages_per_second", "burst_count") ) @@ -60,9 +58,7 @@ class RoomWorkerStore(SQLBaseStore): def get_public_room_ids(self): return self._simple_select_onecol( table="rooms", - keyvalues={ - "is_public": True, - }, + keyvalues={"is_public": True}, retcol="room_id", desc="get_public_room_ids", ) @@ -79,11 +75,11 @@ class RoomWorkerStore(SQLBaseStore): return self.runInteraction( "get_public_room_ids_at_stream_id", self.get_public_room_ids_at_stream_id_txn, - stream_id, network_tuple=network_tuple + stream_id, + network_tuple=network_tuple, ) - def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, - network_tuple): + def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, network_tuple): return { rm for rm, vis in self.get_published_at_stream_id_txn( @@ -96,7 +92,7 @@ class RoomWorkerStore(SQLBaseStore): if network_tuple: # We want to get from a particular list. No aggregation required. - sql = (""" + sql = """ SELECT room_id, visibility FROM public_room_list_stream INNER JOIN ( SELECT room_id, max(stream_id) AS stream_id @@ -104,25 +100,22 @@ class RoomWorkerStore(SQLBaseStore): WHERE stream_id <= ? %s GROUP BY room_id ) grouped USING (room_id, stream_id) - """) + """ if network_tuple.appservice_id is not None: txn.execute( sql % ("AND appservice_id = ? AND network_id = ?",), - (stream_id, network_tuple.appservice_id, network_tuple.network_id,) + (stream_id, network_tuple.appservice_id, network_tuple.network_id), ) else: - txn.execute( - sql % ("AND appservice_id IS NULL",), - (stream_id,) - ) + txn.execute(sql % ("AND appservice_id IS NULL",), (stream_id,)) return dict(txn) else: # We want to get from all lists, so we need to aggregate the results logger.info("Executing full list") - sql = (""" + sql = """ SELECT room_id, visibility FROM public_room_list_stream INNER JOIN ( @@ -133,12 +126,9 @@ class RoomWorkerStore(SQLBaseStore): WHERE stream_id <= ? GROUP BY room_id, appservice_id, network_id ) grouped USING (room_id, stream_id) - """) + """ - txn.execute( - sql, - (stream_id,) - ) + txn.execute(sql, (stream_id,)) results = {} # A room is visible if its visible on any list. @@ -147,8 +137,7 @@ class RoomWorkerStore(SQLBaseStore): return results - def get_public_room_changes(self, prev_stream_id, new_stream_id, - network_tuple): + def get_public_room_changes(self, prev_stream_id, new_stream_id, network_tuple): def get_public_room_changes_txn(txn): then_rooms = self.get_public_room_ids_at_stream_id_txn( txn, prev_stream_id, network_tuple @@ -158,9 +147,7 @@ class RoomWorkerStore(SQLBaseStore): txn, new_stream_id, network_tuple ) - now_rooms_visible = set( - rm for rm, vis in now_rooms_dict.items() if vis - ) + now_rooms_visible = set(rm for rm, vis in now_rooms_dict.items() if vis) now_rooms_not_visible = set( rm for rm, vis in now_rooms_dict.items() if not vis ) @@ -178,9 +165,7 @@ class RoomWorkerStore(SQLBaseStore): def is_room_blocked(self, room_id): return self._simple_select_one_onecol( table="blocked_rooms", - keyvalues={ - "room_id": room_id, - }, + keyvalues={"room_id": room_id}, retcol="1", allow_none=True, desc="is_room_blocked", @@ -208,16 +193,17 @@ class RoomWorkerStore(SQLBaseStore): ) if row: - defer.returnValue(RatelimitOverride( - messages_per_second=row["messages_per_second"], - burst_count=row["burst_count"], - )) + defer.returnValue( + RatelimitOverride( + messages_per_second=row["messages_per_second"], + burst_count=row["burst_count"], + ) + ) else: defer.returnValue(None) class RoomStore(RoomWorkerStore, SearchStore): - @defer.inlineCallbacks def store_room(self, room_id, room_creator_user_id, is_public): """Stores a room. @@ -231,6 +217,7 @@ class RoomStore(RoomWorkerStore, SearchStore): StoreError if the room could not be stored. """ try: + def store_room_txn(txn, next_id): self._simple_insert_txn( txn, @@ -249,13 +236,11 @@ class RoomStore(RoomWorkerStore, SearchStore): "stream_id": next_id, "room_id": room_id, "visibility": is_public, - } + }, ) + with self._public_room_id_gen.get_next() as next_id: - yield self.runInteraction( - "store_room_txn", - store_room_txn, next_id, - ) + yield self.runInteraction("store_room_txn", store_room_txn, next_id) except Exception as e: logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") @@ -297,19 +282,19 @@ class RoomStore(RoomWorkerStore, SearchStore): "visibility": is_public, "appservice_id": None, "network_id": None, - } + }, ) with self._public_room_id_gen.get_next() as next_id: yield self.runInteraction( - "set_room_is_public", - set_room_is_public_txn, next_id, + "set_room_is_public", set_room_is_public_txn, next_id ) self.hs.get_notifier().on_new_replication_data() @defer.inlineCallbacks - def set_room_is_public_appservice(self, room_id, appservice_id, network_id, - is_public): + def set_room_is_public_appservice( + self, room_id, appservice_id, network_id, is_public + ): """Edit the appservice/network specific public room list. Each appservice can have a number of published room lists associated @@ -324,6 +309,7 @@ class RoomStore(RoomWorkerStore, SearchStore): is_public (bool): Whether to publish or unpublish the room from the list. """ + def set_room_is_public_appservice_txn(txn, next_id): if is_public: try: @@ -333,7 +319,7 @@ class RoomStore(RoomWorkerStore, SearchStore): values={ "appservice_id": appservice_id, "network_id": network_id, - "room_id": room_id + "room_id": room_id, }, ) except self.database_engine.module.IntegrityError: @@ -346,7 +332,7 @@ class RoomStore(RoomWorkerStore, SearchStore): keyvalues={ "appservice_id": appservice_id, "network_id": network_id, - "room_id": room_id + "room_id": room_id, }, ) @@ -377,13 +363,14 @@ class RoomStore(RoomWorkerStore, SearchStore): "visibility": is_public, "appservice_id": appservice_id, "network_id": network_id, - } + }, ) with self._public_room_id_gen.get_next() as next_id: yield self.runInteraction( "set_room_is_public_appservice", - set_room_is_public_appservice_txn, next_id, + set_room_is_public_appservice_txn, + next_id, ) self.hs.get_notifier().on_new_replication_data() @@ -397,9 +384,7 @@ class RoomStore(RoomWorkerStore, SearchStore): row = txn.fetchone() return row[0] or 0 - return self.runInteraction( - "get_rooms", f - ) + return self.runInteraction("get_rooms", f) def _store_room_topic_txn(self, txn, event): if hasattr(event, "content") and "topic" in event.content: @@ -414,7 +399,7 @@ class RoomStore(RoomWorkerStore, SearchStore): ) self.store_event_search_txn( - txn, event, "content.topic", event.content["topic"], + txn, event, "content.topic", event.content["topic"] ) def _store_room_name_txn(self, txn, event): @@ -426,17 +411,17 @@ class RoomStore(RoomWorkerStore, SearchStore): "event_id": event.event_id, "room_id": event.room_id, "name": event.content["name"], - } + }, ) self.store_event_search_txn( - txn, event, "content.name", event.content["name"], + txn, event, "content.name", event.content["name"] ) def _store_room_message_txn(self, txn, event): if hasattr(event, "content") and "body" in event.content: self.store_event_search_txn( - txn, event, "content.body", event.content["body"], + txn, event, "content.body", event.content["body"] ) def _store_history_visibility_txn(self, txn, event): @@ -452,14 +437,11 @@ class RoomStore(RoomWorkerStore, SearchStore): " (event_id, room_id, %(key)s)" " VALUES (?, ?, ?)" % {"key": key} ) - txn.execute(sql, ( - event.event_id, - event.room_id, - event.content[key] - )) - - def add_event_report(self, room_id, event_id, user_id, reason, content, - received_ts): + txn.execute(sql, (event.event_id, event.room_id, event.content[key])) + + def add_event_report( + self, room_id, event_id, user_id, reason, content, received_ts + ): next_id = self._event_reports_id_gen.get_next() return self._simple_insert( table="event_reports", @@ -472,7 +454,7 @@ class RoomStore(RoomWorkerStore, SearchStore): "reason": reason, "content": json.dumps(content), }, - desc="add_event_report" + desc="add_event_report", ) def get_current_public_room_stream_id(self): @@ -480,23 +462,21 @@ class RoomStore(RoomWorkerStore, SearchStore): def get_all_new_public_rooms(self, prev_id, current_id, limit): def get_all_new_public_rooms(txn): - sql = (""" + sql = """ SELECT stream_id, room_id, visibility, appservice_id, network_id FROM public_room_list_stream WHERE stream_id > ? AND stream_id <= ? ORDER BY stream_id ASC LIMIT ? - """) + """ - txn.execute(sql, (prev_id, current_id, limit,)) + txn.execute(sql, (prev_id, current_id, limit)) return txn.fetchall() if prev_id == current_id: return defer.succeed([]) - return self.runInteraction( - "get_all_new_public_rooms", get_all_new_public_rooms - ) + return self.runInteraction("get_all_new_public_rooms", get_all_new_public_rooms) @defer.inlineCallbacks def block_room(self, room_id, user_id): @@ -511,19 +491,16 @@ class RoomStore(RoomWorkerStore, SearchStore): """ yield self._simple_upsert( table="blocked_rooms", - keyvalues={ - "room_id": room_id, - }, + keyvalues={"room_id": room_id}, values={}, - insertion_values={ - "user_id": user_id, - }, + insertion_values={"user_id": user_id}, desc="block_room", ) yield self.runInteraction( "block_room_invalidation", self._invalidate_cache_and_stream, - self.is_room_blocked, (room_id,), + self.is_room_blocked, + (room_id,), ) def get_media_mxcs_in_room(self, room_id): @@ -536,6 +513,7 @@ class RoomStore(RoomWorkerStore, SearchStore): The local and remote media as a lists of tuples where the key is the hostname and the value is the media ID. """ + def _get_media_mxcs_in_room_txn(txn): local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) local_media_mxcs = [] @@ -548,23 +526,28 @@ class RoomStore(RoomWorkerStore, SearchStore): remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id)) return local_media_mxcs, remote_media_mxcs + return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn) def quarantine_media_ids_in_room(self, room_id, quarantined_by): """For a room loops through all events with media and quarantines the associated media """ + def _quarantine_media_in_room_txn(txn): local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) total_media_quarantined = 0 # Now update all the tables to set the quarantined_by flag - txn.executemany(""" + txn.executemany( + """ UPDATE local_media_repository SET quarantined_by = ? WHERE media_id = ? - """, ((quarantined_by, media_id) for media_id in local_mxcs)) + """, + ((quarantined_by, media_id) for media_id in local_mxcs), + ) txn.executemany( """ @@ -575,7 +558,7 @@ class RoomStore(RoomWorkerStore, SearchStore): ( (quarantined_by, origin, media_id) for origin, media_id in remote_mxcs - ) + ), ) total_media_quarantined += len(local_mxcs) @@ -584,8 +567,7 @@ class RoomStore(RoomWorkerStore, SearchStore): return total_media_quarantined return self.runInteraction( - "quarantine_media_in_room", - _quarantine_media_in_room_txn, + "quarantine_media_in_room", _quarantine_media_in_room_txn ) def _get_media_mxcs_in_room_txn(self, txn, room_id): |