summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-09-18 09:55:04 -0400
committerGitHub <noreply@github.com>2023-09-18 09:55:04 -0400
commitc1e244c8f70ff1a23e358e1608c555f9722dee1f (patch)
tree3a38c06702db3335ec7ca131af283c1f0ef370c0
parentReturn an immutable value from get_latest_event_ids_in_room. (#16326) (diff)
downloadsynapse-c1e244c8f70ff1a23e358e1608c555f9722dee1f.tar.xz
Make cached account data/tags/admin types immutable (#16325)
-rw-r--r--changelog.d/16325.misc1
-rw-r--r--synapse/app/admin_cmd.py14
-rw-r--r--synapse/handlers/admin.py18
-rw-r--r--synapse/handlers/sync.py27
-rw-r--r--synapse/rest/admin/users.py8
-rw-r--r--synapse/rest/client/account_data.py10
-rw-r--r--synapse/storage/databases/main/account_data.py14
-rw-r--r--synapse/storage/databases/main/experimental_features.py7
-rw-r--r--synapse/storage/databases/main/tags.py6
9 files changed, 55 insertions, 50 deletions
diff --git a/changelog.d/16325.misc b/changelog.d/16325.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/16325.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index f9aada269a..aa24f7da6c 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -17,7 +17,7 @@ import logging
 import os
 import sys
 import tempfile
-from typing import List, Mapping, Optional
+from typing import List, Mapping, Optional, Sequence
 
 from twisted.internet import defer, task
 
@@ -57,7 +57,7 @@ from synapse.storage.databases.main.state import StateGroupWorkerStore
 from synapse.storage.databases.main.stream import StreamWorkerStore
 from synapse.storage.databases.main.tags import TagsWorkerStore
 from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore
-from synapse.types import JsonDict, StateMap
+from synapse.types import JsonMapping, StateMap
 from synapse.util import SYNAPSE_VERSION
 from synapse.util.logcontext import LoggingContext
 
@@ -198,7 +198,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
             for event in state.values():
                 json.dump(event, fp=f)
 
-    def write_profile(self, profile: JsonDict) -> None:
+    def write_profile(self, profile: JsonMapping) -> None:
         user_directory = os.path.join(self.base_directory, "user_data")
         os.makedirs(user_directory, exist_ok=True)
         profile_file = os.path.join(user_directory, "profile")
@@ -206,7 +206,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
         with open(profile_file, "a") as f:
             json.dump(profile, fp=f)
 
-    def write_devices(self, devices: List[JsonDict]) -> None:
+    def write_devices(self, devices: Sequence[JsonMapping]) -> None:
         user_directory = os.path.join(self.base_directory, "user_data")
         os.makedirs(user_directory, exist_ok=True)
         device_file = os.path.join(user_directory, "devices")
@@ -215,7 +215,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
             with open(device_file, "a") as f:
                 json.dump(device, fp=f)
 
-    def write_connections(self, connections: List[JsonDict]) -> None:
+    def write_connections(self, connections: Sequence[JsonMapping]) -> None:
         user_directory = os.path.join(self.base_directory, "user_data")
         os.makedirs(user_directory, exist_ok=True)
         connection_file = os.path.join(user_directory, "connections")
@@ -225,7 +225,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
                 json.dump(connection, fp=f)
 
     def write_account_data(
-        self, file_name: str, account_data: Mapping[str, JsonDict]
+        self, file_name: str, account_data: Mapping[str, JsonMapping]
     ) -> None:
         account_data_directory = os.path.join(
             self.base_directory, "user_data", "account_data"
@@ -237,7 +237,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
         with open(account_data_file, "a") as f:
             json.dump(account_data, fp=f)
 
-    def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None:
+    def write_media_id(self, media_id: str, media_metadata: JsonMapping) -> None:
         file_directory = os.path.join(self.base_directory, "media_ids")
         os.makedirs(file_directory, exist_ok=True)
         media_id_file = os.path.join(file_directory, media_id)
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 7092ff3449..ba9704a065 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -14,11 +14,11 @@
 
 import abc
 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Set
 
 from synapse.api.constants import Direction, Membership
 from synapse.events import EventBase
-from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID, UserInfo
+from synapse.types import JsonMapping, RoomStreamToken, StateMap, UserID, UserInfo
 from synapse.visibility import filter_events_for_client
 
 if TYPE_CHECKING:
@@ -35,7 +35,7 @@ class AdminHandler:
         self._state_storage_controller = self._storage_controllers.state
         self._msc3866_enabled = hs.config.experimental.msc3866.enabled
 
-    async def get_whois(self, user: UserID) -> JsonDict:
+    async def get_whois(self, user: UserID) -> JsonMapping:
         connections = []
 
         sessions = await self._store.get_user_ip_and_agents(user)
@@ -55,7 +55,7 @@ class AdminHandler:
 
         return ret
 
-    async def get_user(self, user: UserID) -> Optional[JsonDict]:
+    async def get_user(self, user: UserID) -> Optional[JsonMapping]:
         """Function to get user details"""
         user_info: Optional[UserInfo] = await self._store.get_user_by_id(
             user.to_string()
@@ -344,7 +344,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def write_profile(self, profile: JsonDict) -> None:
+    def write_profile(self, profile: JsonMapping) -> None:
         """Write the profile of a user.
 
         Args:
@@ -353,7 +353,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def write_devices(self, devices: List[JsonDict]) -> None:
+    def write_devices(self, devices: Sequence[JsonMapping]) -> None:
         """Write the devices of a user.
 
         Args:
@@ -362,7 +362,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def write_connections(self, connections: List[JsonDict]) -> None:
+    def write_connections(self, connections: Sequence[JsonMapping]) -> None:
         """Write the connections of a user.
 
         Args:
@@ -372,7 +372,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
 
     @abc.abstractmethod
     def write_account_data(
-        self, file_name: str, account_data: Mapping[str, JsonDict]
+        self, file_name: str, account_data: Mapping[str, JsonMapping]
     ) -> None:
         """Write the account data of a user.
 
@@ -383,7 +383,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None:
+    def write_media_id(self, media_id: str, media_metadata: JsonMapping) -> None:
         """Write the media's metadata of a user.
         Exports only the metadata, as this can be fetched from the database via
         read only. In order to access the files, a connection to the correct
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index f1f19666d7..1a4d394eda 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -57,6 +57,7 @@ from synapse.storage.roommember import MemberSummary
 from synapse.types import (
     DeviceListUpdates,
     JsonDict,
+    JsonMapping,
     MutableStateMap,
     Requester,
     RoomStreamToken,
@@ -1793,19 +1794,23 @@ class SyncHandler:
             )
 
             if push_rules_changed:
-                global_account_data = dict(global_account_data)
-                global_account_data[
-                    AccountDataTypes.PUSH_RULES
-                ] = await self._push_rules_handler.push_rules_for_user(sync_config.user)
+                global_account_data = {
+                    AccountDataTypes.PUSH_RULES: await self._push_rules_handler.push_rules_for_user(
+                        sync_config.user
+                    ),
+                    **global_account_data,
+                }
         else:
             all_global_account_data = await self.store.get_global_account_data_for_user(
                 user_id
             )
 
-            global_account_data = dict(all_global_account_data)
-            global_account_data[
-                AccountDataTypes.PUSH_RULES
-            ] = await self._push_rules_handler.push_rules_for_user(sync_config.user)
+            global_account_data = {
+                AccountDataTypes.PUSH_RULES: await self._push_rules_handler.push_rules_for_user(
+                    sync_config.user
+                ),
+                **all_global_account_data,
+            }
 
         account_data_for_user = (
             await sync_config.filter_collection.filter_global_account_data(
@@ -1909,7 +1914,7 @@ class SyncHandler:
             blocks_all_rooms
             or sync_result_builder.sync_config.filter_collection.blocks_all_room_account_data()
         ):
-            account_data_by_room: Mapping[str, Mapping[str, JsonDict]] = {}
+            account_data_by_room: Mapping[str, Mapping[str, JsonMapping]] = {}
         elif since_token and not sync_result_builder.full_state:
             account_data_by_room = (
                 await self.store.get_updated_room_account_data_for_user(
@@ -2349,8 +2354,8 @@ class SyncHandler:
         sync_result_builder: "SyncResultBuilder",
         room_builder: "RoomSyncResultBuilder",
         ephemeral: List[JsonDict],
-        tags: Optional[Mapping[str, Mapping[str, Any]]],
-        account_data: Mapping[str, JsonDict],
+        tags: Optional[Mapping[str, JsonMapping]],
+        account_data: Mapping[str, JsonMapping],
         always_include: bool = False,
     ) -> None:
         """Populates the `joined` and `archived` section of `sync_result_builder`
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 91898a5c13..9aaa88e229 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -39,7 +39,7 @@ from synapse.rest.admin._base import (
 from synapse.rest.client._base import client_patterns
 from synapse.storage.databases.main.registration import ExternalIDReuseException
 from synapse.storage.databases.main.stats import UserSortOrder
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, JsonMapping, UserID
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -211,7 +211,7 @@ class UserRestServletV2(RestServlet):
 
     async def on_GET(
         self, request: SynapseRequest, user_id: str
-    ) -> Tuple[int, JsonDict]:
+    ) -> Tuple[int, JsonMapping]:
         await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)
@@ -226,7 +226,7 @@ class UserRestServletV2(RestServlet):
 
     async def on_PUT(
         self, request: SynapseRequest, user_id: str
-    ) -> Tuple[int, JsonDict]:
+    ) -> Tuple[int, JsonMapping]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester)
 
@@ -658,7 +658,7 @@ class WhoisRestServlet(RestServlet):
 
     async def on_GET(
         self, request: SynapseRequest, user_id: str
-    ) -> Tuple[int, JsonDict]:
+    ) -> Tuple[int, JsonMapping]:
         target_user = UserID.from_string(user_id)
         requester = await self.auth.get_user_by_req(request)
 
diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py
index b1f9e9dc9b..ce0c4e7742 100644
--- a/synapse/rest/client/account_data.py
+++ b/synapse/rest/client/account_data.py
@@ -20,7 +20,7 @@ from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.http.site import SynapseRequest
-from synapse.types import JsonDict, RoomID
+from synapse.types import JsonDict, JsonMapping, RoomID
 
 from ._base import client_patterns
 
@@ -95,7 +95,7 @@ class AccountDataServlet(RestServlet):
 
     async def on_GET(
         self, request: SynapseRequest, user_id: str, account_data_type: str
-    ) -> Tuple[int, JsonDict]:
+    ) -> Tuple[int, JsonMapping]:
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot get account data for other users.")
@@ -106,7 +106,7 @@ class AccountDataServlet(RestServlet):
             and account_data_type == AccountDataTypes.PUSH_RULES
         ):
             account_data: Optional[
-                JsonDict
+                JsonMapping
             ] = await self._push_rules_handler.push_rules_for_user(requester.user)
         else:
             account_data = await self.store.get_global_account_data_by_type_for_user(
@@ -236,7 +236,7 @@ class RoomAccountDataServlet(RestServlet):
         user_id: str,
         room_id: str,
         account_data_type: str,
-    ) -> Tuple[int, JsonDict]:
+    ) -> Tuple[int, JsonMapping]:
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot get account data for other users.")
@@ -253,7 +253,7 @@ class RoomAccountDataServlet(RestServlet):
             self._hs.config.experimental.msc4010_push_rules_account_data
             and account_data_type == AccountDataTypes.PUSH_RULES
         ):
