 synapse/api/filtering.py                                |  22
 synapse/handlers/register.py                            |   2
 synapse/handlers/sync.py                                | 182
 synapse/push/__init__.py                                |  30
 synapse/rest/client/v2_alpha/sync.py                    |  57
 synapse/storage/_base.py                                |  72
 synapse/storage/presence.py                             |  35
 synapse/storage/push_rule.py                            |  63
 synapse/storage/roommember.py                           |   1
 synapse/storage/schema/delta/28/event_push_actions.sql  |   1
 tests/api/test_filtering.py                             |  10
 11 files changed, 304 insertions(+), 171 deletions(-)
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 5530b8c48f..6c13ada5df 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -190,18 +190,16 @@ class Filter(object):
         Returns:
             bool: True if the event matches
         """
-        if isinstance(event, dict):
-            return self.check_fields(
-                event.get("room_id", None),
-                event.get("sender", None),
-                event.get("type", None),
-            )
-        else:
-            return self.check_fields(
-                getattr(event, "room_id", None),
-                getattr(event, "sender", None),
-                event.type,
-            )
+        sender = event.get("sender", None)
+        if not sender:
+            # Presence events have their 'sender' in content.user_id
+            sender = event.get("content", {}).get("user_id", None)
+
+        return self.check_fields(
+            event.get("room_id", None),
+            sender,
+            event.get("type", None),
+        )
 
     def check_fields(self, room_id, sender, event_type):
         """Checks whether the filter matches the given event fields.
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 1e99c1303c..c11b98d0b7 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -52,7 +52,7 @@ class RegistrationHandler(BaseHandler):
         if urllib.quote(localpart.encode('utf-8')) != localpart:
             raise SynapseError(
                 400,
-                "User ID can only contain characters a-z, 0-9, or '-./'",
+                "User ID can only contain characters a-z, 0-9, or '_-./'",
                 Codes.INVALID_USERNAME
             )
 
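
The localpart check above relies on Python 2's urllib.quote leaving letters, digits and '_.-' (plus '/' via its default safe set) unescaped, so '_' was already accepted and only the error message needed updating. For example:

    # Python 2: quote() round-trips the allowed characters unchanged, so the
    # inequality above only fires for characters that get percent-encoded.
    from urllib import quote

    assert quote("bob_1-./") == "bob_1-./"   # allowed characters survive quoting
    assert quote("bob!") == "bob%21"         # '!' is escaped, so registration fails
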
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 53e1eb0508..328c049b03 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
 
 SyncConfig = collections.namedtuple("SyncConfig", [
     "user",
-    "filter",
+    "filter_collection",
     "is_guest",
 ])
 
@@ -142,8 +142,9 @@ class SyncHandler(BaseHandler):
         if timeout == 0 or since_token is None or full_state:
             # we are going to return immediately, so don't bother calling
             # notifier.wait_for_events.
-            result = yield self.current_sync_for_user(sync_config, since_token,
-                                                      full_state=full_state)
+            result = yield self.current_sync_for_user(
+                sync_config, since_token, full_state=full_state,
+            )
             defer.returnValue(result)
         else:
             def current_sync_callback(before_token, after_token):
@@ -151,7 +152,7 @@ class SyncHandler(BaseHandler):
 
             result = yield self.notifier.wait_for_events(
                 sync_config.user.to_string(), timeout, current_sync_callback,
-                from_token=since_token
+                from_token=since_token,
             )
             defer.returnValue(result)
 
@@ -205,7 +206,7 @@ class SyncHandler(BaseHandler):
         )
 
         membership_list = (Membership.INVITE, Membership.JOIN)
-        if sync_config.filter.include_leave:
+        if sync_config.filter_collection.include_leave:
             membership_list += (Membership.LEAVE, Membership.BAN)
 
         room_list = yield self.store.get_rooms_for_user_where_membership_is(
@@ -266,9 +267,17 @@ class SyncHandler(BaseHandler):
             deferreds, consumeErrors=True
         ).addErrback(unwrapFirstError)
 
+        account_data_for_user = sync_config.filter_collection.filter_account_data(
+            self.account_data_for_user(account_data)
+        )
+
+        presence = sync_config.filter_collection.filter_presence(
+            presence
+        )
+
         defer.returnValue(SyncResult(
             presence=presence,
-            account_data=self.account_data_for_user(account_data),
+            account_data=account_data_for_user,
             joined=joined,
             invited=invited,
             archived=archived,
@@ -302,14 +311,31 @@ class SyncHandler(BaseHandler):
 
         current_state = yield self.get_state_at(room_id, now_token)
 
+        current_state = {
+            (e.type, e.state_key): e
+            for e in sync_config.filter_collection.filter_room_state(
+                current_state.values()
+            )
+        }
+
+        account_data = self.account_data_for_room(
+            room_id, tags_by_room, account_data_by_room
+        )
+
+        account_data = sync_config.filter_collection.filter_room_account_data(
+            account_data
+        )
+
+        ephemeral = sync_config.filter_collection.filter_room_ephemeral(
+            ephemeral_by_room.get(room_id, [])
+        )
+
         defer.returnValue(JoinedSyncResult(
             room_id=room_id,
             timeline=batch,
             state=current_state,
-            ephemeral=ephemeral_by_room.get(room_id, []),
-            account_data=self.account_data_for_room(
-                room_id, tags_by_room, account_data_by_room
-            ),
+            ephemeral=ephemeral,
+            account_data=account_data,
             unread_notifications=unread_notifications,
         ))
 
@@ -365,7 +391,7 @@ class SyncHandler(BaseHandler):
         typing, typing_key = yield typing_source.get_new_events(
             user=sync_config.user,
             from_key=typing_key,
-            limit=sync_config.filter.ephemeral_limit(),
+            limit=sync_config.filter_collection.ephemeral_limit(),
             room_ids=room_ids,
             is_guest=sync_config.is_guest,
         )
@@ -388,7 +414,7 @@ class SyncHandler(BaseHandler):
         receipts, receipt_key = yield receipt_source.get_new_events(
             user=sync_config.user,
             from_key=receipt_key,
-            limit=sync_config.filter.ephemeral_limit(),
+            limit=sync_config.filter_collection.ephemeral_limit(),
             room_ids=room_ids,
             is_guest=sync_config.is_guest,
         )
@@ -419,13 +445,26 @@ class SyncHandler(BaseHandler):
 
         leave_state = yield self.store.get_state_for_event(leave_event_id)
 
+        leave_state = {
+            (e.type, e.state_key): e
+            for e in sync_config.filter_collection.filter_room_state(
+                leave_state.values()
+            )
+        }
+
+        account_data = self.account_data_for_room(
+            room_id, tags_by_room, account_data_by_room
+        )
+
+        account_data = sync_config.filter_collection.filter_room_account_data(
+            account_data
+        )
+
         defer.returnValue(ArchivedSyncResult(
             room_id=room_id,
             timeline=batch,
             state=leave_state,
-            account_data=self.account_data_for_room(
-                room_id, tags_by_room, account_data_by_room
-            ),
+            account_data=account_data,
         ))
 
     @defer.inlineCallbacks
@@ -444,7 +483,7 @@ class SyncHandler(BaseHandler):
         presence, presence_key = yield presence_source.get_new_events(
             user=sync_config.user,
             from_key=since_token.presence_key,
-            limit=sync_config.filter.presence_limit(),
+            limit=sync_config.filter_collection.presence_limit(),
             room_ids=room_ids,
             is_guest=sync_config.is_guest,
         )
@@ -473,7 +512,7 @@ class SyncHandler(BaseHandler):
                 sync_config.user
             )
 
-        timeline_limit = sync_config.filter.timeline_limit()
+        timeline_limit = sync_config.filter_collection.timeline_limit()
 
         room_events, _ = yield self.store.get_room_events_stream(
             sync_config.user.to_string(),
@@ -538,6 +577,27 @@ class SyncHandler(BaseHandler):
                     # the timeline is inherently limited if we've just joined
                     limited = True
 
+                recents = sync_config.filter_collection.filter_room_timeline(recents)
+
+                state = {
+                    (e.type, e.state_key): e
+                    for e in sync_config.filter_collection.filter_room_state(
+                        state.values()
+                    )
+                }
+
+                acc_data = self.account_data_for_room(
+                    room_id, tags_by_room, account_data_by_room
+                )
+
+                acc_data = sync_config.filter_collection.filter_room_account_data(
+                    acc_data
+                )
+
+                ephemeral = sync_config.filter_collection.filter_room_ephemeral(
+                    ephemeral_by_room.get(room_id, [])
+                )
+
                 room_sync = JoinedSyncResult(
                     room_id=room_id,
                     timeline=TimelineBatch(
@@ -546,10 +606,8 @@ class SyncHandler(BaseHandler):
                         limited=limited,
                     ),
                     state=state,
-                    ephemeral=ephemeral_by_room.get(room_id, []),
-                    account_data=self.account_data_for_room(
-                        room_id, tags_by_room, account_data_by_room
-                    ),
+                    ephemeral=ephemeral,
+                    account_data=acc_data,
                     unread_notifications={},
                 )
                 logger.debug("Result for room %s: %r", room_id, room_sync)
@@ -603,9 +661,17 @@ class SyncHandler(BaseHandler):
             for event in invite_events
         ]
 
+        account_data_for_user = sync_config.filter_collection.filter_account_data(
+            self.account_data_for_user(account_data)
+        )
+
+        presence = sync_config.filter_collection.filter_presence(
+            presence
+        )
+
         defer.returnValue(SyncResult(
             presence=presence,
-            account_data=self.account_data_for_user(account_data),
+            account_data=account_data_for_user,
             joined=joined,
             invited=invited,
             archived=archived,
@@ -621,7 +687,7 @@ class SyncHandler(BaseHandler):
         limited = True
         recents = []
         filtering_factor = 2
-        timeline_limit = sync_config.filter.timeline_limit()
+        timeline_limit = sync_config.filter_collection.timeline_limit()
         load_limit = max(timeline_limit * filtering_factor, 100)
         max_repeat = 3  # Only try a few times per room, otherwise
         room_key = now_token.room_key
@@ -634,9 +700,9 @@ class SyncHandler(BaseHandler):
                 from_token=since_token.room_key if since_token else None,
                 end_token=end_key,
             )
-            (room_key, _) = keys
+            room_key, _ = keys
             end_key = "s" + room_key.split('-')[-1]
-            loaded_recents = sync_config.filter.filter_room_timeline(events)
+            loaded_recents = sync_config.filter_collection.filter_room_timeline(events)
             loaded_recents = yield self._filter_events_for_client(
                 sync_config.user.to_string(),
                 loaded_recents,
@@ -684,17 +750,23 @@ class SyncHandler(BaseHandler):
 
         logger.debug("Recents %r", batch)
 
-        current_state = yield self.get_state_at(room_id, now_token)
+        if batch.limited:
+            current_state = yield self.get_state_at(room_id, now_token)
 
-        state_at_previous_sync = yield self.get_state_at(
-            room_id, stream_position=since_token
-        )
+            state_at_previous_sync = yield self.get_state_at(
+                room_id, stream_position=since_token
+            )
 
-        state = yield self.compute_state_delta(
-            since_token=since_token,
-            previous_state=state_at_previous_sync,
-            current_state=current_state,
-        )
+            state = yield self.compute_state_delta(
+                since_token=since_token,
+                previous_state=state_at_previous_sync,
+                current_state=current_state,
+            )
+        else:
+            state = {
+                (event.type, event.state_key): event
+                for event in batch.events if event.is_state()
+            }
 
         just_joined = yield self.check_joined_room(sync_config, state)
         if just_joined:
@@ -711,14 +783,29 @@ class SyncHandler(BaseHandler):
                 1 for notif in notifs if _action_has_highlight(notif["actions"])
             ])
 
+        state = {
+            (e.type, e.state_key): e
+            for e in sync_config.filter_collection.filter_room_state(state.values())
+        }
+
+        account_data = self.account_data_for_room(
+            room_id, tags_by_room, account_data_by_room
+        )
+
+        account_data = sync_config.filter_collection.filter_room_account_data(
+            account_data
+        )
+
+        ephemeral = sync_config.filter_collection.filter_room_ephemeral(
+            ephemeral_by_room.get(room_id, [])
+        )
+
         room_sync = JoinedSyncResult(
             room_id=room_id,
             timeline=batch,
             state=state,
-            ephemeral=ephemeral_by_room.get(room_id, []),
-            account_data=self.account_data_for_room(
-                room_id, tags_by_room, account_data_by_room
-            ),
+            ephemeral=ephemeral,
+            account_data=account_data,
             unread_notifications=unread_notifications,
         )
 
@@ -765,13 +852,26 @@ class SyncHandler(BaseHandler):
             current_state=state_events_at_leave,
         )
 
+        state_events_delta = {
+            (e.type, e.state_key): e
+            for e in sync_config.filter_collection.filter_room_state(
+                state_events_delta.values()
+            )
+        }
+
+        account_data = self.account_data_for_room(
+            leave_event.room_id, tags_by_room, account_data_by_room
+        )
+
+        account_data = sync_config.filter_collection.filter_room_account_data(
+            account_data
+        )
+
         room_sync = ArchivedSyncResult(
             room_id=leave_event.room_id,
             timeline=batch,
             state=state_events_delta,
-            account_data=self.account_data_for_room(
-                leave_event.room_id, tags_by_room, account_data_by_room
-            ),
+            account_data=account_data,
         )
 
         logger.debug("Room sync: %r", room_sync)
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index e6a28bd8c0..9bc0b356f4 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -17,7 +17,6 @@ from twisted.internet import defer
 
 from synapse.streams.config import PaginationConfig
 from synapse.types import StreamToken
-from synapse.api.constants import Membership
 
 import synapse.util.async
 import push_rule_evaluator as push_rule_evaluator
@@ -296,31 +295,28 @@ class Pusher(object):
 
     @defer.inlineCallbacks
     def _get_badge_count(self):
-        room_list = yield self.store.get_rooms_for_user_where_membership_is(
-            user_id=self.user_id,
-            membership_list=(Membership.INVITE, Membership.JOIN)
-        )
+        invites, joins = yield defer.gatherResults([
+            self.store.get_invites_for_user(self.user_id),
+            self.store.get_rooms_for_user(self.user_id),
+        ], consumeErrors=True)
 
         my_receipts_by_room = yield self.store.get_receipts_for_user(
             self.user_id,
             "m.read",
         )
 
-        badge = 0
+        badge = len(invites)
 
-        for r in room_list:
-            if r.membership == Membership.INVITE:
-                badge += 1
-            else:
-                if r.room_id in my_receipts_by_room:
-                    last_unread_event_id = my_receipts_by_room[r.room_id]
+        for r in joins:
+            if r.room_id in my_receipts_by_room:
+                last_unread_event_id = my_receipts_by_room[r.room_id]
 
-                    notifs = yield (
-                        self.store.get_unread_event_push_actions_by_room_for_user(
-                            r.room_id, self.user_id, last_unread_event_id
-                        )
+                notifs = yield (
+                    self.store.get_unread_event_push_actions_by_room_for_user(
+                        r.room_id, self.user_id, last_unread_event_id
                     )
-                    badge += len(notifs)
+                )
+                badge += len(notifs)
         defer.returnValue(badge)
 
 
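
The badge calculation now fetches invites and joined rooms concurrently, seeds the badge with the invite count, and then adds unread push actions for each joined room that has a read receipt. The same shape pulled together into one standalone function:

    # Condensed sketch of the new badge logic; `store` is assumed to expose the
    # same four methods the Pusher uses above.
    from twisted.internet import defer

    @defer.inlineCallbacks
    def badge_for_user(store, user_id):
        invites, joins = yield defer.gatherResults([
            store.get_invites_for_user(user_id),
            store.get_rooms_for_user(user_id),
        ], consumeErrors=True)

        receipts = yield store.get_receipts_for_user(user_id, "m.read")

        badge = len(invites)                      # one per pending invite
        for room in joins:
            if room.room_id in receipts:
                notifs = yield store.get_unread_event_push_actions_by_room_for_user(
                    room.room_id, user_id, receipts[room.room_id]
                )
                badge += len(notifs)              # plus unread notifications
        defer.returnValue(badge)
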
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index ab924ad9e0..07b5b5dfd5 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -130,7 +130,7 @@ class SyncRestServlet(RestServlet):
 
         sync_config = SyncConfig(
             user=user,
-            filter=filter,
+            filter_collection=filter,
             is_guest=requester.is_guest,
         )
 
@@ -154,23 +154,21 @@ class SyncRestServlet(RestServlet):
         time_now = self.clock.time_msec()
 
         joined = self.encode_joined(
-            sync_result.joined, filter, time_now, requester.access_token_id
+            sync_result.joined, time_now, requester.access_token_id
         )
 
         invited = self.encode_invited(
-            sync_result.invited, filter, time_now, requester.access_token_id
+            sync_result.invited, time_now, requester.access_token_id
         )
 
         archived = self.encode_archived(
-            sync_result.archived, filter, time_now, requester.access_token_id
+            sync_result.archived, time_now, requester.access_token_id
         )
 
         response_content = {
-            "account_data": self.encode_account_data(
-                sync_result.account_data, filter, time_now
-            ),
+            "account_data": {"events": sync_result.account_data},
             "presence": self.encode_presence(
-                sync_result.presence, filter, time_now
+                sync_result.presence, time_now
             ),
             "rooms": {
                 "join": joined,
@@ -182,24 +180,20 @@ class SyncRestServlet(RestServlet):
 
         defer.returnValue((200, response_content))
 
-    def encode_presence(self, events, filter, time_now):
+    def encode_presence(self, events, time_now):
         formatted = []
         for event in events:
             event = copy.deepcopy(event)
             event['sender'] = event['content'].pop('user_id')
             formatted.append(event)
-        return {"events": filter.filter_presence(formatted)}
-
-    def encode_account_data(self, events, filter, time_now):
-        return {"events": filter.filter_account_data(events)}
+        return {"events": formatted}
 
-    def encode_joined(self, rooms, filter, time_now, token_id):
+    def encode_joined(self, rooms, time_now, token_id):
         """
         Encode the joined rooms in a sync result
 
         :param list[synapse.handlers.sync.JoinedSyncResult] rooms: list of sync
             results for rooms this user is joined to
-        :param FilterCollection filter: filters to apply to the results
         :param int time_now: current time - used as a baseline for age
             calculations
         :param int token_id: ID of the user's auth token - used for namespacing
@@ -211,18 +205,17 @@ class SyncRestServlet(RestServlet):
         joined = {}
         for room in rooms:
             joined[room.room_id] = self.encode_room(
-                room, filter, time_now, token_id
+                room, time_now, token_id
             )
 
         return joined
 
-    def encode_invited(self, rooms, filter, time_now, token_id):
+    def encode_invited(self, rooms, time_now, token_id):
         """
         Encode the invited rooms in a sync result
 
         :param list[synapse.handlers.sync.InvitedSyncResult] rooms: list of
              sync results for rooms this user is joined to
-        :param FilterCollection filter: filters to apply to the results
         :param int time_now: current time - used as a baseline for age
             calculations
         :param int token_id: ID of the user's auth token - used for namespacing
@@ -237,7 +230,9 @@ class SyncRestServlet(RestServlet):
                 room.invite, time_now, token_id=token_id,
                 event_format=format_event_for_client_v2_without_room_id,
             )
-            invited_state = invite.get("unsigned", {}).pop("invite_room_state", [])
+            unsigned = dict(invite.get("unsigned", {}))
+            invite["unsigned"] = unsigned
+            invited_state = list(unsigned.pop("invite_room_state", []))
             invited_state.append(invite)
             invited[room.room_id] = {
                 "invite_state": {"events": invited_state}
@@ -245,13 +240,12 @@ class SyncRestServlet(RestServlet):
 
         return invited
 
-    def encode_archived(self, rooms, filter, time_now, token_id):
+    def encode_archived(self, rooms, time_now, token_id):
         """
         Encode the archived rooms in a sync result
 
         :param list[synapse.handlers.sync.ArchivedSyncResult] rooms: list of
              sync results for rooms this user is joined to
-        :param FilterCollection filter: filters to apply to the results
         :param int time_now: current time - used as a baseline for age
             calculations
         :param int token_id: ID of the user's auth token - used for namespacing
@@ -263,17 +257,16 @@ class SyncRestServlet(RestServlet):
         joined = {}
         for room in rooms:
             joined[room.room_id] = self.encode_room(
-                room, filter, time_now, token_id, joined=False
+                room, time_now, token_id, joined=False
             )
 
         return joined
 
     @staticmethod
-    def encode_room(room, filter, time_now, token_id, joined=True):
+    def encode_room(room, time_now, token_id, joined=True):
         """
         :param JoinedSyncResult|ArchivedSyncResult room: sync result for a
             single room
-        :param FilterCollection filter: filters to apply to the results
         :param int time_now: current time - used as a baseline for age
             calculations
         :param int token_id: ID of the user's auth token - used for namespacing
@@ -292,19 +285,17 @@ class SyncRestServlet(RestServlet):
             )
 
         state_dict = room.state
-        timeline_events = filter.filter_room_timeline(room.timeline.events)
+        timeline_events = room.timeline.events
 
         state_dict = SyncRestServlet._rollback_state_for_timeline(
             state_dict, timeline_events)
 
-        state_events = filter.filter_room_state(state_dict.values())
+        state_events = state_dict.values()
 
         serialized_state = [serialize(e) for e in state_events]
         serialized_timeline = [serialize(e) for e in timeline_events]
 
-        account_data = filter.filter_room_account_data(
-            room.account_data
-        )
+        account_data = room.account_data
 
         result = {
             "timeline": {
@@ -317,7 +308,7 @@ class SyncRestServlet(RestServlet):
         }
 
         if joined:
-            ephemeral_events = filter.filter_room_ephemeral(room.ephemeral)
+            ephemeral_events = room.ephemeral
             result["ephemeral"] = {"events": ephemeral_events}
             result["unread_notifications"] = room.unread_notifications
 
@@ -334,8 +325,6 @@ class SyncRestServlet(RestServlet):
         :param list[synapse.events.EventBase] timeline: the event timeline
         :return: updated state dictionary
         """
-        logger.debug("Processing state dict %r; timeline %r", state,
-                     [e.get_dict() for e in timeline])
 
         result = state.copy()
 
@@ -357,8 +346,8 @@ class SyncRestServlet(RestServlet):
                 # the event graph, and the state is no longer valid. Really,
                 # the event shouldn't be in the timeline. We're going to ignore
                 # it for now, however.
-                logger.warn("Found state event %r in timeline which doesn't "
-                            "match state dictionary", timeline_event)
+                logger.debug("Found state event %r in timeline which doesn't "
+                             "match state dictionary", timeline_event)
                 continue
 
             prev_event_id = timeline_event.unsigned.get("replaces_state", None)
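
The encode_invited change copies unsigned (and the invite_room_state list) before popping and appending, which looks intended to stop the serializer mutating event data that may be shared elsewhere. A plain-Python illustration of the difference, independent of synapse:

    # Why the copy-before-pop matters: popping from a shared dict leaks into
    # every other holder of that dict.
    cached = {"unsigned": {"invite_room_state": [{"type": "m.room.name"}]}}

    # Old shape: pop straight off the (shallow-copied, so still shared) dict.
    serialized = dict(cached)
    serialized.get("unsigned", {}).pop("invite_room_state", [])
    assert "invite_room_state" not in cached["unsigned"]   # original corrupted

    # New shape: copy the inner dict first, then pop on the copy.
    cached = {"unsigned": {"invite_room_state": [{"type": "m.room.name"}]}}
    serialized = dict(cached)
    unsigned = dict(serialized.get("unsigned", {}))
    serialized["unsigned"] = unsigned
    unsigned.pop("invite_room_state", [])
    assert "invite_room_state" in cached["unsigned"]        # original left intact
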
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 183a752387..90d7aee94a 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -629,6 +629,78 @@ class SQLBaseStore(object):
 
         return self.cursor_to_dict(txn)
 
+    @defer.inlineCallbacks
+    def _simple_select_many_batch(self, table, column, iterable, retcols,
+                                  keyvalues={}, desc="_simple_select_many_batch",
+                                  batch_size=100):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Filters rows by if value of `column` is in `iterable`.
+
+        Args:
+            table : string giving the table name
+            column : column name to test for inclusion against `iterable`
+            iterable : list
+            keyvalues : dict of column names and values to select the rows with
+            retcols : list of strings giving the names of the columns to return
+        """
+        results = []
+
+        if not iterable:
+            defer.returnValue(results)
+
+        chunks = [iterable[i:i+batch_size] for i in xrange(0, len(iterable), batch_size)]
+        for chunk in chunks:
+            rows = yield self.runInteraction(
+                desc,
+                self._simple_select_many_txn,
+                table, column, chunk, keyvalues, retcols
+            )
+
+            results.extend(rows)
+
+        defer.returnValue(results)
+
+    def _simple_select_many_txn(self, txn, table, column, iterable, keyvalues, retcols):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Filters rows by if value of `column` is in `iterable`.
+
+        Args:
+            txn : Transaction object
+            table : string giving the table name
+            column : column name to test for inclusion against `iterable`
+            iterable : list
+            keyvalues : dict of column names and values to select the rows with
+            retcols : list of strings giving the names of the columns to return
+        """
+        if not iterable:
+            return []
+
+        sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
+
+        clauses = []
+        values = []
+        clauses.append(
+            "%s IN (%s)" % (column, ",".join("?" for _ in iterable))
+        )
+        values.extend(iterable)
+
+        for key, value in keyvalues.items():
+            clauses.append("%s = ?" % (key,))
+            values.append(value)
+
+        if clauses:
+            sql = "%s WHERE %s" % (
+                sql,
+                " AND ".join(clauses),
+            )
+
+        txn.execute(sql, values)
+        return self.cursor_to_dict(txn)
+
     def _simple_update_one(self, table, keyvalues, updatevalues,
                            desc="_simple_update_one"):
         """Executes an UPDATE query on the named table, setting new values for
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 1095d52ace..9b3aecaf8c 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -48,24 +48,25 @@ class PresenceStore(SQLBaseStore):
             desc="get_presence_state",
         )
 
-    @cachedList(get_presence_state.cache, list_name="user_localparts")
+    @cachedList(get_presence_state.cache, list_name="user_localparts",
+                inlineCallbacks=True)
     def get_presence_states(self, user_localparts):
-        def f(txn):
-            results = {}
-            for user_localpart in user_localparts:
-                res = self._simple_select_one_txn(
-                    txn,
-                    table="presence",
-                    keyvalues={"user_id": user_localpart},
-                    retcols=["state", "status_msg", "mtime"],
-                    allow_none=True,
-                )
-                if res:
-                    results[user_localpart] = res
-
-            return results
-
-        return self.runInteraction("get_presence_states", f)
+        rows = yield self._simple_select_many_batch(
+            table="presence",
+            column="user_id",
+            iterable=user_localparts,
+            retcols=("user_id", "state", "status_msg", "mtime",),
+            desc="get_presence_states",
+        )
+
+        defer.returnValue({
+            row["user_id"]: {
+                "state": row["state"],
+                "status_msg": row["status_msg"],
+                "mtime": row["mtime"],
+            }
+            for row in rows
+        })
 
     def set_presence_state(self, user_localpart, new_state):
         res = self._simple_update_one(
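
The rewritten get_presence_states returns a mapping keyed by user_localpart, built from the batched rows; users without a presence row are simply absent from the result. The rows-to-mapping step in isolation (field values here are arbitrary):

    # Plain-Python version of the conversion above.
    rows = [
        {"user_id": "alice", "state": 2, "status_msg": "away", "mtime": 123},
        {"user_id": "bob", "state": 0, "status_msg": None, "mtime": 456},
    ]
    states = {
        row["user_id"]: {k: row[k] for k in ("state", "status_msg", "mtime")}
        for row in rows
    }
    assert sorted(states) == ["alice", "bob"]
    assert "carol" not in states   # no row, no entry
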
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 35ec7e8cef..1f51c90ee5 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -65,32 +65,20 @@ class PushRuleStore(SQLBaseStore):
         if not user_ids:
             defer.returnValue({})
 
-        batch_size = 100
-
-        def f(txn, user_ids_to_fetch):
-            sql = (
-                "SELECT pr.*"
-                " FROM push_rules AS pr"
-                " LEFT JOIN push_rules_enable AS pre"
-                " ON pr.user_name = pre.user_name AND pr.rule_id = pre.rule_id"
-                " WHERE pr.user_name"
-                " IN (" + ",".join("?" for _ in user_ids_to_fetch) + ")"
-                " AND (pre.enabled IS NULL OR pre.enabled = 1)"
-                " ORDER BY pr.user_name, pr.priority_class DESC, pr.priority DESC"
-            )
-            txn.execute(sql, user_ids_to_fetch)
-            return self.cursor_to_dict(txn)
-
         results = {}
 
-        chunks = [user_ids[i:i+batch_size] for i in xrange(0, len(user_ids), batch_size)]
-        for batch_user_ids in chunks:
-            rows = yield self.runInteraction(
-                "bulk_get_push_rules", f, batch_user_ids
-            )
+        rows = yield self._simple_select_many_batch(
+            table="push_rules",
+            column="user_name",
+            iterable=user_ids,
+            retcols=("*",),
+            desc="bulk_get_push_rules",
+        )
+
+        rows.sort(key=lambda e: (-e["priority_class"], -e["priority"]))
 
-            for row in rows:
-                results.setdefault(row['user_name'], []).append(row)
+        for row in rows:
+            results.setdefault(row['user_name'], []).append(row)
         defer.returnValue(results)
 
     @defer.inlineCallbacks
@@ -98,28 +86,17 @@ class PushRuleStore(SQLBaseStore):
         if not user_ids:
             defer.returnValue({})
 
-        batch_size = 100
-
-        def f(txn, user_ids_to_fetch):
-            sql = (
-                "SELECT user_name, rule_id, enabled"
-                " FROM push_rules_enable"
-                " WHERE user_name"
-                " IN (" + ",".join("?" for _ in user_ids_to_fetch) + ")"
-            )
-            txn.execute(sql, user_ids_to_fetch)
-            return self.cursor_to_dict(txn)
-
         results = {}
 
-        chunks = [user_ids[i:i+batch_size] for i in xrange(0, len(user_ids), batch_size)]
-        for batch_user_ids in chunks:
-            rows = yield self.runInteraction(
-                "bulk_get_push_rules_enabled", f, batch_user_ids
-            )
-
-            for row in rows:
-                results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
+        rows = yield self._simple_select_many_batch(
+            table="push_rules_enable",
+            column="user_name",
+            iterable=user_ids,
+            retcols=("user_name", "rule_id", "enabled",),
+            desc="bulk_get_push_rules_enabled",
+        )
+        for row in rows:
+            results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
         defer.returnValue(results)
 
     @defer.inlineCallbacks
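
With ordering no longer pushed into SQL, the in-memory sort reproduces the old ORDER BY priority_class DESC, priority DESC before rows are grouped per user. A quick check of the sort key:

    # The key sorts higher priority_class first, then higher priority,
    # matching the DESC, DESC ordering of the query this replaces.
    rows = [
        {"priority_class": 1, "priority": 5},
        {"priority_class": 5, "priority": 0},
        {"priority_class": 5, "priority": 3},
    ]
    rows.sort(key=lambda e: (-e["priority_class"], -e["priority"]))
    assert [(r["priority_class"], r["priority"]) for r in rows] == [
        (5, 3), (5, 0), (1, 5),
    ]
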
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 68ac88905f..edfecced05 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -110,6 +110,7 @@ class RoomMemberStore(SQLBaseStore):
             membership=membership,
         ).addCallback(self._get_events)
 
+    @cached()
     def get_invites_for_user(self, user_id):
         """ Get all the invite events for a user
         Args:
diff --git a/synapse/storage/schema/delta/28/event_push_actions.sql b/synapse/storage/schema/delta/28/event_push_actions.sql
index bdf6ae3f24..4d519849df 100644
--- a/synapse/storage/schema/delta/28/event_push_actions.sql
+++ b/synapse/storage/schema/delta/28/event_push_actions.sql
@@ -24,3 +24,4 @@ CREATE TABLE IF NOT EXISTS event_push_actions(
 
 
 CREATE INDEX event_push_actions_room_id_event_id_user_id_profile_tag on event_push_actions(room_id, event_id, user_id, profile_tag);
+CREATE INDEX event_push_actions_room_id_user_id on event_push_actions(room_id, user_id);
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 16ee6bbe6a..1a4e439d30 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -13,26 +13,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections import namedtuple
 from tests import unittest
 from twisted.internet import defer
 
-from mock import Mock, NonCallableMock
+from mock import Mock
 from tests.utils import (
     MockHttpResource, DeferredMockCallable, setup_test_homeserver
 )
 
 from synapse.types import UserID
-from synapse.api.filtering import FilterCollection, Filter
+from synapse.api.filtering import Filter
+from synapse.events import FrozenEvent
 
 user_localpart = "test_user"
 # MockEvent = namedtuple("MockEvent", "sender type room_id")
 
 
 def MockEvent(**kwargs):
-    ev = NonCallableMock(spec_set=kwargs.keys())
-    ev.configure_mock(**kwargs)
-    return ev
+    return FrozenEvent(kwargs)
 
 
 class FilteringTestCase(unittest.TestCase):
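
Backing MockEvent with FrozenEvent means the tests exercise the same dict-style get() access that Filter.check now uses in production. A hypothetical extra case in the style of this file (not part of the commit), covering the presence fallback:

    # Hypothetical test method for FilteringTestCase; the definition dict uses
    # the same Matrix filter format as the existing cases.
    def test_definition_presence_sender_in_content(self):
        definition = {"senders": ["@alice:example.com"]}
        event = MockEvent(
            type="m.presence",
            content={"user_id": "@alice:example.com"},
        )
        self.assertTrue(Filter(definition).check(event))
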