summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/user_directory.py153
1 files changed, 66 insertions, 87 deletions
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index e8b574ee5e..fea866c043 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -44,7 +44,7 @@ class UserDirectoryStore(SQLBaseStore):
         )
 
         current_state_ids = yield self.get_filtered_current_state_ids(
-            room_id, StateFilter.from_types(types_to_filter),
+            room_id, StateFilter.from_types(types_to_filter)
         )
 
         join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
@@ -74,14 +74,8 @@ class UserDirectoryStore(SQLBaseStore):
         """
         yield self._simple_insert_many(
             table="users_in_public_rooms",
-            values=[
-                {
-                    "user_id": user_id,
-                    "room_id": room_id,
-                }
-                for user_id in user_ids
-            ],
-            desc="add_users_to_public_room"
+            values=[{"user_id": user_id, "room_id": room_id} for user_id in user_ids],
+            desc="add_users_to_public_room",
         )
         for user_id in user_ids:
             self.get_user_in_public_room.invalidate((user_id,))
@@ -107,7 +101,9 @@ class UserDirectoryStore(SQLBaseStore):
             """
             args = (
                 (
-                    user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id),
+                    user_id,
+                    get_localpart_from_id(user_id),
+                    get_domain_from_id(user_id),
                     profile.display_name,
                 )
                 for user_id, profile in iteritems(users_with_profile)
@@ -120,7 +116,7 @@ class UserDirectoryStore(SQLBaseStore):
             args = (
                 (
                     user_id,
-                    "%s %s" % (user_id, p.display_name,) if p.display_name else user_id
+                    "%s %s" % (user_id, p.display_name) if p.display_name else user_id,
                 )
                 for user_id, p in iteritems(users_with_profile)
             )
@@ -141,12 +137,10 @@ class UserDirectoryStore(SQLBaseStore):
                         "avatar_url": profile.avatar_url,
                     }
                     for user_id, profile in iteritems(users_with_profile)
-                ]
+                ],
             )
             for user_id in users_with_profile:
-                txn.call_after(
-                    self.get_user_in_directory.invalidate, (user_id,)
-                )
+                txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
 
         return self.runInteraction(
             "add_profiles_to_user_dir", _add_profiles_to_user_dir_txn
@@ -188,9 +182,11 @@ class UserDirectoryStore(SQLBaseStore):
                     txn.execute(
                         sql,
                         (
-                            user_id, get_localpart_from_id(user_id),
-                            get_domain_from_id(user_id), display_name,
-                        )
+                            user_id,
+                            get_localpart_from_id(user_id),
+                            get_domain_from_id(user_id),
+                            display_name,
+                        ),
                     )
                 else:
                     # TODO: Remove this code after we've bumped the minimum version
@@ -208,9 +204,11 @@ class UserDirectoryStore(SQLBaseStore):
                         txn.execute(
                             sql,
                             (
-                                user_id, get_localpart_from_id(user_id),
-                                get_domain_from_id(user_id), display_name,
-                            )
+                                user_id,
+                                get_localpart_from_id(user_id),
+                                get_domain_from_id(user_id),
+                                display_name,
+                            ),
                         )
                     elif new_entry is False:
                         sql = """
@@ -225,15 +223,16 @@ class UserDirectoryStore(SQLBaseStore):
                             (
                                 get_localpart_from_id(user_id),
                                 get_domain_from_id(user_id),
-                                display_name, user_id,
-                            )
+                                display_name,
+                                user_id,
+                            ),
                         )
                     else:
                         raise RuntimeError(
                             "upsert returned None when 'can_native_upsert' is False"
                         )
             elif isinstance(self.database_engine, Sqlite3Engine):
-                value = "%s %s" % (user_id, display_name,) if display_name else user_id
+                value = "%s %s" % (user_id, display_name) if display_name else user_id
                 self._simple_upsert_txn(
                     txn,
                     table="user_directory_search",
@@ -264,29 +263,18 @@ class UserDirectoryStore(SQLBaseStore):
     def remove_from_user_dir(self, user_id):
         def _remove_from_user_dir_txn(txn):
             self._simple_delete_txn(
-                txn,
-                table="user_directory",
-                keyvalues={"user_id": user_id},
+                txn, table="user_directory", keyvalues={"user_id": user_id}
             )
             self._simple_delete_txn(
-                txn,
-                table="user_directory_search",
-                keyvalues={"user_id": user_id},
+                txn, table="user_directory_search", keyvalues={"user_id": user_id}
             )
             self._simple_delete_txn(
-                txn,
-                table="users_in_public_rooms",
-                keyvalues={"user_id": user_id},
-            )
-            txn.call_after(
-                self.get_user_in_directory.invalidate, (user_id,)
-            )
-            txn.call_after(
-                self.get_user_in_public_room.invalidate, (user_id,)
+                txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
             )
-        return self.runInteraction(
-            "remove_from_user_dir", _remove_from_user_dir_txn,
-        )
+            txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
+            txn.call_after(self.get_user_in_public_room.invalidate, (user_id,))
+
+        return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
 
     @defer.inlineCallbacks
     def remove_from_user_in_public_room(self, user_id):
@@ -371,6 +359,7 @@ class UserDirectoryStore(SQLBaseStore):
             share_private (bool): Is the room private
             user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
         """
