summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/presence.py26
-rw-r--r--synapse/storage/databases/main/presence.py61
-rw-r--r--synapse/storage/databases/main/purge_events.py13
-rw-r--r--synapse/storage/databases/main/user_directory.py22
-rw-r--r--synapse/types.py6
5 files changed, 78 insertions, 50 deletions
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 067c43ae47..b223b72623 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -204,25 +204,27 @@ class BasePresenceHandler(abc.ABC):
         Returns:
             dict: `user_id` -> `UserPresenceState`
         """
-        states = {
-            user_id: self.user_to_current_state.get(user_id, None)
-            for user_id in user_ids
-        }
+        states = {}
+        missing = []
+        for user_id in user_ids:
+            state = self.user_to_current_state.get(user_id, None)
+            if state:
+                states[user_id] = state
+            else:
+                missing.append(user_id)
 
-        missing = [user_id for user_id, state in states.items() if not state]
         if missing:
             # There are things not in our in memory cache. Lets pull them out of
             # the database.
             res = await self.store.get_presence_for_users(missing)
             states.update(res)
 
-            missing = [user_id for user_id, state in states.items() if not state]
-            if missing:
-                new = {
-                    user_id: UserPresenceState.default(user_id) for user_id in missing
-                }
-                states.update(new)
-                self.user_to_current_state.update(new)
+            for user_id in missing:
+                # if user has no state in database, create the state
+                if not res.get(user_id, None):
+                    new_state = UserPresenceState.default(user_id)
+                    states[user_id] = new_state
+                    self.user_to_current_state[user_id] = new_state
 
         return states
 
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 4f05811a77..d3c4611686 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -12,15 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast
 
 from synapse.api.presence import PresenceState, UserPresenceState
 from synapse.replication.tcp.streams import PresenceStream
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdGenerator,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.iterutils import batch_iter
@@ -35,7 +43,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore):
         database: DatabasePool,
         db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
-    ):
+    ) -> None:
         super().__init__(database, db_conn, hs)
 
         # Used by `PresenceStore._get_active_presence()`
@@ -54,11 +62,14 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         database: DatabasePool,
         db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
-    ):
+    ) -> None:
         super().__init__(database, db_conn, hs)
 
+        self._instance_name = hs.get_instance_name()
+        self._presence_id_gen: AbstractStreamIdGenerator
+
         self._can_persist_presence = (
-            hs.get_instance_name() in hs.config.worker.writers.presence
+            self._instance_name in hs.config.worker.writers.presence
         )
 
         if isinstance(database.engine, PostgresEngine):
@@ -109,7 +120,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
 
         return stream_orderings[-1], self._presence_id_gen.get_current_token()
 
-    def _update_presence_txn(self, txn, stream_orderings, presence_states):
+    def _update_presence_txn(
+        self, txn: LoggingTransaction, stream_orderings, presence_states
+    ) -> None:
         for stream_id, state in zip(stream_orderings, presence_states):
             txn.call_after(
                 self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
@@ -183,19 +196,23 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_presence_updates_txn(txn):
+        def get_all_presence_updates_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, list]], int, bool]:
             sql = """
                 SELECT stream_id, user_id, state, last_active_ts,
                     last_federation_update_ts, last_user_sync_ts,
-                    status_msg,
-                currently_active
+                    status_msg, currently_active
                 FROM presence_stream
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC
                 LIMIT ?
             """
             txn.execute(sql, (last_id, current_id, limit))
-            updates = [(row[0], row[1:]) for row in txn]
+            updates = cast(
+                List[Tuple[int, list]],
+                [(row[0], row[1:]) for row in txn],
+            )
 
             upper_bound = current_id
             limited = False
@@ -210,7 +227,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         )
 
     @cached()
-    def _get_presence_for_user(self, user_id):
+    def _get_presence_for_user(self, user_id: str) -> None:
         raise NotImplementedError()
 
     @cachedList(
@@ -218,7 +235,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         list_name="user_ids",
         num_args=1,
     )
-    async def get_presence_for_users(self, user_ids):
+    async def get_presence_for_users(
+        self, user_ids: Iterable[str]
+    ) -> Dict[str, UserPresenceState]:
         rows = await self.db_pool.simple_select_many_batch(
             table="presence_stream",
             column="user_id",
@@ -257,7 +276,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
             True if the user should have full presence sent to them, False otherwise.
         """
 
-        def _should_user_receive_full_presence_with_token_txn(txn):
+        def _should_user_receive_full_presence_with_token_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
             sql = """
                 SELECT 1 FROM users_to_send_full_presence_to
                 WHERE user_id = ?
@@ -271,7 +292,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
             _should_user_receive_full_presence_with_token_txn,
         )
 
-    async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
+    async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None:
         """Adds to the list of users who should receive a full snapshot of presence
         upon their next sync.
 
@@ -353,10 +374,10 @@ class PresenceStore(PresenceBackgroundUpdateStore):
 
         return users_to_state
 
-    def get_current_presence_token(self):
+    def get_current_presence_token(self) -> int:
         return self._presence_id_gen.get_current_token()
 
-    def _get_active_presence(self, db_conn: Connection):
+    def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
         """Fetch non-offline presence from the database so that we can register
         the appropriate time outs.
         """
@@ -379,12 +400,12 @@ class PresenceStore(PresenceBackgroundUpdateStore):
 
         return [UserPresenceState(**row) for row in rows]
 
-    def take_presence_startup_info(self):
+    def take_presence_startup_info(self) -> List[UserPresenceState]:
         active_on_startup = self._presence_on_startup
