diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 98a50f0948..d60c1b15ae 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -72,7 +72,7 @@ class Auth(object):
auth_events = {
(e.type, e.state_key): e for e in auth_events.values()
}
- self.check(event, auth_events=auth_events, do_sig_check=False)
+ self.check(event, auth_events=auth_events, do_sig_check=do_sig_check)
def check(self, event, auth_events, do_sig_check=True):
""" Checks if this event is correctly authed.
@@ -92,9 +92,21 @@ class Auth(object):
raise AuthError(500, "Event has no room_id: %s" % event)
sender_domain = get_domain_from_id(event.sender)
+ event_id_domain = get_domain_from_id(event.event_id)
+
+ is_invite_via_3pid = (
+ event.type == EventTypes.Member
+ and event.membership == Membership.INVITE
+ and "third_party_invite" in event.content
+ )
# Check the sender's domain has signed the event
if do_sig_check and not event.signatures.get(sender_domain):
+ if not is_invite_via_3pid:
+ raise AuthError(403, "Event not signed by sender's server")
+
+ # Check the event_id's domain has signed the event
+ if do_sig_check and not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
if auth_events is None:
@@ -491,6 +503,9 @@ class Auth(object):
if not invite_event:
return False
+ if invite_event.sender != event.sender:
+ return False
+
if event.user_id != invite_event.user_id:
return False
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index f7cb3c1bb2..a393263e1e 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1922,15 +1922,15 @@ class FederationHandler(BaseHandler):
original_invite = yield self.store.get_event(
original_invite_id, allow_none=True
)
- if not original_invite:
+ if original_invite:
+ display_name = original_invite.content["display_name"]
+ event_dict["content"]["third_party_invite"]["display_name"] = display_name
+ else:
logger.info(
- "Could not find invite event for third_party_invite - "
- "discarding: %s" % (event_dict,)
+ "Could not find invite event for third_party_invite: %r",
+ event_dict
)
- return
- display_name = original_invite.content["display_name"]
- event_dict["content"]["third_party_invite"]["display_name"] = display_name
builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder)
message_handler = self.hs.get_handlers().message_handler
|