+
         def _add_users_who_share_room_txn(txn):
             self._simple_insert_many_txn(
                 txn,
@@ -387,13 +376,12 @@ class UserDirectoryStore(SQLBaseStore):
             )
             for user_id, other_user_id in user_id_tuples:
                 txn.call_after(
-                    self.get_users_who_share_room_from_dir.invalidate,
-                    (user_id,),
+                    self.get_users_who_share_room_from_dir.invalidate, (user_id,)
                 )
                 txn.call_after(
-                    self.get_if_users_share_a_room.invalidate,
-                    (user_id, other_user_id),
+                    self.get_if_users_share_a_room.invalidate, (user_id, other_user_id)
                 )
+
         return self.runInteraction(
             "add_users_who_share_room", _add_users_who_share_room_txn
         )
@@ -407,6 +395,7 @@ class UserDirectoryStore(SQLBaseStore):
             share_private (bool): Is the room private
             user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
         """
+
         def _update_users_who_share_room_txn(txn):
             sql = """
                 UPDATE users_who_share_rooms
@@ -414,21 +403,16 @@ class UserDirectoryStore(SQLBaseStore):
                 WHERE user_id = ? AND other_user_id = ?
             """
             txn.executemany(
-                sql,
-                (
-                    (room_id, share_private, uid, oid)
-                    for uid, oid in user_id_sets
-                )
+                sql, ((room_id, share_private, uid, oid) for uid, oid in user_id_sets)
             )
             for user_id, other_user_id in user_id_sets:
                 txn.call_after(
-                    self.get_users_who_share_room_from_dir.invalidate,
-                    (user_id,),
+                    self.get_users_who_share_room_from_dir.invalidate, (user_id,)
                 )
                 txn.call_after(
-                    self.get_if_users_share_a_room.invalidate,
-                    (user_id, other_user_id),
+                    self.get_if_users_share_a_room.invalidate, (user_id, other_user_id)
                 )
+
         return self.runInteraction(
             "update_users_who_share_room", _update_users_who_share_room_txn
         )
@@ -442,22 +426,18 @@ class UserDirectoryStore(SQLBaseStore):
             share_private (bool): Is the room private
             user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
         """
