summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/_base.py7
-rw-r--r--synapse/handlers/events.py10
-rw-r--r--synapse/handlers/message.py47
-rw-r--r--synapse/handlers/presence.py4
-rw-r--r--synapse/handlers/private_user_data.py2
-rw-r--r--synapse/handlers/receipts.py6
-rw-r--r--synapse/handlers/room.py11
-rw-r--r--synapse/handlers/sync.py20
-rw-r--r--synapse/handlers/typing.py11
-rw-r--r--synapse/notifier.py42
-rw-r--r--synapse/rest/client/v1/events.py13
-rw-r--r--synapse/rest/client/v1/room.py6
-rw-r--r--synapse/storage/events.py2
-rw-r--r--synapse/storage/room.py13
-rw-r--r--synapse/storage/schema/delta/25/history_visibility.sql26
-rw-r--r--synapse/storage/stream.py46
-rw-r--r--tests/handlers/test_presence.py71
-rw-r--r--tests/handlers/test_typing.py30
-rw-r--r--tests/rest/client/v1/test_presence.py9
-rw-r--r--tests/rest/client/v1/test_typing.py5
20 files changed, 299 insertions, 82 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 6873a4575d..a9e43052b7 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -47,7 +47,8 @@ class BaseHandler(object):
         self.event_builder_factory = hs.get_event_builder_factory()
 
     @defer.inlineCallbacks
-    def _filter_events_for_client(self, user_id, events, is_guest=False):
+    def _filter_events_for_client(self, user_id, events, is_guest=False,
+                                  require_all_visible_for_guests=True):
         # Assumes that user has at some point joined the room if not is_guest.
 
         def allowed(event, membership, visibility):
@@ -100,7 +101,9 @@ class BaseHandler(object):
             if should_include:
                 events_to_return.append(event)
 
-        if is_guest and len(events_to_return) < len(events):
+        if (require_all_visible_for_guests
+                and is_guest
+                and len(events_to_return) < len(events)):
             # This indicates that some events in the requested range were not
             # visible to guest users. To be safe, we reject the entire request,
             # so that we don't have to worry about interpreting visibility
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 53c8ca3a26..0e4c0d4d06 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -100,7 +100,7 @@ class EventStreamHandler(BaseHandler):
     @log_function
     def get_stream(self, auth_user_id, pagin_config, timeout=0,
                    as_client_event=True, affect_presence=True,
-                   only_room_events=False):
+                   only_room_events=False, room_id=None, is_guest=False):
         """Fetches the events stream for a given user.
 
         If `only_room_events` is `True` only room events will be returned.
@@ -119,9 +119,15 @@ class EventStreamHandler(BaseHandler):
                 # thundering herds on restart.
                 timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
 
+            if is_guest:
+                yield self.distributor.fire(
+                    "user_joined_room", user=auth_user, room_id=room_id
+                )
+
             events, tokens = yield self.notifier.get_events_for(
                 auth_user, pagin_config, timeout,
-                only_room_events=only_room_events
+                only_room_events=only_room_events,
+                is_guest=is_guest, guest_room_id=room_id
             )
 
             time_now = self.clock.time_msec()
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 687e1527f7..654ecd2b37 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -16,7 +16,7 @@
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import SynapseError
+from synapse.api.errors import SynapseError, AuthError, Codes
 from synapse.streams.config import PaginationConfig
 from synapse.events.utils import serialize_event
 from synapse.events.validator import EventValidator
@@ -229,7 +229,7 @@ class MessageHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def get_room_data(self, user_id=None, room_id=None,
-                      event_type=None, state_key=""):
+                      event_type=None, state_key="", is_guest=False):
         """ Get data from a room.
 
         Args:
@@ -239,23 +239,42 @@ class MessageHandler(BaseHandler):
         Raises:
             SynapseError if something went wrong.
         """
-        member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
+        membership, membership_event_id = yield self._check_in_room_or_world_readable(
+            room_id, user_id, is_guest
+        )
 
-        if member_event.membership == Membership.JOIN:
+        if membership == Membership.JOIN:
             data = yield self.state_handler.get_current_state(
                 room_id, event_type, state_key
             )
-        elif member_event.membership == Membership.LEAVE:
+        elif membership == Membership.LEAVE:
             key = (event_type, state_key)
             room_state = yield self.store.get_state_for_events(
-                [member_event.event_id], [key]
+                [membership_event_id], [key]
             )
-            data = room_state[member_event.event_id].get(key)
+            data = room_state[membership_event_id].get(key)
 
         defer.returnValue(data)
 
     @defer.inlineCallbacks
-    def get_state_events(self, user_id, room_id):
+    def _check_in_room_or_world_readable(self, room_id, user_id, is_guest):
+        if is_guest:
+            visibility = yield self.state_handler.get_current_state(
+                room_id, EventTypes.RoomHistoryVisibility, ""
+            )
+            if visibility.content["history_visibility"] == "world_readable":
+                defer.returnValue((Membership.JOIN, None))
+                return
+            else:
+                raise AuthError(
+                    403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
+                )
+        else:
+            member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
+            defer.returnValue((member_event.membership, member_event.event_id))
+
+    @defer.inlineCallbacks
+    def get_state_events(self, user_id, room_id, is_guest=False):
         """Retrieve all state events for a given room. If the user is
         joined to the room then return the current state. If the user has
         left the room return the state events from when they left.
@@ -266,15 +285,17 @@ class MessageHandler(BaseHandler):
         Returns:
             A list of dicts representing state events. [{}, {}, {}]
         """
-        member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
+        membership, membership_event_id = yield self._check_in_room_or_world_readable(
+            room_id, user_id, is_guest
+        )
 
-        if member_event.membership == Membership.JOIN:
+        if membership == Membership.JOIN:
             room_state = yield self.state_handler.get_current_state(room_id)
-        elif member_event.membership == Membership.LEAVE:
+        elif membership == Membership.LEAVE:
             room_state = yield self.store.get_state_for_events(
-                [member_event.event_id], None
+                [membership_event_id], None
             )
-            room_state = room_state[member_event.event_id]
+            room_state = room_state[membership_event_id]
 
         now = self.clock.time_msec()
         defer.returnValue(
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index ce60642127..0b780cd528 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1142,8 +1142,9 @@ class PresenceEventSource(object):
 
     @defer.inlineCallbacks
     @log_function
-    def get_new_events_for_user(self, user, from_key, limit):
+    def get_new_events(self, user, from_key, room_ids=None, **kwargs):
         from_key = int(from_key)
+        room_ids = room_ids or []
 
         presence = self.hs.get_handlers().presence_handler
         cachemap = presence._user_cachemap
@@ -1161,7 +1162,6 @@ class PresenceEventSource(object):
             user_ids_to_check |= set(
                 UserID.from_string(p["observed_user_id"]) for p in presence_list
             )
-        room_ids = yield presence.get_joined_rooms_for_user(user)
         for room_id in set(room_ids) & set(presence._room_serials):
             if presence._room_serials[room_id] > from_key:
                 joined = yield presence.get_joined_users_for_room_id(room_id)
diff --git a/synapse/handlers/private_user_data.py b/synapse/handlers/private_user_data.py
index 1778c71325..1abe45ed7b 100644
--- a/synapse/handlers/private_user_data.py
+++ b/synapse/handlers/private_user_data.py
@@ -24,7 +24,7 @@ class PrivateUserDataEventSource(object):
         return self.store.get_max_private_user_data_stream_id()
 
     @defer.inlineCallbacks
-    def get_new_events_for_user(self, user, from_key, limit):
+    def get_new_events(self, user, from_key, **kwargs):
         user_id = user.to_string()
         last_stream_id = from_key
 
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index a47ae3df42..973f4d5cae 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -164,17 +164,15 @@ class ReceiptEventSource(object):
         self.store = hs.get_datastore()
 
     @defer.inlineCallbacks
-    def get_new_events_for_user(self, user, from_key, limit):
+    def get_new_events(self, from_key, room_ids, **kwargs):
         from_key = int(from_key)
         to_key = yield self.get_current_key()
 
         if from_key == to_key:
             defer.returnValue(([], to_key))
 
-        rooms = yield self.store.get_rooms_for_user(user.to_string())
-        rooms = [room.room_id for room in rooms]
         events = yield self.store.get_linearized_receipts_for_rooms(
-            rooms,
+            room_ids,
             from_key=from_key,
             to_key=to_key,
         )
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 9184dcd048..736ffe9066 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -807,7 +807,14 @@ class RoomEventSource(object):
         self.store = hs.get_datastore()
 
     @defer.inlineCallbacks
-    def get_new_events_for_user(self, user, from_key, limit):
+    def get_new_events(
+            self,
+            user,
+            from_key,
+            limit,
+            room_ids,
+            is_guest,
+    ):
         # We just ignore the key for now.
 
         to_key = yield self.get_current_key()
@@ -828,6 +835,8 @@ class RoomEventSource(object):
                 from_key=from_key,
                 to_key=to_key,
                 limit=limit,
+                room_ids=room_ids,
+                is_guest=is_guest,
             )
 
         defer.returnValue((events, end_key))
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 1c1ee34b1e..5294d96466 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -295,11 +295,16 @@ class SyncHandler(BaseHandler):
 
         typing_key = since_token.typing_key if since_token else "0"
 
+        rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
+        room_ids = [room.room_id for room in rooms]
+
         typing_source = self.event_sources.sources["typing"]
-        typing, typing_key = yield typing_source.get_new_events_for_user(
+        typing, typing_key = yield typing_source.get_new_events(
             user=sync_config.user,
             from_key=typing_key,
             limit=sync_config.filter.ephemeral_limit(),
+            room_ids=room_ids,
+            is_guest=False,
         )
         now_token = now_token.copy_and_replace("typing_key", typing_key)
 
@@ -312,10 +317,13 @@ class SyncHandler(BaseHandler):
         receipt_key = since_token.receipt_key if since_token else "0"
 
         receipt_source = self.event_sources.sources["receipt"]
-        receipts, receipt_key = yield receipt_source.get_new_events_for_user(
+        receipts, receipt_key = yield receipt_source.get_new_events(
             user=sync_config.user,
             from_key=receipt_key,
             limit=sync_config.filter.ephemeral_limit(),
+            room_ids=room_ids,
+            # /sync doesn't support guest access, they can't get to this point in code
+            is_guest=False,
         )
         now_token = now_token.copy_and_replace("receipt_key", receipt_key)
 
@@ -360,11 +368,17 @@ class SyncHandler(BaseHandler):
         """
         now_token = yield self.event_sources.get_current_token()
 
