summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-11-11 08:47:31 -0500
committerGitHub <noreply@github.com>2021-11-11 08:47:31 -0500
commit64ef25391d22795463ebf3c48604f7aee1690fe4 (patch)
tree2d439152e7ac70db70a541c001ae058585751881 /synapse
parentFix error in thumbnail generation (#11288) (diff)
downloadsynapse-64ef25391d22795463ebf3c48604f7aee1690fe4.tar.xz
Add type hints to some storage classes (#11307)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/databases/main/censor_events.py30
-rw-r--r--synapse/storage/databases/main/deviceinbox.py52
-rw-r--r--synapse/storage/databases/main/filtering.py6
-rw-r--r--synapse/storage/databases/main/lock.py6
-rw-r--r--synapse/storage/databases/main/openid.py17
-rw-r--r--synapse/storage/databases/main/tags.py27
-rw-r--r--synapse/storage/util/id_generators.py24
7 files changed, 115 insertions, 47 deletions
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index eee07227ef..0f56e10220 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -13,12 +13,12 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
 from synapse.events.utils import prune_event_dict
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.util import json_encoder
@@ -41,7 +41,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
             hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000)
 
     @wrap_as_background_process("_censor_redactions")
-    async def _censor_redactions(self):
+    async def _censor_redactions(self) -> None:
         """Censors all redactions older than the configured period that haven't
         been censored yet.
 
@@ -105,7 +105,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
                 and original_event.internal_metadata.is_redacted()
             ):
                 # Redaction was allowed
-                pruned_json = json_encoder.encode(
+                pruned_json: Optional[str] = json_encoder.encode(
                     prune_event_dict(
                         original_event.room_version, original_event.get_dict()
                     )
@@ -116,7 +116,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
 
             updates.append((redaction_id, event_id, pruned_json))
 
-        def _update_censor_txn(txn):
+        def _update_censor_txn(txn: LoggingTransaction) -> None:
             for redaction_id, event_id, pruned_json in updates:
                 if pruned_json:
                     self._censor_event_txn(txn, event_id, pruned_json)
@@ -130,14 +130,16 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
 
         await self.db_pool.runInteraction("_update_censor_txn", _update_censor_txn)
 
-    def _censor_event_txn(self, txn, event_id, pruned_json):
+    def _censor_event_txn(
+        self, txn: LoggingTransaction, event_id: str, pruned_json: str
+    ) -> None:
         """Censor an event by replacing its JSON in the event_json table with the
         provided pruned JSON.
 
         Args:
-            txn (LoggingTransaction): The database transaction.
-            event_id (str): The ID of the event to censor.
-            pruned_json (str): The pruned JSON
+            txn: The database transaction.
+            event_id: The ID of the event to censor.
+            pruned_json: The pruned JSON
         """
         self.db_pool.simple_update_one_txn(
             txn,
@@ -157,7 +159,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
         # Try to retrieve the event's content from the database or the event cache.
         event = await self.get_event(event_id)
 
-        def delete_expired_event_txn(txn):
+        def delete_expired_event_txn(txn: LoggingTransaction) -> None:
             # Delete the expiry timestamp associated with this event from the database.
             self._delete_event_expiry_txn(txn, event_id)
 
@@ -194,14 +196,14 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
             "delete_expired_event", delete_expired_event_txn
         )
 
-    def _delete_event_expiry_txn(self, txn, event_id):
+    def _delete_event_expiry_txn(self, txn: LoggingTransaction, event_id: str) -> None:
         """Delete the expiry timestamp associated with an event ID without deleting the
         actual event.
 
         Args:
-            txn (LoggingTransaction): The transaction to use to perform the deletion.
-            event_id (str): The event ID to delete the associated expiry timestamp of.
+            txn: The transaction to use to perform the deletion.
+            event_id: The event ID to delete the associated expiry timestamp of.
         """
-        return self.db_pool.simple_delete_txn(
+        self.db_pool.simple_delete_txn(
             txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
         )
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index ae3afdd5d2..7c0f953365 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -1,4 +1,5 @@
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,9 +20,17 @@ from synapse.logging import issue9533_logger
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.replication.tcp.streams import ToDeviceStream
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdGenerator,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -34,14 +43,21 @@ logger = logging.getLogger(__name__)
 
 
 class DeviceInboxWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
 
         # Map of (user_id, device_id) to the last stream_id that has been
         # deleted up to. This is so that we can no op deletions.
-        self._last_device_delete_cache = ExpiringCache(
+        self._last_device_delete_cache: ExpiringCache[
+            Tuple[str, Optional[str]], int
+        ] = ExpiringCache(
             cache_name="last_device_delete_cache",
             clock=self._clock,
             max_len=10000,
@@ -53,14 +69,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 self._instance_name in hs.config.worker.writers.to_device
             )
 
-            self._device_inbox_id_gen = MultiWriterIdGenerator(
-                db_conn=db_conn,
-                db=database,
-                stream_name="to_device",
-                instance_name=self._instance_name,
-                tables=[("device_inbox", "instance_name", "stream_id")],
-                sequence_name="device_inbox_sequence",
-                writers=hs.config.worker.writers.to_device,
+            self._device_inbox_id_gen: AbstractStreamIdGenerator = (
+                MultiWriterIdGenerator(
+                    db_conn=db_conn,
+                    db=database,
+                    stream_name="to_device",
+                    instance_name=self._instance_name,
+                    tables=[("device_inbox", "instance_name", "stream_id")],
+                    sequence_name="device_inbox_sequence",
+                    writers=hs.config.worker.writers.to_device,
+                )
             )
         else:
             self._can_write_to_device = True
@@ -101,6 +119,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == ToDeviceStream.NAME:
+            # If replication is happening than postgres must be being used.
+            assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
             self._device_inbox_id_gen.advance(instance_name, token)
             for row in rows:
                 if row.entity.startswith("@"):
@@ -220,11 +240,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         log_kv({"message": f"deleted {count} messages for device", "count": count})
 
         # Update the cache, ensuring that we only ever increase the value
-        last_deleted_stream_id = self._last_device_delete_cache.get(
+        updated_last_deleted_stream_id = self._last_device_delete_cache.get(
             (user_id, device_id), 0
         )
         self._last_device_delete_cache[(user_id, device_id)] = max(
-            last_deleted_stream_id, up_to_stream_id
+            updated_last_deleted_stream_id, up_to_stream_id
         )
 
         return count
@@ -432,7 +452,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 )
 
         async with self._device_inbox_id_gen.get_next() as stream_id:
-            now_ms = self.clock.time_msec()
+            now_ms = self._clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
             )
@@ -483,7 +503,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             )
 
         async with self._device_inbox_id_gen.get_next() as stream_id:
-            now_ms = self.clock.time_msec()
+            now_ms = self._clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_from_remote_to_device_inbox",
                 add_messages_txn,
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 434986fa64..cf842803bc 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -1,4 +1,5 @@
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,6 +19,7 @@ from canonicaljson import encode_canonical_json
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import LoggingTransaction
 from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
 
@@ -49,7 +51,7 @@ class FilteringStore(SQLBaseStore):
 
         # Need an atomic transaction to SELECT the maximal ID so far then
         # INSERT a new one
-        def _do_txn(txn):
+        def _do_txn(txn: LoggingTransaction) -> int:
             sql = (
                 "SELECT filter_id FROM user_filters "
                 "WHERE user_id = ? AND filter_json = ?"
@@ -61,7 +63,7 @@ class FilteringStore(SQLBaseStore):
 
             sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
             txn.execute(sql, (user_localpart,))
-            max_id = txn.fetchone()[0]
+            max_id = txn.fetchone()[0]  # type: ignore[index]
             if max_id is None:
                 filter_id = 0
             else:
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index 3d0df0cbd4..a540f7fb26 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import logging
 from types import TracebackType
-from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Optional, Tuple, Type
 from weakref import WeakValueDictionary
 
 from twisted.internet.interfaces import IReactorCore
@@ -62,7 +62,9 @@ class LockStore(SQLBaseStore):
 
         # A map from `(lock_name, lock_key)` to the token of any locks that we
         # think we currently hold.
-        self._live_tokens: Dict[Tuple[str, str], Lock] = WeakValueDictionary()
+        self._live_tokens: WeakValueDictionary[
+            Tuple[str, str], Lock
+        ] = WeakValueDictionary()
 
         # When we shut down we want to remove the locks. Technically this can
         # lead to a race, as we may drop the lock while we are still processing.
diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py
index 2aac64901b..a46685219f 100644
--- a/synapse/storage/databases/main/openid.py
+++ b/synapse/storage/databases/main/openid.py
@@ -1,6 +1,21 @@
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Optional
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import LoggingTransaction
 
 
 class OpenIdStore(SQLBaseStore):
@@ -20,7 +35,7 @@ class OpenIdStore(SQLBaseStore):
     async def get_user_id_for_open_id_token(
         self, token: str, ts_now_ms: int
     ) -> Optional[str]:
-        def get_user_id_for_token_txn(txn):
+        def get_user_id_for_token_txn(txn: LoggingTransaction) -> Optional[str]:
             sql = (
                 "SELECT user_id FROM open_id_tokens"
                 " WHERE token = ? AND ? <= ts_valid_until_ms"
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index f93ff0a545..8f510de53d 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -1,5 +1,6 @@
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2018 New Vector Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,9 +15,10 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, cast
 
 from synapse.storage._base import db_to_json
+from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
 from synapse.types import JsonDict
 from synapse.util import json_encoder
@@ -50,7 +52,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
 
     async def get_all_updated_tags(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+    ) -> Tuple[List[Tuple[int, Tuple[str, str, str]]], int, bool]:
         """Get updates for tags replication stream.
 
         Args:
@@ -75,7 +77,9 @@ class TagsWorkerStore(AccountDataWorkerStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_updated_tags_txn(txn):
+        def get_all_updated_tags_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str]]:
             sql = (
                 "SELECT stream_id, user_id, room_id"
                 " FROM room_tags_revisions as r"
@@ -83,13 +87,16 @@ class TagsWorkerStore(AccountDataWorkerStore):
                 " ORDER BY stream_id ASC LIMIT ?"
             )
             txn.execute(sql, (last_id, current_id, limit))
