summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/10891.misc1
-rw-r--r--mypy.ini2
-rw-r--r--synapse/storage/databases/main/user_directory.py124
-rw-r--r--tests/handlers/test_user_directory.py5
4 files changed, 95 insertions, 37 deletions
diff --git a/changelog.d/10891.misc b/changelog.d/10891.misc
new file mode 100644
index 0000000000..6eecea4065
--- /dev/null
+++ b/changelog.d/10891.misc
@@ -0,0 +1 @@
+Improve type hinting in the user directory code.
\ No newline at end of file
diff --git a/mypy.ini b/mypy.ini
index 3cb6cecd7e..437d0a46a5 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -85,9 +85,11 @@ files =
   tests/handlers/test_room_summary.py,
   tests/handlers/test_send_email.py,
   tests/handlers/test_sync.py,
+  tests/handlers/test_user_directory.py,
   tests/rest/client/test_login.py,
   tests/rest/client/test_auth.py,
   tests/storage/test_state.py,
+  tests/storage/test_user_directory.py,
   tests/util/test_itertools.py,
   tests/util/test_stream_change_cache.py
 
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 718f3e9976..7ca04237a5 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -14,14 +14,28 @@
 
 import logging
 import re
-from typing import Any, Dict, Iterable, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    cast,
+)
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
 
 from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.state import StateFilter
 from synapse.storage.databases.main.state_deltas import StateDeltasStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.types import get_domain_from_id, get_localpart_from_id
+from synapse.storage.types import Connection
+from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
@@ -36,7 +50,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
     # add_users_who_share_private_rooms?
     SHARE_PRIVATE_WORKING_SET = 500
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: Connection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
@@ -57,10 +76,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             "populate_user_directory_cleanup", self._populate_user_directory_cleanup
         )
 
-    async def _populate_user_directory_createtables(self, progress, batch_size):
+    async def _populate_user_directory_createtables(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
 
         # Get all the rooms that we want to process.
-        def _make_staging_area(txn):
+        def _make_staging_area(txn: LoggingTransaction) -> None:
             sql = (
                 "CREATE TABLE IF NOT EXISTS "
                 + TEMP_TABLE
@@ -110,16 +131,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         )
         return 1
 
-    async def _populate_user_directory_cleanup(self, progress, batch_size):
+    async def _populate_user_directory_cleanup(
+        self,
+        progress: JsonDict,
+        batch_size: int,
+    ) -> int:
         """
         Update the user directory stream position, then clean up the old tables.
         """
         position = await self.db_pool.simple_select_one_onecol(
-            TEMP_TABLE + "_position", None, "position"
+            TEMP_TABLE + "_position", {}, "position"
         )
         await self.update_user_directory_stream_pos(position)
 
-        def _delete_staging_area(txn):
+        def _delete_staging_area(txn: LoggingTransaction) -> None:
             txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
             txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
             txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
@@ -133,18 +158,32 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         )
         return 1
 