+        rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
+        room_ids = [room.room_id for room in rooms]
+
         presence_source = self.event_sources.sources["presence"]
-        presence, presence_key = yield presence_source.get_new_events_for_user(
+        presence, presence_key = yield presence_source.get_new_events(
             user=sync_config.user,
             from_key=since_token.presence_key,
             limit=sync_config.filter.presence_limit(),
+            room_ids=room_ids,
+            # /sync doesn't support guest access, they can't get to this point in code
+            is_guest=False,
         )
         now_token = now_token.copy_and_replace("presence_key", presence_key)
 
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index d7096aab8c..2846f3e6e8 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -246,17 +246,12 @@ class TypingNotificationEventSource(object):
             },
         }
 
-    @defer.inlineCallbacks
-    def get_new_events_for_user(self, user, from_key, limit):
+    def get_new_events(self, from_key, room_ids, **kwargs):
         from_key = int(from_key)
         handler = self.handler()
 
-        joined_room_ids = (
-            yield self.room_member_handler().get_joined_rooms_for_user(user)
-        )
-
         events = []
-        for room_id in joined_room_ids:
+        for room_id in room_ids:
             if room_id not in handler._room_serials:
                 continue
             if handler._room_serials[room_id] <= from_key:
@@ -264,7 +259,7 @@ class TypingNotificationEventSource(object):
 
             events.append(self._make_event_for(room_id))
 
-        defer.returnValue((events, handler._latest_room_serial))
+        return events, handler._latest_room_serial
 
     def get_current_key(self):
         return self.handler()._latest_room_serial