-            account_data: Optional[JsonDict] = {}
+            account_data: Optional[JsonMapping] = {}
         else:
             account_data = await self.store.get_account_data_for_room_and_type(
                 user_id, room_id, account_data_type
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 8f7bdbc61a..80f146dd53 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -43,7 +43,7 @@ from synapse.storage.util.id_generators import (
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
-from synapse.types import JsonDict
+from synapse.types import JsonDict, JsonMapping
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -119,7 +119,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
     @cached()
     async def get_global_account_data_for_user(
         self, user_id: str
-    ) -> Mapping[str, JsonDict]:
+    ) -> Mapping[str, JsonMapping]:
         """
         Get all the global client account_data for a user.
 
@@ -164,7 +164,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
     @cached()
     async def get_room_account_data_for_user(
         self, user_id: str
-    ) -> Mapping[str, Mapping[str, JsonDict]]:
+    ) -> Mapping[str, Mapping[str, JsonMapping]]:
         """
         Get all of the per-room client account_data for a user.
 
@@ -213,7 +213,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
     @cached(num_args=2, max_entries=5000, tree=True)
     async def get_global_account_data_by_type_for_user(
         self, user_id: str, data_type: str
-    ) -> Optional[JsonDict]:
+    ) -> Optional[JsonMapping]:
         """
         Returns:
             The account data.
@@ -265,7 +265,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
     @cached(num_args=2, tree=True)
     async def get_account_data_for_room(
         self, user_id: str, room_id: str
-    ) -> Mapping[str, JsonDict]:
+    ) -> Mapping[str, JsonMapping]:
         """Get all the client account_data for a user for a room.
 
         Args:
