diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 4ff20599d6..c1bce07e31 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -21,6 +21,7 @@ from synapse.api.errors import (
AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
)
from synapse.api.constants import EventTypes, Membership, RejectedReason
+from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function
@@ -40,7 +41,6 @@ from twisted.internet import defer
import itertools
import logging
-
logger = logging.getLogger(__name__)
@@ -58,6 +58,8 @@ class FederationHandler(BaseHandler):
def __init__(self, hs):
super(FederationHandler, self).__init__(hs)
+ self.hs = hs
+
self.distributor.observe(
"user_joined_room",
self._on_user_joined
@@ -68,12 +70,9 @@ class FederationHandler(BaseHandler):
self.store = hs.get_datastore()
self.replication_layer = hs.get_replication_layer()
self.state_handler = hs.get_state_handler()
- # self.auth_handler = gs.get_auth_handler()
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
- self.lock_manager = hs.get_room_lock_manager()
-
self.replication_layer.set_handler(self)
# When joining a room we need to queue any events for that room up
@@ -125,60 +124,72 @@ class FederationHandler(BaseHandler):
)
if not is_in_room and not event.internal_metadata.is_outlier():
logger.debug("Got event for room we're not in.")
- current_state = state
- event_ids = set()
- if state:
- event_ids |= {e.event_id for e in state}
- if auth_chain:
- event_ids |= {e.event_id for e in auth_chain}
+ try:
+ event_stream_id, max_stream_id = yield self._persist_auth_tree(
+ auth_chain, state, event
+ )
+ except AuthError as e:
+ raise FederationError(
+ "ERROR",
+ e.code,
+ e.msg,
+ affected=event.event_id,
+ )
- seen_ids = set(
- (yield self.store.have_events(event_ids)).keys()
- )
+ else:
+ event_ids = set()
+ if state:
+ event_ids |= {e.event_id for e in state}
+ if auth_chain:
+ event_ids |= {e.event_id for e in auth_chain}
+
+ seen_ids = set(
+ (yield self.store.have_events(event_ids)).keys()
+ )
- if state and auth_chain is not None:
- # If we have any state or auth_chain given to us by the replication
- # layer, then we should handle them (if we haven't before.)
+ if state and auth_chain is not None:
+ # If we have any state or auth_chain given to us by the replication
+ # layer, then we should handle them (if we haven't before.)
- event_infos = []
+ event_infos = []
- for e in itertools.chain(auth_chain, state):
- if e.event_id in seen_ids:
- continue
- e.internal_metadata.outlier = True
- auth_ids = [e_id for e_id, _ in e.auth_events]
- auth = {
- (e.type, e.state_key): e for e in auth_chain
- if e.event_id in auth_ids
- }
- event_infos.append({
- "event": e,
- "auth_events": auth,
- })
- seen_ids.add(e.event_id)
+ for e in itertools.chain(auth_chain, state):
+ if e.event_id in seen_ids:
+ continue
+ e.internal_metadata.outlier = True
+ auth_ids = [e_id for e_id, _ in e.auth_events]
+ auth = {
+ (e.type, e.state_key): e for e in auth_chain
+ if e.event_id in auth_ids or e.type == EventTypes.Create
+ }
+ event_infos.append({
+ "event": e,
+ "auth_events": auth,
+ })
+ seen_ids.add(e.event_id)
- yield self._handle_new_events(
- origin,
- event_infos,
- outliers=True
- )
+ yield self._handle_new_events(
+ origin,
+ event_infos,
+ outliers=True
+ )
- try:
- _, event_stream_id, max_stream_id = yield self._handle_new_event(
- origin,
- event,
- state=state,
- backfilled=backfilled,
- current_state=current_state,
- )
- except AuthError as e:
- raise FederationError(
- "ERROR",
- e.code,
- e.msg,
- affected=event.event_id,
- )
+ try:
+ _, event_stream_id, max_stream_id = yield self._handle_new_event(
+ origin,
+ event,
+ state=state,
+ backfilled=backfilled,
+ current_state=current_state,
+ )
+ except AuthError as e:
+ raise FederationError(
+ "ERROR",
+ e.code,
+ e.msg,
+ affected=event.event_id,
+ )
# if we're receiving valid events from an origin,
# it's probably a good idea to mark it as not in retry-state
@@ -230,7 +241,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events):
event_to_state = yield self.store.get_state_for_events(
- room_id, frozenset(e.event_id for e in events),
+ frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, None),
@@ -553,7 +564,7 @@ class FederationHandler(BaseHandler):
@log_function
@defer.inlineCallbacks
- def do_invite_join(self, target_hosts, room_id, joinee, content, snapshot):
+ def do_invite_join(self, target_hosts, room_id, joinee, content):
""" Attempts to join the `joinee` to the room `room_id` via the
server `target_host`.
@@ -569,49 +580,19 @@ class FederationHandler(BaseHandler):
yield self.store.clean_room_for_join(room_id)
- origin, pdu = yield self.replication_layer.make_join(
+ origin, event = yield self._make_and_verify_event(
target_hosts,
room_id,
- joinee
+ joinee,
+ "join",
+ content,
)
- logger.debug("Got response to make_join: %s", pdu)
-
- event = pdu
-
- # We should assert some things.
- # FIXME: Do this in a nicer way
- assert(event.type == EventTypes.Member)
- assert(event.user_id == joinee)
- assert(event.state_key == joinee)
- assert(event.room_id == room_id)
-
- event.internal_metadata.outlier = False
-
self.room_queues[room_id] = []
-
- builder = self.event_builder_factory.new(
- unfreeze(event.get_pdu_json())
- )
-
handled_events = set()
try:
- builder.event_id = self.event_builder_factory.create_event_id()
- builder.origin = self.hs.hostname
- builder.content = content
-
- if not hasattr(event, "signatures"):
- builder.signatures = {}
-
- add_hashes_and_signatures(
- builder,
- self.hs.hostname,
- self.hs.config.signing_key[0],
- )
-
- new_event = builder.build()
-
+ new_event = self._sign_event(event)
# Try the host we successfully got a response to /make_join/
# request first.
try:
@@ -619,11 +600,7 @@ class FederationHandler(BaseHandler):
target_hosts.insert(0, origin)
except ValueError:
pass
-
- ret = yield self.replication_layer.send_join(
- target_hosts,
- new_event
- )
+ ret = yield self.replication_layer.send_join(target_hosts, new_event)
origin = ret["origin"]
state = ret["state"]
@@ -649,35 +626,8 @@ class FederationHandler(BaseHandler):
# FIXME
pass
- ev_infos = []
- for e in itertools.chain(state, auth_chain):
- if e.event_id == event.event_id:
- continue
-
- e.internal_metadata.outlier = True
- auth_ids = [e_id for e_id, _ in e.auth_events]
- ev_infos.append({
- "event": e,
- "auth_events": {
- (e.type, e.state_key): e for e in auth_chain
- if e.event_id in auth_ids
- }
- })
-
- yield self._handle_new_events(origin, ev_infos, outliers=True)
-
- auth_ids = [e_id for e_id, _ in event.auth_events]
- auth_events = {
- (e.type, e.state_key): e for e in auth_chain
- if e.event_id in auth_ids
- }
-
- _, event_stream_id, max_stream_id = yield self._handle_new_event(
- origin,
- new_event,
- state=state,
- current_state=state,
- auth_events=auth_events,
+ event_stream_id, max_stream_id = yield self._persist_auth_tree(
+ auth_chain, state, event
)
with PreserveLoggingContext():
@@ -714,12 +664,14 @@ class FederationHandler(BaseHandler):
@log_function
def on_make_join_request(self, room_id, user_id):
""" We've received a /make_join/ request, so we create a partial
- join event for the room and return that. We don *not* persist or
+ join event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back.
"""
+ event_content = {"membership": Membership.JOIN}
+
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
- "content": {"membership": Membership.JOIN},
+ "content": event_content,
"room_id": room_id,
"sender": user_id,
"state_key": user_id,
@@ -865,6 +817,168 @@ class FederationHandler(BaseHandler):
defer.returnValue(event)
@defer.inlineCallbacks
+ def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
+ origin, event = yield self._make_and_verify_event(
+ target_hosts,
+ room_id,
+ user_id,
+ "leave"
+ )
+ signed_event = self._sign_event(event)
+
+ # Try the host we successfully got a response to /make_join/
+ # request first.
+ try:
+ target_hosts.remove(origin)
+ target_hosts.insert(0, origin)
+ except ValueError:
+ pass
+
+ yield self.replication_layer.send_leave(
+ target_hosts,
+ signed_event
+ )
+ defer.returnValue(None)
+
+ @defer.inlineCallbacks
+ def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
+ content={},):
+ origin, pdu = yield self.replication_layer.make_membership_event(
+ target_hosts,
+ room_id,
+ user_id,
+ membership,
+ content,
+ )
+
+ logger.debug("Got response to make_%s: %s", membership, pdu)
+
+ event = pdu
+
+ # We should assert some things.
+ # FIXME: Do this in a nicer way
+ assert(event.type == EventTypes.Member)
+ assert(event.user_id == user_id)
+ assert(event.state_key == user_id)
+ assert(event.room_id == room_id)
+ defer.returnValue((origin, event))
+
+ def _sign_event(self, event):
+ event.internal_metadata.outlier = False
+
+ builder = self.event_builder_factory.new(
+ unfreeze(event.get_pdu_json())
+ )
+
+ builder.event_id = self.event_builder_factory.create_event_id()
+ builder.origin = self.hs.hostname
+
+ if not hasattr(event, "signatures"):
+ builder.signatures = {}
+
+ add_hashes_and_signatures(
+ builder,
+ self.hs.hostname,
+ self.hs.config.signing_key[0],
+ )
+
+ return builder.build()
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_make_leave_request(self, room_id, user_id):
+ """ We've received a /make_leave/ request, so we create a partial
+ join event for the room and return that. We do *not* persist or
+ process it until the other server has signed it and sent it back.
+ """
+ builder = self.event_builder_factory.new({
+ "type": EventTypes.Member,
+ "content": {"membership": Membership.LEAVE},
+ "room_id": room_id,
+ "sender": user_id,
+ "state_key": user_id,
+ })
+
+ event, context = yield self._create_new_client_event(
+ builder=builder,
+ )
+
+ self.auth.check(event, auth_events=context.current_state)
+
+ defer.returnValue(event)
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_send_leave_request(self, origin, pdu):
+ """ We have received a leave event for a room. Fully process it."""
+ event = pdu
+
+ logger.debug(
+ "on_send_leave_request: Got event: %s, signatures: %s",
+ event.event_id,
+ event.signatures,
+ )
+
+ event.internal_metadata.outlier = False
+
+ context, event_stream_id, max_stream_id = yield self._handle_new_event(
+ origin, event
+ )
+
+ logger.debug(
+ "on_send_leave_request: After _handle_new_event: %s, sigs: %s",
+ event.event_id,
+ event.signatures,
+ )
+
+ extra_users = []
+ if event.type == EventTypes.Member:
+ target_user_id = event.state_key
+ target_user = UserID.from_string(target_user_id)
+ extra_users.append(target_user)
+
+ with PreserveLoggingContext():
+ d = self.notifier.on_new_room_event(
+ event, event_stream_id, max_stream_id, extra_users=extra_users
+ )
+
+ def log_failure(f):
+ logger.warn(
+ "Failed to notify about %s: %s",
+ event.event_id, f.value
+ )
+
+ d.addErrback(log_failure)
+
+ new_pdu = event
+
+ destinations = set()
+
+ for k, s in context.current_state.items():
+ try:
+ if k[0] == EventTypes.Member:
+ if s.content["membership"] == Membership.LEAVE:
+ destinations.add(
+ UserID.from_string(s.state_key).domain
+ )
+ except:
+ logger.warn(
+ "Failed to get destination from event %s", s.event_id
+ )
+
+ destinations.discard(origin)
+
+ logger.debug(
+ "on_send_leave_request: Sending event: %s, signatures: %s",
+ event.event_id,
+ event.signatures,
+ )
+
+ self.replication_layer.send_pdu(new_pdu, destinations)
+
+ defer.returnValue(None)
+
+ @defer.inlineCallbacks
def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
yield run_on_reactor()
@@ -986,8 +1100,6 @@ class FederationHandler(BaseHandler):
context = yield self._prep_event(
origin, event,
state=state,
- backfilled=backfilled,
- current_state=current_state,
auth_events=auth_events,
)
@@ -1010,7 +1122,6 @@ class FederationHandler(BaseHandler):
origin,
ev_info["event"],
state=ev_info.get("state"),
- backfilled=backfilled,
auth_events=ev_info.get("auth_events"),
)
for ev_info in event_infos
@@ -1027,8 +1138,77 @@ class FederationHandler(BaseHandler):
)
@defer.inlineCallbacks
- def _prep_event(self, origin, event, state=None, backfilled=False,
- current_state=None, auth_events=None):
+ def _persist_auth_tree(self, auth_events, state, event):
+ """Checks the auth chain is valid (and passes auth checks) for the
+ state and event. Then persists the auth chain and state atomically.
+ Persists the event seperately.
+
+ Returns:
+ 2-tuple of (event_stream_id, max_stream_id) from the persist_event
+ call for `event`
+ """
+ events_to_context = {}
+ for e in itertools.chain(auth_events, state):
+ ctx = yield self.state_handler.compute_event_context(
+ e, outlier=True,
+ )
+ events_to_context[e.event_id] = ctx
+ e.internal_metadata.outlier = True
+
+ event_map = {
+ e.event_id: e
+ for e in auth_events
+ }
+
+ create_event = None
+ for e in auth_events:
+ if (e.type, e.state_key) == (EventTypes.Create, ""):
+ create_event = e
+ break
+
+ for e in itertools.chain(auth_events, state, [event]):
+ auth_for_e = {
+ (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
+ for e_id, _ in e.auth_events
+ }
+ if create_event:
+ auth_for_e[(EventTypes.Create, "")] = create_event
+
+ try:
+ self.auth.check(e, auth_events=auth_for_e)
+ except AuthError as err:
+ logger.warn(
+ "Rejecting %s because %s",
+ e.event_id, err.msg
+ )
+
+ if e == event:
+ raise
+ events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
+
+ yield self.store.persist_events(
+ [
+ (e, events_to_context[e.event_id])
+ for e in itertools.chain(auth_events, state)
+ ],
+ is_new_state=False,
+ )
+
+ new_event_context = yield self.state_handler.compute_event_context(
+ event, old_state=state, outlier=False,
+ )
+
+ event_stream_id, max_stream_id = yield self.store.persist_event(
+ event, new_event_context,
+ backfilled=False,
+ is_new_state=True,
+ current_state=state,
+ )
+
+ defer.returnValue((event_stream_id, max_stream_id))
+
+ @defer.inlineCallbacks
+ def _prep_event(self, origin, event, state=None, auth_events=None):
outlier = event.internal_metadata.is_outlier()
context = yield self.state_handler.compute_event_context(
@@ -1061,6 +1241,10 @@ class FederationHandler(BaseHandler):
context.rejected = RejectedReason.AUTH_ERROR
+ if event.type == EventTypes.GuestAccess:
+ full_context = yield self.store.get_current_state(room_id=event.room_id)
+ yield self.maybe_kick_guest_users(event, full_context)
+
defer.returnValue(context)
@defer.inlineCallbacks
@@ -1166,7 +1350,7 @@ class FederationHandler(BaseHandler):
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in remote_auth_chain
- if e.event_id in auth_ids
+ if e.event_id in auth_ids or e.type == EventTypes.Create
}
e.internal_metadata.outlier = True
@@ -1284,6 +1468,7 @@ class FederationHandler(BaseHandler):
(e.type, e.state_key): e
for e in result["auth_chain"]
if e.event_id in auth_ids
+ or event.type == EventTypes.Create
}
ev.internal_metadata.outlier = True
@@ -1458,50 +1643,73 @@ class FederationHandler(BaseHandler):
})
@defer.inlineCallbacks
- def _handle_auth_events(self, origin, auth_events):
- auth_ids_to_deferred = {}
-
- def process_auth_ev(ev):
- auth_ids = [e_id for e_id, _ in ev.auth_events]
-
- prev_ds = [
- auth_ids_to_deferred[i]
- for i in auth_ids
- if i in auth_ids_to_deferred
- ]
-
- d = defer.Deferred()
+ @log_function
+ def exchange_third_party_invite(self, invite):
+ sender = invite["sender"]
+ room_id = invite["room_id"]
- auth_ids_to_deferred[ev.event_id] = d
+ event_dict = {
+ "type": EventTypes.Member,
+ "content": {
+ "membership": Membership.INVITE,
+ "third_party_invite": invite,
+ },
+ "room_id": room_id,
+ "sender": sender,
+ "state_key": invite["mxid"],
+ }
+
+ if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
+ builder = self.event_builder_factory.new(event_dict)
+ EventValidator().validate_new(builder)
+ event, context = yield self._create_new_client_event(builder=builder)
+ self.auth.check(event, context.current_state)
+ yield self._validate_keyserver(event, auth_events=context.current_state)
+ member_handler = self.hs.get_handlers().room_member_handler
+ yield member_handler.change_membership(event, context)
+ else:
+ destinations = set([x.split(":", 1)[-1] for x in (sender, room_id)])
+ yield self.replication_layer.forward_third_party_invite(
+ destinations,
+ room_id,
+ event_dict,
+ )
- @defer.inlineCallbacks
- def f(*_):
- ev.internal_metadata.outlier = True
+ @defer.inlineCallbacks
+ @log_function
+ def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
+ builder = self.event_builder_factory.new(event_dict)
- try:
- auth = {
- (e.type, e.state_key): e for e in auth_events
- if e.event_id in auth_ids
- }
+ event, context = yield self._create_new_client_event(
+ builder=builder,
+ )
- yield self._handle_new_event(
- origin, ev, auth_events=auth
- )
- except:
- logger.exception(
- "Failed to handle auth event %s",
- ev.event_id,
- )
+ self.auth.check(event, auth_events=context.current_state)
+ yield self._validate_keyserver(event, auth_events=context.current_state)
- d.callback(None)
+ returned_invite = yield self.send_invite(origin, event)
+ # TODO: Make sure the signatures actually are correct.
+ event.signatures.update(returned_invite.signatures)
+ member_handler = self.hs.get_handlers().room_member_handler
+ yield member_handler.change_membership(event, context)
- if prev_ds:
- dx = defer.DeferredList(prev_ds)
- dx.addBoth(f)
- else:
- f()
+ @defer.inlineCallbacks
+ def _validate_keyserver(self, event, auth_events):
+ token = event.content["third_party_invite"]["signed"]["token"]
- for e in auth_events:
- process_auth_ev(e)
+ invite_event = auth_events.get(
+ (EventTypes.ThirdPartyInvite, token,)
+ )
- yield defer.DeferredList(auth_ids_to_deferred.values())
+ try:
+ response = yield self.hs.get_simple_http_client().get_json(
+ invite_event.content["key_validity_url"],
+ {"public_key": invite_event.content["public_key"]}
+ )
+ except Exception:
+ raise SynapseError(
+ 502,
+ "Third party certificate could not be checked"
+ )
+ if "valid" not in response or not response["valid"]:
+ raise AuthError(403, "Third party certificate was invalid")
|