diff --git a/synapse/notifier.py b/synapse/notifier.py
index b69da63d43..56c4c863b5 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -269,7 +269,7 @@ class Notifier(object):
                 logger.exception("Failed to notify listener")
 
     @defer.inlineCallbacks
-    def wait_for_events(self, user, timeout, callback,
+    def wait_for_events(self, user, timeout, callback, room_ids=None,
                         from_token=StreamToken("s0", "0", "0", "0", "0")):
         """Wait until the callback returns a non empty response or the
         timeout fires.
@@ -279,11 +279,12 @@ class Notifier(object):
         if user_stream is None:
             appservice = yield self.store.get_app_service_by_user_id(user)
             current_token = yield self.event_sources.get_current_token()
-            rooms = yield self.store.get_rooms_for_user(user)
-            rooms = [room.room_id for room in rooms]
+            if room_ids is None:
+                rooms = yield self.store.get_rooms_for_user(user)
+                room_ids = [room.room_id for room in rooms]
             user_stream = _NotifierUserStream(
                 user=user,
-                rooms=rooms,
+                rooms=room_ids,
                 appservice=appservice,
                 current_token=current_token,
                 time_now_ms=self.clock.time_msec(),
@@ -329,7 +330,8 @@ class Notifier(object):
 
     @defer.inlineCallbacks
     def get_events_for(self, user, pagination_config, timeout,
-                       only_room_events=False):
+                       only_room_events=False,
+                       is_guest=False, guest_room_id=None):
         """ For the given user and rooms, return any new events for them. If
         there are no new events wait for up to `timeout` milliseconds for any
         new events to happen before returning.
@@ -342,6 +344,16 @@ class Notifier(object):
 
         limit = pagination_config.limit
 
+        room_ids = []
+        if is_guest:
+            # TODO(daniel): Deal with non-room events too
+            only_room_events = True
+            if guest_room_id:
+                room_ids = [guest_room_id]
+        else:
+            rooms = yield self.store.get_rooms_for_user(user.to_string())
+            room_ids = [room.room_id for room in rooms]
+
         @defer.inlineCallbacks
         def check_for_updates(before_token, after_token):
             if not after_token.is_after(before_token):
@@ -357,9 +369,23 @@ class Notifier(object):
                     continue
                 if only_room_events and name != "room":
                     continue
-                new_events, new_key = yield source.get_new_events_for_user(
-                    user, getattr(from_token, keyname), limit,
+                new_events, new_key = yield source.get_new_events(
+                    user=user,
+                    from_key=getattr(from_token, keyname),
+                    limit=limit,
+                    is_guest=is_guest,
+                    room_ids=room_ids,
                 )
+
+                if is_guest:
+                    room_member_handler = self.hs.get_handlers().room_member_handler
+                    new_events = yield room_member_handler._filter_events_for_client(
+                        user.to_string(),
+                        new_events,
+                        is_guest=is_guest,
+                        require_all_visible_for_guests=False
+                    )
+
                 events.extend(new_events)
                 end_token = end_token.copy_and_replace(keyname, new_key)
 
@@ -369,7 +395,7 @@ class Notifier(object):
                 defer.returnValue(None)
 
         result = yield self.wait_for_events(
-            user, timeout, check_for_updates, from_token=from_token
+            user, timeout, check_for_updates, room_ids=room_ids, from_token=from_token
         )
 
         if result is None:
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index 4073b0d2d1..3e1750d1a1 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -34,7 +34,15 @@ class EventStreamRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request):
-        auth_user, _, _ = yield self.auth.get_user_by_req(request)
+        auth_user, _, is_guest = yield self.auth.get_user_by_req(
+            request,
+            allow_guest=True
+        )
+        room_id = None
+        if is_guest:
+            if "room_id" not in request.args:
+                raise SynapseError(400, "Guest users must specify room_id param")
+            room_id = request.args["room_id"][0]
         try:
             handler = self.handlers.event_stream_handler
             pagin_config = PaginationConfig.from_request(request)
@@ -49,7 +57,8 @@ class EventStreamRestServlet(ClientV1RestServlet):
 
             chunk = yield handler.get_stream(
                 auth_user.to_string(), pagin_config, timeout=timeout,
-                as_client_event=as_client_event
+                as_client_event=as_client_event, affect_presence=(not is_guest),
+                room_id=room_id, is_guest=is_guest
             )
         except:
             logger.exception("Event stream failed")
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 0876e593c5..afb802baec 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -125,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, room_id, event_type, state_key):
-        user, _, _ = yield self.auth.get_user_by_req(request)
+        user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
 
         msg_handler = self.handlers.message_handler
         data = yield msg_handler.get_room_data(
@@ -133,6 +133,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
             room_id=room_id,
             event_type=event_type,
             state_key=state_key,
+            is_guest=is_guest,
         )
 
         if not data:
@@ -348,12 +349,13 @@ class RoomStateRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, room_id):
-        user, _, _ = yield self.auth.get_user_by_req(request)
+        user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
         handler = self.handlers.message_handler
         # Get all the current state for this room
         events = yield handler.get_state_events(
             room_id=room_id,
             user_id=user.to_string(),
+            is_guest=is_guest,
         )
         defer.returnValue((200, events))
 
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index e6c1abfc27..59c9987202 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -311,6 +311,8 @@ class EventsStore(SQLBaseStore):
                 self._store_room_message_txn(txn, event)
             elif event.type == EventTypes.Redaction:
                 self._store_redaction(txn, event)
+            elif event.type == EventTypes.RoomHistoryVisibility:
+                self._store_history_visibility_txn(txn, event)
 
         self._store_room_members_txn(
             txn,
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 13441fcdce..1c79626736 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -202,6 +202,19 @@ class RoomStore(SQLBaseStore):
                 txn, event, "content.body", event.content["body"]
             )
 
+    def _store_history_visibility_txn(self, txn, event):
+        if hasattr(event, "content") and "history_visibility" in event.content:
+            sql = (
+                "INSERT INTO history_visibility"
+                " (event_id, room_id, history_visibility)"
+                " VALUES (?, ?, ?)"
+            )
+            txn.execute(sql, (
+                event.event_id,
+                event.room_id,
+                event.content["history_visibility"]
+            ))
+
     def _store_event_search_txn(self, txn, event, key, value):
         if isinstance(self.database_engine, PostgresEngine):
             sql = (
diff --git a/synapse/storage/schema/delta/25/history_visibility.sql b/synapse/storage/schema/delta/25/history_visibility.sql
new file mode 100644
index 0000000000..9f387ed69f
--- /dev/null
+++ b/synapse/storage/schema/delta/25/history_visibility.sql
@@ -0,0 +1,26 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * This is a manual index of history_visibility content of state events,
+ * so that we can join on them in SELECT statements.
+ */
+CREATE TABLE IF NOT EXISTS history_visibility(
+    id INTEGER PRIMARY KEY,
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    history_visibility TEXT NOT NULL,
+    UNIQUE (event_id)
+);
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index c728013f4c..be8ba76aae 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -158,13 +158,40 @@ class StreamStore(SQLBaseStore):
         defer.returnValue(results)
 
     @log_function
-    def get_room_events_stream(self, user_id, from_key, to_key, limit=0):
-        current_room_membership_sql = (
-            "SELECT m.room_id FROM room_memberships as m "
-            " INNER JOIN current_state_events as c"
-            " ON m.event_id = c.event_id AND c.state_key = m.user_id"
-            " WHERE m.user_id = ? AND m.membership = 'join'"
-        )
+    def get_room_events_stream(
+        self,
+        user_id,
+        from_key,
+        to_key,
+        limit=0,
+        is_guest=False,
+        room_ids=None
+    ):
+        room_ids = room_ids or []
+        room_ids = [r for r in room_ids]
+        if is_guest:
+            current_room_membership_sql = (
+                "SELECT c.room_id FROM history_visibility AS h"
+                " INNER JOIN current_state_events AS c"
+                " ON h.event_id = c.event_id"
+                " WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % (
+                    ",".join(map(lambda _: "?", room_ids))
+                )
+            )
+            current_room_membership_args = room_ids
+        else:
+            current_room_membership_sql = (
+                "SELECT m.room_id FROM room_memberships as m "
+                " INNER JOIN current_state_events as c"
+                " ON m.event_id = c.event_id AND c.state_key = m.user_id"
+                " WHERE m.user_id = ? AND m.membership = 'join'"
+            )
+            current_room_membership_args = [user_id]
+            if room_ids:
+                current_room_membership_sql += " AND m.room_id in (%s)" % (
+                    ",".join(map(lambda _: "?", room_ids))
+                )
+                current_room_membership_args = [user_id] + room_ids
 
         # We also want to get any membership events about that user, e.g.
         # invites or leave notifications.