-            return txn.fetchall()
+            # mypy doesn't understand what the query is selecting.
+            return cast(List[Tuple[int, str, str]], txn.fetchall())
 
         tag_ids = await self.db_pool.runInteraction(
             "get_all_updated_tags", get_all_updated_tags_txn
         )
 
-        def get_tag_content(txn, tag_ids):
+        def get_tag_content(
+            txn: LoggingTransaction, tag_ids
+        ) -> List[Tuple[int, Tuple[str, str, str]]]:
             sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
             results = []
             for stream_id, user_id, room_id in tag_ids:
@@ -127,15 +134,15 @@ class TagsWorkerStore(AccountDataWorkerStore):
         given version
 
         Args:
-            user_id(str): The user to get the tags for.
-            stream_id(int): The earliest update to get for the user.
+            user_id: The user to get the tags for.
+            stream_id: The earliest update to get for the user.
 
         Returns:
             A mapping from room_id strings to lists of tag strings for all the
             rooms that changed since the stream_id token.
         """
 
-        def get_updated_tags_txn(txn):
+        def get_updated_tags_txn(txn: LoggingTransaction) -> List[str]:
             sql = (
                 "SELECT room_id from room_tags_revisions"
                 " WHERE user_id = ? AND stream_id > ?"
@@ -200,7 +207,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
 
         content_json = json_encoder.encode(content)
 
-        def add_tag_txn(txn, next_id):
+        def add_tag_txn(txn: LoggingTransaction, next_id: int) -> None:
             self.db_pool.simple_upsert_txn(
                 txn,
                 table="room_tags",
@@ -224,7 +231,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
         """
         assert self._can_write_to_account_data
 
-        def remove_tag_txn(txn, next_id):
+        def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None:
             sql = (
                 "DELETE FROM room_tags "
                 " WHERE user_id = ? AND room_id = ? AND tag = ?"
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 670811611f..ac56bc9a05 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -1,4 +1,5 @@
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 import heapq
 import logging
 import threading
@@ -87,7 +89,25 @@ def _load_current_id(
     return (max if step > 0 else min)(current_id, step)
 
 
-class StreamIdGenerator:
+class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
+    @abc.abstractmethod
+    def get_next(self) -> AsyncContextManager[int]:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_current_token(self) -> int:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_current_token_for_writer(self, instance_name: str) -> int:
+        raise NotImplementedError()
+
+
+class StreamIdGenerator(AbstractStreamIdGenerator):
     """Used to generate new stream ids when persisting events while keeping
     track of which transactions have been completed.
 
@@ -209,7 +229,7 @@ class StreamIdGenerator:
         return self.get_current_token()
 
 
-class MultiWriterIdGenerator:
+class MultiWriterIdGenerator(AbstractStreamIdGenerator):
     """An ID generator that tracks a stream that can have multiple writers.
 
     Uses a Postgres sequence to coordinate ID assignment, but positions of other