-        self._presence_on_startup = None
+        self._presence_on_startup = []
         return active_on_startup
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows) -> None:
         if stream_name == PresenceStream.NAME:
             self._presence_id_gen.advance(instance_name, token)
             for row in rows:
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index e87a8fb85d..2e3818e432 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -13,9 +13,10 @@
 # limitations under the License.
 
 import logging
-from typing import Any, List, Set, Tuple
+from typing import Any, List, Set, Tuple, cast
 
 from synapse.api.errors import SynapseError
+from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.state import StateGroupWorkerStore
 from synapse.types import RoomStreamToken
@@ -55,7 +56,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         )
 
     def _purge_history_txn(
-        self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        token: RoomStreamToken,
+        delete_local_events: bool,
     ) -> Set[int]:
         # Tables that should be pruned:
         #     event_auth
@@ -273,7 +278,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         """,
             (room_id,),
         )
-        (min_depth,) = txn.fetchone()
+        (min_depth,) = cast(Tuple[int], txn.fetchone())
 
         logger.info("[purge] updating room_depth to %d", min_depth)
 
@@ -318,7 +323,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "purge_room", self._purge_room_txn, room_id
         )
 
-    def _purge_room_txn(self, txn, room_id: str) -> List[int]:
+    def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
         # First we fetch all the state groups that should be deleted, before
         # we delete that information.
         txn.execute(
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index f7c778bdf2..e7fddd2426 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -58,7 +58,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         database: DatabasePool,
         db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
-    ):
+    ) -> None:
         super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
@@ -234,10 +234,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         processed_event_count = 0
 
         for room_id, event_count in rooms_to_work_on:
-            is_in_room = await self.is_host_joined(room_id, self.server_name)
+            is_in_room = await self.is_host_joined(room_id, self.server_name)  # type: ignore[attr-defined]
 
             if is_in_room:
-                users_with_profile = await self.get_users_in_room_with_profiles(room_id)
+                users_with_profile = await self.get_users_in_room_with_profiles(room_id)  # type: ignore[attr-defined]
                 # Throw away users excluded from the directory.
                 users_with_profile = {
                     user_id: profile
@@ -368,7 +368,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
         for user_id in users_to_work_on:
             if await self.should_include_local_user_in_dir(user_id):
-                profile = await self.get_profileinfo(get_localpart_from_id(user_id))
+                profile = await self.get_profileinfo(get_localpart_from_id(user_id))  # type: ignore[attr-defined]
                 await self.update_profile_in_user_dir(
                     user_id, profile.display_name, profile.avatar_url
                 )
@@ -397,7 +397,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         # technically it could be DM-able. In the future, this could potentially
         # be configurable per-appservice whether the appservice sender can be
         # contacted.
-        if self.get_app_service_by_user_id(user) is not None:
+        if self.get_app_service_by_user_id(user) is not None:  # type: ignore[attr-defined]
             return False
 
         # We're opting to exclude appservice users (anyone matching the user
@@ -405,17 +405,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         # they could be DM-able. In the future, this could potentially
         # be configurable per-appservice whether the appservice users can be
         # contacted.
-        if self.get_if_app_services_interested_in_user(user):
+        if self.get_if_app_services_interested_in_user(user):  # type: ignore[attr-defined]
             # TODO we might want to make this configurable for each app service
             return False
 
         # Support users are for diagnostics and should not appear in the user directory.
-        if await self.is_support_user(user):
+        if await self.is_support_user(user):  # type: ignore[attr-defined]
             return False
 
         # Deactivated users aren't contactable, so should not appear in the user directory.
         try:
-            if await self.get_user_deactivated_status(user):
+            if await self.get_user_deactivated_status(user):  # type: ignore[attr-defined]
                 return False
         except StoreError:
             # No such user in the users table. No need to do this when calling
@@ -433,20 +433,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             (EventTypes.RoomHistoryVisibility, ""),
         )
 
-        current_state_ids = await self.get_filtered_current_state_ids(
+        current_state_ids = await self.get_filtered_current_state_ids(  # type: ignore[attr-defined]
             room_id, StateFilter.from_types(types_to_filter)
         )
 
         join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
         if join_rules_id:
-            join_rule_ev = await self.get_event(join_rules_id, allow_none=True)
+            join_rule_ev = await self.get_event(join_rules_id, allow_none=True)  # type: ignore[attr-defined]
             if join_rule_ev:
                 if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
                     return True
 
         hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
         if hist_vis_id:
-            hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
+            hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)  # type: ignore[attr-defined]
             if hist_vis_ev:
                 if (
                     hist_vis_ev.content.get("history_visibility")
diff --git a/synapse/types.py b/synapse/types.py
index f89fb216a6..53be3583a0 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -51,7 +51,7 @@ from synapse.util.stringutils import parse_and_validate_server_name
 
 if TYPE_CHECKING:
     from synapse.appservice.api import ApplicationService
-    from synapse.storage.databases.main import DataStore
+    from synapse.storage.databases.main import DataStore, PurgeEventsStore
 
 # Define a state map type from type/state_key to T (usually an event ID or
 # event)
@@ -485,7 +485,7 @@ class RoomStreamToken:
             )
 
     @classmethod
-    async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken":
+    async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
         try:
             if string[0] == "s":
                 return cls(topological=None, stream=int(string[1:]))
@@ -502,7 +502,7 @@ class RoomStreamToken:
                     instance_id = int(key)
                     pos = int(value)
 
-                    instance_name = await store.get_name_from_instance_id(instance_id)
+                    instance_name = await store.get_name_from_instance_id(instance_id)  # type: ignore[attr-defined]
                     instance_map[instance_name] = pos
 
                 return cls(