-    async def _populate_user_directory_process_rooms(self, progress, batch_size):
+    async def _populate_user_directory_process_rooms(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """
+        Rescan the state of all rooms so we can track
+
+        - who's in a public room;
+        - which local users share a private room with other users (local
+          and remote); and
+        - who should be in the user_directory.
+
         Args:
             progress (dict)
             batch_size (int): Maximum number of state events to process
                 per cycle.
+
+        Returns:
+            number of events processed.
         """
         # If we don't have progress filed, delete everything.
         if not progress:
             await self.delete_all_from_user_dir()
 
-        def _get_next_batch(txn):
+        def _get_next_batch(
+            txn: LoggingTransaction,
+        ) -> Optional[Sequence[Tuple[str, int]]]:
             # Only fetch 250 rooms, so we don't fetch too many at once, even
             # if those 250 rooms have less than batch_size state events.
             sql = """
@@ -155,7 +194,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                 TEMP_TABLE + "_rooms",
             )
             txn.execute(sql)
-            rooms_to_work_on = txn.fetchall()
+            rooms_to_work_on = cast(List[Tuple[str, int]], txn.fetchall())
 
             if not rooms_to_work_on:
                 return None
@@ -163,7 +202,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             # Get how many are left to process, so we can give status on how
             # far we are in processing
             txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
-            progress["remaining"] = txn.fetchone()[0]
+            result = txn.fetchone()
+            assert result is not None
+            progress["remaining"] = result[0]
 
             return rooms_to_work_on
 
@@ -261,29 +302,33 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
         return processed_event_count
 
-    async def _populate_user_directory_process_users(self, progress, batch_size):
+    async def _populate_user_directory_process_users(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """
         Add all local users to the user directory.
         """
 
-        def _get_next_batch(txn):
+        def _get_next_batch(txn: LoggingTransaction) -> Optional[List[str]]:
             sql = "SELECT user_id FROM %s LIMIT %s" % (
                 TEMP_TABLE + "_users",
                 str(batch_size),
             )
             txn.execute(sql)
-            users_to_work_on = txn.fetchall()
+            user_result = cast(List[Tuple[str]], txn.fetchall())
 
-            if not users_to_work_on:
+            if not user_result:
                 return None
 
-            users_to_work_on = [x[0] for x in users_to_work_on]
+            users_to_work_on = [x[0] for x in user_result]
 
             # Get how many are left to process, so we can give status on how
             # far we are in processing
             sql = "SELECT COUNT(*) FROM " + TEMP_TABLE + "_users"
             txn.execute(sql)
-            progress["remaining"] = txn.fetchone()[0]
+            count_result = txn.fetchone()
+            assert count_result is not None
+            progress["remaining"] = count_result[0]
 
             return users_to_work_on
 
@@ -324,7 +369,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
         return len(users_to_work_on)
 
-    async def is_room_world_readable_or_publicly_joinable(self, room_id):
+    async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
         """Check if the room is either world_readable or publically joinable"""
 
         # Create a state filter that only queries join and history state event
@@ -368,7 +413,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         if not isinstance(avatar_url, str):
             avatar_url = None
 
-        def _update_profile_in_user_dir_txn(txn):
+        def _update_profile_in_user_dir_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_upsert_txn(
                 txn,
                 table="user_directory",
@@ -435,7 +480,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                 for user_id, other_user_id in user_id_tuples
             ],
             value_names=(),
-            value_values=None,
+            value_values=(),
             desc="add_users_who_share_room",
         )
 
@@ -454,14 +499,14 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             key_names=["user_id", "room_id"],
             key_values=[(user_id, room_id) for user_id in user_ids],
             value_names=(),
-            value_values=None,
+            value_values=(),
             desc="add_users_in_public_rooms",
         )
 
     async def delete_all_from_user_dir(self) -> None:
         """Delete the entire user directory"""
 
-        def _delete_all_from_user_dir_txn(txn):
+        def _delete_all_from_user_dir_txn(txn: LoggingTransaction) -> None:
             txn.execute("DELETE FROM user_directory")
             txn.execute("DELETE FROM user_directory_search")
             txn.execute("DELETE FROM users_in_public_rooms")
@@ -473,7 +518,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         )
 
     @cached()
-    async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
+    async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]:
         return await self.db_pool.simple_select_one(
             table="user_directory",
             keyvalues={"user_id": user_id},
@@ -497,7 +542,12 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
     # add_users_who_share_private_rooms?
     SHARE_PRIVATE_WORKING_SET = 500
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: Connection,
+        hs: "HomeServer",
+    ) -> None:
         super().__init__(database, db_conn, hs)
 
         self._prefer_local_users_in_search = (
@@ -506,7 +556,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         self._server_name = hs.config.server.server_name
 
     async def remove_from_user_dir(self, user_id: str) -> None:
-        def _remove_from_user_dir_txn(txn):
+        def _remove_from_user_dir_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_delete_txn(
                 txn, table="user_directory", keyvalues={"user_id": user_id}
             )
@@ -532,7 +582,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             "remove_from_user_dir", _remove_from_user_dir_txn
         )
 
-    async def get_users_in_dir_due_to_room(self, room_id):
+    async def get_users_in_dir_due_to_room(self, room_id: str) -> Set[str]:
         """Get all user_ids that are in the room directory because they're
         in the given room_id
         """
@@ -565,7 +615,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             room_id
         """
 
-        def _remove_user_who_share_room_txn(txn):
+        def _remove_user_who_share_room_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_delete_txn(
                 txn,
                 table="users_who_share_private_rooms",
@@ -586,7 +636,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             "remove_user_who_share_room", _remove_user_who_share_room_txn
         )
 
-    async def get_user_dir_rooms_user_is_in(self, user_id):
+    async def get_user_dir_rooms_user_is_in(self, user_id: str) -> List[str]:
         """
         Returns the rooms that a user is in.
 
@@ -628,7 +678,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             A set of room ID's that the users share.
         """
 
-        def _get_shared_rooms_for_users_txn(txn):
+        def _get_shared_rooms_for_users_txn(
+            txn: LoggingTransaction,
+        ) -> List[Dict[str, str]]:
             txn.execute(
                 """
                 SELECT p1.room_id
@@ -669,7 +721,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             desc="get_user_directory_stream_pos",
         )
 
-    async def search_user_dir(self, user_id, search_term, limit):
+    async def search_user_dir(
+        self, user_id: str, search_term: str, limit: int
+    ) -> JsonDict:
         """Searches for users in directory
 
         Returns:
@@ -705,7 +759,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         # We allow manipulating the ranking algorithm by injecting statements
         # based on config options.
         additional_ordering_statements = []
-        ordering_arguments = ()
+        ordering_arguments: Tuple[str, ...] = ()
 
         if isinstance(self.database_engine, PostgresEngine):
             full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
@@ -811,7 +865,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         return {"limited": limited, "results": results}
 
 
-def _parse_query_sqlite(search_term):
+def _parse_query_sqlite(search_term: str) -> str:
     """Takes a plain unicode string from the user and converts it into a form
     that can be passed to database.
     We use this so that we can add prefix matching, which isn't something
@@ -826,7 +880,7 @@ def _parse_query_sqlite(search_term):
     return " & ".join("(%s* OR %s)" % (result, result) for result in results)
 
 
-def _parse_query_postgres(search_term):
+def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]:
     """Takes a plain unicode string from the user and converts it into a form
     that can be passed to database.
     We use this so that we can add prefix matching, which isn't something
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index f3684c34a2..ba32585a14 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import List, Tuple
 from unittest.mock import Mock
 from urllib.parse import quote
 
@@ -325,7 +326,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             r.add((i["user_id"], i["other_user_id"], i["room_id"]))
         return r
 
-    def get_users_in_public_rooms(self):
+    def get_users_in_public_rooms(self) -> List[Tuple[str, str]]:
         r = self.get_success(
             self.store.db_pool.simple_select_list(
                 "users_in_public_rooms", None, ("user_id", "room_id")
@@ -336,7 +337,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             retval.append((i["user_id"], i["room_id"]))
         return retval
 
-    def get_users_who_share_private_rooms(self):
+    def get_users_who_share_private_rooms(self) -> List[Tuple[str, str, str]]:
         return self.get_success(
             self.store.db_pool.simple_select_list(
                 "users_who_share_private_rooms",