@@ -173,6 +200,7 @@ class StreamStore(SQLBaseStore):
             "INNER JOIN current_state_events as c ON m.event_id = c.event_id "
             "WHERE m.user_id = ? "
         )
+        membership_args = [user_id]
 
         if limit:
             limit = max(limit, MAX_STREAM_SIZE)
@@ -199,7 +227,9 @@ class StreamStore(SQLBaseStore):
         }
 
         def f(txn):
-            txn.execute(sql, (False, user_id, user_id, from_id.stream, to_id.stream,))
+            args = ([False] + current_room_membership_args + membership_args +
+                    [from_id.stream, to_id.stream])
+            txn.execute(sql, args)
 
             rows = self.cursor_to_dict(txn)
 
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 29372d488a..10d4482cce 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -650,9 +650,30 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
             {"presence": ONLINE}
         )
 
+        # Apple sees self-reflection even without room_id
+        (events, _) = yield self.event_source.get_new_events(
+            user=self.u_apple,
+            from_key=0,
+        )
+
+        self.assertEquals(self.event_source.get_current_key(), 1)
+        self.assertEquals(events,
+            [
+                {"type": "m.presence",
+                 "content": {
+                    "user_id": "@apple:test",
+                    "presence": ONLINE,
+                    "last_active_ago": 0,
+                }},
+            ],
+            msg="Presence event should be visible to self-reflection"
+        )
+
         # Apple sees self-reflection
-        (events, _) = yield self.event_source.get_new_events_for_user(
-            self.u_apple, 0, None
+        (events, _) = yield self.event_source.get_new_events(
+            user=self.u_apple,
+            from_key=0,
+            room_ids=[self.room_id],
         )
 
         self.assertEquals(self.event_source.get_current_key(), 1)
@@ -684,8 +705,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
         )
 
         # Banana sees it because of presence subscription
-        (events, _) = yield self.event_source.get_new_events_for_user(
-            self.u_banana, 0, None
+        (events, _) = yield self.event_source.get_new_events(
+            user=self.u_banana,
+            from_key=0,
+            room_ids=[self.room_id],
         )
 
         self.assertEquals(self.event_source.get_current_key(), 1)
@@ -702,8 +725,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
         )
 
         # Elderberry sees it because of same room
-        (events, _) = yield self.event_source.get_new_events_for_user(
-            self.u_elderberry, 0, None
+        (events, _) = yield self.event_source.get_new_events(
+            user=self.u_elderberry,
+            from_key=0,
+            room_ids=[self.room_id],
         )
 
         self.assertEquals(self.event_source.get_current_key(), 1)
@@ -720,8 +745,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
         )
 
         # Durian is not in the room, should not see this event
-        (events, _) = yield self.event_source.get_new_events_for_user(
-            self.u_durian, 0, None
+        (events, _) = yield self.event_source.get_new_events(
+            user=self.u_durian,
+            from_key=0,
+            room_ids=[],
         )
 
         self.assertEquals(self.event_source.get_current_key(), 1)
@@ -767,8 +794,9 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
                  "accepted": True},
         ], presence)
 
-        (events, _) = yield self.event_source.get_new_events_for_user(
-            self.u_apple, 1, None
+        (events, _) = yield self.event_source.get_new_events(
+            user=self.u_apple,
+            from_key=1,
         )
 
         self.assertEquals(self.event_source.get_current_key(), 2)
@@ -858,8 +886,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
             )
         )
 
