summary refs log tree commit diff
path: root/synapse/storage/databases/main/roommember.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/roommember.py')
-rw-r--r--synapse/storage/databases/main/roommember.py68
1 files changed, 47 insertions, 21 deletions
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 86ffe2479e..dcdaf09682 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -14,19 +14,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.metrics import LaterGauge
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import (
-    LoggingTransaction,
-    SQLBaseStore,
-    db_to_json,
-    make_in_list_sql_clause,
+from synapse.metrics.background_process_metrics import (
+    run_as_background_process,
+    wrap_as_background_process,
 )
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import Sqlite3Engine
@@ -60,27 +58,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # background update still running?
         self._current_state_events_membership_up_to_date = False
 
-        txn = LoggingTransaction(
-            db_conn.cursor(),
-            name="_check_safe_current_state_events_membership_updated",
-            database_engine=self.database_engine,
+        txn = db_conn.cursor(
+            txn_name="_check_safe_current_state_events_membership_updated"
         )
         self._check_safe_current_state_events_membership_updated_txn(txn)
         txn.close()
 
-        if self.hs.config.metrics_flags.known_servers:
+        if (
+            self.hs.config.run_background_tasks
+            and self.hs.config.metrics_flags.known_servers
+        ):
             self._known_servers_count = 1
             self.hs.get_clock().looping_call(
-                run_as_background_process,
-                60 * 1000,
-                "_count_known_servers",
-                self._count_known_servers,
+                self._count_known_servers, 60 * 1000,
             )
             self.hs.get_clock().call_later(
-                1000,
-                run_as_background_process,
-                "_count_known_servers",
-                self._count_known_servers,
+                1000, self._count_known_servers,
             )
             LaterGauge(
                 "synapse_federation_known_servers",
@@ -89,6 +82,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 lambda: self._known_servers_count,
             )
 
+    @wrap_as_background_process("_count_known_servers")
     async def _count_known_servers(self):
         """
         Count the servers that this server knows about.
@@ -356,6 +350,38 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return results
 
+    async def get_local_current_membership_for_user_in_room(
+        self, user_id: str, room_id: str
+    ) -> Tuple[Optional[str], Optional[str]]:
+        """Retrieve the current local membership state and event ID for a user in a room.
+
+        Args:
+            user_id: The ID of the user.
+            room_id: The ID of the room.
+
+        Returns:
+            A tuple of (membership_type, event_id). Both will be None if a
+                room_id/user_id pair is not found.
+        """
+        # Paranoia check.
+        if not self.hs.is_mine_id(user_id):
+            raise Exception(
+                "Cannot call 'get_local_current_membership_for_user_in_room' on "
+                "non-local user %s" % (user_id,),
+            )
+
+        results_dict = await self.db_pool.simple_select_one(
+            "local_current_membership",
+            {"room_id": room_id, "user_id": user_id},
+            ("membership", "event_id"),
+            allow_none=True,
+            desc="get_local_current_membership_for_user_in_room",
+        )
+        if not results_dict:
+            return None, None
+
+        return results_dict.get("membership"), results_dict.get("event_id")
+
     @cached(max_entries=500000, iterable=True)
     async def get_rooms_for_user_with_stream_ordering(
         self, user_id: str
@@ -535,7 +561,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # If we do then we can reuse that result and simply update it with
             # any membership changes in `delta_ids`
             if context.prev_group and context.delta_ids:
-                prev_res = self._get_joined_users_from_context.cache.get(
+                prev_res = self._get_joined_users_from_context.cache.get_immediate(
                     (room_id, context.prev_group), None
                 )
                 if prev_res and isinstance(prev_res, dict):