summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2016-08-31 13:55:02 +0100
committerErik Johnston <erik@matrix.org>2016-08-31 14:26:22 +0100
commitc10cb581c6ce54e7dfa1f8a0f6449ee7f6d049d4 (patch)
tree18c7bb219894d6d75489ff33e92b3cba897aa872 /synapse/handlers
parentGenerate state group ids in state layer (diff)
downloadsynapse-c10cb581c6ce54e7dfa1f8a0f6449ee7f6d049d4.tar.xz
Correctly handle the difference between prev and current state
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/federation.py31
-rw-r--r--synapse/handlers/message.py10
-rw-r--r--synapse/handlers/room_member.py6
3 files changed, 30 insertions, 17 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index a7ea8fb98f..8e61d74b13 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -222,7 +222,7 @@ class FederationHandler(BaseHandler):
                 # joined the room. Don't bother if the user is just
                 # changing their profile info.
                 newly_joined = True
-                prev_state_id = context.current_state_ids.get(
+                prev_state_id = context.prev_state_ids.get(
                     (event.type, event.state_key)
                 )
                 if prev_state_id:
@@ -835,12 +835,12 @@ class FederationHandler(BaseHandler):
 
         self.replication_layer.send_pdu(new_pdu, destinations)
 
-        state_ids = context.current_state_ids.values()
+        state_ids = context.prev_state_ids.values()
         auth_chain = yield self.store.get_auth_chain(set(
             [event.event_id] + state_ids
         ))
 
-        state = yield self.store.get_events(context.current_state_ids.values())
+        state = yield self.store.get_events(context.prev_state_ids.values())
 
         defer.returnValue({
             "state": state.values(),
@@ -1333,7 +1333,7 @@ class FederationHandler(BaseHandler):
 
         if not auth_events:
             auth_events_ids = yield self.auth.compute_auth_events(
-                event, context.current_state_ids, for_verification=True,
+                event, context.prev_state_ids, for_verification=True,
             )
             auth_events = yield self.store.get_events(auth_events_ids)
             auth_events = {
@@ -1432,6 +1432,11 @@ class FederationHandler(BaseHandler):
         current_state = set(e.event_id for e in auth_events.values())
         event_auth_events = set(e_id for e_id, _ in event.auth_events)
 
+        if event.is_state():
+            event_key = (event.type, event.state_key)
+        else:
+            event_key = None
+
         if event_auth_events - current_state:
             have_events = yield self.store.have_events(
                 event_auth_events - current_state
@@ -1537,8 +1542,12 @@ class FederationHandler(BaseHandler):
 
                 context.current_state_ids.update({
                     k: a.event_id for k, a in auth_events.items()
+                    if k != event_key
+                })
+                context.prev_state_ids.update({
+                    k: a.event_id for k, a in auth_events.items()
                 })
-                context.state_group = None
+                context.state_group = self.store.get_next_state_group()
 
         if different_auth and not event.internal_metadata.is_outlier():
             logger.info("Different auth after resolution: %s", different_auth)
@@ -1560,7 +1569,7 @@ class FederationHandler(BaseHandler):
             if do_resolution:
                 # 1. Get what we think is the auth chain.
                 auth_ids = yield self.auth.compute_auth_events(
-                    event, context.current_state_ids
+                    event, context.prev_state_ids
                 )
                 local_auth_chain = yield self.store.get_auth_chain(auth_ids)
 
@@ -1618,8 +1627,12 @@ class FederationHandler(BaseHandler):
 
                 context.current_state_ids.update({
                     k: a.event_id for k, a in auth_events.items()
+                    if k != event_key
+                })
+                context.prev_state_ids.update({
+                    k: a.event_id for k, a in auth_events.items()
                 })
-                context.state_group = None
+                context.state_group = self.store.get_next_state_group()
 
         try:
             self.auth.check(event, auth_events=auth_events)
@@ -1855,7 +1868,7 @@ class FederationHandler(BaseHandler):
             event.content["third_party_invite"]["signed"]["token"]
         )
         original_invite = None
-        original_invite_id = context.current_state_ids.get(key)
+        original_invite_id = context.prev_state_ids.get(key)
         if original_invite_id:
             original_invite = yield self.store.get_event(
                 original_invite_id, allow_none=True
@@ -1893,7 +1906,7 @@ class FederationHandler(BaseHandler):
         signed = event.content["third_party_invite"]["signed"]
         token = signed["token"]
 
-        invite_event_id = context.current_state_ids.get(
+        invite_event_id = context.prev_state_ids.get(
             (EventTypes.ThirdPartyInvite, token,)
         )
 
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index e2f4387f60..3577db0595 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -272,7 +272,7 @@ class MessageHandler(BaseHandler):
         If so, returns the version of the event in context.
         Otherwise, returns None.
         """
-        prev_event_id = context.current_state_ids.get((event.type, event.state_key))
+        prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
         prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
         if not prev_event:
             return
@@ -808,8 +808,8 @@ class MessageHandler(BaseHandler):
         event = builder.build()
 
         logger.debug(
-            "Created event %s with current state: %s",
-            event.event_id, context.current_state_ids,
+            "Created event %s with state: %s",
+            event.event_id, context.prev_state_ids,
         )
 
         defer.returnValue(
@@ -904,7 +904,7 @@ class MessageHandler(BaseHandler):
 
         if event.type == EventTypes.Redaction:
             auth_events_ids = yield self.auth.compute_auth_events(
-                event, context.current_state_ids, for_verification=True,
+                event, context.prev_state_ids, for_verification=True,
             )
             auth_events = yield self.store.get_events(auth_events_ids)
             auth_events = {
@@ -924,7 +924,7 @@ class MessageHandler(BaseHandler):
                         "You don't have permission to redact events"
                     )
 
-        if event.type == EventTypes.Create and context.current_state_ids:
+        if event.type == EventTypes.Create and context.prev_state_ids:
             raise AuthError(
                 403,
                 "Changing the room create event is forbidden",
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index dd4b90ee24..3ba5335af7 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -93,7 +93,7 @@ class RoomMemberHandler(BaseHandler):
             ratelimit=ratelimit,
         )
 
-        prev_member_event_id = context.current_state_ids.get(
+        prev_member_event_id = context.prev_state_ids.get(
             (EventTypes.Member, target.to_string()),
             None
         )
@@ -341,7 +341,7 @@ class RoomMemberHandler(BaseHandler):
 
         if event.membership == Membership.JOIN:
             if requester.is_guest:
-                guest_can_join = yield self._can_guest_join(context.current_state_ids)
+                guest_can_join = yield self._can_guest_join(context.prev_state_ids)
                 if not guest_can_join:
                     # This should be an auth check, but guests are a local concept,
                     # so don't really fit into the general auth process.
@@ -355,7 +355,7 @@ class RoomMemberHandler(BaseHandler):
             ratelimit=ratelimit,
         )
 
-        prev_member_event_id = context.current_state_ids.get(
+        prev_member_event_id = context.prev_state_ids.get(
             (EventTypes.Member, event.state_key),
             None
         )