summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/auth.py4
-rw-r--r--synapse/events/snapshot.py5
-rw-r--r--synapse/handlers/federation.py31
-rw-r--r--synapse/handlers/message.py10
-rw-r--r--synapse/handlers/room_member.py6
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py2
-rw-r--r--synapse/state.py31
-rw-r--r--synapse/storage/roommember.py27
-rw-r--r--synapse/storage/state.py3
-rw-r--r--tests/replication/slave/storage/test_events.py4
-rw-r--r--tests/replication/test_resource.py10
-rw-r--r--tests/test_state.py38
12 files changed, 102 insertions, 69 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index f26e585623..fcf0b0d25f 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -66,7 +66,7 @@ class Auth(object):
     @defer.inlineCallbacks
     def check_from_context(self, event, context, do_sig_check=True):
         auth_events_ids = yield self.compute_auth_events(
-            event, context.current_state_ids, for_verification=True,
+            event, context.prev_state_ids, for_verification=True,
         )
         auth_events = yield self.store.get_events(auth_events_ids)
         auth_events = {
@@ -852,7 +852,7 @@ class Auth(object):
 
     @defer.inlineCallbacks
     def add_auth_events(self, builder, context):
-        auth_ids = yield self.compute_auth_events(builder, context.current_state_ids)
+        auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids)
 
         auth_events_entries = yield self.store.add_event_hashes(
             auth_ids
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index c75afd02d8..e895b1c450 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -15,8 +15,9 @@
 
 
 class EventContext(object):
-    def __init__(self, current_state_ids=None):
-        self.current_state_ids = current_state_ids
+    def __init__(self):
+        self.current_state_ids = None
+        self.prev_state_ids = None
         self.state_group = None
         self.rejected = False
         self.push_actions = []
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index a7ea8fb98f..8e61d74b13 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -222,7 +222,7 @@ class FederationHandler(BaseHandler):
                 # joined the room. Don't bother if the user is just
                 # changing their profile info.
                 newly_joined = True
-                prev_state_id = context.current_state_ids.get(
+                prev_state_id = context.prev_state_ids.get(
                     (event.type, event.state_key)
                 )
                 if prev_state_id:
@@ -835,12 +835,12 @@ class FederationHandler(BaseHandler):
 
         self.replication_layer.send_pdu(new_pdu, destinations)
 
-        state_ids = context.current_state_ids.values()
+        state_ids = context.prev_state_ids.values()
         auth_chain = yield self.store.get_auth_chain(set(
             [event.event_id] + state_ids
         ))
 
-        state = yield self.store.get_events(context.current_state_ids.values())
+        state = yield self.store.get_events(context.prev_state_ids.values())
 
         defer.returnValue({
             "state": state.values(),
@@ -1333,7 +1333,7 @@ class FederationHandler(BaseHandler):
 
         if not auth_events:
             auth_events_ids = yield self.auth.compute_auth_events(
-                event, context.current_state_ids, for_verification=True,
+                event, context.prev_state_ids, for_verification=True,
             )
             auth_events = yield self.store.get_events(auth_events_ids)
             auth_events = {
@@ -1432,6 +1432,11 @@ class FederationHandler(BaseHandler):
         current_state = set(e.event_id for e in auth_events.values())
         event_auth_events = set(e_id for e_id, _ in event.auth_events)
 
+        if event.is_state():
+            event_key = (event.type, event.state_key)
+        else:
+            event_key = None
+
         if event_auth_events - current_state:
             have_events = yield self.store.have_events(
                 event_auth_events - current_state
@@ -1537,8 +1542,12 @@ class FederationHandler(BaseHandler):
 
                 context.current_state_ids.update({
                     k: a.event_id for k, a in auth_events.items()
+                    if k != event_key
+                })
+                context.prev_state_ids.update({
+                    k: a.event_id for k, a in auth_events.items()
                 })
-                context.state_group = None
+                context.state_group = self.store.get_next_state_group()
 
         if different_auth and not event.internal_metadata.is_outlier():
             logger.info("Different auth after resolution: %s", different_auth)
@@ -1560,7 +1569,7 @@ class FederationHandler(BaseHandler):
             if do_resolution:
                 # 1. Get what we think is the auth chain.
                 auth_ids = yield self.auth.compute_auth_events(
-                    event, context.current_state_ids
+                    event, context.prev_state_ids
                 )
                 local_auth_chain = yield self.store.get_auth_chain(auth_ids)
 
@@ -1618,8 +1627,12 @@ class FederationHandler(BaseHandler):
 
                 context.current_state_ids.update({
                     k: a.event_id for k, a in auth_events.items()
+                    if k != event_key
+                })
+                context.prev_state_ids.update({
+                    k: a.event_id for k, a in auth_events.items()
                 })
-                context.state_group = None
+                context.state_group = self.store.get_next_state_group()
 
         try:
             self.auth.check(event, auth_events=auth_events)
@@ -1855,7 +1868,7 @@ class FederationHandler(BaseHandler):
             event.content["third_party_invite"]["signed"]["token"]
         )
         original_invite = None
-        original_invite_id = context.current_state_ids.get(key)
+        original_invite_id = context.prev_state_ids.get(key)
         if original_invite_id:
             original_invite = yield self.store.get_event(
                 original_invite_id, allow_none=True
@@ -1893,7 +1906,7 @@ class FederationHandler(BaseHandler):
         signed = event.content["third_party_invite"]["signed"]
         token = signed["token"]
 
-        invite_event_id = context.current_state_ids.get(
+        invite_event_id = context.prev_state_ids.get(
             (EventTypes.ThirdPartyInvite, token,)
         )
 
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index e2f4387f60..3577db0595 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -272,7 +272,7 @@ class MessageHandler(BaseHandler):
         If so, returns the version of the event in context.
         Otherwise, returns None.
         """
-        prev_event_id = context.current_state_ids.get((event.type, event.state_key))
+        prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
         prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
         if not prev_event:
             return
@@ -808,8 +808,8 @@ class MessageHandler(BaseHandler):
         event = builder.build()
 
         logger.debug(
-            "Created event %s with current state: %s",
-            event.event_id, context.current_state_ids,
+            "Created event %s with state: %s",
+            event.event_id, context.prev_state_ids,
         )
 
         defer.returnValue(
@@ -904,7 +904,7 @@ class MessageHandler(BaseHandler):
 
         if event.type == EventTypes.Redaction:
             auth_events_ids = yield self.auth.compute_auth_events(
-                event, context.current_state_ids, for_verification=True,
+                event, context.prev_state_ids, for_verification=True,
             )
             auth_events = yield self.store.get_events(auth_events_ids)
             auth_events = {
@@ -924,7 +924,7 @@ class MessageHandler(BaseHandler):
                         "You don't have permission to redact events"
                     )
 
-        if event.type == EventTypes.Create and context.current_state_ids:
+        if event.type == EventTypes.Create and context.prev_state_ids:
             raise AuthError(
                 403,
                 "Changing the room create event is forbidden",
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index dd4b90ee24..3ba5335af7 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -93,7 +93,7 @@ class RoomMemberHandler(BaseHandler):
             ratelimit=ratelimit,
         )
 
-        prev_member_event_id = context.current_state_ids.get(
+        prev_member_event_id = context.prev_state_ids.get(
             (EventTypes.Member, target.to_string()),
             None
         )
@@ -341,7 +341,7 @@ class RoomMemberHandler(BaseHandler):
 
         if event.membership == Membership.JOIN:
             if requester.is_guest:
-                guest_can_join = yield self._can_guest_join(context.current_state_ids)
+                guest_can_join = yield self._can_guest_join(context.prev_state_ids)
                 if not guest_can_join:
                     # This should be an auth check, but guests are a local concept,
                     # so don't really fit into the general auth process.
@@ -355,7 +355,7 @@ class RoomMemberHandler(BaseHandler):
             ratelimit=ratelimit,
         )
 
-        prev_member_event_id = context.current_state_ids.get(
+        prev_member_event_id = context.prev_state_ids.get(
             (EventTypes.Member, event.state_key),
             None
         )
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 51cb21ee9d..6ff9a06de1 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -87,7 +87,7 @@ class BulkPushRuleEvaluator:
         )
 
         room_members = yield self.store.get_joined_users_from_context(
-            event.room_id, context.state_group, context.current_state_ids
+            event, context
         )
 
         evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
diff --git a/synapse/state.py b/synapse/state.py
index 147416fd81..a0f807e3b9 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -128,7 +128,7 @@ class StateHandler(object):
     def get_current_user_in_room(self, room_id):
         latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
         group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids)
-        joined_users = yield self.store.get_joined_users_from_context(
+        joined_users = yield self.store.get_joined_users_from_state(
             room_id, group, state_ids
         )
         defer.returnValue(joined_users)
@@ -154,27 +154,38 @@ class StateHandler(object):
             # state. Certainly store.get_current_state won't return any, and
             # persisting the event won't store the state group.
             if old_state:
-                context.current_state_ids = {
+                context.prev_state_ids = {
                     (s.type, s.state_key): s.event_id for s in old_state
                 }
+                if event.is_state():
+                    context.current_state_events = dict(context.prev_state_ids)
+                    key = (event.type, event.state_key)
+                    context.current_state_events[key] = event.event_id
+                else:
+                    context.current_state_events = context.prev_state_ids
             else:
                 context.current_state_ids = {}
+                context.prev_state_ids = {}
             context.prev_state_events = []
             context.state_group = self.store.get_next_state_group()
             defer.returnValue(context)
 
         if old_state:
-            context.current_state_ids = {
+            context.prev_state_ids = {
                 (s.type, s.state_key): s.event_id for s in old_state
             }
             context.state_group = self.store.get_next_state_group()
 
             if event.is_state():
                 key = (event.type, event.state_key)
-                if key in context.current_state_ids:
-                    replaces = context.current_state_ids[key]
+                if key in context.prev_state_ids:
+                    replaces = context.prev_state_ids[key]
                     if replaces != event.event_id:  # Paranoia check
                         event.unsigned["replaces_state"] = replaces
+                context.current_state_ids = dict(context.prev_state_ids)
+                context.current_state_ids[key] = event.event_id
+            else:
+                context.current_state_ids = context.prev_state_ids
 
             context.prev_state_events = []
             defer.returnValue(context)
@@ -192,7 +203,7 @@ class StateHandler(object):
 
         group, curr_state = ret
 
-        context.current_state_ids = curr_state
+        context.prev_state_ids = curr_state
         if event.is_state() or group is None:
             context.state_group = self.store.get_next_state_group()
         else:
@@ -200,9 +211,13 @@ class StateHandler(object):
 
         if event.is_state():
             key = (event.type, event.state_key)
-            if key in context.current_state_ids:
-                replaces = context.current_state_ids[key]
+            if key in context.prev_state_ids:
+                replaces = context.prev_state_ids[key]
                 event.unsigned["replaces_state"] = replaces
+            context.current_state_ids = dict(context.prev_state_ids)
+            context.current_state_ids[key] = event.event_id
+        else:
+            context.current_state_ids = context.prev_state_ids
 
         context.prev_state_events = []
         defer.returnValue(context)
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index cab1660830..6ab10db328 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -354,7 +354,8 @@ class RoomMemberStore(SQLBaseStore):
             desc="who_forgot"
         )
 
-    def get_joined_users_from_context(self, room_id, state_group, state_ids):
+    def get_joined_users_from_context(self, event, context):
+        state_group = context.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
             # state group, i.e. we need to make sure that calls with a state_group
@@ -363,12 +364,24 @@ class RoomMemberStore(SQLBaseStore):
             state_group = object()
 
         return self._get_joined_users_from_context(
-            room_id, state_group, state_ids
+            event.room_id, state_group, context.current_state_ids, event=event,
+        )
+
+    def get_joined_users_from_state(self, room_id, state_group, state_ids):
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        return self._get_joined_users_from_context(
+            room_id, state_group, state_ids,
         )
 
     @cachedInlineCallbacks(num_args=2, cache_context=True)
     def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
-                                       cache_context):
+                                       cache_context, event=None):
         # We don't use `state_group`, its there so that we can cache based
         # on it. However, its important that its never None, since two current_state's
         # with a state_group of None are likely to be different.
@@ -393,7 +406,13 @@ class RoomMemberStore(SQLBaseStore):
             desc="_get_joined_users_from_context",
         )
 
-        defer.returnValue(set(row["user_id"] for row in rows))
+        users_in_room = set(row["user_id"] for row in rows)
+        if event is not None and event.type == EventTypes.Member:
+            if event.membership == Membership.JOIN:
+                if event.event_id in member_event_ids:
+                    users_in_room.add(event.state_key)
+
+        defer.returnValue(users_in_room)
 
     def is_host_joined(self, room_id, host, state_group, state_ids):
         if not state_group:
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 56bfdc0b55..dce5a2f135 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -108,9 +108,6 @@ class StateStore(SQLBaseStore):
 
             state_event_ids = dict(context.current_state_ids)
 
-            if event.is_state():
-                state_event_ids[(event.type, event.state_key)] = event.event_id
-
             self._simple_insert_txn(
                 txn,
                 table="state_groups",
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 218cb24889..44e859b5d1 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -312,7 +312,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         else:
             state_ids = None
 
-        context = EventContext(current_state_ids=state_ids)
+        context = EventContext()
+        context.current_state_ids = state_ids
+        context.prev_state_ids = state_ids
         context.push_actions = push_actions
 
         ordering = None
diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py
index e70ac6f14d..b69832cc1b 100644
--- a/tests/replication/test_resource.py
+++ b/tests/replication/test_resource.py
@@ -60,8 +60,8 @@ class ReplicationResourceCase(unittest.TestCase):
         self.assertEquals(body, {})
 
     @defer.inlineCallbacks
-    def test_events_and_state(self):
-        get = self.get(events="-1", state="-1", timeout="0")
+    def test_events(self):
+        get = self.get(events="-1", timeout="0")
         yield self.hs.get_handlers().room_creation_handler.create_room(
             synapse.types.create_requester(self.user), {}
         )
@@ -70,12 +70,6 @@ class ReplicationResourceCase(unittest.TestCase):
         self.assertEquals(body["events"]["field_names"], [
             "position", "internal", "json", "state_group"
         ])
-        self.assertEquals(body["state_groups"]["field_names"], [
-            "position", "room_id", "event_id"
-        ])
-        self.assertEquals(body["state_group_state"]["field_names"], [
-            "position", "type", "state_key", "event_id"
-        ])
 
     @defer.inlineCallbacks
     def test_presence(self):
diff --git a/tests/test_state.py b/tests/test_state.py
index de2d35145a..6454f994e3 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -86,17 +86,8 @@ class StateGroupStore(object):
 
         state_events = dict(context.current_state_ids)
 
-        if event.is_state():
-            state_events[(event.type, event.state_key)] = event.event_id
-
-        state_group = context.state_group
-        if not state_group:
-            state_group = self._next_group
-            self._next_group += 1
-
-            self._group_to_state[state_group] = state_events
-
-        self._event_to_state_group[event.event_id] = state_group
+        self._group_to_state[context.state_group] = state_events
+        self._event_to_state_group[event.event_id] = context.state_group
 
     def get_events(self, event_ids, **kwargs):
         return {
@@ -151,6 +142,7 @@ class StateTestCase(unittest.TestCase):
                 "get_state_groups_ids",
                 "add_event_hashes",
                 "get_events",
+                "get_next_state_group",
             ]
         )
         hs = Mock(spec_set=[
@@ -161,6 +153,8 @@ class StateTestCase(unittest.TestCase):
         hs.get_clock.return_value = MockClock()
         hs.get_auth.return_value = Auth(hs)
 
+        self.store.get_next_state_group.side_effect = Mock
+
         self.state = StateHandler(hs)
         self.event_id = 0
 
@@ -209,7 +203,7 @@ class StateTestCase(unittest.TestCase):
             store.store_state_groups(event, context)
             context_store[event.event_id] = context
 
-        self.assertEqual(2, len(context_store["D"].current_state_ids))
+        self.assertEqual(2, len(context_store["D"].prev_state_ids))
 
     @defer.inlineCallbacks
     def test_branch_basic_conflict(self):
@@ -265,7 +259,7 @@ class StateTestCase(unittest.TestCase):
 
         self.assertSetEqual(
             {"START", "A", "C"},
-            {e_id for e_id in context_store["D"].current_state_ids.values()}
+            {e_id for e_id in context_store["D"].prev_state_ids.values()}
         )
 
     @defer.inlineCallbacks
@@ -331,7 +325,7 @@ class StateTestCase(unittest.TestCase):
 
         self.assertSetEqual(
             {"START", "A", "B", "C"},
-            {e for e in context_store["E"].current_state_ids.values()}
+            {e for e in context_store["E"].prev_state_ids.values()}
         )
 
     @defer.inlineCallbacks
@@ -414,7 +408,7 @@ class StateTestCase(unittest.TestCase):
 
         self.assertSetEqual(
             {"A1", "A2", "A3", "A5", "B"},
-            {e for e in context_store["D"].current_state_ids.values()}
+            {e for e in context_store["D"].prev_state_ids.values()}
         )
 
     def _add_depths(self, nodes, edges):
@@ -447,7 +441,7 @@ class StateTestCase(unittest.TestCase):
             set(e.event_id for e in old_state), set(context.current_state_ids.values())
         )
 
-        self.assertIsNone(context.state_group)
+        self.assertIsNotNone(context.state_group)
 
     @defer.inlineCallbacks
     def test_annotate_with_old_state(self):
@@ -464,11 +458,9 @@ class StateTestCase(unittest.TestCase):
         )
 
         self.assertEqual(
-            set(e.event_id for e in old_state), set(context.current_state_ids.values())
+            set(e.event_id for e in old_state), set(context.prev_state_ids.values())
         )
 
-        self.assertIsNone(context.state_group)
-
     @defer.inlineCallbacks
     def test_trivial_annotate_message(self):
         event = create_event(type="test_message", name="event")
@@ -514,10 +506,10 @@ class StateTestCase(unittest.TestCase):
 
         self.assertEqual(
             set([e.event_id for e in old_state]),
-            set(context.current_state_ids.values())
+            set(context.prev_state_ids.values())
         )
 
-        self.assertIsNone(context.state_group)
+        self.assertIsNotNone(context.state_group)
 
     @defer.inlineCallbacks
     def test_resolve_message_conflict(self):
@@ -550,7 +542,7 @@ class StateTestCase(unittest.TestCase):
 
         self.assertEqual(len(context.current_state_ids), 6)
 
-        self.assertIsNone(context.state_group)
+        self.assertIsNotNone(context.state_group)
 
     @defer.inlineCallbacks
     def test_resolve_state_conflict(self):
@@ -583,7 +575,7 @@ class StateTestCase(unittest.TestCase):
 
         self.assertEqual(len(context.current_state_ids), 6)
 
-        self.assertIsNone(context.state_group)
+        self.assertIsNotNone(context.state_group)
 
     @defer.inlineCallbacks
     def test_standard_depth_conflict(self):