summary refs log tree commit diff
path: root/synapse/storage/receipts.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/receipts.py')
-rw-r--r--synapse/storage/receipts.py126
1 files changed, 97 insertions, 29 deletions
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py

index dbc074d6b5..9747a04a9a 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py
@@ -31,9 +31,29 @@ class ReceiptsStore(SQLBaseStore): super(ReceiptsStore, self).__init__(hs) self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token() + "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() ) + @cachedInlineCallbacks() + def get_users_with_read_receipts_in_room(self, room_id): + receipts = yield self.get_receipts_for_room(room_id, "m.read") + defer.returnValue(set(r['user_id'] for r in receipts)) + + def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type, + user_id): + if receipt_type != "m.read": + return + + # Returns an ObservableDeferred + res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None) + + if res and res.called and user_id in res.result: + # We'd only be adding to the set, so no point invalidating if the + # user is already there + return + + self.get_users_with_read_receipts_in_room.invalidate((room_id,)) + @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): return self._simple_select_list( @@ -62,18 +82,42 @@ class ReceiptsStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2) def get_receipts_for_user(self, user_id, receipt_type): + rows = yield self._simple_select_list( + table="receipts_linearized", + keyvalues={ + "user_id": user_id, + "receipt_type": receipt_type, + }, + retcols=("room_id", "event_id"), + desc="get_receipts_for_user", + ) + + defer.returnValue({row["room_id"]: row["event_id"] for row in rows}) + + @defer.inlineCallbacks + def get_receipts_for_user_with_orderings(self, user_id, receipt_type): def f(txn): sql = ( - "SELECT room_id,event_id " - "FROM receipts_linearized " - "WHERE user_id = ? AND receipt_type = ? " + "SELECT rl.room_id, rl.event_id," + " e.topological_ordering, e.stream_ordering" + " FROM receipts_linearized AS rl" + " INNER JOIN events AS e USING (room_id, event_id)" + " WHERE rl.room_id = e.room_id" + " AND rl.event_id = e.event_id" + " AND user_id = ?" ) - txn.execute(sql, (user_id, receipt_type)) + txn.execute(sql, (user_id,)) return txn.fetchall() - - defer.returnValue(dict( - (yield self.runInteraction("get_receipts_for_user", f)) - )) + rows = yield self.runInteraction( + "get_receipts_for_user_with_orderings", f + ) + defer.returnValue({ + row[0]: { + "event_id": row[1], + "topological_ordering": row[2], + "stream_ordering": row[3], + } for row in rows + }) @defer.inlineCallbacks def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): @@ -101,7 +145,7 @@ class ReceiptsStore(SQLBaseStore): defer.returnValue([ev for res in results.values() for ev in res]) - @cachedInlineCallbacks(num_args=3, max_entries=5000) + @cachedInlineCallbacks(num_args=3, tree=True) def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): """Get receipts for a single room for sending to clients. @@ -161,8 +205,8 @@ class ReceiptsStore(SQLBaseStore): "content": content, }]) - @cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids", - num_args=3, inlineCallbacks=True) + @cachedList(cached_method_name="get_linearized_receipts_for_room", + list_name="room_ids", num_args=3, inlineCallbacks=True) def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): if not room_ids: defer.returnValue({}) @@ -222,7 +266,7 @@ class ReceiptsStore(SQLBaseStore): defer.returnValue(results) def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_max_token() + return self._receipts_id_gen.get_current_token() def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, user_id, event_id, data, stream_id): @@ -230,10 +274,14 @@ class ReceiptsStore(SQLBaseStore): self.get_receipts_for_room.invalidate, (room_id, receipt_type) ) txn.call_after( + self._invalidate_get_users_with_receipts_in_room, + room_id, receipt_type, user_id, + ) + txn.call_after( self.get_receipts_for_user.invalidate, (user_id, receipt_type) ) # FIXME: This shouldn't invalidate the whole cache - txn.call_after(self.get_linearized_receipts_for_room.invalidate_all) + txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,)) txn.call_after( self._receipts_stream_cache.entity_has_changed, @@ -245,6 +293,17 @@ class ReceiptsStore(SQLBaseStore): (user_id, room_id, receipt_type) ) + res = self._simple_select_one_txn( + txn, + table="events", + retcols=["topological_ordering", "stream_ordering"], + keyvalues={"event_id": event_id}, + allow_none=True + ) + + topological_ordering = int(res["topological_ordering"]) if res else None + stream_ordering = int(res["stream_ordering"]) if res else None + # We don't want to clobber receipts for more recent events, so we # have to compare orderings of existing receipts sql = ( @@ -256,16 +315,7 @@ class ReceiptsStore(SQLBaseStore): txn.execute(sql, (room_id, receipt_type, user_id)) results = txn.fetchall() - if results: - res = self._simple_select_one_txn( - txn, - table="events", - retcols=["topological_ordering", "stream_ordering"], - keyvalues={"event_id": event_id}, - ) - topological_ordering = int(res["topological_ordering"]) - stream_ordering = int(res["stream_ordering"]) - + if results and topological_ordering: for to, so, _ in results: if int(to) > topological_ordering: return False @@ -295,6 +345,14 @@ class ReceiptsStore(SQLBaseStore): } ) + if receipt_type == "m.read" and topological_ordering: + self._remove_old_push_actions_before_txn( + txn, + room_id=room_id, + user_id=user_id, + topological_ordering=topological_ordering, + ) + return True @defer.inlineCallbacks @@ -347,7 +405,7 @@ class ReceiptsStore(SQLBaseStore): room_id, receipt_type, user_id, event_ids, data ) - max_persisted_id = self._stream_id_gen.get_max_token() + max_persisted_id = self._stream_id_gen.get_current_token() defer.returnValue((stream_id, max_persisted_id)) @@ -365,10 +423,14 @@ class ReceiptsStore(SQLBaseStore): self.get_receipts_for_room.invalidate, (room_id, receipt_type) ) txn.call_after( + self._invalidate_get_users_with_receipts_in_room, + room_id, receipt_type, user_id, + ) + txn.call_after( self.get_receipts_for_user.invalidate, (user_id, receipt_type) ) # FIXME: This shouldn't invalidate the whole cache - txn.call_after(self.get_linearized_receipts_for_room.invalidate_all) + txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,)) self._simple_delete_txn( txn, @@ -391,16 +453,22 @@ class ReceiptsStore(SQLBaseStore): } ) - def get_all_updated_receipts(self, last_id, current_id, limit): + def get_all_updated_receipts(self, last_id, current_id, limit=None): + if last_id == current_id: + return defer.succeed([]) + def get_all_updated_receipts_txn(txn): sql = ( "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" " FROM receipts_linearized" " WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC" - " LIMIT ?" ) - txn.execute(sql, (last_id, current_id, limit)) + args = [last_id, current_id] + if limit is not None: + sql += " LIMIT ?" + args.append(limit) + txn.execute(sql, args) return txn.fetchall() return self.runInteraction(