summary refs log tree commit diff
path: root/synapse/storage/user_directory.py
diff options
context:
space:
mode:
authorMichael Telatynski <7t3chguy@gmail.com>2018-07-24 17:17:46 +0100
committerMichael Telatynski <7t3chguy@gmail.com>2018-07-24 17:17:46 +0100
commit87951d3891efb5bccedf72c12b3da0d6ab482253 (patch)
treede7d997567c66c5a4d8743c1f3b9d6b474f5cfd9 /synapse/storage/user_directory.py
parentif inviter_display_name == ""||None then default to inviter MXID (diff)
parentMerge pull request #3595 from matrix-org/erikj/use_deltas (diff)
downloadsynapse-87951d3891efb5bccedf72c12b3da0d6ab482253.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into t3chguy/default_inviter_display_name_3pid
Diffstat (limited to 'synapse/storage/user_directory.py')
-rw-r--r--synapse/storage/user_directory.py99
1 files changed, 61 insertions, 38 deletions
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index 2a4db3f03c..a8781b0e5d 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -13,17 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
+import logging
+import re
 
-from ._base import SQLBaseStore
+from six import iteritems
+
+from twisted.internet import defer
 
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 from synapse.api.constants import EventTypes, JoinRules
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.types import get_domain_from_id, get_localpart_from_id
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
-import re
-import logging
+from ._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
@@ -63,7 +65,7 @@ class UserDirectoryStore(SQLBaseStore):
             user_ids (list(str)): Users to add
         """
         yield self._simple_insert_many(
-            table="users_in_pubic_room",
+            table="users_in_public_rooms",
             values=[
                 {
                     "user_id": user_id,
@@ -100,7 +102,7 @@ class UserDirectoryStore(SQLBaseStore):
                     user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id),
                     profile.display_name,
                 )
-                for user_id, profile in users_with_profile.iteritems()
+                for user_id, profile in iteritems(users_with_profile)
             )
         elif isinstance(self.database_engine, Sqlite3Engine):
             sql = """
@@ -112,7 +114,7 @@ class UserDirectoryStore(SQLBaseStore):
                     user_id,
                     "%s %s" % (user_id, p.display_name,) if p.display_name else user_id
                 )
-                for user_id, p in users_with_profile.iteritems()
+                for user_id, p in iteritems(users_with_profile)
             )
         else:
             # This should be unreachable.
@@ -130,7 +132,7 @@ class UserDirectoryStore(SQLBaseStore):
                         "display_name": profile.display_name,
                         "avatar_url": profile.avatar_url,
                     }
-                    for user_id, profile in users_with_profile.iteritems()
+                    for user_id, profile in iteritems(users_with_profile)
                 ]
             )
             for user_id in users_with_profile:
@@ -164,7 +166,7 @@ class UserDirectoryStore(SQLBaseStore):
             )
 
             if isinstance(self.database_engine, PostgresEngine):
