diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 3882ba79ed..8d99101619 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -39,7 +39,7 @@ from twisted.internet import defer
import itertools
import logging
-
+from synapse.util.thirdpartyinvites import ThirdPartyInvites
logger = logging.getLogger(__name__)
@@ -125,72 +125,60 @@ class FederationHandler(BaseHandler):
)
if not is_in_room and not event.internal_metadata.is_outlier():
logger.debug("Got event for room we're not in.")
+ current_state = state
- try:
- event_stream_id, max_stream_id = yield self._persist_auth_tree(
- auth_chain, state, event
- )
- except AuthError as e:
- raise FederationError(
- "ERROR",
- e.code,
- e.msg,
- affected=event.event_id,
- )
+ event_ids = set()
+ if state:
+ event_ids |= {e.event_id for e in state}
+ if auth_chain:
+ event_ids |= {e.event_id for e in auth_chain}
- else:
- event_ids = set()
- if state:
- event_ids |= {e.event_id for e in state}
- if auth_chain:
- event_ids |= {e.event_id for e in auth_chain}
-
- seen_ids = set(
- (yield self.store.have_events(event_ids)).keys()
- )
+ seen_ids = set(
+ (yield self.store.have_events(event_ids)).keys()
+ )
- if state and auth_chain is not None:
- # If we have any state or auth_chain given to us by the replication
- # layer, then we should handle them (if we haven't before.)
+ if state and auth_chain is not None:
+ # If we have any state or auth_chain given to us by the replication
+ # layer, then we should handle them (if we haven't before.)
- event_infos = []
+ event_infos = []
- for e in itertools.chain(auth_chain, state):
- if e.event_id in seen_ids:
- continue
- e.internal_metadata.outlier = True
- auth_ids = [e_id for e_id, _ in e.auth_events]
- auth = {
- (e.type, e.state_key): e for e in auth_chain
- if e.event_id in auth_ids or e.type == EventTypes.Create
- }
- event_infos.append({
- "event": e,
- "auth_events": auth,
- })
- seen_ids.add(e.event_id)
+ for e in itertools.chain(auth_chain, state):
+ if e.event_id in seen_ids:
+ continue
+ e.internal_metadata.outlier = True
+ auth_ids = [e_id for e_id, _ in e.auth_events]
+ auth = {
+ (e.type, e.state_key): e for e in auth_chain
+ if e.event_id in auth_ids
+ }
+ event_infos.append({
+ "event": e,
+ "auth_events": auth,
+ })
+ seen_ids.add(e.event_id)
- yield self._handle_new_events(
- origin,
- event_infos,
- outliers=True
- )
+ yield self._handle_new_events(
+ origin,
+ event_infos,
+ outliers=True
+ )
- try:
- _, event_stream_id, max_stream_id = yield self._handle_new_event(
- origin,
- event,
- state=state,
- backfilled=backfilled,
- current_state=current_state,
- )
- except AuthError as e:
- raise FederationError(
- "ERROR",
- e.code,
- e.msg,
- affected=event.event_id,
- )
+ try:
+ _, event_stream_id, max_stream_id = yield self._handle_new_event(
+ origin,
+ event,
+ state=state,
+ backfilled=backfilled,
+ current_state=current_state,
+ )
+ except AuthError as e:
+ raise FederationError(
+ "ERROR",
+ e.code,
+ e.msg,
+ affected=event.event_id,
+ )
# if we're receiving valid events from an origin,
# it's probably a good idea to mark it as not in retry-state
@@ -584,7 +572,8 @@ class FederationHandler(BaseHandler):
origin, pdu = yield self.replication_layer.make_join(
target_hosts,
room_id,
- joinee
+ joinee,
+ content
)
logger.debug("Got response to make_join: %s", pdu)
@@ -661,8 +650,35 @@ class FederationHandler(BaseHandler):
# FIXME
pass
- event_stream_id, max_stream_id = yield self._persist_auth_tree(
- auth_chain, state, event
+ ev_infos = []
+ for e in itertools.chain(state, auth_chain):
+ if e.event_id == event.event_id:
+ continue
+
+ e.internal_metadata.outlier = True
+ auth_ids = [e_id for e_id, _ in e.auth_events]
+ ev_infos.append({
+ "event": e,
+ "auth_events": {
+ (e.type, e.state_key): e for e in auth_chain
+ if e.event_id in auth_ids
+ }
+ })
+
+ yield self._handle_new_events(origin, ev_infos, outliers=True)
+
+ auth_ids = [e_id for e_id, _ in event.auth_events]
+ auth_events = {
+ (e.type, e.state_key): e for e in auth_chain
+ if e.event_id in auth_ids
+ }
+
+ _, event_stream_id, max_stream_id = yield self._handle_new_event(
+ origin,
+ new_event,
+ state=state,
+ current_state=state,
+ auth_events=auth_events,
)
with PreserveLoggingContext():
@@ -697,14 +713,18 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
- def on_make_join_request(self, room_id, user_id):
+ def on_make_join_request(self, room_id, user_id, query):
""" We've received a /make_join/ request, so we create a partial
join event for the room and return that. We don *not* persist or
process it until the other server has signed it and sent it back.
"""
+ event_content = {"membership": Membership.JOIN}
+ if ThirdPartyInvites.has_join_keys(query):
+ ThirdPartyInvites.copy_join_keys(query, event_content)
+
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
- "content": {"membership": Membership.JOIN},
+ "content": event_content,
"room_id": room_id,
"sender": user_id,
"state_key": user_id,
@@ -716,6 +736,9 @@ class FederationHandler(BaseHandler):
self.auth.check(event, auth_events=context.current_state)
+ if ThirdPartyInvites.has_join_keys(event.content):
+ ThirdPartyInvites.check_key_valid(self.hs.get_simple_http_client(), event)
+
defer.returnValue(event)
@defer.inlineCallbacks
@@ -1012,76 +1035,6 @@ class FederationHandler(BaseHandler):
)
@defer.inlineCallbacks
- def _persist_auth_tree(self, auth_events, state, event):
- """Checks the auth chain is valid (and passes auth checks) for the
- state and event. Then persists the auth chain and state atomically.
- Persists the event seperately.
-
- Returns:
- 2-tuple of (event_stream_id, max_stream_id) from the persist_event
- call for `event`
- """
- events_to_context = {}
- for e in itertools.chain(auth_events, state):
- ctx = yield self.state_handler.compute_event_context(
- e, outlier=True,
- )
- events_to_context[e.event_id] = ctx
- e.internal_metadata.outlier = True
-
- event_map = {
- e.event_id: e
- for e in auth_events
- }
-
- create_event = None
- for e in auth_events:
- if (e.type, e.state_key) == (EventTypes.Create, ""):
- create_event = e
- break
-
- for e in itertools.chain(auth_events, state, [event]):
- auth_for_e = {
- (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
- for e_id, _ in e.auth_events
- }
- if create_event:
- auth_for_e[(EventTypes.Create, "")] = create_event
-
- try:
- self.auth.check(e, auth_events=auth_for_e)
- except AuthError as err:
- logger.warn(
- "Rejecting %s because %s",
- e.event_id, err.msg
- )
-
- if e == event:
- raise
- events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
-
- yield self.store.persist_events(
- [
- (e, events_to_context[e.event_id])
- for e in itertools.chain(auth_events, state)
- ],
- is_new_state=False,
- )
-
- new_event_context = yield self.state_handler.compute_event_context(
- event, old_state=state, outlier=False,
- )
-
- event_stream_id, max_stream_id = yield self.store.persist_event(
- event, new_event_context,
- backfilled=False,
- is_new_state=True,
- current_state=state,
- )
-
- defer.returnValue((event_stream_id, max_stream_id))
-
- @defer.inlineCallbacks
def _prep_event(self, origin, event, state=None, backfilled=False,
current_state=None, auth_events=None):
outlier = event.internal_metadata.is_outlier()
|