diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index a27702a7a0..027bf8c85e 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -461,17 +461,29 @@ class RoomMemberWorkerStore(EventsWorkerStore):
def _get_joined_hosts_cache(self, room_id):
return _JoinedHostsCache(self, room_id)
- @cached()
- def who_forgot_in_room(self, room_id):
- return self._simple_select_list(
- table="room_memberships",
- retcols=("user_id", "event_id"),
- keyvalues={
- "room_id": room_id,
- "forgotten": 1,
- },
- desc="who_forgot"
- )
+ @cachedInlineCallbacks(num_args=2)
+ def did_forget(self, user_id, room_id):
+ """Returns whether user_id has elected to discard history for room_id.
+
+ Returns False if they have since re-joined."""
+ def f(txn):
+ sql = (
+ "SELECT"
+ " COUNT(*)"
+ " FROM"
+ " room_memberships"
+ " WHERE"
+ " user_id = ?"
+ " AND"
+ " room_id = ?"
+ " AND"
+ " forgotten = 0"
+ )
+ txn.execute(sql, (user_id, room_id))
+ rows = txn.fetchall()
+ return rows[0][0]
+ count = yield self.runInteraction("did_forget_membership", f)
+ defer.returnValue(count == 0)
class RoomMemberStore(RoomMemberWorkerStore):
@@ -580,36 +592,11 @@ class RoomMemberStore(RoomMemberWorkerStore):
)
txn.execute(sql, (user_id, room_id))
- txn.call_after(self.did_forget.invalidate, (user_id, room_id))
self._invalidate_cache_and_stream(
- txn, self.who_forgot_in_room, (room_id,)
+ txn, self.did_forget, (user_id, room_id,),
)
return self.runInteraction("forget_membership", f)
- @cachedInlineCallbacks(num_args=2)
- def did_forget(self, user_id, room_id):
- """Returns whether user_id has elected to discard history for room_id.
-
- Returns False if they have since re-joined."""
- def f(txn):
- sql = (
- "SELECT"
- " COUNT(*)"
- " FROM"
- " room_memberships"
- " WHERE"
- " user_id = ?"
- " AND"
- " room_id = ?"
- " AND"
- " forgotten = 0"
- )
- txn.execute(sql, (user_id, room_id))
- rows = txn.fetchall()
- return rows[0][0]
- count = yield self.runInteraction("did_forget_membership", f)
- defer.returnValue(count == 0)
-
@defer.inlineCallbacks
def _background_add_membership_profile(self, progress, batch_size):
target_min_stream_id = progress.get(
|