summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/auth.py16
-rw-r--r--synapse/federation/handler.py11
-rw-r--r--synapse/handlers/room.py96
-rw-r--r--synapse/storage/__init__.py125
-rw-r--r--synapse/storage/_base.py12
-rw-r--r--synapse/storage/feedback.py4
-rw-r--r--synapse/storage/room.py10
-rw-r--r--synapse/storage/roommember.py10
-rw-r--r--synapse/storage/stream.py13
-rw-r--r--tests/handlers/test_room.py25
-rw-r--r--tests/utils.py7
11 files changed, 222 insertions, 107 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 15407df14a..646f6dc06c 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -33,7 +33,7 @@ class Auth(object):
         self.store = hs.get_datastore()
 
     @defer.inlineCallbacks
-    def check(self, event, raises=False):
+    def check(self, event, snapshot, raises=False):
         """ Checks if this event is correctly authed.
 
         Returns:
@@ -48,7 +48,11 @@ class Auth(object):
                     allowed = yield self.is_membership_change_allowed(event)
                     defer.returnValue(allowed)
                 else:
-                    yield self.check_joined_room(event.room_id, event.user_id)
+                    self._check_joined_room(
+                        member=snapshot.membership_state,
+                        user_id=snapshot.user_id,
+                        room_id=snapshot.room_id,
+                    )
                     defer.returnValue(True)
             else:
                 raise AuthError(500, "Unknown event: %s" % event)
@@ -66,14 +70,16 @@ class Auth(object):
                 room_id=room_id,
                 user_id=user_id
             )
-            if not member or member.membership != Membership.JOIN:
-                raise AuthError(403, "User %s not in room %s" %
-                                (user_id, room_id))
+            self._check_joined_room(member, user_id, room_id)
             defer.returnValue(member)
         except AttributeError:
             pass
         defer.returnValue(None)
 
+    def _check_joined_room(self, member, user_id, room_id):
+        if not member or member.membership != Membership.JOIN:
+            raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
+
     @defer.inlineCallbacks
     def is_membership_change_allowed(self, event):
         target_user_id = event.state_key
diff --git a/synapse/federation/handler.py b/synapse/federation/handler.py
index 984c1558e9..ce98f4f94a 100644
--- a/synapse/federation/handler.py
+++ b/synapse/federation/handler.py
@@ -51,19 +51,20 @@ class FederationEventHandler(object):
 
     @log_function
     @defer.inlineCallbacks
-    def handle_new_event(self, event):
+    def handle_new_event(self, event, snapshot):
         """ Takes in an event from the client to server side, that has already
         been authed and handled by the state module, and sends it to any
         remote home servers that may be interested.
 
         Args:
             event
+            snapshot (.storage.Snapshot): THe snapshot the event happened after
 
         Returns:
             Deferred: Resolved when it has successfully been queued for
             processing.
         """
-        yield self.fill_out_prev_events(event)
+        yield self.fill_out_prev_events(event, snapshot)
 
         pdu = self.pdu_codec.pdu_from_event(event)
 
@@ -137,13 +138,11 @@ class FederationEventHandler(object):
         yield self.event_handler.on_receive(new_state_event)
 
     @defer.inlineCallbacks
-    def fill_out_prev_events(self, event):
+    def fill_out_prev_events(self, event, snapshot):
         if hasattr(event, "prev_events"):
             return
 
-        results = yield self.store.get_latest_pdus_in_context(
-            event.room_id
-        )
+        results = snapshot.prev_pdus
 
         es = [
             "%s@%s" % (p_id, origin) for p_id, origin, _ in results
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index c2b10f4189..7e34b4a6fc 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -84,20 +84,21 @@ class MessageHandler(BaseHandler):
         if stamp_event:
             event.content["hsob_ts"] = int(self.clock.time_msec())
 
-        with (yield self.room_lock.lock(event.room_id)):
-            if not suppress_auth:
-                yield self.auth.check(event, raises=True)
+        snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
 
-            # store message in db
-            store_id = yield self.store.persist_event(event)
+        if not suppress_auth:
+            yield self.auth.check(event, snapshot, raises=True)
 
-            event.destinations = yield self.store.get_joined_hosts_for_room(
-                event.room_id
-            )
+        # store message in db
+        store_id = yield self.store.persist_event(event)
+
+        event.destinations = yield self.store.get_joined_hosts_for_room(
+            event.room_id
+        )
 
-            self.notifier.on_new_room_event(event, store_id)
+        self.notifier.on_new_room_event(event, store_id)
 
-        yield self.hs.get_federation().handle_new_event(event)
+        yield self.hs.get_federation().handle_new_event(event, snapshot)
 
     @defer.inlineCallbacks
     def get_messages(self, user_id=None, room_id=None, pagin_config=None,
@@ -134,23 +135,24 @@ class MessageHandler(BaseHandler):
             SynapseError if something went wrong.
         """
 
