summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8232.misc1
-rw-r--r--mypy.ini1
-rw-r--r--synapse/events/__init__.py4
-rw-r--r--synapse/storage/database.py34
-rw-r--r--synapse/storage/databases/main/stream.py46
5 files changed, 66 insertions, 20 deletions
diff --git a/changelog.d/8232.misc b/changelog.d/8232.misc
new file mode 100644
index 0000000000..3a7a352c4f
--- /dev/null
+++ b/changelog.d/8232.misc
@@ -0,0 +1 @@
+Add type hints to `StreamStore`.
diff --git a/mypy.ini b/mypy.ini
index 21c6f523a0..ae3290d5bb 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -43,6 +43,7 @@ files =
   synapse/server_notices,
   synapse/spam_checker_api,
   synapse/state,
+  synapse/storage/databases/main/stream.py,
   synapse/storage/databases/main/ui_auth.py,
   synapse/storage/database.py,
   synapse/storage/engines,
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 67db763dbf..62ea44fa49 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -18,7 +18,7 @@
 import abc
 import os
 from distutils.util import strtobool
-from typing import Dict, Optional, Type
+from typing import Dict, Optional, Tuple, Type
 
 from unpaddedbase64 import encode_base64
 
@@ -120,7 +120,7 @@ class _EventInternalMetadata(object):
     # be here
     before = DictProperty("before")  # type: str
     after = DictProperty("after")  # type: str
-    order = DictProperty("order")  # type: int
+    order = DictProperty("order")  # type: Tuple[int, int]
 
     def get_dict(self) -> JsonDict:
         return dict(self._dict)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 7ab370efef..af8796ad92 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -604,6 +604,18 @@ class DatabasePool(object):
         results = [dict(zip(col_headers, row)) for row in cursor]
         return results
 
+    @overload
+    async def execute(
+        self, desc: str, decoder: Literal[None], query: str, *args: Any
+    ) -> List[Tuple[Any, ...]]:
+        ...
+
+    @overload
+    async def execute(
+        self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
+    ) -> R:
+        ...
+
     async def execute(
         self,
         desc: str,
@@ -1088,6 +1100,28 @@ class DatabasePool(object):
             desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
         )
 
+    @overload
+    async def simple_select_one_onecol(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: Iterable[str],
+        allow_none: Literal[False] = False,
+        desc: str = "simple_select_one_onecol",
+    ) -> Any:
+        ...
+
+    @overload
+    async def simple_select_one_onecol(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: Iterable[str],
+        allow_none: Literal[True] = True,
+        desc: str = "simple_select_one_onecol",
+    ) -> Optional[Any]:
+        ...
+
     async def simple_select_one_onecol(
         self,
         table: str,
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 83c1ddf95a..be6df8a6d1 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -39,7 +39,7 @@ what sort order was used:
 import abc
 import logging
 from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
 from twisted.internet import defer
 
@@ -54,9 +54,12 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
-from synapse.types import RoomStreamToken
+from synapse.types import Collection, RoomStreamToken
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -206,7 +209,7 @@ def _make_generic_sql_bound(
     )
 
 
-def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]:
+def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
     # NB: This may create SQL clauses that don't optimise well (and we don't
     # have indices on all possible clauses). E.g. it may create
     # "room_id == X AND room_id != X", which postgres doesn't optimise.
@@ -264,7 +267,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super(StreamWorkerStore, self).__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
@@ -297,16 +300,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         self._stream_order_on_start = self.get_room_max_stream_ordering()
 
     @abc.abstractmethod
-    def get_room_max_stream_ordering(self):
+    def get_room_max_stream_ordering(self) -> int:
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def get_room_min_stream_ordering(self):
+    def get_room_min_stream_ordering(self) -> int:
         raise NotImplementedError()
 
     async def get_room_events_stream_for_rooms(
         self,
-        room_ids: Iterable[str],
+        room_ids: Collection[str],
         from_key: str,
         to_key: str,
         limit: int = 0,
@@ -360,19 +363,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return results
 
-    def get_rooms_that_changed(self, room_ids, from_key):
+    def get_rooms_that_changed(
+        self, room_ids: Collection[str], from_key: str
+    ) -> Set[str]:
         """Given a list of rooms and a token, return rooms where there may have
         been changes.
 
         Args:
-            room_ids (list)
-            from_key (str): The room_key portion of a StreamToken
+            room_ids
+            from_key: The room_key portion of a StreamToken
         """
-        from_key = RoomStreamToken.parse_stream_token(from_key).stream
+        from_id = RoomStreamToken.parse_stream_token(from_key).stream
         return {
             room_id
             for room_id in room_ids
-            if self._events_stream_cache.has_entity_changed(room_id, from_key)
+            if self._events_stream_cache.has_entity_changed(room_id, from_id)
         }
 
     async def get_room_events_stream_for_room(
@@ -444,7 +449,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return ret, key
 
-    async def get_membership_changes_for_user(self, user_id, from_key, to_key):
+    async def get_membership_changes_for_user(
+        self, user_id: str, from_key: str, to_key: str
+    ) -> List[EventBase]:
         from_id = RoomStreamToken.parse_stream_token(from_key).stream
         to_id = RoomStreamToken.parse_stream_token(to_key).stream
 
@@ -661,7 +668,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
         return row[0][0] if row else 0
 
-    def _get_max_topological_txn(self, txn, room_id):
+    def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int:
         txn.execute(
             "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
             (room_id,),
@@ -734,7 +741,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     def _get_events_around_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         room_id: str,
         event_id: str,
         before_limit: int,
@@ -762,6 +769,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             retcols=["stream_ordering", "topological_ordering"],
         )
 
+        # This cannot happen as `allow_none=False`.
+        assert results is not None
+
         # Paginating backwards includes the event at the token, but paginating
         # forward doesn't.
         before_token = RoomStreamToken(
@@ -871,7 +881,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             desc="update_federation_out_pos",
         )
 
-    def _reset_federation_positions_txn(self, txn) -> None:
+    def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None:
         """Fiddles with the `federation_stream_position` table to make it match
         the configured federation sender instances during start up.
         """
@@ -910,7 +920,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             GROUP BY type
         """
         txn.execute(sql)
-        min_positions = dict(txn)  # Map from type -> min position
+        min_positions = {typ: pos for typ, pos in txn}  # Map from type -> min position
 
         # Ensure we do actually have some values here
         assert set(min_positions) == {"federation", "events"}
@@ -937,7 +947,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     def _paginate_room_events_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         room_id: str,
         from_token: RoomStreamToken,
         to_token: Optional[RoomStreamToken] = None,