-        (events, _) = yield self.event_source.get_new_events_for_user(
-            self.u_apple, 0, None
+        (events, _) = yield self.event_source.get_new_events(
+            user=self.u_apple,
+            from_key=0,
+            room_ids=[self.room_id],
         )
 
         self.assertEquals(self.event_source.get_current_key(), 1)
@@ -905,8 +935,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
 
         self.assertEquals(self.event_source.get_current_key(), 1)
 
-        (events, _) = yield self.event_source.get_new_events_for_user(
-            self.u_apple, 0, None
+        (events, _) = yield self.event_source.get_new_events(
+            user=self.u_apple,
+            from_key=0,
+            room_ids=[self.room_id,]
         )
         self.assertEquals(events,
             [
@@ -932,8 +964,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
 
         self.assertEquals(self.event_source.get_current_key(), 2)
 
-        (events, _) = yield self.event_source.get_new_events_for_user(
-            self.u_apple, 0, None
+        (events, _) = yield self.event_source.get_new_events(
+            user=self.u_apple,
+            from_key=0,
+            room_ids=[self.room_id,]
         )
         self.assertEquals(events,
             [
@@ -966,8 +1000,9 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
 
         self.room_members.append(self.u_clementine)
 
-        (events, _) = yield self.event_source.get_new_events_for_user(
-            self.u_apple, 0, None
+        (events, _) = yield self.event_source.get_new_events(
+            user=self.u_apple,
+            from_key=0,
         )
 
         self.assertEquals(self.event_source.get_current_key(), 1)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 41bb08b7ca..2d7ba43561 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -187,7 +187,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
         ])
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
+        events = yield self.event_source.get_new_events(
+            room_ids=[self.room_id],
+            from_key=0,
+        )
         self.assertEquals(
             events[0],
             [
@@ -250,7 +253,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
         ])
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
+        events = yield self.event_source.get_new_events(
+            room_ids=[self.room_id],
+            from_key=0
+        )
         self.assertEquals(
             events[0],
             [
@@ -306,7 +312,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
         yield put_json.await_calls()
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
+        events = yield self.event_source.get_new_events(
+            room_ids=[self.room_id],
+            from_key=0,
+        )
         self.assertEquals(
             events[0],
             [
@@ -337,7 +346,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
         self.on_new_event.reset_mock()
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
+        events = yield self.event_source.get_new_events(
+            room_ids=[self.room_id],
+            from_key=0,
+        )
         self.assertEquals(
             events[0],
             [
@@ -356,7 +368,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
         ])
 
         self.assertEquals(self.event_source.get_current_key(), 2)
-        events = yield self.event_source.get_new_events_for_user(self.u_apple, 1, None)
+        events = yield self.event_source.get_new_events(
+            room_ids=[self.room_id],
+            from_key=1,
+        )
         self.assertEquals(
             events[0],
             [
@@ -383,7 +398,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
         self.on_new_event.reset_mock()
 
         self.assertEquals(self.event_source.get_current_key(), 3)
-        events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
+        events = yield self.event_source.get_new_events(
+            room_ids=[self.room_id],
+            from_key=0,
+        )
         self.assertEquals(
             events[0],
             [
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 3e0f294630..7f29d73d95 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -47,7 +47,14 @@ class NullSource(object):
     def __init__(self, hs):
         pass
 
-    def get_new_events_for_user(self, user, from_key, limit):
+    def get_new_events(
+            self,
+            user,
+            from_key,
+            room_ids=None,
+            limit=None,
+            is_guest=None
+    ):
         return defer.succeed(([], from_key))
 
     def get_current_key(self, direction='f'):
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 8433585616..61b9cc743b 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -116,7 +116,10 @@ class RoomTypingTestCase(RestTestCase):
         self.assertEquals(200, code)
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = yield self.event_source.get_new_events_for_user(self.user, 0, None)
+        events = yield self.event_source.get_new_events(
+            from_key=0,
+            room_ids=[self.room_id],
+        )
         self.assertEquals(
             events[0],
             [