-        with (yield self.room_lock.lock(event.room_id)):
-            yield self.auth.check(event, raises=True)
+        snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
 
-            if stamp_event:
-                event.content["hsob_ts"] = int(self.clock.time_msec())
+        yield self.auth.check(event, snapshot, raises=True)
 
-            yield self.state_handler.handle_new_event(event)
+        if stamp_event:
+            event.content["hsob_ts"] = int(self.clock.time_msec())
 
-            # store in db
-            store_id = yield self.store.persist_event(event)
+        yield self.state_handler.handle_new_event(event)
 
-            event.destinations = yield self.store.get_joined_hosts_for_room(
-                event.room_id
-            )
-            self.notifier.on_new_room_event(event, store_id)
+        # store in db
+        store_id = yield self.store.persist_event(event)
+
+        event.destinations = yield self.store.get_joined_hosts_for_room(
+            event.room_id
+        )
+        self.notifier.on_new_room_event(event, store_id)
 
-        yield self.hs.get_federation().handle_new_event(event)
+        yield self.hs.get_federation().handle_new_event(event, snapshot)
 
     @defer.inlineCallbacks
     def get_room_data(self, user_id=None, room_id=None,
@@ -219,16 +221,17 @@ class MessageHandler(BaseHandler):
         if stamp_event:
             event.content["hsob_ts"] = int(self.clock.time_msec())
 
-        with (yield self.room_lock.lock(event.room_id)):
-            yield self.auth.check(event, raises=True)
+        snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
 
-            # store message in db
-            store_id = yield self.store.persist_event(event)
+        yield self.auth.check(event, snapshot, raises=True)
 
-            event.destinations = yield self.store.get_joined_hosts_for_room(
-                event.room_id
-            )
-        yield self.hs.get_federation().handle_new_event(event)
+        # store message in db
+        store_id = yield self.store.persist_event(event)
+
+        event.destinations = yield self.store.get_joined_hosts_for_room(
+            event.room_id
+        )
+        yield self.hs.get_federation().handle_new_event(event, snapshot)
 
         self.notifier.on_new_room_event(event, store_id)
 
@@ -525,6 +528,11 @@ class RoomMemberHandler(BaseHandler):
         """
         target_user_id = event.state_key
 
+        snapshot = yield self.store.snapshot_room(
+            event.room_id, event.user_id,
+            RoomMemberEvent.TYPE, event.target_user_id
+        )
+        ## TODO(markjh): get prev state from snapshot.
         prev_state = yield self.store.get_room_member(
             target_user_id, event.room_id
         )
@@ -545,24 +553,22 @@ class RoomMemberHandler(BaseHandler):
         # if this HS is not currently in the room, i.e. we have to do the
         # invite/join dance.
         if event.membership == Membership.JOIN:
-            yield self._do_join(event, do_auth=do_auth)
+            yield self._do_join(event, snapshot, do_auth=do_auth)
         else:
             # This is not a JOIN, so we can handle it normally.
             if do_auth:
-                yield self.auth.check(event, raises=True)
+                yield self.auth.check(event, snapshot, raises=True)
 
-            prev_state = yield self.store.get_room_member(
-                target_user_id, event.room_id
-            )
             if prev_state and prev_state.membership == event.membership:
                 # double same action, treat this event as a NOOP.
                 defer.returnValue({})
                 return
 
-            yield self.state_handler.handle_new_event(event)
+            yield self.state_handler.handle_new_event(event, snapshot)
             yield self._do_local_membership_update(
                 event,
                 membership=event.content["membership"],
+                snapshot=snapshot,
             )
 
         defer.returnValue({"room_id": room_id})
@@ -592,12 +598,16 @@ class RoomMemberHandler(BaseHandler):
             content=content,
         )
 
-        yield self._do_join(new_event, room_host=host, do_auth=True)
+        snapshot = yield self.store.snapshot_room(
+            room_id, joinee, RoomMemberEvent.TYPE, joinee
+        )
+
+        yield self._do_join(new_event, snapshot, room_host=host, do_auth=True)
 
         defer.returnValue({"room_id": room_id})
 
     @defer.inlineCallbacks
-    def _do_join(self, event, room_host=None, do_auth=True):
+    def _do_join(self, event, snapshot, room_host=None, do_auth=True):
         joinee = self.hs.parse_userid(event.state_key)
         # room_id = RoomID.from_string(event.room_id, self.hs)
         room_id = event.room_id
@@ -619,6 +629,7 @@ class RoomMemberHandler(BaseHandler):
         elif room_host:
             should_do_dance = True
         else:
+            # TODO(markjh): get prev_state from snapshot
             prev_state = yield self.store.get_room_member(
                 joinee.to_string(), room_id
             )
@@ -646,12 +657,13 @@ class RoomMemberHandler(BaseHandler):
             logger.debug("Doing normal join")
 
             if do_auth:
-                yield self.auth.check(event, raises=True)
+                yield self.auth.check(event, snapshot, raises=True)
 
-            yield self.state_handler.handle_new_event(event)
+            yield self.state_handler.handle_new_event(event, snapshot)
             yield self._do_local_membership_update(
                 event,
                 membership=event.content["membership"],
+                snapshot=snapshot,
             )
 
         user = self.hs.parse_userid(event.user_id)
@@ -696,7 +708,7 @@ class RoomMemberHandler(BaseHandler):
         defer.returnValue([r.room_id for r in rooms])
 
     @defer.inlineCallbacks
-    def _do_local_membership_update(self, event, membership):
+    def _do_local_membership_update(self, event, membership, snapshot):
         # store membership
         store_id = yield self.store.persist_event(event)
 
@@ -723,7 +735,7 @@ class RoomMemberHandler(BaseHandler):
 
         event.destinations = list(set(destinations))
 
-        yield self.hs.get_federation().handle_new_event(event)
+        yield self.hs.get_federation().handle_new_event(event, snapshot)
         self.notifier.on_new_room_event(event, store_id)
 
 
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index a97a42e1e3..5e52e9fecf 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -58,19 +58,21 @@ class DataStore(RoomMemberStore, RoomStore,
     @defer.inlineCallbacks
     @log_function
     def persist_event(self, event, backfilled=False):
-        if event.type == RoomMemberEvent.TYPE:
-            yield self._store_room_member(event)
-        elif event.type == FeedbackEvent.TYPE:
-            yield self._store_feedback(event)
-#        elif event.type == RoomConfigEvent.TYPE:
-#            yield self._store_room_config(event)
-        elif event.type == RoomNameEvent.TYPE:
-            yield self._store_room_name(event)
-        elif event.type == RoomTopicEvent.TYPE:
-            yield self._store_room_topic(event)
+        # FIXME (erikj): This should be removed when we start amalgamating
+        # event and pdu storage
+        yield self.hs.get_federation().fill_out_prev_events(event)
 
-        ret = yield self._store_event(event, backfilled)
-        defer.returnValue(ret)
+        stream_ordering = None
+        if backfilled:
+            if not self.min_token_deferred.called:
+                yield self.min_token_deferred
+            self.min_token -= 1
+            stream_ordering = self.min_token
+
+        latest = yield self._db_pool.runInteraction(
+            self._persist_event_txn, event, backfilled, stream_ordering
+        )
+        defer.returnValue(latest)
 
     @defer.inlineCallbacks
     def get_event(self, event_id):
@@ -90,12 +92,18 @@ class DataStore(RoomMemberStore, RoomStore,
         event = self._parse_event_from_row(events_dict)
         defer.returnValue(event)
 
-    @defer.inlineCallbacks
     @log_function
-    def _store_event(self, event, backfilled):
-        # FIXME (erikj): This should be removed when we start amalgamating
-        # event and pdu storage
-        yield self.hs.get_federation().fill_out_prev_events(event)
+    def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None):
+        if event.type == RoomMemberEvent.TYPE:
+            self._store_room_member_txn(txn, event)
+        elif event.type == FeedbackEvent.TYPE:
+            self._store_feedback_txn(txn,event)
+#        elif event.type == RoomConfigEvent.TYPE:
+#            self._store_room_config_txn(txn, event)
+        elif event.type == RoomNameEvent.TYPE:
+            self._store_room_name_txn(txn, event)
+        elif event.type == RoomTopicEvent.TYPE:
+            self._store_room_topic_txn(txn, event)
 
         vals = {
             "topological_ordering": event.depth,
@@ -106,17 +114,14 @@ class DataStore(RoomMemberStore, RoomStore,
             "processed": True,
         }
 
+        if stream_ordering is not None:
+            vals["stream_ordering"] = stream_ordering
+
         if hasattr(event, "outlier"):
             vals["outlier"] = event.outlier
         else:
             vals["outlier"] = False
 
-        if backfilled:
-            if not self.min_token_deferred.called:
-                yield self.min_token_deferred
-            self.min_token -= 1
-            vals["stream_ordering"] = self.min_token
-
         unrec = {
             k: v
             for k, v in event.get_full_dict().items()
@@ -125,7 +130,7 @@ class DataStore(RoomMemberStore, RoomStore,
         vals["unrecognized_keys"] = json.dumps(unrec)
 
         try:
-            yield self._simple_insert("events", vals)
+            self._simple_insert_txn(txn, "events", vals)
         except:
             logger.exception(
                 "Failed to persist, probably duplicate: %s",
@@ -144,9 +149,10 @@ class DataStore(RoomMemberStore, RoomStore,
             if hasattr(event, "prev_state"):
                 vals["prev_state"] = event.prev_state
 
-            yield self._simple_insert("state_events", vals)
+            self._simple_insert_txn(txn, "state_events", vals)
 
-            yield self._simple_insert(
+            self._simple_insert_txn(
+                txn,
                 "current_state_events",
                 {
                     "event_id": event.event_id,
@@ -156,8 +162,7 @@ class DataStore(RoomMemberStore, RoomStore,
                 }
             )
 
-        latest = yield self.get_room_events_max_id()
-        defer.returnValue(latest)
+        return self._get_room_events_max_id_(txn)
 
     @defer.inlineCallbacks
     def get_current_state(self, room_id, event_type=None, state_key=""):
@@ -193,6 +198,70 @@ class DataStore(RoomMemberStore, RoomStore,
         defer.returnValue(self.min_token)
 
 
+    def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
+        """Snapshot the room for an update by a user
+        Args:
+            room_id (synapse.types.RoomId): The room to snapshot.
+            user_id (synapse.types.UserId): The user to snapshot the room for.
+            state_type (str): Optional state type to snapshot.
+            state_key (str): Optional state key to snapshot.
+        Returns:
+            synapse.storage.Snapshot: A snapshot of the state of the room.
+        """
+        def _snapshot(txn):
+            membership_state = self._get_room_member(txn, user_id)
+            prev_pdus = self._get_latest_pdus_in_context(
+                txn, room_id
+            )
+            if state_type is not None and state_key is not None:
+                prev_state_pdu = self._get_current_state_pdu(
+                    txn, room_id, state_type, state_key
+                )
+            else:
+                prev_state_pdu = None
+
+            return Snapshot(
+                store=self,
+                room_id=room_id,
+                user_id=user_id,
+                prev_pdus=prev_pdus,
+                membership_state=membership_state,
+                state_type=state_type,
+                state_key=state_key,
+                prev_state_pdu=prev_state_pdu,
+            )
+
+        return self._db_pool.runInteraction(_snapshot)
+
+
+class Snapshot(object):
+    """Snapshot of the state of a room
+    Args:
+        store (DataStore): The datastore.
+        room_id (RoomId): The room of the snapshot.
+        user_id (UserId): The user this snapshot is for.
+        prev_pdus (list): The list of PDU ids this snapshot is after.
+        membership_state (RoomMemberEvent): The current state of the user in
+            the room.
+        state_type (str, optional): State type captured by the snapshot
+        state_key (str, optional): State key captured by the snapshot
+        prev_state_pdu (PduEntry, optional): pdu id of
+            the previous value of the state type and key in the room.
+    """
+
+    def __init__(self, store, room_id, user_id, prev_pdus,
+                 membership_state, state_type=None, state_key=None,
+                 prev_state_pdu=None):
+        self.store = store
+        self.room_id = room_id
+        self.user_id = user_id
+        self.prev_pdus = prev_pdus
+        self.membership_state
+        self.state_type = state_type
+        self.state_key = state_key
+        self.prev_state_pdu = prev_state_pdu
+
+
 def schema_path(schema):
     """ Get a filesystem path for the named database schema
 
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 75aab2d3b9..33d56f47ce 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -86,16 +86,18 @@ class SQLBaseStore(object):
             table : string giving the table name
             values : dict of new column names and values for them
         """
+        return self._db_pool.runInteraction(
+            self._simple_insert_txn, table, values,
+        )
+
+    def _simple_insert_txn(self, txn, table, values):
         sql = "INSERT INTO %s (%s) VALUES(%s)" % (
             table,
             ", ".join(k for k in values),
             ", ".join("?" for k in values)
         )
-
-        def func(txn):
-            txn.execute(sql, values.values())
-            return txn.lastrowid
-        return self._db_pool.runInteraction(func)
+        txn.execute(sql, values.values())
+        return txn.lastrowid
 
     def _simple_select_one(self, table, keyvalues, retcols,
                            allow_none=False):
diff --git a/synapse/storage/feedback.py b/synapse/storage/feedback.py
index cdc6670116..336192ad01 100644
--- a/synapse/storage/feedback.py
+++ b/synapse/storage/feedback.py
@@ -20,8 +20,8 @@ from ._base import SQLBaseStore
 
 class FeedbackStore(SQLBaseStore):
 
-    def _store_feedback(self, event):
-        return self._simple_insert("feedback", {
+    def _store_feedback_txn(self, txn, event):
+        self._simple_insert_txn(txn, "feedback", {
             "event_id": event.event_id,
             "feedback_type": event.feedback_type,
             "room_id": event.room_id,
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index a5751005ef..d1f1a232f8 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -129,8 +129,9 @@ class RoomStore(SQLBaseStore):
 
         defer.returnValue(ret)
 
-    def _store_room_topic(self, event):
-        return self._simple_insert(
+    def _store_room_topic_txn(self, txn, event):
+        self._simple_insert_txn(
+            txn,
             "topics",
             {
                 "event_id": event.event_id,
@@ -139,8 +140,9 @@ class RoomStore(SQLBaseStore):
             }
         )
 
-    def _store_room_name(self, event):
-        return self._simple_insert(
+    def _store_room_name_txn(self, txn, event):
+        self._simple_insert_txn(
+            txn,
             "room_names",
             {
                 "event_id": event.event_id,
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 4ad37af0f3..5038aeea03 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -26,14 +26,14 @@ logger = logging.getLogger(__name__)
 
 class RoomMemberStore(SQLBaseStore):
 
-    @defer.inlineCallbacks
-    def _store_room_member(self, event):
+    def _store_room_member_txn(self, txn, event):
         """Store a room member in the database.
         """
         target_user_id = event.state_key
         domain = self.hs.parse_userid(target_user_id).domain
 
-        yield self._simple_insert(
+        self._simple_insert_txn(
+            txn,
             "room_memberships",
             {
                 "event_id": event.event_id,
@@ -50,13 +50,13 @@ class RoomMemberStore(SQLBaseStore):
                 "INSERT OR IGNORE INTO room_hosts (room_id, host) "
                 "VALUES (?, ?)"
             )
-            yield self._execute(None, sql, event.room_id, domain)
+            txn.execute(sql, event.room_id, domain)
         else:
             sql = (
                 "DELETE FROM room_hosts WHERE room_id = ? AND host = ?"
             )
 
-            yield self._execute(None, sql, event.room_id, domain)
+            txn.execute(sql, event.room_id, domain)
 
     @defer.inlineCallbacks
     def get_room_member(self, user_id, room_id):
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index cae80563b4..ac887e2957 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -281,17 +281,20 @@ class StreamStore(SQLBaseStore):
             )
         )
 
-    @defer.inlineCallbacks
     def get_room_events_max_id(self):
-        res = yield self._execute_and_decode(
+        return self._db_pool.runInteraction(self._get_room_events_max_id_txn)
+
+    def _get_room_events_max_id_txn(self, txn):
+        txn.execute(
             "SELECT MAX(stream_ordering) as m FROM events"
         )
 
+        res = self.cursor_to_dict(txn)
+
         logger.debug("get_room_events_max_id: %s", res)
 
         if not res or not res[0] or not res[0]["m"]:
-            defer.returnValue("s1")
-            return
+            return "s1"
 
         key = res[0]["m"] + 1
-        defer.returnValue("s%d" % (key,))
+        return "s%d" % (key,)
diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py
index fddab8f74f..a1ab8dde68 100644
--- a/tests/handlers/test_room.py
+++ b/tests/handlers/test_room.py
@@ -45,6 +45,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
                 "get_room_member",
                 "get_room",
                 "store_room",
+                "snapshot_room",
             ]),
             resource_for_federation=NonCallableMock(),
             http_client=NonCallableMock(spec_set=[]),
@@ -75,6 +76,10 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
         self.handlers.profile_handler = ProfileHandler(self.hs)
         self.room_member_handler = self.handlers.room_member_handler
 
+        self.snapshot = Mock()
+        self.datastore.snapshot_room.return_value = self.snapshot
+
+
     @defer.inlineCallbacks
     def test_invite(self):
         room_id = "!foo:red"
@@ -104,8 +109,12 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
         # Actual invocation
         yield self.room_member_handler.change_membership(event)
 
-        self.state_handler.handle_new_event.assert_called_once_with(event)
-        self.federation.handle_new_event.assert_called_once_with(event)
+        self.state_handler.handle_new_event.assert_called_once_with(
+            event, self.snapshot,
+        )
+        self.federation.handle_new_event.assert_called_once_with(
+            event, self.snapshot,
+        )
 
         self.assertEquals(
             set(["blue", "red", "green"]),
@@ -116,7 +125,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
             event
         )
         self.notifier.on_new_room_event.assert_called_once_with(
-                event, store_id)
+            event, store_id
+        )
 
         self.assertFalse(self.datastore.get_room.called)
         self.assertFalse(self.datastore.store_room.called)
@@ -148,6 +158,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
 
         self.datastore.get_joined_hosts_for_room.side_effect = get_joined
 
+
         store_id = "store_id_fooo"
         self.datastore.persist_event.return_value = defer.succeed(store_id)
         self.datastore.get_room.return_value = defer.succeed(1)  # Not None.
@@ -163,8 +174,12 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
         # Actual invocation
         yield self.room_member_handler.change_membership(event)
 
-        self.state_handler.handle_new_event.assert_called_once_with(event)
-        self.federation.handle_new_event.assert_called_once_with(event)
+        self.state_handler.handle_new_event.assert_called_once_with(
+            event, self.snapshot
+        )
+        self.federation.handle_new_event.assert_called_once_with(
+            event, self.snapshot
+        )
 
         self.assertEquals(
             set(["red", "green"]),
diff --git a/tests/utils.py b/tests/utils.py
index f40cbce51d..aa9608a1ed 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -127,6 +127,13 @@ class MemoryDataStore(object):
         self.current_state = {}
         self.events = []
 
+    Snapshot = namedtuple("Snapshot", "room_id user_id membership_state")
+
+    def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
+        return self.Snapshot(
+            room_id, user_id, self.get_room_member(user_id, room_id)
+        )
+
     def register(self, user_id, token, password_hash):
         if user_id in self.tokens_to_users.values():
             raise StoreError(400, "User in use.")