diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index f42b8014c7..0ac665e967 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,52 +14,52 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+import abc
+import logging
+
+from canonicaljson import json
from twisted.internet import defer
-import logging
-import ujson as json
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+from ._base import SQLBaseStore
+from .util.id_generators import StreamIdGenerator
logger = logging.getLogger(__name__)
-class ReceiptsStore(SQLBaseStore):
- def __init__(self, hs):
- super(ReceiptsStore, self).__init__(hs)
+class ReceiptsWorkerStore(SQLBaseStore):
+ """This is an abstract base class where subclasses must implement
+ `get_max_receipt_stream_id` which can be called in the initializer.
+ """
+
+ # This ABCMeta metaclass ensures that we cannot be instantiated without
+ # the abstract methods being implemented.
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, db_conn, hs):
+ super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
self._receipts_stream_cache = StreamChangeCache(
- "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
+ "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
)
+ @abc.abstractmethod
+ def get_max_receipt_stream_id(self):
+ """Get the current max stream ID for receipts stream
+
+ Returns:
+ int
+ """
+ raise NotImplementedError()
+
@cachedInlineCallbacks()
def get_users_with_read_receipts_in_room(self, room_id):
receipts = yield self.get_receipts_for_room(room_id, "m.read")
defer.returnValue(set(r['user_id'] for r in receipts))
- def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
- user_id):
- if receipt_type != "m.read":
- return
-
- # Returns an ObservableDeferred
- res = self.get_users_with_read_receipts_in_room.cache.get(
- room_id, None, update_metrics=False,
- )
-
- if res:
- if isinstance(res, defer.Deferred) and res.called:
- res = res.result
- if user_id in res:
- # We'd only be adding to the set, so no point invalidating if the
- # user is already there
- return
-
- self.get_users_with_read_receipts_in_room.invalidate((room_id,))
-
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
return self._simple_select_list(
@@ -139,7 +140,9 @@ class ReceiptsStore(SQLBaseStore):
"""
room_ids = set(room_ids)
- if from_key:
+ if from_key is not None:
+ # Only ask the database about rooms where there have been new
+ # receipts added since `from_key`
room_ids = yield self._receipts_stream_cache.get_entities_changed(
room_ids, from_key
)
@@ -150,7 +153,6 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue([ev for res in results.values() for ev in res])
- @cachedInlineCallbacks(num_args=3, tree=True)
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients.
@@ -161,7 +163,19 @@ class ReceiptsStore(SQLBaseStore):
from the start.
Returns:
- list: A list of receipts.
+ Deferred[list]: A list of receipts.
+ """
+ if from_key is not None:
+ # Check the cache first to see if any new receipts have been added
+ # since`from_key`. If not we can no-op.
+ if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
+ defer.succeed([])
+
+ return self._get_linearized_receipts_for_room(room_id, to_key, from_key)
+
+ @cachedInlineCallbacks(num_args=3, tree=True)
+ def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+ """See get_linearized_receipts_for_room
"""
def f(txn):
if from_key:
@@ -210,7 +224,7 @@ class ReceiptsStore(SQLBaseStore):
"content": content,
}])
- @cachedList(cached_method_name="get_linearized_receipts_for_room",
+ @cachedList(cached_method_name="_get_linearized_receipts_for_room",
list_name="room_ids", num_args=3, inlineCallbacks=True)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
@@ -270,11 +284,97 @@ class ReceiptsStore(SQLBaseStore):
}
defer.returnValue(results)
+ def get_all_updated_receipts(self, last_id, current_id, limit=None):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_updated_receipts_txn(txn):
+ sql = (
+ "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
+ " FROM receipts_linearized"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC"
+ )
+ args = [last_id, current_id]
+ if limit is not None:
+ sql += " LIMIT ?"
+ args.append(limit)
+ txn.execute(sql, args)
+
+ return txn.fetchall()
+ return self.runInteraction(
+ "get_all_updated_receipts", get_all_updated_receipts_txn
+ )
+
+ def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
+ user_id):
+ if receipt_type != "m.read":
+ return
+
+ # Returns either an ObservableDeferred or the raw result
+ res = self.get_users_with_read_receipts_in_room.cache.get(
+ room_id, None, update_metrics=False,
+ )
+
+ # first handle the Deferred case
+ if isinstance(res, defer.Deferred):
+ if res.called:
+ res = res.result
+ else:
+ res = None
+
+ if res and user_id in res:
+ # We'd only be adding to the set, so no point invalidating if the
+ # user is already there
+ return
+
+ self.get_users_with_read_receipts_in_room.invalidate((room_id,))
+
+
+class ReceiptsStore(ReceiptsWorkerStore):
+ def __init__(self, db_conn, hs):
+ # We instantiate this first as the ReceiptsWorkerStore constructor
+ # needs to be able to call get_max_receipt_stream_id
+ self._receipts_id_gen = StreamIdGenerator(
+ db_conn, "receipts_linearized", "stream_id"
+ )
+
+ super(ReceiptsStore, self).__init__(db_conn, hs)
+
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id):
+ res = self._simple_select_one_txn(
+ txn,
+ table="events",
+ retcols=["topological_ordering", "stream_ordering"],
+ keyvalues={"event_id": event_id},
+ allow_none=True
+ )
+
+ stream_ordering = int(res["stream_ordering"]) if res else None
+
+ # We don't want to clobber receipts for more recent events, so we
+ # have to compare orderings of existing receipts
+ if stream_ordering is not None:
+ sql = (
+ "SELECT stream_ordering, event_id FROM events"
+ " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
+ " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
+ )
+ txn.execute(sql, (room_id, receipt_type, user_id))
+
+ for so, eid in txn:
+ if int(so) >= stream_ordering:
+ logger.debug(
+ "Ignoring new receipt for %s in favour of existing "
+ "one for later event %s",
+ event_id, eid,
+ )
+ return False
+
txn.call_after(
self.get_receipts_for_room.invalidate, (room_id, receipt_type)
)
@@ -286,7 +386,7 @@ class ReceiptsStore(SQLBaseStore):
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
)
# FIXME: This shouldn't invalidate the whole cache
- txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
+ txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
txn.call_after(
self._receipts_stream_cache.entity_has_changed,
@@ -298,34 +398,6 @@ class ReceiptsStore(SQLBaseStore):
(user_id, room_id, receipt_type)
)
- res = self._simple_select_one_txn(
- txn,
- table="events",
- retcols=["topological_ordering", "stream_ordering"],
- keyvalues={"event_id": event_id},
- allow_none=True
- )
-
- topological_ordering = int(res["topological_ordering"]) if res else None
- stream_ordering = int(res["stream_ordering"]) if res else None
-
- # We don't want to clobber receipts for more recent events, so we
- # have to compare orderings of existing receipts
- sql = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
- " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
- )
-
- txn.execute(sql, (room_id, receipt_type, user_id))
-
- if topological_ordering:
- for to, so, _ in txn:
- if int(to) > topological_ordering:
- return False
- elif int(to) == topological_ordering and int(so) >= stream_ordering:
- return False
-
self._simple_delete_txn(
txn,
table="receipts_linearized",
@@ -349,12 +421,11 @@ class ReceiptsStore(SQLBaseStore):
}
)
- if receipt_type == "m.read" and topological_ordering:
+ if receipt_type == "m.read" and stream_ordering is not None:
self._remove_old_push_actions_before_txn(
txn,
room_id=room_id,
user_id=user_id,
- topological_ordering=topological_ordering,
stream_ordering=stream_ordering,
)
@@ -435,7 +506,7 @@ class ReceiptsStore(SQLBaseStore):
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
)
# FIXME: This shouldn't invalidate the whole cache
- txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
+ txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
self._simple_delete_txn(
txn,
@@ -457,25 +528,3 @@ class ReceiptsStore(SQLBaseStore):
"data": json.dumps(data),
}
)
-
- def get_all_updated_receipts(self, last_id, current_id, limit=None):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_updated_receipts_txn(txn):
- sql = (
- "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
- " FROM receipts_linearized"
- " WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC"
- )
- args = [last_id, current_id]
- if limit is not None:
- sql += " LIMIT ?"
- args.append(limit)
- txn.execute(sql, args)
-
- return txn.fetchall()
- return self.runInteraction(
- "get_all_updated_receipts", get_all_updated_receipts_txn
- )
|