diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index ffae02a285..9a4af2b3ca 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -16,9 +16,8 @@
from twisted.internet import defer
from synapse.streams.config import PaginationConfig
-from synapse.types import StreamToken, UserID
+from synapse.types import StreamToken
from synapse.api.constants import Membership
-from synapse.api.filtering import FilterCollection
import synapse.util.async
import push_rule_evaluator as push_rule_evaluator
@@ -290,22 +289,9 @@ class Pusher(object):
membership_list=(Membership.INVITE, Membership.JOIN)
)
- user_is_guest = yield self.store.is_guest(self.user_id)
-
- # XXX: importing inside method to break circular dependency.
- # should sort out the mess by moving all this logic out of
- # push/__init__.py and probably moving the logic we use from the sync
- # handler to somewhere more amenable to re-use.
- from synapse.handlers.sync import SyncConfig
- sync_config = SyncConfig(
- user=UserID.from_string(self.user_id),
- filter=FilterCollection({}),
- is_guest=user_is_guest,
- )
- now_token = yield self.hs.get_event_sources().get_current_token()
- sync_handler = self.hs.get_handlers().sync_handler
- _, ephemeral_by_room = yield sync_handler.ephemeral_by_room(
- sync_config, now_token
+ my_receipts_by_room = yield self.store.get_receipts_for_user(
+ self.user_id,
+ "m.read",
)
badge = 0
@@ -314,11 +300,9 @@ class Pusher(object):
if r.membership == Membership.INVITE:
badge += 1
else:
- last_unread_event_id = sync_handler.last_read_event_id_for_room_and_user(
- r.room_id, self.user_id, ephemeral_by_room
- )
+ if r.room_id in my_receipts_by_room:
+ last_unread_event_id = my_receipts_by_room[r.room_id]
- if last_unread_event_id:
notifs = yield (
self.store.get_unread_event_push_actions_by_room_for_user(
r.room_id, self.user_id, last_unread_event_id
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index c80e576620..018140f47a 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -45,6 +45,21 @@ class ReceiptsStore(SQLBaseStore):
desc="get_receipts_for_room",
)
+ @cachedInlineCallbacks(num_args=2)
+ def get_receipts_for_user(self, user_id, receipt_type):
+ def f(txn):
+ sql = (
+ "SELECT room_id,event_id "
+ "FROM receipts_linearized "
+ "WHERE user_id = ? AND receipt_type = ? "
+ )
+ txn.execute(sql, (user_id, receipt_type))
+ return txn.fetchall()
+
+ defer.returnValue(dict(
+ (yield self.runInteraction("get_receipts_for_user", f))
+ ))
+
@defer.inlineCallbacks
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
"""Get receipts for multiple rooms for sending to clients.
@@ -194,29 +209,16 @@ class ReceiptsStore(SQLBaseStore):
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_max_token(self)
- @cachedInlineCallbacks()
- def get_graph_receipts_for_room(self, room_id):
- """Get receipts for sending to remote servers.
- """
- rows = yield self._simple_select_list(
- table="receipts_graph",
- keyvalues={"room_id": room_id},
- retcols=["receipt_type", "user_id", "event_id"],
- desc="get_linearized_receipts_for_room",
- )
-
- result = {}
- for row in rows:
- result.setdefault(
- row["user_id"], {}
- ).setdefault(
- row["receipt_type"], []
- ).append(row["event_id"])
-
- defer.returnValue(result)
-
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id):
+ txn.call_after(
+ self.get_receipts_for_room.invalidate, (room_id, receipt_type)
+ )
+ txn.call_after(
+ self.get_receipts_for_user.invalidate, (user_id, receipt_type)
+ )
+ # FIXME: This shouldn't invalidate the whole cache
+ txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
# We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts
@@ -324,6 +326,7 @@ class ReceiptsStore(SQLBaseStore):
)
max_persisted_id = yield self._stream_id_gen.get_max_token(self)
+
defer.returnValue((stream_id, max_persisted_id))
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
@@ -336,6 +339,16 @@ class ReceiptsStore(SQLBaseStore):
def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_ids, data):
+ txn.call_after(
+ self.get_receipts_for_room.invalidate, (room_id, receipt_type)
+ )
+ txn.call_after(
+ self.get_receipts_for_user.invalidate, (user_id, receipt_type)
+ )
+ # FIXME: This shouldn't invalidate the whole cache
+ txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
+
+
self._simple_delete_txn(
txn,
table="receipts_graph",
|