-                # We weight the loclpart most highly, then display name and finally
+                # We weight the localpart most highly, then display name and finally
                 # server name
                 if new_entry:
                     sql = """
@@ -219,7 +221,7 @@ class UserDirectoryStore(SQLBaseStore):
     @defer.inlineCallbacks
     def update_user_in_public_user_list(self, user_id, room_id):
         yield self._simple_update_one(
-            table="users_in_pubic_room",
+            table="users_in_public_rooms",
             keyvalues={"user_id": user_id},
             updatevalues={"room_id": room_id},
             desc="update_user_in_public_user_list",
@@ -240,7 +242,7 @@ class UserDirectoryStore(SQLBaseStore):
             )
             self._simple_delete_txn(
                 txn,
-                table="users_in_pubic_room",
+                table="users_in_public_rooms",
                 keyvalues={"user_id": user_id},
             )
             txn.call_after(
@@ -256,18 +258,18 @@ class UserDirectoryStore(SQLBaseStore):
     @defer.inlineCallbacks
     def remove_from_user_in_public_room(self, user_id):
         yield self._simple_delete(
-            table="users_in_pubic_room",
+            table="users_in_public_rooms",
             keyvalues={"user_id": user_id},
             desc="remove_from_user_in_public_room",
         )
         self.get_user_in_public_room.invalidate((user_id,))
 
     def get_users_in_public_due_to_room(self, room_id):
-        """Get all user_ids that are in the room directory becuase they're
+        """Get all user_ids that are in the room directory because they're
         in the given room_id
         """
         return self._simple_select_onecol(
-            table="users_in_pubic_room",
+            table="users_in_public_rooms",
             keyvalues={"room_id": room_id},
             retcol="user_id",
             desc="get_users_in_public_due_to_room",
@@ -275,7 +277,7 @@ class UserDirectoryStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_users_in_dir_due_to_room(self, room_id):
-        """Get all user_ids that are in the room directory becuase they're
+        """Get all user_ids that are in the room directory because they're
         in the given room_id
         """
         user_ids_dir = yield self._simple_select_onecol(
@@ -286,7 +288,7 @@ class UserDirectoryStore(SQLBaseStore):
         )
 
         user_ids_pub = yield self._simple_select_onecol(
-            table="users_in_pubic_room",
+            table="users_in_public_rooms",
             keyvalues={"room_id": room_id},
             retcol="user_id",
             desc="get_users_in_dir_due_to_room",
@@ -317,6 +319,16 @@ class UserDirectoryStore(SQLBaseStore):
         rows = yield self._execute("get_all_rooms", None, sql)
         defer.returnValue([room_id for room_id, in rows])
 
+    @defer.inlineCallbacks
+    def get_all_local_users(self):
+        """Get all local users
+        """
+        sql = """
+            SELECT name FROM users
+        """
+        rows = yield self._execute("get_all_local_users", None, sql)
+        defer.returnValue([name for name, in rows])
+
     def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
         """Insert entries into the users_who_share_rooms table. The first
         user should be a local user.
@@ -514,7 +526,7 @@ class UserDirectoryStore(SQLBaseStore):
         def _delete_all_from_user_dir_txn(txn):
             txn.execute("DELETE FROM user_directory")
             txn.execute("DELETE FROM user_directory_search")
-            txn.execute("DELETE FROM users_in_pubic_room")
+            txn.execute("DELETE FROM users_in_public_rooms")
             txn.execute("DELETE FROM users_who_share_rooms")
             txn.call_after(self.get_user_in_directory.invalidate_all)
             txn.call_after(self.get_user_in_public_room.invalidate_all)
@@ -537,7 +549,7 @@ class UserDirectoryStore(SQLBaseStore):
     @cached()
     def get_user_in_public_room(self, user_id):
         return self._simple_select_one(
-            table="users_in_pubic_room",
+            table="users_in_public_rooms",
             keyvalues={"user_id": user_id},
             retcols=("room_id",),
             allow_none=True,
@@ -629,6 +641,25 @@ class UserDirectoryStore(SQLBaseStore):
                     ]
                 }
         """
+
+        if self.hs.config.user_directory_search_all_users:
+            # make s.user_id null to keep the ordering algorithm happy
+            join_clause = """
+                CROSS JOIN (SELECT NULL as user_id) AS s
+            """
+            join_args = ()
+            where_clause = "1=1"
+        else:
+            join_clause = """
+                LEFT JOIN users_in_public_rooms AS p USING (user_id)
+                LEFT JOIN (
+                    SELECT other_user_id AS user_id FROM users_who_share_rooms
+                    WHERE user_id = ? AND share_private
+                ) AS s USING (user_id)
+            """
+            join_args = (user_id,)
+            where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
+
         if isinstance(self.database_engine, PostgresEngine):
             full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
 
@@ -638,16 +669,12 @@ class UserDirectoryStore(SQLBaseStore):
             # The array of numbers are the weights for the various part of the
             # search: (domain, _, display name, localpart)
             sql = """
-                SELECT d.user_id, display_name, avatar_url
+                SELECT d.user_id AS user_id, display_name, avatar_url
                 FROM user_directory_search
                 INNER JOIN user_directory AS d USING (user_id)
-                LEFT JOIN users_in_pubic_room AS p USING (user_id)
-                LEFT JOIN (
-                    SELECT other_user_id AS user_id FROM users_who_share_rooms
-                    WHERE user_id = ? AND share_private
-                ) AS s USING (user_id)
+                %s
                 WHERE
-                    (s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
+                    %s
                     AND vector @@ to_tsquery('english', ?)
                 ORDER BY
                     (CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
@@ -671,30 +698,26 @@ class UserDirectoryStore(SQLBaseStore):
                     display_name IS NULL,
                     avatar_url IS NULL
                 LIMIT ?
-            """
-            args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
+            """ % (join_clause, where_clause)
+            args = join_args + (full_query, exact_query, prefix_query, limit + 1,)
         elif isinstance(self.database_engine, Sqlite3Engine):
             search_query = _parse_query_sqlite(search_term)
 
             sql = """
-                SELECT d.user_id, display_name, avatar_url
+                SELECT d.user_id AS user_id, display_name, avatar_url
                 FROM user_directory_search
                 INNER JOIN user_directory AS d USING (user_id)
-                LEFT JOIN users_in_pubic_room AS p USING (user_id)
-                LEFT JOIN (
-                    SELECT other_user_id AS user_id FROM users_who_share_rooms
-                    WHERE user_id = ? AND share_private
-                ) AS s USING (user_id)
+                %s
                 WHERE
-                    (s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
+                    %s
                     AND value MATCH ?
                 ORDER BY
                     rank(matchinfo(user_directory_search)) DESC,
                     display_name IS NULL,
                     avatar_url IS NULL
                 LIMIT ?
-            """
-            args = (user_id, search_query, limit + 1)
+            """ % (join_clause, where_clause)
+            args = join_args + (search_query, limit + 1)
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
@@ -723,7 +746,7 @@ def _parse_query_sqlite(search_term):
 
     # Pull out the individual words, discarding any non-word characters.
     results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
-    return " & ".join("(%s* | %s)" % (result, result,) for result in results)
+    return " & ".join("(%s* OR %s)" % (result, result,) for result in results)
 
 
 def _parse_query_postgres(search_term):