diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 0db26fcfd7..40c3e9db0d 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -52,7 +52,7 @@ class Auth(object):
self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
# Docs for these currently lives at
- # https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
+ # github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
# In addition, we have type == delete_pusher which grants access only to
# delete pushers.
self._KNOWN_CAVEAT_PREFIXES = set([
@@ -63,6 +63,17 @@ class Auth(object):
"user_id = ",
])
+ @defer.inlineCallbacks
+ def check_from_context(self, event, context, do_sig_check=True):
+ auth_events_ids = yield self.compute_auth_events(
+ event, context.current_state_ids, for_verification=True,
+ )
+ auth_events = yield self.store.get_events(auth_events_ids)
+ auth_events = {
+ (e.type, e.state_key): e for e in auth_events.values()
+ }
+ self.check(event, auth_events=auth_events, do_sig_check=False)
+
def check(self, event, auth_events, do_sig_check=True):
""" Checks if this event is correctly authed.
@@ -847,7 +858,7 @@ class Auth(object):
@defer.inlineCallbacks
def add_auth_events(self, builder, context):
- auth_ids = self.compute_auth_events(builder, context.current_state)
+ auth_ids = yield self.compute_auth_events(builder, context.current_state_ids)
auth_events_entries = yield self.store.add_event_hashes(
auth_ids
@@ -855,30 +866,32 @@ class Auth(object):
builder.auth_events = auth_events_entries
- def compute_auth_events(self, event, current_state):
+ @defer.inlineCallbacks
+ def compute_auth_events(self, event, current_state_ids, for_verification=False):
if event.type == EventTypes.Create:
- return []
+ defer.returnValue([])
auth_ids = []
key = (EventTypes.PowerLevels, "", )
- power_level_event = current_state.get(key)
+ power_level_event_id = current_state_ids.get(key)
- if power_level_event:
- auth_ids.append(power_level_event.event_id)
+ if power_level_event_id:
+ auth_ids.append(power_level_event_id)
key = (EventTypes.JoinRules, "", )
- join_rule_event = current_state.get(key)
+ join_rule_event_id = current_state_ids.get(key)
key = (EventTypes.Member, event.user_id, )
- member_event = current_state.get(key)
+ member_event_id = current_state_ids.get(key)
key = (EventTypes.Create, "", )
- create_event = current_state.get(key)
- if create_event:
- auth_ids.append(create_event.event_id)
+ create_event_id = current_state_ids.get(key)
+ if create_event_id:
+ auth_ids.append(create_event_id)
- if join_rule_event:
+ if join_rule_event_id:
+ join_rule_event = yield self.store.get_event(join_rule_event_id)
join_rule = join_rule_event.content.get("join_rule")
is_public = join_rule == JoinRules.PUBLIC if join_rule else False
else:
@@ -887,15 +900,21 @@ class Auth(object):
if event.type == EventTypes.Member:
e_type = event.content["membership"]
if e_type in [Membership.JOIN, Membership.INVITE]:
- if join_rule_event:
- auth_ids.append(join_rule_event.event_id)
+ if join_rule_event_id:
+ auth_ids.append(join_rule_event_id)
if e_type == Membership.JOIN:
- if member_event and not is_public:
- auth_ids.append(member_event.event_id)
+ if member_event_id and not is_public:
+ auth_ids.append(member_event_id)
else:
- if member_event:
- auth_ids.append(member_event.event_id)
+ if member_event_id:
+ auth_ids.append(member_event_id)
+
+ if for_verification:
+ key = (EventTypes.Member, event.state_key, )
+ existing_event_id = current_state_ids.get(key)
+ if existing_event_id:
+ auth_ids.append(existing_event_id)
if e_type == Membership.INVITE:
if "third_party_invite" in event.content:
@@ -903,14 +922,15 @@ class Auth(object):
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"]
)
- third_party_invite = current_state.get(key)
- if third_party_invite:
- auth_ids.append(third_party_invite.event_id)
- elif member_event:
+ third_party_invite_id = current_state_ids.get(key)
+ if third_party_invite_id:
+ auth_ids.append(third_party_invite_id)
+ elif member_event_id:
+ member_event = yield self.store.get_event(member_event_id)
if member_event.content["membership"] == Membership.JOIN:
auth_ids.append(member_event.event_id)
- return auth_ids
+ defer.returnValue(auth_ids)
def _get_send_level(self, etype, state_key, auth_events):
key = (EventTypes.PowerLevels, "", )
|