summary refs log tree commit diff
path: root/synapse/storage/receipts.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/receipts.py')
-rw-r--r--synapse/storage/receipts.py166
1 files changed, 76 insertions, 90 deletions
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index a535063547..4202a6b3dc 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014, 2015 OpenMarket Ltd
+# Copyright 2014-2016 OpenMarket Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,12 +14,11 @@
 # limitations under the License.
 
 from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
-from synapse.util.caches import cache_counter, caches_by_name
+from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from twisted.internet import defer
 
-from blist import sorteddict
 import logging
 import ujson as json
 
@@ -31,7 +30,50 @@ class ReceiptsStore(SQLBaseStore):
     def __init__(self, hs):
         super(ReceiptsStore, self).__init__(hs)
 
-        self._receipts_stream_cache = _RoomStreamChangeCache()
+        self._receipts_stream_cache = StreamChangeCache(
+            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token(None)
+        )
+
+    @cached(num_args=2)
+    def get_receipts_for_room(self, room_id, receipt_type):
+        return self._simple_select_list(
+            table="receipts_linearized",
+            keyvalues={
+                "room_id": room_id,
+                "receipt_type": receipt_type,
+            },
+            retcols=("user_id", "event_id"),
+            desc="get_receipts_for_room",
+        )
+
+    @cached(num_args=3)
+    def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
+        return self._simple_select_one_onecol(
+            table="receipts_linearized",
+            keyvalues={
+                "room_id": room_id,
+                "receipt_type": receipt_type,
+                "user_id": user_id
+            },
+            retcol="event_id",
+            desc="get_own_receipt_for_user",
+            allow_none=True,
+        )
+
+    @cachedInlineCallbacks(num_args=2)
+    def get_receipts_for_user(self, user_id, receipt_type):
+        def f(txn):
+            sql = (
+                "SELECT room_id,event_id "
+                "FROM receipts_linearized "
+                "WHERE user_id = ? AND receipt_type = ? "
+            )
+            txn.execute(sql, (user_id, receipt_type))
+            return txn.fetchall()
+
+        defer.returnValue(dict(
+            (yield self.runInteraction("get_receipts_for_user", f))
+        ))
 
     @defer.inlineCallbacks
     def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
@@ -49,8 +91,8 @@ class ReceiptsStore(SQLBaseStore):
         room_ids = set(room_ids)
 
         if from_key:
-            room_ids = yield self._receipts_stream_cache.get_rooms_changed(
-                self, room_ids, from_key
+            room_ids = yield self._receipts_stream_cache.get_entities_changed(
+                room_ids, from_key
             )
 
         results = yield self._get_linearized_receipts_for_rooms(
@@ -182,29 +224,26 @@ class ReceiptsStore(SQLBaseStore):
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_max_token(self)
 
-    @cachedInlineCallbacks()
-    def get_graph_receipts_for_room(self, room_id):
-        """Get receipts for sending to remote servers.
-        """
-        rows = yield self._simple_select_list(
-            table="receipts_graph",
-            keyvalues={"room_id": room_id},
-            retcols=["receipt_type", "user_id", "event_id"],
-            desc="get_linearized_receipts_for_room",
+    def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
+                                      user_id, event_id, data, stream_id):
+        txn.call_after(
+            self.get_receipts_for_room.invalidate, (room_id, receipt_type)
         )
+        txn.call_after(
+            self.get_receipts_for_user.invalidate, (user_id, receipt_type)
+        )
+        # FIXME: This shouldn't invalidate the whole cache
+        txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
 
-        result = {}
-        for row in rows:
-            result.setdefault(
-                row["user_id"], {}
-            ).setdefault(
-                row["receipt_type"], []
-            ).append(row["event_id"])
-
-        defer.returnValue(result)
+        txn.call_after(
+            self._receipts_stream_cache.entity_has_changed,
+            room_id, stream_id
+        )
 
-    def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
-                                      user_id, event_id, data, stream_id):
+        txn.call_after(
+            self.get_last_receipt_event_id_for_user.invalidate,
+            (user_id, room_id, receipt_type)
+        )
 
         # We don't want to clobber receipts for more recent events, so we
         # have to compare orderings of existing receipts
@@ -293,9 +332,6 @@ class ReceiptsStore(SQLBaseStore):
 
         stream_id_manager = yield self._receipts_id_gen.get_next(self)
         with stream_id_manager as stream_id:
-            yield self._receipts_stream_cache.room_has_changed(
-                self, room_id, stream_id
-            )
             have_persisted = yield self.runInteraction(
                 "insert_linearized_receipt",
                 self.insert_linearized_receipt_txn,
@@ -312,6 +348,7 @@ class ReceiptsStore(SQLBaseStore):
         )
 
         max_persisted_id = yield self._stream_id_gen.get_max_token(self)
+
         defer.returnValue((stream_id, max_persisted_id))
 
     def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
@@ -324,6 +361,15 @@ class ReceiptsStore(SQLBaseStore):
 
     def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
                                  user_id, event_ids, data):
+        txn.call_after(
+            self.get_receipts_for_room.invalidate, (room_id, receipt_type)
+        )
+        txn.call_after(
+            self.get_receipts_for_user.invalidate, (user_id, receipt_type)
+        )
+        # FIXME: This shouldn't invalidate the whole cache
+        txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
+
         self._simple_delete_txn(
             txn,
             table="receipts_graph",
@@ -344,63 +390,3 @@ class ReceiptsStore(SQLBaseStore):
                 "data": json.dumps(data),
             }
         )
-
-
-class _RoomStreamChangeCache(object):
-    """Keeps track of the stream_id of the latest change in rooms.
-
-    Given a list of rooms and stream key, it will give a subset of rooms that
-    may have changed since that key. If the key is too old then the cache
-    will simply return all rooms.
-    """
-    def __init__(self, size_of_cache=10000):
-        self._size_of_cache = size_of_cache
-        self._room_to_key = {}
-        self._cache = sorteddict()
-        self._earliest_key = None
-        self.name = "ReceiptsRoomChangeCache"
-        caches_by_name[self.name] = self._cache
-
-    @defer.inlineCallbacks
-    def get_rooms_changed(self, store, room_ids, key):
-        """Returns subset of room ids that have had new receipts since the
-        given key. If the key is too old it will just return the given list.
-        """
-        if key > (yield self._get_earliest_key(store)):
-            keys = self._cache.keys()
-            i = keys.bisect_right(key)
-
-            result = set(
-                self._cache[k] for k in keys[i:]
-            ).intersection(room_ids)
-
-            cache_counter.inc_hits(self.name)
-        else:
-            result = room_ids
-            cache_counter.inc_misses(self.name)
-
-        defer.returnValue(result)
-
-    @defer.inlineCallbacks
-    def room_has_changed(self, store, room_id, key):
-        """Informs the cache that the room has been changed at the given key.
-        """
-        if key > (yield self._get_earliest_key(store)):
-            old_key = self._room_to_key.get(room_id, None)
-            if old_key:
-                key = max(key, old_key)
-                self._cache.pop(old_key, None)
-            self._cache[key] = room_id
-
-            while len(self._cache) > self._size_of_cache:
-                k, r = self._cache.popitem()
-                self._earliest_key = max(k, self._earliest_key)
-                self._room_to_key.pop(r, None)
-
-    @defer.inlineCallbacks
-    def _get_earliest_key(self, store):
-        if self._earliest_key is None:
-            self._earliest_key = yield store.get_max_receipt_stream_id()
-            self._earliest_key = int(self._earliest_key)
-
-        defer.returnValue(self._earliest_key)