author    Erik Johnston <erik@matrix.org>    2020-09-02 17:52:38 +0100
committer GitHub <noreply@github.com>        2020-09-02 17:52:38 +0100
commit    112266eafd457204a34a76fa51d7074d0809a1db (patch)
tree      e860525baa50d8fc0cd5e5b5c2ce7d2a9727ee92 /synapse/storage/databases/main/stream.py
parent    Re-implement unread counts (again) (#8059) (diff)
Add StreamStore to mypy (#8232)
Diffstat (limited to 'synapse/storage/databases/main/stream.py')
 synapse/storage/databases/main/stream.py | 46 ++++++++++++++++++++++++++++------------------
 1 file changed, 28 insertions(+), 18 deletions(-)
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 83c1ddf95a..be6df8a6d1 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -39,7 +39,7 @@ what sort order was used:
 import abc
 import logging
 from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
 from twisted.internet import defer
 
@@ -54,9 +54,12 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
-from synapse.types import RoomStreamToken
+from synapse.types import Collection, RoomStreamToken
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
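The `TYPE_CHECKING` guard makes `synapse.server.HomeServer` visible to mypy while skipping the import at runtime, which sidesteps the circular dependency between the server object and the stores it constructs. A minimal sketch of the pattern (module and class layout here are hypothetical, not Synapse's):

    # store.py -- illustrative only
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by the type checker, never at runtime, so no
        # import cycle is created between store.py and server.py.
        from server import HomeServer

    class Store:
        def __init__(self, hs: "HomeServer") -> None:
            # The quoted annotation is a forward reference: the
            # interpreter stores it as a string and mypy resolves it.
            self._instance_name = hs.get_instance_name()
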
@@ -206,7 +209,7 @@ def _make_generic_sql_bound(
     )
 
 
-def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]:
+def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
     # NB: This may create SQL clauses that don't optimise well (and we don't
     # have indices on all possible clauses). E.g. it may create
     # "room_id == X AND room_id != X", which postgres doesn't optimise.
@@ -264,7 +267,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super(StreamWorkerStore, self).__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
@@ -297,16 +300,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         self._stream_order_on_start = self.get_room_max_stream_ordering()
 
     @abc.abstractmethod
-    def get_room_max_stream_ordering(self):
+    def get_room_max_stream_ordering(self) -> int:
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def get_room_min_stream_ordering(self):
+    def get_room_min_stream_ordering(self) -> int:
         raise NotImplementedError()
 
     async def get_room_events_stream_for_rooms(
         self,
-        room_ids: Iterable[str],
+        room_ids: Collection[str],
         from_key: str,
         to_key: str,
         limit: int = 0,
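`Iterable[str]` only promises a single pass; `Collection[str]` (here Synapse's own `Collection` alias from `synapse.types`) additionally guarantees `__len__`, `__contains__`, and repeatable iteration. That matters when an argument is consumed more than once, since a generator silently comes up empty on the second pass. A small illustration with hypothetical helpers:

    from typing import Collection, Iterable

    def twice_unsafe(ids: Iterable[str]) -> int:
        # A generator passed here is exhausted by the first list();
        # the second pass sees nothing, with no error raised.
        return len(list(ids)) + len(list(ids))

    def twice_safe(ids: Collection[str]) -> int:
        # list, set, and tuple all satisfy Collection, and each can be
        # iterated repeatedly and asked for its length directly.
        return len(ids) + sum(1 for _ in ids)

    print(twice_unsafe(x for x in ["a", "b"]))  # 2, not 4
    print(twice_safe(["a", "b"]))               # 4
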
@@ -360,19 +363,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return results
 
-    def get_rooms_that_changed(self, room_ids, from_key):
+    def get_rooms_that_changed(
+        self, room_ids: Collection[str], from_key: str
+    ) -> Set[str]:
         """Given a list of rooms and a token, return rooms where there may have
         been changes.
 
         Args:
-            room_ids (list)
-            from_key (str): The room_key portion of a StreamToken
+            room_ids
+            from_key: The room_key portion of a StreamToken
         """
-        from_key = RoomStreamToken.parse_stream_token(from_key).stream
+        from_id = RoomStreamToken.parse_stream_token(from_key).stream
         return {
             room_id
             for room_id in room_ids
-            if self._events_stream_cache.has_entity_changed(room_id, from_key)
+            if self._events_stream_cache.has_entity_changed(room_id, from_id)
         }
 
     async def get_room_events_stream_for_room(
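The rename from `from_key` to `from_id` is not cosmetic: re-binding a parameter annotated `str` to the `int` stream position would change the variable's type mid-function, which mypy reports as an incompatible assignment. A tiny reproduction:

    def lookup(from_key: str) -> int:
        # mypy: Incompatible types in assignment (expression has type
        # "int", variable has type "str") if we wrote:
        #   from_key = int(from_key)
        from_id = int(from_key)  # a fresh name keeps each type stable
        return from_id

    print(lookup("42"))  # 42
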
@@ -444,7 +449,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return ret, key
 
-    async def get_membership_changes_for_user(self, user_id, from_key, to_key):
+    async def get_membership_changes_for_user(
+        self, user_id: str, from_key: str, to_key: str
+    ) -> List[EventBase]:
         from_id = RoomStreamToken.parse_stream_token(from_key).stream
         to_id = RoomStreamToken.parse_stream_token(to_key).stream
 
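This hunk only adds annotations, but note the async convention: the return annotation names the awaited value, not the coroutine object. A simplified stand-alone example (plain strings stand in for `EventBase` to keep it self-contained):

    import asyncio
    from typing import List

    async def get_membership_changes_for_user(
        user_id: str, from_key: str, to_key: str
    ) -> List[str]:
        # mypy types the un-awaited call as Coroutine[Any, Any, List[str]]
        # and the awaited result as List[str].
        return [user_id, from_key, to_key]

    print(asyncio.run(get_membership_changes_for_user("@a:hs", "s1", "s2")))
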
@@ -661,7 +668,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
         return row[0][0] if row else 0
 
-    def _get_max_topological_txn(self, txn, room_id):
+    def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int:
         txn.execute(
             "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
             (room_id,),
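Annotating the bare `txn` parameters as `LoggingTransaction` (here and in the later hunks) lets mypy check every `txn.execute(...)` call against a real signature instead of treating the object as `Any`. A rough stand-in, using a deliberately minimal hypothetical class rather than Synapse's actual wrapper:

    from typing import Any, Tuple

    class LoggingTransaction:
        # Hypothetical stand-in: the real class wraps a DB-API cursor
        # and logs/times each query before delegating to it.
        def execute(self, sql: str, args: Tuple[Any, ...] = ()) -> None:
            print("executing:", sql, args)

    def _get_max_topological_txn(txn: LoggingTransaction, room_id: str) -> int:
        # With txn typed, a typo such as txn.exeucte(...) becomes a
        # mypy error instead of a runtime AttributeError.
        txn.execute(
            "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
            (room_id,),
        )
        return 0  # placeholder: the real code reads the fetched row

    _get_max_topological_txn(LoggingTransaction(), "!room:example.org")
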
@@ -734,7 +741,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     def _get_events_around_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         room_id: str,
         event_id: str,
         before_limit: int,
@@ -762,6 +769,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             retcols=["stream_ordering", "topological_ordering"],
         )
 
+        # This cannot happen as `allow_none=False`.
+        assert results is not None
+
         # Paginating backwards includes the event at the token, but paginating
         # forward doesn't.
         before_token = RoomStreamToken(
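The select helper called just above this hunk is declared to return an Optional row because `allow_none=True` is possible; mypy cannot see that this call site passes `allow_none=False`, so the new `assert results is not None` narrows the type by hand (and fails loudly if the "impossible" case ever occurs). A self-contained illustration of the narrowing, with a hypothetical helper name:

    from typing import Dict, Optional

    def select_one(allow_none: bool = False) -> Optional[Dict[str, int]]:
        # Stand-in for a DB helper whose *declared* return type must
        # admit None because some callers pass allow_none=True.
        return {"stream_ordering": 1, "topological_ordering": 2}

    results = select_one(allow_none=False)
    assert results is not None  # narrows Optional[...] -> Dict[str, int]
    print(results["stream_ordering"])  # now type-checks without error
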
@@ -871,7 +881,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             desc="update_federation_out_pos",
         )
 
-    def _reset_federation_positions_txn(self, txn) -> None:
+    def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None:
         """Fiddles with the `federation_stream_position` table to make it match
         the configured federation sender instances during start up.
         """
@@ -910,7 +920,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             GROUP BY type
         """
         txn.execute(sql)
-        min_positions = dict(txn)  # Map from type -> min position
+        min_positions = {typ: pos for typ, pos in txn}  # Map from type -> min position
 
         # Ensure we do actually have some values here
         assert set(min_positions) == {"federation", "events"}
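`dict(txn)` builds the same mapping at runtime, but mypy cannot infer useful key/value types from a loosely typed cursor; the comprehension spells out the two-tuple unpacking, so the result is a properly typed dict. A sketch with a minimal cursor stand-in:

    from typing import Dict, Iterator, Tuple

    class FakeCursor:
        # Yields (stream type, minimum position) pairs, as the SQL above
        # (SELECT type, MIN(stream_id) ... GROUP BY type) would.
        def __iter__(self) -> Iterator[Tuple[str, int]]:
            return iter([("federation", 3), ("events", 7)])

    txn = FakeCursor()
    min_positions: Dict[str, int] = {typ: pos for typ, pos in txn}
    assert set(min_positions) == {"federation", "events"}
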
@@ -937,7 +947,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     def _paginate_room_events_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         room_id: str,
         from_token: RoomStreamToken,
         to_token: Optional[RoomStreamToken] = None,