diff options
Diffstat (limited to 'synapse/storage/receipts.py')
-rw-r--r-- | synapse/storage/receipts.py | 223 |
1 files changed, 136 insertions, 87 deletions
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index f42b8014c7..0ac665e967 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,52 +14,52 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached -from synapse.util.caches.stream_change_cache import StreamChangeCache +import abc +import logging + +from canonicaljson import json from twisted.internet import defer -import logging -import ujson as json +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.caches.stream_change_cache import StreamChangeCache +from ._base import SQLBaseStore +from .util.id_generators import StreamIdGenerator logger = logging.getLogger(__name__) -class ReceiptsStore(SQLBaseStore): - def __init__(self, hs): - super(ReceiptsStore, self).__init__(hs) +class ReceiptsWorkerStore(SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_max_receipt_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, db_conn, hs): + super(ReceiptsWorkerStore, self).__init__(db_conn, hs) self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() + "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() ) + @abc.abstractmethod + def get_max_receipt_stream_id(self): + """Get the current max stream ID for receipts stream + + Returns: + int + """ + raise NotImplementedError() + @cachedInlineCallbacks() def get_users_with_read_receipts_in_room(self, room_id): receipts = yield self.get_receipts_for_room(room_id, "m.read") defer.returnValue(set(r['user_id'] for r in receipts)) - def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type, - user_id): - if receipt_type != "m.read": - return - - # Returns an ObservableDeferred - res = self.get_users_with_read_receipts_in_room.cache.get( - room_id, None, update_metrics=False, - ) - - if res: - if isinstance(res, defer.Deferred) and res.called: - res = res.result - if user_id in res: - # We'd only be adding to the set, so no point invalidating if the - # user is already there - return - - self.get_users_with_read_receipts_in_room.invalidate((room_id,)) - @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): return self._simple_select_list( @@ -139,7 +140,9 @@ class ReceiptsStore(SQLBaseStore): """ room_ids = set(room_ids) - if from_key: + if from_key is not None: + # Only ask the database about rooms where there have been new + # receipts added since `from_key` room_ids = yield self._receipts_stream_cache.get_entities_changed( room_ids, from_key ) @@ -150,7 +153,6 @@ class ReceiptsStore(SQLBaseStore): defer.returnValue([ev for res in results.values() for ev in res]) - @cachedInlineCallbacks(num_args=3, tree=True) def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): """Get receipts for a single room for sending to clients. @@ -161,7 +163,19 @@ class ReceiptsStore(SQLBaseStore): from the start. Returns: - list: A list of receipts. + Deferred[list]: A list of receipts. + """ + if from_key is not None: + # Check the cache first to see if any new receipts have been added + # since`from_key`. If not we can no-op. + if not self._receipts_stream_cache.has_entity_changed(room_id, from_key): + defer.succeed([]) + + return self._get_linearized_receipts_for_room(room_id, to_key, from_key) + + @cachedInlineCallbacks(num_args=3, tree=True) + def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): + """See get_linearized_receipts_for_room """ def f(txn): if from_key: @@ -210,7 +224,7 @@ class ReceiptsStore(SQLBaseStore): "content": content, }]) - @cachedList(cached_method_name="get_linearized_receipts_for_room", + @cachedList(cached_method_name="_get_linearized_receipts_for_room", list_name="room_ids", num_args=3, inlineCallbacks=True) def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): if not room_ids: @@ -270,11 +284,97 @@ class ReceiptsStore(SQLBaseStore): } defer.returnValue(results) + def get_all_updated_receipts(self, last_id, current_id, limit=None): + if last_id == current_id: + return defer.succeed([]) + + def get_all_updated_receipts_txn(txn): + sql = ( + "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" + " FROM receipts_linearized" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + ) + args = [last_id, current_id] + if limit is not None: + sql += " LIMIT ?" + args.append(limit) + txn.execute(sql, args) + + return txn.fetchall() + return self.runInteraction( + "get_all_updated_receipts", get_all_updated_receipts_txn + ) + + def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type, + user_id): + if receipt_type != "m.read": + return + + # Returns either an ObservableDeferred or the raw result + res = self.get_users_with_read_receipts_in_room.cache.get( + room_id, None, update_metrics=False, + ) + + # first handle the Deferred case + if isinstance(res, defer.Deferred): + if res.called: + res = res.result + else: + res = None + + if res and user_id in res: + # We'd only be adding to the set, so no point invalidating if the + # user is already there + return + + self.get_users_with_read_receipts_in_room.invalidate((room_id,)) + + +class ReceiptsStore(ReceiptsWorkerStore): + def __init__(self, db_conn, hs): + # We instantiate this first as the ReceiptsWorkerStore constructor + # needs to be able to call get_max_receipt_stream_id + self._receipts_id_gen = StreamIdGenerator( + db_conn, "receipts_linearized", "stream_id" + ) + + super(ReceiptsStore, self).__init__(db_conn, hs) + def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, user_id, event_id, data, stream_id): + res = self._simple_select_one_txn( + txn, + table="events", + retcols=["topological_ordering", "stream_ordering"], + keyvalues={"event_id": event_id}, + allow_none=True + ) + + stream_ordering = int(res["stream_ordering"]) if res else None + + # We don't want to clobber receipts for more recent events, so we + # have to compare orderings of existing receipts + if stream_ordering is not None: + sql = ( + "SELECT stream_ordering, event_id FROM events" + " INNER JOIN receipts_linearized as r USING (event_id, room_id)" + " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?" + ) + txn.execute(sql, (room_id, receipt_type, user_id)) + + for so, eid in txn: + if int(so) >= stream_ordering: + logger.debug( + "Ignoring new receipt for %s in favour of existing " + "one for later event %s", + event_id, eid, + ) + return False + txn.call_after( self.get_receipts_for_room.invalidate, (room_id, receipt_type) ) @@ -286,7 +386,7 @@ class ReceiptsStore(SQLBaseStore): self.get_receipts_for_user.invalidate, (user_id, receipt_type) ) # FIXME: This shouldn't invalidate the whole cache - txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,)) + txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,)) txn.call_after( self._receipts_stream_cache.entity_has_changed, @@ -298,34 +398,6 @@ class ReceiptsStore(SQLBaseStore): (user_id, room_id, receipt_type) ) - res = self._simple_select_one_txn( - txn, - table="events", - retcols=["topological_ordering", "stream_ordering"], - keyvalues={"event_id": event_id}, - allow_none=True - ) - - topological_ordering = int(res["topological_ordering"]) if res else None - stream_ordering = int(res["stream_ordering"]) if res else None - - # We don't want to clobber receipts for more recent events, so we - # have to compare orderings of existing receipts - sql = ( - "SELECT topological_ordering, stream_ordering, event_id FROM events" - " INNER JOIN receipts_linearized as r USING (event_id, room_id)" - " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?" - ) - - txn.execute(sql, (room_id, receipt_type, user_id)) - - if topological_ordering: - for to, so, _ in txn: - if int(to) > topological_ordering: - return False - elif int(to) == topological_ordering and int(so) >= stream_ordering: - return False - self._simple_delete_txn( txn, table="receipts_linearized", @@ -349,12 +421,11 @@ class ReceiptsStore(SQLBaseStore): } ) - if receipt_type == "m.read" and topological_ordering: + if receipt_type == "m.read" and stream_ordering is not None: self._remove_old_push_actions_before_txn( txn, room_id=room_id, user_id=user_id, - topological_ordering=topological_ordering, stream_ordering=stream_ordering, ) @@ -435,7 +506,7 @@ class ReceiptsStore(SQLBaseStore): self.get_receipts_for_user.invalidate, (user_id, receipt_type) ) # FIXME: This shouldn't invalidate the whole cache - txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,)) + txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,)) self._simple_delete_txn( txn, @@ -457,25 +528,3 @@ class ReceiptsStore(SQLBaseStore): "data": json.dumps(data), } ) - - def get_all_updated_receipts(self, last_id, current_id, limit=None): - if last_id == current_id: - return defer.succeed([]) - - def get_all_updated_receipts_txn(txn): - sql = ( - "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" - " FROM receipts_linearized" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" - ) - args = [last_id, current_id] - if limit is not None: - sql += " LIMIT ?" - args.append(limit) - txn.execute(sql, args) - - return txn.fetchall() - return self.runInteraction( - "get_all_updated_receipts", get_all_updated_receipts_txn - ) |