diff --git a/synapse/visibility.py b/synapse/visibility.py
index 015c2bab37..dc33f61d2b 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -16,10 +16,13 @@ import itertools
import logging
import operator
+import six
+
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events.utils import prune_event
+from synapse.types import get_domain_from_id
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
logger = logging.getLogger(__name__)
@@ -225,3 +228,141 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
# we turn it into a list before returning it.
defer.returnValue(list(filtered_events))
+
+
+@defer.inlineCallbacks
+def filter_events_for_server(store, server_name, events):
+ # First lets check to see if all the events have a history visibility
+ # of "shared" or "world_readable". If thats the case then we don't
+ # need to check membership (as we know the server is in the room).
+ event_to_state_ids = yield store.get_state_ids_for_events(
+ frozenset(e.event_id for e in events),
+ types=(
+ (EventTypes.RoomHistoryVisibility, ""),
+ )
+ )
+
+ visibility_ids = set()
+ for sids in event_to_state_ids.itervalues():
+ hist = sids.get((EventTypes.RoomHistoryVisibility, ""))
+ if hist:
+ visibility_ids.add(hist)
+
+ # If we failed to find any history visibility events then the default
+ # is "shared" visiblity.
+ if not visibility_ids:
+ defer.returnValue(events)
+
+ event_map = yield store.get_events(visibility_ids)
+ all_open = all(
+ e.content.get("history_visibility") in (None, "shared", "world_readable")
+ for e in event_map.itervalues()
+ )
+
+ if all_open:
+ defer.returnValue(events)
+
+ # Ok, so we're dealing with events that have non-trivial visibility
+ # rules, so we need to also get the memberships of the room.
+
+ # first, for each event we're wanting to return, get the event_ids
+ # of the history vis and membership state at those events.
+ event_to_state_ids = yield store.get_state_ids_for_events(
+ frozenset(e.event_id for e in events),
+ types=(
+ (EventTypes.RoomHistoryVisibility, ""),
+ (EventTypes.Member, None),
+ )
+ )
+
+ # We only want to pull out member events that correspond to the
+ # server's domain.
+ #
+ # event_to_state_ids contains lots of duplicates, so it turns out to be
+ # cheaper to build a complete set of unique
+ # ((type, state_key), event_id) tuples, and then filter out the ones we
+ # don't want.
+ #
+ state_key_to_event_id_set = {
+ e
+ for key_to_eid in six.itervalues(event_to_state_ids)
+ for e in key_to_eid.items()
+ }
+
+ def include(typ, state_key):
+ if typ != EventTypes.Member:
+ return True
+
+ # we avoid using get_domain_from_id here for efficiency.
+ idx = state_key.find(":")
+ if idx == -1:
+ return False
+ return state_key[idx + 1:] == server_name
+
+ event_map = yield store.get_events([
+ e_id
+ for key, e_id in state_key_to_event_id_set
+ if include(key[0], key[1])
+ ])
+
+ event_to_state = {
+ e_id: {
+ key: event_map[inner_e_id]
+ for key, inner_e_id in key_to_eid.iteritems()
+ if inner_e_id in event_map
+ }
+ for e_id, key_to_eid in event_to_state_ids.iteritems()
+ }
+
+ erased_senders = yield store.are_users_erased(
+ e.sender for e in events,
+ )
+
+ def redact_disallowed(event, state):
+ # if the sender has been gdpr17ed, always return a redacted
+ # copy of the event.
+ if erased_senders[event.sender]:
+ logger.info(
+ "Sender of %s has been erased, redacting",
+ event.event_id,
+ )
+ return prune_event(event)
+
+ if not state:
+ return event
+
+ history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
+ if history:
+ visibility = history.content.get("history_visibility", "shared")
+ if visibility in ["invited", "joined"]:
+ # We now loop through all state events looking for
+ # membership states for the requesting server to determine
+ # if the server is either in the room or has been invited
+ # into the room.
+ for ev in state.itervalues():
+ if ev.type != EventTypes.Member:
+ continue
+ try:
+ domain = get_domain_from_id(ev.state_key)
+ except Exception:
+ continue
+
+ if domain != server_name:
+ continue
+
+ memtype = ev.membership
+ if memtype == Membership.JOIN:
+ return event
+ elif memtype == Membership.INVITE:
+ if visibility == "invited":
+ return event
+ else:
+ # server has no users in the room: redact
+ return prune_event(event)
+
+ return event
+
+ defer.returnValue([
+ redact_disallowed(e, event_to_state[e.event_id])
+ for e in events
+ ])
|