diff --git a/synapse/notifier.py b/synapse/notifier.py
index f998fc83bf..e3b42e2331 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -14,6 +14,8 @@
# limitations under the License.
from twisted.internet import defer
+from synapse.api.constants import EventTypes
+from synapse.api.errors import AuthError
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor, ObservableDeferred
@@ -269,8 +271,8 @@ class Notifier(object):
logger.exception("Failed to notify listener")
@defer.inlineCallbacks
- def wait_for_events(self, user, rooms, timeout, callback,
- from_token=StreamToken("s0", "0", "0", "0")):
+ def wait_for_events(self, user, timeout, callback, room_ids=None,
+ from_token=StreamToken("s0", "0", "0", "0", "0")):
"""Wait until the callback returns a non empty response or the
timeout fires.
"""
@@ -279,11 +281,12 @@ class Notifier(object):
if user_stream is None:
appservice = yield self.store.get_app_service_by_user_id(user)
current_token = yield self.event_sources.get_current_token()
- rooms = yield self.store.get_rooms_for_user(user)
- rooms = [room.room_id for room in rooms]
+ if room_ids is None:
+ rooms = yield self.store.get_rooms_for_user(user)
+ room_ids = [room.room_id for room in rooms]
user_stream = _NotifierUserStream(
user=user,
- rooms=rooms,
+ rooms=room_ids,
appservice=appservice,
current_token=current_token,
time_now_ms=self.clock.time_msec(),
@@ -328,8 +331,9 @@ class Notifier(object):
defer.returnValue(result)
@defer.inlineCallbacks
- def get_events_for(self, user, rooms, pagination_config, timeout,
- only_room_events=False):
+ def get_events_for(self, user, pagination_config, timeout,
+ only_room_events=False,
+ is_guest=False, guest_room_id=None):
""" For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any
new events to happen before returning.
@@ -342,6 +346,16 @@ class Notifier(object):
limit = pagination_config.limit
+ room_ids = []
+ if is_guest:
+ if guest_room_id:
+ if not self._is_world_readable(guest_room_id):
+ raise AuthError(403, "Guest access not allowed")
+ room_ids = [guest_room_id]
+ else:
+ rooms = yield self.store.get_rooms_for_user(user.to_string())
+ room_ids = [room.room_id for room in rooms]
+
@defer.inlineCallbacks
def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token):
@@ -349,6 +363,7 @@ class Notifier(object):
events = []
end_token = from_token
+
for name, source in self.event_sources.sources.items():
keyname = "%s_key" % name
before_id = getattr(before_token, keyname)
@@ -357,9 +372,23 @@ class Notifier(object):
continue
if only_room_events and name != "room":
continue
- new_events, new_key = yield source.get_new_events_for_user(
- user, getattr(from_token, keyname), limit,
+ new_events, new_key = yield source.get_new_events(
+ user=user,
+ from_key=getattr(from_token, keyname),
+ limit=limit,
+ is_guest=is_guest,
+ room_ids=room_ids,
)
+
+ if name == "room":
+ room_member_handler = self.hs.get_handlers().room_member_handler
+ new_events = yield room_member_handler._filter_events_for_client(
+ user.to_string(),
+ new_events,
+ is_guest=is_guest,
+ require_all_visible_for_guests=False
+ )
+
events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key)
@@ -369,7 +398,7 @@ class Notifier(object):
defer.returnValue(None)
result = yield self.wait_for_events(
- user, rooms, timeout, check_for_updates, from_token=from_token
+ user, timeout, check_for_updates, room_ids=room_ids, from_token=from_token
)
if result is None:
@@ -377,6 +406,17 @@ class Notifier(object):
defer.returnValue(result)
+ @defer.inlineCallbacks
+ def _is_world_readable(self, room_id):
+ state = yield self.hs.get_state_handler().get_current_state(
+ room_id,
+ EventTypes.RoomHistoryVisibility
+ )
+ if state and "history_visibility" in state.content:
+ defer.returnValue(state.content["history_visibility"] == "world_readable")
+ else:
+ defer.returnValue(False)
+
@log_function
def remove_expired_streams(self):
time_now_ms = self.clock.time_msec()
|