summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/events/spamcheck.py7
-rw-r--r--synapse/handlers/user_directory.py4
-rw-r--r--synapse/module_api/__init__.py2
-rw-r--r--synapse/rest/client/user_directory.py4
-rw-r--r--synapse/storage/databases/main/user_directory.py23
-rw-r--r--synapse/types.py11
6 files changed, 39 insertions, 12 deletions
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 60904a55f5..cd80fcf9d1 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -21,7 +21,6 @@ from typing import (
     Awaitable,
     Callable,
     Collection,
-    Dict,
     List,
     Optional,
     Tuple,
@@ -31,7 +30,7 @@ from typing import (
 from synapse.rest.media.v1._base import FileInfo
 from synapse.rest.media.v1.media_storage import ReadableFileWrapper
 from synapse.spam_checker_api import RegistrationBehaviour
-from synapse.types import RoomAlias
+from synapse.types import RoomAlias, UserProfile
 from synapse.util.async_helpers import maybe_awaitable
 
 if TYPE_CHECKING:
@@ -50,7 +49,7 @@ USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[[str, str, str, str], Awaitable[bo
 USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]]
 USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]]
 USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
-CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[Dict[str, str]], Awaitable[bool]]
+CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[UserProfile], Awaitable[bool]]
 LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
     [
         Optional[dict],
@@ -383,7 +382,7 @@ class SpamChecker:
 
         return True
 
-    async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
+    async def check_username_for_spam(self, user_profile: UserProfile) -> bool:
         """Checks if a user ID or display name are considered "spammy" by this server.
 
         If the server considers a username spammy, then it will not be included in
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index d27ed2be6a..048fd4bb82 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -19,8 +19,8 @@ import synapse.metrics
 from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
 from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.user_directory import SearchResult
 from synapse.storage.roommember import ProfileInfo
-from synapse.types import JsonDict
 from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
@@ -78,7 +78,7 @@ class UserDirectoryHandler(StateDeltasHandler):
 
     async def search_users(
         self, user_id: str, search_term: str, limit: int
-    ) -> JsonDict:
+    ) -> SearchResult:
         """Searches for users in directory
 
         Returns:
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index d735c1d461..aa8256b36f 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -111,6 +111,7 @@ from synapse.types import (
     StateMap,
     UserID,
     UserInfo,
+    UserProfile,
     create_requester,
 )
 from synapse.util import Clock
@@ -150,6 +151,7 @@ __all__ = [
     "EventBase",
     "StateMap",
     "ProfileInfo",
+    "UserProfile",
 ]
 
 logger = logging.getLogger(__name__)
diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py
index a47d9bd01d..116c982ce6 100644
--- a/synapse/rest/client/user_directory.py
+++ b/synapse/rest/client/user_directory.py
@@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.http.site import SynapseRequest
-from synapse.types import JsonDict
+from synapse.types import JsonMapping
 
 from ._base import client_patterns
 
@@ -38,7 +38,7 @@ class UserDirectorySearchRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.user_directory_handler = hs.get_user_directory_handler()
 
-    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonMapping]:
         """Searches for users in directory
 
         Returns:
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index e7fddd2426..55cc9178f0 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -26,6 +26,8 @@ from typing import (
     cast,
 )
 
+from typing_extensions import TypedDict
+
 from synapse.api.errors import StoreError
 
 if TYPE_CHECKING:
@@ -40,7 +42,12 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.state import StateFilter
 from synapse.storage.databases.main.state_deltas import StateDeltasStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
+from synapse.types import (
+    JsonDict,
+    UserProfile,
+    get_domain_from_id,
+    get_localpart_from_id,
+)
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
@@ -591,6 +598,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         )
 
 
+class SearchResult(TypedDict):
+    limited: bool
+    results: List[UserProfile]
+
+
 class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
     # How many records do we calculate before sending it to
     # add_users_who_share_private_rooms?
@@ -777,7 +789,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
 
     async def search_user_dir(
         self, user_id: str, search_term: str, limit: int
-    ) -> JsonDict:
+    ) -> SearchResult:
         """Searches for users in directory
 
         Returns:
@@ -910,8 +922,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
 
-        results = await self.db_pool.execute(
-            "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
+        results = cast(
+            List[UserProfile],
+            await self.db_pool.execute(
+                "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
+            ),
         )
 
         limited = len(results) > limit
diff --git a/synapse/types.py b/synapse/types.py
index 53be3583a0..5ce2a5b0a5 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -34,6 +34,7 @@ from typing import (
 import attr
 from frozendict import frozendict
 from signedjson.key import decode_verify_key_bytes
+from typing_extensions import TypedDict
 from unpaddedbase64 import decode_base64
 from zope.interface import Interface
 
@@ -63,6 +64,10 @@ MutableStateMap = MutableMapping[StateKey, T]
 # JSON types. These could be made stronger, but will do for now.
 # A JSON-serialisable dict.
 JsonDict = Dict[str, Any]
+# A JSON-serialisable mapping; roughly speaking an immutable JSONDict.
+# Useful when you have a TypedDict which isn't going to be mutated and you don't want
+# to cast to JsonDict everywhere.
+JsonMapping = Mapping[str, Any]
 # A JSON-serialisable object.
 JsonSerializable = object
 
@@ -791,3 +796,9 @@ class UserInfo:
     is_deactivated: bool
     is_guest: bool
     is_shadow_banned: bool
+
+
+class UserProfile(TypedDict):
+    user_id: str
+    display_name: Optional[str]
+    avatar_url: Optional[str]