summary refs log tree commit diff
path: root/synapse/storage/databases/main/presence.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/presence.py')
-rw-r--r--synapse/storage/databases/main/presence.py61
1 files changed, 41 insertions, 20 deletions
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 4f05811a77..d3c4611686 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -12,15 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast
 
 from synapse.api.presence import PresenceState, UserPresenceState
 from synapse.replication.tcp.streams import PresenceStream
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdGenerator,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.iterutils import batch_iter
@@ -35,7 +43,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore):
         database: DatabasePool,
         db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
-    ):
+    ) -> None:
         super().__init__(database, db_conn, hs)
 
         # Used by `PresenceStore._get_active_presence()`
@@ -54,11 +62,14 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         database: DatabasePool,
         db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
-    ):
+    ) -> None:
         super().__init__(database, db_conn, hs)
 
+        self._instance_name = hs.get_instance_name()
+        self._presence_id_gen: AbstractStreamIdGenerator
+
         self._can_persist_presence = (
-            hs.get_instance_name() in hs.config.worker.writers.presence
+            self._instance_name in hs.config.worker.writers.presence
         )
 
         if isinstance(database.engine, PostgresEngine):
@@ -109,7 +120,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
 
         return stream_orderings[-1], self._presence_id_gen.get_current_token()
 
-    def _update_presence_txn(self, txn, stream_orderings, presence_states):
+    def _update_presence_txn(
+        self, txn: LoggingTransaction, stream_orderings, presence_states
+    ) -> None:
         for stream_id, state in zip(stream_orderings, presence_states):
             txn.call_after(
                 self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
@@ -183,19 +196,23 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_presence_updates_txn(txn):
+        def get_all_presence_updates_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, list]], int, bool]:
             sql = """
                 SELECT stream_id, user_id, state, last_active_ts,
                     last_federation_update_ts, last_user_sync_ts,
-                    status_msg,
-                currently_active
+                    status_msg, currently_active
                 FROM presence_stream
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC
                 LIMIT ?
             """
             txn.execute(sql, (last_id, current_id, limit))
-            updates = [(row[0], row[1:]) for row in txn]
+            updates = cast(
+                List[Tuple[int, list]],
+                [(row[0], row[1:]) for row in txn],
+            )
 
             upper_bound = current_id
             limited = False
@@ -210,7 +227,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         )
 
     @cached()
-    def _get_presence_for_user(self, user_id):
+    def _get_presence_for_user(self, user_id: str) -> None:
         raise NotImplementedError()
 
     @cachedList(
@@ -218,7 +235,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         list_name="user_ids",
         num_args=1,
     )
-    async def get_presence_for_users(self, user_ids):
+    async def get_presence_for_users(
+        self, user_ids: Iterable[str]
+    ) -> Dict[str, UserPresenceState]:
         rows = await self.db_pool.simple_select_many_batch(
             table="presence_stream",
             column="user_id",
@@ -257,7 +276,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
             True if the user should have full presence sent to them, False otherwise.
         """
 
-        def _should_user_receive_full_presence_with_token_txn(txn):
+        def _should_user_receive_full_presence_with_token_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
             sql = """
                 SELECT 1 FROM users_to_send_full_presence_to
                 WHERE user_id = ?
@@ -271,7 +292,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
             _should_user_receive_full_presence_with_token_txn,
         )
 
-    async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
+    async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None:
         """Adds to the list of users who should receive a full snapshot of presence
         upon their next sync.
 
@@ -353,10 +374,10 @@ class PresenceStore(PresenceBackgroundUpdateStore):
 
         return users_to_state
 
-    def get_current_presence_token(self):
+    def get_current_presence_token(self) -> int:
         return self._presence_id_gen.get_current_token()
 
-    def _get_active_presence(self, db_conn: Connection):
+    def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
         """Fetch non-offline presence from the database so that we can register
         the appropriate time outs.
         """
@@ -379,12 +400,12 @@ class PresenceStore(PresenceBackgroundUpdateStore):
 
         return [UserPresenceState(**row) for row in rows]
 
-    def take_presence_startup_info(self):
+    def take_presence_startup_info(self) -> List[UserPresenceState]:
         active_on_startup = self._presence_on_startup
-        self._presence_on_startup = None
+        self._presence_on_startup = []
         return active_on_startup
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows) -> None:
         if stream_name == PresenceStream.NAME:
             self._presence_id_gen.advance(instance_name, token)
             for row in rows: