summary refs log tree commit diff
path: root/synapse/storage/roommember.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/roommember.py')
-rw-r--r--synapse/storage/roommember.py41
1 files changed, 20 insertions, 21 deletions
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 839c74f63a..d36a6c18a8 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -66,6 +66,7 @@ class RoomMemberStore(SQLBaseStore):
 
         txn.call_after(self.get_rooms_for_user.invalidate, target_user_id)
         txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
+        txn.call_after(self.get_users_in_room.invalidate, event.room_id)
 
     def get_room_member(self, user_id, room_id):
         """Retrieve the current state of a room member.
@@ -76,17 +77,18 @@ class RoomMemberStore(SQLBaseStore):
         Returns:
             Deferred: Results in a MembershipEvent or None.
         """
-        def f(txn):
-            events = self._get_members_events_txn(
-                txn,
-                room_id,
-                user_id=user_id,
-            )
-
-            return events[0] if events else None
-
-        return self.runInteraction("get_room_member", f)
+        return self.runInteraction(
+            "get_room_member",
+            self._get_members_events_txn,
+            room_id,
+            user_id=user_id,
+        ).addCallback(
+            self._get_events
+        ).addCallback(
+            lambda events: events[0] if events else None
+        )
 
+    @cached()
     def get_users_in_room(self, room_id):
         def f(txn):
 
@@ -110,15 +112,12 @@ class RoomMemberStore(SQLBaseStore):
         Returns:
             list of namedtuples representing the members in this room.
         """
-
-        def f(txn):
-            return self._get_members_events_txn(
-                txn,
-                room_id,
-                membership=membership,
-            )
-
-        return self.runInteraction("get_room_members", f)
+        return self.runInteraction(
+            "get_room_members",
+            self._get_members_events_txn,
+            room_id,
+            membership=membership,
+        ).addCallback(self._get_events)
 
     def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
         """ Get all the rooms for this user where the membership for this user
@@ -190,14 +189,14 @@ class RoomMemberStore(SQLBaseStore):
         return self.runInteraction(
             "get_members_query", self._get_members_events_txn,
             where_clause, where_values
-        )
+        ).addCallbacks(self._get_events)
 
     def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None):
         rows = self._get_members_rows_txn(
             txn,
             room_id, membership, user_id,
         )
-        return self._get_events_txn(txn, [r["event_id"] for r in rows])
+        return [r["event_id"] for r in rows]
 
     def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
         where_clause = "c.room_id = ?"