summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11342.misc1
-rw-r--r--mypy.ini7
-rw-r--r--synapse/storage/databases/main/profile.py12
-rw-r--r--tests/storage/test_profile.py9
4 files changed, 21 insertions, 8 deletions
diff --git a/changelog.d/11342.misc b/changelog.d/11342.misc
new file mode 100644
index 0000000000..86594a332d
--- /dev/null
+++ b/changelog.d/11342.misc
@@ -0,0 +1 @@
+Add type hints to storage classes.
diff --git a/mypy.ini b/mypy.ini
index 710b1f3a4b..b2953974ea 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -38,7 +38,6 @@ exclude = (?x)
    |synapse/storage/databases/main/metrics.py
    |synapse/storage/databases/main/monthly_active_users.py
    |synapse/storage/databases/main/presence.py
-   |synapse/storage/databases/main/profile.py
    |synapse/storage/databases/main/purge_events.py
    |synapse/storage/databases/main/push_rule.py
    |synapse/storage/databases/main/receipts.py
@@ -182,6 +181,9 @@ disallow_untyped_defs = True
 [mypy-synapse.storage.databases.main.room_batch]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.databases.main.profile]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.databases.main.state_deltas]
 disallow_untyped_defs = True
 
@@ -284,6 +286,9 @@ disallow_untyped_defs = True
 [mypy-tests.handlers.test_user_directory]
 disallow_untyped_defs = True
 
+[mypy-tests.storage.test_profile]
+disallow_untyped_defs = True
+
 [mypy-tests.storage.test_user_directory]
 disallow_untyped_defs = True
 
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index dd8e27e226..e197b7203e 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional
 
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main.roommember import ProfileInfo
 
 
@@ -104,7 +105,7 @@ class ProfileWorkerStore(SQLBaseStore):
             desc="update_remote_profile_cache",
         )
 
-    async def maybe_delete_remote_profile_cache(self, user_id):
+    async def maybe_delete_remote_profile_cache(self, user_id: str) -> None:
         """Check if we still care about the remote user's profile, and if we
         don't then remove their profile from the cache
         """
@@ -116,9 +117,9 @@ class ProfileWorkerStore(SQLBaseStore):
                 desc="delete_remote_profile_cache",
             )
 
-    async def is_subscribed_remote_profile_for_user(self, user_id):
+    async def is_subscribed_remote_profile_for_user(self, user_id: str) -> bool:
         """Check whether we are interested in a remote user's profile."""
-        res = await self.db_pool.simple_select_one_onecol(
+        res: Optional[str] = await self.db_pool.simple_select_one_onecol(
             table="group_users",
             keyvalues={"user_id": user_id},
             retcol="user_id",
@@ -139,13 +140,16 @@ class ProfileWorkerStore(SQLBaseStore):
 
         if res:
             return True
+        return False
 
     async def get_remote_profile_cache_entries_that_expire(
         self, last_checked: int
     ) -> List[Dict[str, str]]:
         """Get all users who haven't been checked since `last_checked`"""
 
-        def _get_remote_profile_cache_entries_that_expire_txn(txn):
+        def _get_remote_profile_cache_entries_that_expire_txn(
+            txn: LoggingTransaction,
+        ) -> List[Dict[str, str]]:
             sql = """
                 SELECT user_id, displayname, avatar_url
                 FROM remote_profile_cache
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index a1ba99ff14..d37736edf8 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -11,19 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
 
+from synapse.server import HomeServer
 from synapse.types import UserID
+from synapse.util import Clock
 
 from tests import unittest
 
 
 class ProfileStoreTestCase(unittest.HomeserverTestCase):
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastore()
 
         self.u_frank = UserID.from_string("@frank:test")
 
-    def test_displayname(self):
+    def test_displayname(self) -> None:
         self.get_success(self.store.create_profile(self.u_frank.localpart))
 
         self.get_success(
@@ -48,7 +51,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
             self.get_success(self.store.get_profile_displayname(self.u_frank.localpart))
         )
 
-    def test_avatar_url(self):
+    def test_avatar_url(self) -> None:
         self.get_success(self.store.create_profile(self.u_frank.localpart))
 
         self.get_success(