summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/push/__init__.py73
-rw-r--r--synapse/push/httppusher.py14
-rw-r--r--synapse/push/pusherpool.py15
-rw-r--r--synapse/storage/receipts.py54
-rw-r--r--synapse/storage/schema/delta/28/receipts_user_id_index.sql18
5 files changed, 108 insertions, 66 deletions
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index a5dc84160c..9a4af2b3ca 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -17,6 +17,7 @@ from twisted.internet import defer
 
 from synapse.streams.config import PaginationConfig
 from synapse.types import StreamToken
+from synapse.api.constants import Membership
 
 import synapse.util.async
 import push_rule_evaluator as push_rule_evaluator
@@ -55,6 +56,7 @@ class Pusher(object):
         self.backoff_delay = Pusher.INITIAL_BACKOFF
         self.failing_since = failing_since
         self.alive = True
+        self.badge = None
 
         # The last value of last_active_time that we saw
         self.last_last_active_time = 0
@@ -92,8 +94,7 @@ class Pusher(object):
             # we fail to dispatch the push)
             config = PaginationConfig(from_token=None, limit='1')
             chunk = yield self.evStreamHandler.get_stream(
-                self.user_id, config, timeout=0, affect_presence=False,
-                only_room_events=True
+                self.user_id, config, timeout=0, affect_presence=False
             )
             self.last_token = chunk['end']
             self.store.update_pusher_last_token(
@@ -124,9 +125,11 @@ class Pusher(object):
         from_tok = StreamToken.from_string(self.last_token)
         config = PaginationConfig(from_token=from_tok, limit='1')
         timeout = (300 + random.randint(-60, 60)) * 1000
+        # note that we need to get read receipts down the stream as we need to
+        # wake up when one arrives. we don't need to explicitly look for
+        # them though.
         chunk = yield self.evStreamHandler.get_stream(
-            self.user_id, config, timeout=timeout, affect_presence=False,
-            only_room_events=True
+            self.user_id, config, timeout=timeout, affect_presence=False
         )
 
         # limiting to 1 may get 1 event plus 1 presence event, so
@@ -135,10 +138,10 @@ class Pusher(object):
         for c in chunk['chunk']:
             if 'event_id' in c:  # Hmmm...
                 single_event = c
-                break
+
         if not single_event:
+            yield self.update_badge()
             self.last_token = chunk['end']
-            logger.debug("Event stream timeout for pushkey %s", self.pushkey)
             yield self.store.update_pusher_last_token(
                 self.app_id,
                 self.pushkey,
@@ -161,7 +164,8 @@ class Pusher(object):
         tweaks = rule_evaluator.tweaks_for_actions(actions)
 
         if 'notify' in actions:
-            rejected = yield self.dispatch_push(single_event, tweaks)
+            self.badge = yield self._get_badge_count()
+            rejected = yield self.dispatch_push(single_event, tweaks, self.badge)
             self.has_unread = True
             if isinstance(rejected, list) or isinstance(rejected, tuple):
                 processed = True
@@ -181,7 +185,6 @@ class Pusher(object):
                         yield self.hs.get_pusherpool().remove_pusher(
                             self.app_id, pk, self.user_id
                         )
-        else:
             processed = True
 
         if not self.alive:
@@ -254,7 +257,7 @@ class Pusher(object):
     def stop(self):
         self.alive = False
 
-    def dispatch_push(self, p, tweaks):
+    def dispatch_push(self, p, tweaks, badge):
         """
         Overridden by implementing classes to actually deliver the notification
         Args:
@@ -266,23 +269,47 @@ class Pusher(object):
         """
         pass
 
-    def reset_badge_count(self):
-        pass
+    @defer.inlineCallbacks
+    def update_badge(self):
+        new_badge = yield self._get_badge_count()
+        if self.badge != new_badge:
+            self.badge = new_badge
+            yield self.send_badge(self.badge)
 
-    def presence_changed(self, state):
+    def send_badge(self, badge):
         """
-        We clear badge counts whenever a user's last_active time is bumped
-        This is by no means perfect but I think it's the best we can do
-        without read receipts.
+        Overridden by implementing classes to send an updated badge count
         """
-        if 'last_active' in state.state:
-            last_active = state.state['last_active']
-            if last_active > self.last_last_active_time:
-                self.last_last_active_time = last_active
-                if self.has_unread:
-                    logger.info("Resetting badge count for %s", self.user_id)
-                    self.reset_badge_count()
-                    self.has_unread = False
+        pass
+
+    @defer.inlineCallbacks
+    def _get_badge_count(self):
+        room_list = yield self.store.get_rooms_for_user_where_membership_is(
+            user_id=self.user_id,
+            membership_list=(Membership.INVITE, Membership.JOIN)
+        )
+
+        my_receipts_by_room = yield self.store.get_receipts_for_user(
+            self.user_id,
+            "m.read",
+        )
+
+        badge = 0
+
+        for r in room_list:
+            if r.membership == Membership.INVITE:
+                badge += 1
+            else:
+                if r.room_id in my_receipts_by_room:
+                    last_unread_event_id = my_receipts_by_room[r.room_id]
+
+                    notifs = yield (
+                        self.store.get_unread_event_push_actions_by_room_for_user(
+                            r.room_id, self.user_id, last_unread_event_id
+                        )
+                    )
+                    badge += len(notifs)
+        defer.returnValue(badge)
 
 
 class PusherConfigException(Exception):
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 28f1fab0e4..cdc4494928 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -51,7 +51,7 @@ class HttpPusher(Pusher):
         del self.data_minus_url['url']
 
     @defer.inlineCallbacks
-    def _build_notification_dict(self, event, tweaks):
+    def _build_notification_dict(self, event, tweaks, badge):
         # we probably do not want to push for every presence update
         # (we may want to be able to set up notifications when specific
         # people sign in, but we'd want to only deliver the pertinent ones)
@@ -71,7 +71,7 @@ class HttpPusher(Pusher):
                 'counts': {  # -- we don't mark messages as read yet so
                              # we have no way of knowing
                     # Just set the badge to 1 until we have read receipts
-                    'unread': 1,
+                    'unread': badge,
                     # 'missed_calls': 2
                 },
                 'devices': [
@@ -101,8 +101,8 @@ class HttpPusher(Pusher):
         defer.returnValue(d)
 
     @defer.inlineCallbacks
-    def dispatch_push(self, event, tweaks):
-        notification_dict = yield self._build_notification_dict(event, tweaks)
+    def dispatch_push(self, event, tweaks, badge):
+        notification_dict = yield self._build_notification_dict(event, tweaks, badge)
         if not notification_dict:
             defer.returnValue([])
         try:
@@ -116,15 +116,15 @@ class HttpPusher(Pusher):
         defer.returnValue(rejected)
 
     @defer.inlineCallbacks
-    def reset_badge_count(self):
+    def send_badge(self, badge):
+        logger.info("Sending updated badge count %d to %r", badge, self.user_id)
         d = {
             'notification': {
                 'id': '',
                 'type': None,
                 'sender': '',
                 'counts': {
-                    'unread': 0,
-                    'missed_calls': 0
+                    'unread': badge
                 },
                 'devices': [
                     {
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 12c4af14bd..d1b7c0802f 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -31,21 +31,6 @@ class PusherPool:
         self.pushers = {}
         self.last_pusher_started = -1
 
-        distributor = self.hs.get_distributor()
-        distributor.observe(
-            "user_presence_changed", self.user_presence_changed
-        )
-
-    @defer.inlineCallbacks
-    def user_presence_changed(self, user, state):
-        user_id = user.to_string()
-
-        # until we have read receipts, pushers use this to reset a user's
-        # badge counters to zero
-        for p in self.pushers.values():
-            if p.user_id == user_id:
-                yield p.presence_changed(state)
-
     @defer.inlineCallbacks
     def start(self):
         pushers = yield self.store.get_all_pushers()
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index c80e576620..c4232bdc65 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -45,6 +45,21 @@ class ReceiptsStore(SQLBaseStore):
             desc="get_receipts_for_room",
         )
 
+    @cachedInlineCallbacks(num_args=2)
+    def get_receipts_for_user(self, user_id, receipt_type):
+        def f(txn):
+            sql = (
+                "SELECT room_id,event_id "
+                "FROM receipts_linearized "
+                "WHERE user_id = ? AND receipt_type = ? "
+            )
+            txn.execute(sql, (user_id, receipt_type))
+            return txn.fetchall()
+
+        defer.returnValue(dict(
+            (yield self.runInteraction("get_receipts_for_user", f))
+        ))
+
     @defer.inlineCallbacks
     def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
         """Get receipts for multiple rooms for sending to clients.
@@ -194,29 +209,16 @@ class ReceiptsStore(SQLBaseStore):
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_max_token(self)
 
-    @cachedInlineCallbacks()
-    def get_graph_receipts_for_room(self, room_id):
-        """Get receipts for sending to remote servers.
-        """
-        rows = yield self._simple_select_list(
-            table="receipts_graph",
-            keyvalues={"room_id": room_id},
-            retcols=["receipt_type", "user_id", "event_id"],
-            desc="get_linearized_receipts_for_room",
-        )
-
-        result = {}
-        for row in rows:
-            result.setdefault(
-                row["user_id"], {}
-            ).setdefault(
-                row["receipt_type"], []
-            ).append(row["event_id"])
-
-        defer.returnValue(result)
-
     def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
                                       user_id, event_id, data, stream_id):
+        txn.call_after(
+            self.get_receipts_for_room.invalidate, (room_id, receipt_type)
+        )
+        txn.call_after(
+            self.get_receipts_for_user.invalidate, (user_id, receipt_type)
+        )
+        # FIXME: This shouldn't invalidate the whole cache
+        txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
 
         # We don't want to clobber receipts for more recent events, so we
         # have to compare orderings of existing receipts
@@ -324,6 +326,7 @@ class ReceiptsStore(SQLBaseStore):
         )
 
         max_persisted_id = yield self._stream_id_gen.get_max_token(self)
+
         defer.returnValue((stream_id, max_persisted_id))
 
     def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
@@ -336,6 +339,15 @@ class ReceiptsStore(SQLBaseStore):
 
     def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
                                  user_id, event_ids, data):
+        txn.call_after(
+            self.get_receipts_for_room.invalidate, (room_id, receipt_type)
+        )
+        txn.call_after(
+            self.get_receipts_for_user.invalidate, (user_id, receipt_type)
+        )
+        # FIXME: This shouldn't invalidate the whole cache
+        txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
+
         self._simple_delete_txn(
             txn,
             table="receipts_graph",
diff --git a/synapse/storage/schema/delta/28/receipts_user_id_index.sql b/synapse/storage/schema/delta/28/receipts_user_id_index.sql
new file mode 100644
index 0000000000..452a1b3c6c
--- /dev/null
+++ b/synapse/storage/schema/delta/28/receipts_user_id_index.sql
@@ -0,0 +1,18 @@
+/* Copyright 2015, 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE INDEX receipts_linearized_user ON receipts_linearized(
+    user_id
+);