@@ -296,7 +296,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
     @cached(num_args=3, max_entries=5000, tree=True)
     async def get_account_data_for_room_and_type(
         self, user_id: str, room_id: str, account_data_type: str
-    ) -> Optional[JsonDict]:
+    ) -> Optional[JsonMapping]:
         """Get the client account_data of given type for a user for a room.
 
         Args:
@@ -394,7 +394,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
 
     async def get_updated_global_account_data_for_user(
         self, user_id: str, stream_id: int
-    ) -> Dict[str, JsonDict]:
+    ) -> Mapping[str, JsonMapping]:
         """Get all the global account_data that's changed for a user.
 
         Args:
diff --git a/synapse/storage/databases/main/experimental_features.py b/synapse/storage/databases/main/experimental_features.py
index cf3226ae5a..654f924019 100644
--- a/synapse/storage/databases/main/experimental_features.py
+++ b/synapse/storage/databases/main/experimental_features.py
@@ -12,11 +12,10 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, Dict, FrozenSet
 
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main import CacheInvalidationWorkerStore
-from synapse.types import StrCollection
 from synapse.util.caches.descriptors import cached
 
 if TYPE_CHECKING:
@@ -34,7 +33,7 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore):
         super().__init__(database, db_conn, hs)
 
     @cached()
-    async def list_enabled_features(self, user_id: str) -> StrCollection:
+    async def list_enabled_features(self, user_id: str) -> FrozenSet[str]:
         """
         Checks to see what features are enabled for a given user
         Args:
@@ -49,7 +48,7 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore):
             ["feature"],
         )
 
-        return [feature["feature"] for feature in enabled]
+        return frozenset(feature["feature"] for feature in enabled)
 
     async def set_features_for_user(
         self,
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index c149a9eacb..61403a98cf 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -23,7 +23,7 @@ from synapse.storage._base import db_to_json
 from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
 from synapse.storage.util.id_generators import AbstractStreamIdGenerator
-from synapse.types import JsonDict
+from synapse.types import JsonDict, JsonMapping
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
@@ -34,7 +34,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
     @cached()
     async def get_tags_for_user(
         self, user_id: str
-    ) -> Mapping[str, Mapping[str, JsonDict]]:
+    ) -> Mapping[str, Mapping[str, JsonMapping]]:
         """Get all the tags for a user.
 
 
@@ -109,7 +109,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
 
     async def get_updated_tags(
         self, user_id: str, stream_id: int
-    ) -> Mapping[str, Mapping[str, JsonDict]]:
+    ) -> Mapping[str, Mapping[str, JsonMapping]]:
         """Get all the tags for the rooms where the tags have changed since the
         given version