diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index a535063547..4202a6b3dc 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014, 2015 OpenMarket Ltd
+# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,12 +14,11 @@
# limitations under the License.
from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
-from synapse.util.caches import cache_counter, caches_by_name
+from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from twisted.internet import defer
-from blist import sorteddict
import logging
import ujson as json
@@ -31,7 +30,50 @@ class ReceiptsStore(SQLBaseStore):
def __init__(self, hs):
super(ReceiptsStore, self).__init__(hs)
- self._receipts_stream_cache = _RoomStreamChangeCache()
+ self._receipts_stream_cache = StreamChangeCache(
+ "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token(None)
+ )
+
+ @cached(num_args=2)
+ def get_receipts_for_room(self, room_id, receipt_type):
+ return self._simple_select_list(
+ table="receipts_linearized",
+ keyvalues={
+ "room_id": room_id,
+ "receipt_type": receipt_type,
+ },
+ retcols=("user_id", "event_id"),
+ desc="get_receipts_for_room",
+ )
+
+ @cached(num_args=3)
+ def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
+ return self._simple_select_one_onecol(
+ table="receipts_linearized",
+ keyvalues={
+ "room_id": room_id,
+ "receipt_type": receipt_type,
+ "user_id": user_id
+ },
+ retcol="event_id",
+ desc="get_own_receipt_for_user",
+ allow_none=True,
+ )
+
+ @cachedInlineCallbacks(num_args=2)
+ def get_receipts_for_user(self, user_id, receipt_type):
+ def f(txn):
+ sql = (
+ "SELECT room_id,event_id "
+ "FROM receipts_linearized "
+ "WHERE user_id = ? AND receipt_type = ? "
+ )
+ txn.execute(sql, (user_id, receipt_type))
+ return txn.fetchall()
+
+ defer.returnValue(dict(
+ (yield self.runInteraction("get_receipts_for_user", f))
+ ))
@defer.inlineCallbacks
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
@@ -49,8 +91,8 @@ class ReceiptsStore(SQLBaseStore):
room_ids = set(room_ids)
if from_key:
- room_ids = yield self._receipts_stream_cache.get_rooms_changed(
- self, room_ids, from_key
+ room_ids = yield self._receipts_stream_cache.get_entities_changed(
+ room_ids, from_key
)
results = yield self._get_linearized_receipts_for_rooms(
@@ -182,29 +224,26 @@ class ReceiptsStore(SQLBaseStore):
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_max_token(self)
- @cachedInlineCallbacks()
- def get_graph_receipts_for_room(self, room_id):
- """Get receipts for sending to remote servers.
- """
- rows = yield self._simple_select_list(
- table="receipts_graph",
- keyvalues={"room_id": room_id},
- retcols=["receipt_type", "user_id", "event_id"],
- desc="get_linearized_receipts_for_room",
+ def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
+ user_id, event_id, data, stream_id):
+ txn.call_after(
+ self.get_receipts_for_room.invalidate, (room_id, receipt_type)
)
+ txn.call_after(
+ self.get_receipts_for_user.invalidate, (user_id, receipt_type)
+ )
+ # FIXME: This shouldn't invalidate the whole cache
+ txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
- result = {}
- for row in rows:
- result.setdefault(
- row["user_id"], {}
- ).setdefault(
- row["receipt_type"], []
- ).append(row["event_id"])
-
- defer.returnValue(result)
+ txn.call_after(
+ self._receipts_stream_cache.entity_has_changed,
+ room_id, stream_id
+ )
- def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
- user_id, event_id, data, stream_id):
+ txn.call_after(
+ self.get_last_receipt_event_id_for_user.invalidate,
+ (user_id, room_id, receipt_type)
+ )
# We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts
@@ -293,9 +332,6 @@ class ReceiptsStore(SQLBaseStore):
stream_id_manager = yield self._receipts_id_gen.get_next(self)
with stream_id_manager as stream_id:
- yield self._receipts_stream_cache.room_has_changed(
- self, room_id, stream_id
- )
have_persisted = yield self.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
@@ -312,6 +348,7 @@ class ReceiptsStore(SQLBaseStore):
)
max_persisted_id = yield self._stream_id_gen.get_max_token(self)
+
defer.returnValue((stream_id, max_persisted_id))
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
@@ -324,6 +361,15 @@ class ReceiptsStore(SQLBaseStore):
def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_ids, data):
+ txn.call_after(
+ self.get_receipts_for_room.invalidate, (room_id, receipt_type)
+ )
+ txn.call_after(
+ self.get_receipts_for_user.invalidate, (user_id, receipt_type)
+ )
+ # FIXME: This shouldn't invalidate the whole cache
+ txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
+
self._simple_delete_txn(
txn,
table="receipts_graph",
@@ -344,63 +390,3 @@ class ReceiptsStore(SQLBaseStore):
"data": json.dumps(data),
}
)
-
-
-class _RoomStreamChangeCache(object):
- """Keeps track of the stream_id of the latest change in rooms.
-
- Given a list of rooms and stream key, it will give a subset of rooms that
- may have changed since that key. If the key is too old then the cache
- will simply return all rooms.
- """
- def __init__(self, size_of_cache=10000):
- self._size_of_cache = size_of_cache
- self._room_to_key = {}
- self._cache = sorteddict()
- self._earliest_key = None
- self.name = "ReceiptsRoomChangeCache"
- caches_by_name[self.name] = self._cache
-
- @defer.inlineCallbacks
- def get_rooms_changed(self, store, room_ids, key):
- """Returns subset of room ids that have had new receipts since the
- given key. If the key is too old it will just return the given list.
- """
- if key > (yield self._get_earliest_key(store)):
- keys = self._cache.keys()
- i = keys.bisect_right(key)
-
- result = set(
- self._cache[k] for k in keys[i:]
- ).intersection(room_ids)
-
- cache_counter.inc_hits(self.name)
- else:
- result = room_ids
- cache_counter.inc_misses(self.name)
-
- defer.returnValue(result)
-
- @defer.inlineCallbacks
- def room_has_changed(self, store, room_id, key):
- """Informs the cache that the room has been changed at the given key.
- """
- if key > (yield self._get_earliest_key(store)):
- old_key = self._room_to_key.get(room_id, None)
- if old_key:
- key = max(key, old_key)
- self._cache.pop(old_key, None)
- self._cache[key] = room_id
-
- while len(self._cache) > self._size_of_cache:
- k, r = self._cache.popitem()
- self._earliest_key = max(k, self._earliest_key)
- self._room_to_key.pop(r, None)
-
- @defer.inlineCallbacks
- def _get_earliest_key(self, store):
- if self._earliest_key is None:
- self._earliest_key = yield store.get_max_receipt_stream_id()
- self._earliest_key = int(self._earliest_key)
-
- defer.returnValue(self._earliest_key)
|