diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index b176db8ce1..90f9eb6847 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -28,6 +28,12 @@ import logging
logger = logging.getLogger(__name__)
+AuthEventTypes = (
+ EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
+ EventTypes.JoinRules,
+)
+
+
class Auth(object):
def __init__(self, hs):
@@ -166,6 +172,7 @@ class Auth(object):
target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN
+ target_banned = target and target.membership == Membership.BAN
key = (EventTypes.JoinRules, "", )
join_rule_event = auth_events.get(key)
@@ -194,6 +201,7 @@ class Auth(object):
{
"caller_in_room": caller_in_room,
"caller_invited": caller_invited,
+ "target_banned": target_banned,
"target_in_room": target_in_room,
"membership": membership,
"join_rule": join_rule,
@@ -202,6 +210,11 @@ class Auth(object):
}
)
+ if ban_level:
+ ban_level = int(ban_level)
+ else:
+ ban_level = 50 # FIXME (erikj): What should we do here?
+
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
# PRIVATE join rules.
@@ -212,6 +225,10 @@ class Auth(object):
403,
"%s not in room %s." % (event.user_id, event.room_id,)
)
+ elif target_banned:
+ raise AuthError(
+ 403, "%s is banned from the room" % (target_user_id,)
+ )
elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." %
target_user_id)
@@ -221,6 +238,8 @@ class Auth(object):
# joined: It's a NOOP
if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.")
+ elif target_banned:
+ raise AuthError(403, "You are banned from this room")
elif join_rule == JoinRules.PUBLIC:
pass
elif join_rule == JoinRules.INVITE:
@@ -238,6 +257,10 @@ class Auth(object):
403,
"%s not in room %s." % (target_user_id, event.room_id,)
)
+ elif target_banned and user_level < ban_level:
+ raise AuthError(
+ 403, "You cannot unban user &s." % (target_user_id,)
+ )
elif target_user_id != event.user_id:
if kick_level:
kick_level = int(kick_level)
@@ -249,11 +272,6 @@ class Auth(object):
403, "You cannot kick user %s." % target_user_id
)
elif Membership.BAN == membership:
- if ban_level:
- ban_level = int(ban_level)
- else:
- ban_level = 50 # FIXME (erikj): What should we do here?
-
if user_level < ban_level:
raise AuthError(403, "You don't have permission to ban")
else:
@@ -412,12 +430,6 @@ class Auth(object):
builder.auth_events = auth_events_entries
- context.auth_events = {
- k: v
- for k, v in context.current_state.items()
- if v.event_id in auth_ids
- }
-
def compute_auth_events(self, event, current_state):
if event.type == EventTypes.Create:
return []
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 7e98bdef28..4ecadf0879 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -16,8 +16,7 @@
class EventContext(object):
- def __init__(self, current_state=None, auth_events=None):
+ def __init__(self, current_state=None):
self.current_state = current_state
- self.auth_events = auth_events
self.state_group = None
self.rejected = False
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 7f07f05215..48816a242d 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -90,8 +90,8 @@ class BaseHandler(object):
event = builder.build()
logger.debug(
- "Created event %s with auth_events: %s, current state: %s",
- event.event_id, context.auth_events, context.current_state,
+ "Created event %s with current state: %s",
+ event.event_id, context.current_state,
)
defer.returnValue(
@@ -106,7 +106,7 @@ class BaseHandler(object):
# We now need to go and hit out to wherever we need to hit out to.
if not suppress_auth:
- self.auth.check(event, auth_events=context.auth_events)
+ self.auth.check(event, auth_events=context.current_state)
yield self.store.persist_event(event, context=context)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ae4e9b316d..65cfacba2e 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -464,11 +464,9 @@ class FederationHandler(BaseHandler):
builder=builder,
)
- self.auth.check(event, auth_events=context.auth_events)
+ self.auth.check(event, auth_events=context.current_state)
- pdu = event
-
- defer.returnValue(pdu)
+ defer.returnValue(event)
@defer.inlineCallbacks
@log_function
@@ -705,7 +703,7 @@ class FederationHandler(BaseHandler):
)
if not auth_events:
- auth_events = context.auth_events
+ auth_events = context.current_state
logger.debug(
"_handle_new_event: %s, auth_events: %s",
diff --git a/synapse/state.py b/synapse/state.py
index 80cced351d..ba2500d61c 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -21,6 +21,7 @@ from synapse.util.async import run_on_reactor
from synapse.util.expiringcache import ExpiringCache
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
+from synapse.api.auth import AuthEventTypes
from synapse.events.snapshot import EventContext
from collections import namedtuple
@@ -38,12 +39,6 @@ def _get_state_key_from_event(event):
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
-AuthEventTypes = (
- EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
- EventTypes.JoinRules,
-)
-
-
SIZE_OF_CACHE = 1000
EVICTION_TIMEOUT_SECONDS = 20
@@ -139,18 +134,6 @@ class StateHandler(object):
}
context.state_group = None
- if hasattr(event, "auth_events") and event.auth_events:
- auth_ids = self.hs.get_auth().compute_auth_events(
- event, context.current_state
- )
- context.auth_events = {
- k: v
- for k, v in context.current_state.items()
- if v.event_id in auth_ids
- }
- else:
- context.auth_events = {}
-
if event.is_state():
key = (event.type, event.state_key)
if key in context.current_state:
@@ -187,18 +170,6 @@ class StateHandler(object):
replaces = context.current_state[key]
event.unsigned["replaces_state"] = replaces.event_id
- if hasattr(event, "auth_events") and event.auth_events:
- auth_ids = self.hs.get_auth().compute_auth_events(
- event, context.current_state
- )
- context.auth_events = {
- k: v
- for k, v in context.current_state.items()
- if v.event_id in auth_ids
- }
- else:
- context.auth_events = {}
-
context.prev_state_events = prev_state
defer.returnValue(context)
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 71db16d0e5..456e4bd45d 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -82,7 +82,7 @@ class StateStore(SQLBaseStore):
if context.current_state is None:
return
- state_events = context.current_state
+ state_events = dict(context.current_state)
if event.is_state():
state_events[(event.type, event.state_key)] = event
|