+
         def _remove_user_who_share_room_txn(txn):
             self._simple_delete_txn(
                 txn,
                 table="users_who_share_rooms",
-                keyvalues={
-                    "user_id": user_id,
-                    "other_user_id": other_user_id,
-                },
+                keyvalues={"user_id": user_id, "other_user_id": other_user_id},
             )
             txn.call_after(
-                self.get_users_who_share_room_from_dir.invalidate,
-                (user_id,),
+                self.get_users_who_share_room_from_dir.invalidate, (user_id,)
             )
             txn.call_after(
-                self.get_if_users_share_a_room.invalidate,
-                (user_id, other_user_id),
+                self.get_if_users_share_a_room.invalidate, (user_id, other_user_id)
             )
 
         return self.runInteraction(
@@ -478,10 +458,7 @@ class UserDirectoryStore(SQLBaseStore):
         """
         return self._simple_select_one_onecol(
             table="users_who_share_rooms",
-            keyvalues={
-                "user_id": user_id,
-                "other_user_id": other_user_id,
-            },
+            keyvalues={"user_id": user_id, "other_user_id": other_user_id},
             retcol="share_private",
             allow_none=True,
             desc="get_if_users_share_a_room",
@@ -499,17 +476,12 @@ class UserDirectoryStore(SQLBaseStore):
         """
         rows = yield self._simple_select_list(
             table="users_who_share_rooms",
-            keyvalues={
-                "user_id": user_id,
-            },
-            retcols=("other_user_id", "share_private",),
+            keyvalues={"user_id": user_id},
+            retcols=("other_user_id", "share_private"),
             desc="get_users_who_share_room_with_user",
         )
 
-        defer.returnValue({
-            row["other_user_id"]: row["share_private"]
-            for row in rows
-        })
+        defer.returnValue({row["other_user_id"]: row["share_private"] for row in rows})
 
     def get_users_in_share_dir_with_room_id(self, user_id, room_id):
         """Get all user tuples that are in the users_who_share_rooms due to the
@@ -556,6 +528,7 @@ class UserDirectoryStore(SQLBaseStore):
     def delete_all_from_user_dir(self):
         """Delete the entire user directory
         """
+
         def _delete_all_from_user_dir_txn(txn):
             txn.execute("DELETE FROM user_directory")
             txn.execute("DELETE FROM user_directory_search")
@@ -565,6 +538,7 @@ class UserDirectoryStore(SQLBaseStore):
             txn.call_after(self.get_user_in_public_room.invalidate_all)
             txn.call_after(self.get_users_who_share_room_from_dir.invalidate_all)
             txn.call_after(self.get_if_users_share_a_room.invalidate_all)
+
         return self.runInteraction(
             "delete_all_from_user_dir", _delete_all_from_user_dir_txn
         )
@@ -574,7 +548,7 @@ class UserDirectoryStore(SQLBaseStore):
         return self._simple_select_one(
             table="user_directory",
             keyvalues={"user_id": user_id},
-            retcols=("room_id", "display_name", "avatar_url",),
+            retcols=("room_id", "display_name", "avatar_url"),
             allow_none=True,
             desc="get_user_in_directory",
         )
@@ -607,7 +581,9 @@ class UserDirectoryStore(SQLBaseStore):
 
     def get_current_state_deltas(self, prev_stream_id):
         prev_stream_id = int(prev_stream_id)
-        if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id):
+        if not self._curr_state_delta_stream_cache.has_any_entity_changed(
+            prev_stream_id
+        ):
             return []
 
         def get_current_state_deltas_txn(txn):
@@ -641,7 +617,7 @@ class UserDirectoryStore(SQLBaseStore):
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC
             """
-            txn.execute(sql, (prev_stream_id, max_stream_id,))
+            txn.execute(sql, (prev_stream_id, max_stream_id))
             return self.cursor_to_dict(txn)
 
         return self.runInteraction(
@@ -731,8 +707,11 @@ class UserDirectoryStore(SQLBaseStore):
                     display_name IS NULL,
                     avatar_url IS NULL
                 LIMIT ?
-            """ % (join_clause, where_clause)
-            args = join_args + (full_query, exact_query, prefix_query, limit + 1,)
+            """ % (
+                join_clause,
+                where_clause,
+            )
+            args = join_args + (full_query, exact_query, prefix_query, limit + 1)
         elif isinstance(self.database_engine, Sqlite3Engine):
             search_query = _parse_query_sqlite(search_term)
 
@@ -749,7 +728,10 @@ class UserDirectoryStore(SQLBaseStore):
                     display_name IS NULL,
                     avatar_url IS NULL
                 LIMIT ?
-            """ % (join_clause, where_clause)
+            """ % (
+                join_clause,
+                where_clause,
+            )
             args = join_args + (search_query, limit + 1)
         else:
             # This should be unreachable.
@@ -761,10 +743,7 @@ class UserDirectoryStore(SQLBaseStore):
 
         limited = len(results) > limit
 
-        defer.returnValue({
-            "limited": limited,
-            "results": results,
-        })
+        defer.returnValue({"limited": limited, "results": results})
 
 
 def _parse_query_sqlite(search_term):
@@ -779,7 +758,7 @@ def _parse_query_sqlite(search_term):
 
     # Pull out the individual words, discarding any non-word characters.
     results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
-    return " & ".join("(%s* OR %s)" % (result, result,) for result in results)
+    return " & ".join("(%s* OR %s)" % (result, result) for result in results)
 
 
 def _parse_query_postgres(search_term):
@@ -792,7 +771,7 @@ def _parse_query_postgres(search_term):
     # Pull out the individual words, discarding any non-word characters.
     results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
 
-    both = " & ".join("(%s:* | %s)" % (result, result,) for result in results)
+    both = " & ".join("(%s:* | %s)" % (result, result) for result in results)
     exact = " & ".join("%s" % (result,) for result in results)
     prefix = " & ".join("%s:*" % (result,) for result in results)