summary refs log tree commit diff
diff options
context:
space:
mode:
authorMark Haines <mark.haines@matrix.org>2014-08-22 17:00:10 +0100
committerMark Haines <mark.haines@matrix.org>2014-08-22 17:00:10 +0100
commit1379dcae6fb30c772fd77d68b27833fb7f418104 (patch)
tree145a0503a56336da9f25e2817f4dd87e2c8e01a7
parentMerge branch 'master' of github.com:matrix-org/synapse into develop (diff)
downloadsynapse-1379dcae6fb30c772fd77d68b27833fb7f418104.tar.xz
Take a snapshot of the state of the room before performing updates
-rw-r--r--synapse/api/auth.py16
-rw-r--r--synapse/federation/handler.py11
-rw-r--r--synapse/handlers/room.py97
-rw-r--r--synapse/storage/__init__.py64
-rw-r--r--tests/handlers/test_room.py25
-rw-r--r--tests/utils.py7
6 files changed, 162 insertions, 58 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 31852b29a5..91ec0995f9 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -34,7 +34,7 @@ class Auth(object):
         self.store = hs.get_datastore()
 
     @defer.inlineCallbacks
-    def check(self, event, raises=False):
+    def check(self, event, snapshot, raises=False):
         """ Checks if this event is correctly authed.
 
         Returns:
@@ -46,7 +46,11 @@ class Auth(object):
         try:
             if event.type in [RoomTopicEvent.TYPE, MessageEvent.TYPE,
                               FeedbackEvent.TYPE]:
-                yield self.check_joined_room(event.room_id, event.user_id)
+                self._check_joined_room(
+                    member=snapshot.membership_state,
+                    user_id=snapshot.user_id,
+                    room_id=snapshot.room_id,
+                )
                 defer.returnValue(True)
             elif event.type == RoomMemberEvent.TYPE:
                 allowed = yield self.is_membership_change_allowed(event)
@@ -67,14 +71,16 @@ class Auth(object):
                 room_id=room_id,
                 user_id=user_id
             )
-            if not member or member.membership != Membership.JOIN:
-                raise AuthError(403, "User %s not in room %s" %
-                                (user_id, room_id))
+            self._check_joined_room(member, user_id, room_id)
             defer.returnValue(member)
         except AttributeError:
             pass
         defer.returnValue(None)
 
+    def _check_joined_room(self, member, user_id, room_id):
+        if not member or member.membership != Membership.JOIN:
+            raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
+
     @defer.inlineCallbacks
     def is_membership_change_allowed(self, event):
         # does this room even exist
diff --git a/synapse/federation/handler.py b/synapse/federation/handler.py
index 984c1558e9..ce98f4f94a 100644
--- a/synapse/federation/handler.py
+++ b/synapse/federation/handler.py
@@ -51,19 +51,20 @@ class FederationEventHandler(object):
 
     @log_function
     @defer.inlineCallbacks
-    def handle_new_event(self, event):
+    def handle_new_event(self, event, snapshot):
         """ Takes in an event from the client to server side, that has already
         been authed and handled by the state module, and sends it to any
         remote home servers that may be interested.
 
         Args:
             event
+            snapshot (.storage.Snapshot): THe snapshot the event happened after
 
         Returns:
             Deferred: Resolved when it has successfully been queued for
             processing.
         """
-        yield self.fill_out_prev_events(event)
+        yield self.fill_out_prev_events(event, snapshot)
 
         pdu = self.pdu_codec.pdu_from_event(event)
 
@@ -137,13 +138,11 @@ class FederationEventHandler(object):
         yield self.event_handler.on_receive(new_state_event)
 
     @defer.inlineCallbacks
-    def fill_out_prev_events(self, event):
+    def fill_out_prev_events(self, event, snapshot):
         if hasattr(event, "prev_events"):
             return
 
-        results = yield self.store.get_latest_pdus_in_context(
-            event.room_id
-        )
+        results = snapshot.prev_pdus
 
         es = [
             "%s@%s" % (p_id, origin) for p_id, origin, _ in results
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 6229ee9bfa..074363e80d 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -85,20 +85,21 @@ class MessageHandler(BaseHandler):
         if stamp_event:
             event.content["hsob_ts"] = int(self.clock.time_msec())
 
-        with (yield self.room_lock.lock(event.room_id)):
-            if not suppress_auth:
-                yield self.auth.check(event, raises=True)
+        snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
 
-            # store message in db
-            store_id = yield self.store.persist_event(event)
+        if not suppress_auth:
+            yield self.auth.check(event, snapshot, raises=True)
 
-            event.destinations = yield self.store.get_joined_hosts_for_room(
-                event.room_id
-            )
+        # store message in db
+        store_id = yield self.store.persist_event(event)
 
-            self.notifier.on_new_room_event(event, store_id)
+        event.destinations = yield self.store.get_joined_hosts_for_room(
+            event.room_id
+        )
 
-        yield self.hs.get_federation().handle_new_event(event)
+        self.notifier.on_new_room_event(event, store_id)
+
+        yield self.hs.get_federation().handle_new_event(event, snapshot)
 
     @defer.inlineCallbacks
     def get_messages(self, user_id=None, room_id=None, pagin_config=None,
@@ -135,23 +136,24 @@ class MessageHandler(BaseHandler):
             SynapseError if something went wrong.
         """
 
-        with (yield self.room_lock.lock(event.room_id)):
-            yield self.auth.check(event, raises=True)
+        snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
 
-            if stamp_event:
-                event.content["hsob_ts"] = int(self.clock.time_msec())
+        yield self.auth.check(event, snapshot, raises=True)
 
-            yield self.state_handler.handle_new_event(event)
+        if stamp_event:
+            event.content["hsob_ts"] = int(self.clock.time_msec())
 
-            # store in db
-            store_id = yield self.store.persist_event(event)
+        yield self.state_handler.handle_new_event(event)
 
-            event.destinations = yield self.store.get_joined_hosts_for_room(
-                event.room_id
-            )
-            self.notifier.on_new_room_event(event, store_id)
+        # store in db
+        store_id = yield self.store.persist_event(event)
 
-        yield self.hs.get_federation().handle_new_event(event)
+        event.destinations = yield self.store.get_joined_hosts_for_room(
+            event.room_id
+        )
+        self.notifier.on_new_room_event(event, store_id)
+
+        yield self.hs.get_federation().handle_new_event(event, snapshot)
 
     @defer.inlineCallbacks
     def get_room_data(self, user_id=None, room_id=None,
@@ -220,16 +222,18 @@ class MessageHandler(BaseHandler):
         if stamp_event:
             event.content["hsob_ts"] = int(self.clock.time_msec())
 
-        with (yield self.room_lock.lock(event.room_id)):
-            yield self.auth.check(event, raises=True)
+        snapshot = yield self.store.snapshot_room(event.room_id, user_id)
 
-            # store message in db
-            store_id = yield self.store.persist_event(event)
 
-            event.destinations = yield self.store.get_joined_hosts_for_room(
-                event.room_id
-            )
-        yield self.hs.get_federation().handle_new_event(event)
+        yield self.auth.check(event, snapshot, raises=True)
+
+        # store message in db
+        store_id = yield self.store.persist_event(event)
+
+        event.destinations = yield self.store.get_joined_hosts_for_room(
+            event.room_id
+        )
+        yield self.hs.get_federation().handle_new_event(event, snapshot)
 
         self.notifier.on_new_room_event(event, store_id)
 
@@ -503,6 +507,11 @@ class RoomMemberHandler(BaseHandler):
             SynapseError if there was a problem changing the membership.
         """
 
+        snapshot = yield self.store.snapshot_room(
+            event.room_id, event.user_id,
+            RoomMemberEvent.TYPE, event.target_user_id
+        )
+        ## TODO(markjh): get prev state from snapshot.
         prev_state = yield self.store.get_room_member(
             event.target_user_id, event.room_id
         )
@@ -523,24 +532,22 @@ class RoomMemberHandler(BaseHandler):
         # if this HS is not currently in the room, i.e. we have to do the
         # invite/join dance.
         if event.membership == Membership.JOIN:
-            yield self._do_join(event, do_auth=do_auth)
+            yield self._do_join(event, snapshot, do_auth=do_auth)
         else:
             # This is not a JOIN, so we can handle it normally.
             if do_auth:
-                yield self.auth.check(event, raises=True)
+                yield self.auth.check(event, snapshot, raises=True)
 
-            prev_state = yield self.store.get_room_member(
-                event.target_user_id, event.room_id
-            )
             if prev_state and prev_state.membership == event.membership:
                 # double same action, treat this event as a NOOP.
                 defer.returnValue({})
                 return
 
-            yield self.state_handler.handle_new_event(event)
+            yield self.state_handler.handle_new_event(event, snapshot)
             yield self._do_local_membership_update(
                 event,
                 membership=event.content["membership"],
+                snapshot=snapshot,
             )
 
         defer.returnValue({"room_id": room_id})
@@ -570,12 +577,16 @@ class RoomMemberHandler(BaseHandler):
             content=content,
         )
 
-        yield self._do_join(new_event, room_host=host, do_auth=True)
+        snapshot = yield store.snapshot_room(
+            room_id, joinee, RoomMemberEvent.TYPE, event.target_user_id
+        )
+
+        yield self._do_join(new_event, snapshot, room_host=host, do_auth=True)
 
         defer.returnValue({"room_id": room_id})
 
     @defer.inlineCallbacks
-    def _do_join(self, event, room_host=None, do_auth=True):
+    def _do_join(self, event, snapshot, room_host=None, do_auth=True):
         joinee = self.hs.parse_userid(event.target_user_id)
         # room_id = RoomID.from_string(event.room_id, self.hs)
         room_id = event.room_id
@@ -597,6 +608,7 @@ class RoomMemberHandler(BaseHandler):
         elif room_host:
             should_do_dance = True
         else:
+            # TODO(markjh): get prev_state from snapshot
             prev_state = yield self.store.get_room_member(
                 joinee.to_string(), room_id
             )
@@ -624,12 +636,13 @@ class RoomMemberHandler(BaseHandler):
             logger.debug("Doing normal join")
 
             if do_auth:
-                yield self.auth.check(event, raises=True)
+                yield self.auth.check(event, snapshot, raises=True)
 
-            yield self.state_handler.handle_new_event(event)
+            yield self.state_handler.handle_new_event(event, snapshot)
             yield self._do_local_membership_update(
                 event,
                 membership=event.content["membership"],
+                snapshot=snapshot,
             )
 
         user = self.hs.parse_userid(event.user_id)
@@ -674,7 +687,7 @@ class RoomMemberHandler(BaseHandler):
         defer.returnValue([r.room_id for r in rooms])
 
     @defer.inlineCallbacks
-    def _do_local_membership_update(self, event, membership):
+    def _do_local_membership_update(self, event, membership, snapshot):
         # store membership
         store_id = yield self.store.persist_event(event)
 
@@ -700,7 +713,7 @@ class RoomMemberHandler(BaseHandler):
 
         event.destinations = list(set(destinations))
 
-        yield self.hs.get_federation().handle_new_event(event)
+        yield self.hs.get_federation().handle_new_event(event, snapshot)
         self.notifier.on_new_room_event(event, store_id)
 
 
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 7732906927..d23df15092 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -187,6 +187,70 @@ class DataStore(RoomMemberStore, RoomStore,
         defer.returnValue(self.min_token)
 
 
+    def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
+        """Snapshot the room for an update by a user
+        Args:
+            room_id (synapse.types.RoomId): The room to snapshot.
+            user_id (synapse.types.UserId): The user to snapshot the room for.
+            state_type (str): Optional state type to snapshot.
+            state_key (str): Optional state key to snapshot.
+        Returns:
+            synapse.storage.Snapshot: A snapshot of the state of the room.
+        """
+        def _snapshot(txn):
+            membership_state = self._get_room_member(txn, user_id)
+            prev_pdus = self._get_latest_pdus_in_context(
+                txn, room_id
+            )
+            if state_type is not None and state_key is not None:
+                prev_state_pdu = self._get_current_state_pdu(
+                    txn, room_id, state_type, state_key
+                )
+            else:
+                prev_state_pdu = None
+
+            return Snapshot(
+                store=self,
+                room_id=room_id,
+                user_id=user_id,
+                prev_pdus=prev_pdus,
+                membership_state=membership_state,
+                state_type=state_type,
+                state_key=state_key,
+                prev_state_pdu=prev_state_pdu,
+            )
+
+        return self._db_pool.runInteraction(_snapshot)
+
+
+class Snapshot(object):
+    """Snapshot of the state of a room
+    Args:
+        store (DataStore): The datastore.
+        room_id (RoomId): The room of the snapshot.
+        user_id (UserId): The user this snapshot is for.
+        prev_pdus (list): The list of PDU ids this snapshot is after.
+        membership_state (RoomMemberEvent): The current state of the user in
+            the room.
+        state_type (str, optional): State type captured by the snapshot
+        state_key (str, optional): State key captured by the snapshot
+        prev_state_pdu (PduEntry, optional): pdu id of
+            the previous value of the state type and key in the room.
+    """
+
+    def __init__(self, store, room_id, user_id, prev_pdus,
+                 membership_state, state_type=None, state_key=None,
+                 prev_state_pdu=None):
+        self.store = store
+        self.room_id = room_id
+        self.user_id = user_id
+        self.prev_pdus = prev_pdus
+        self.membership_state
+        self.state_type = state_type
+        self.state_key = state_key
+        self.prev_state_pdu = prev_state_pdu
+
+
 def schema_path(schema):
     """ Get a filesystem path for the named database schema
 
diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py
index bf71d3be3b..9ab4096438 100644
--- a/tests/handlers/test_room.py
+++ b/tests/handlers/test_room.py
@@ -45,6 +45,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
                 "get_room_member",
                 "get_room",
                 "store_room",
+                "snapshot_room",
             ]),
             resource_for_federation=NonCallableMock(),
             http_client=NonCallableMock(spec_set=[]),
@@ -75,6 +76,10 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
         self.handlers.profile_handler = ProfileHandler(self.hs)
         self.room_member_handler = self.handlers.room_member_handler
 
+        self.snapshot = Mock()
+        self.datastore.snapshot_room.return_value = self.snapshot
+
+
     @defer.inlineCallbacks
     def test_invite(self):
         room_id = "!foo:red"
@@ -104,8 +109,12 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
         # Actual invocation
         yield self.room_member_handler.change_membership(event)
 
-        self.state_handler.handle_new_event.assert_called_once_with(event)
-        self.federation.handle_new_event.assert_called_once_with(event)
+        self.state_handler.handle_new_event.assert_called_once_with(
+            event, self.snapshot,
+        )
+        self.federation.handle_new_event.assert_called_once_with(
+            event, self.snapshot,
+        )
 
         self.assertEquals(
             set(["blue", "red", "green"]),
@@ -116,7 +125,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
             event
         )
         self.notifier.on_new_room_event.assert_called_once_with(
-                event, store_id)
+            event, store_id
+        )
 
         self.assertFalse(self.datastore.get_room.called)
         self.assertFalse(self.datastore.store_room.called)
@@ -148,6 +158,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
 
         self.datastore.get_joined_hosts_for_room.side_effect = get_joined
 
+
         store_id = "store_id_fooo"
         self.datastore.persist_event.return_value = defer.succeed(store_id)
         self.datastore.get_room.return_value = defer.succeed(1)  # Not None.
@@ -163,8 +174,12 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
         # Actual invocation
         yield self.room_member_handler.change_membership(event)
 
-        self.state_handler.handle_new_event.assert_called_once_with(event)
-        self.federation.handle_new_event.assert_called_once_with(event)
+        self.state_handler.handle_new_event.assert_called_once_with(
+            event, self.snapshot
+        )
+        self.federation.handle_new_event.assert_called_once_with(
+            event, self.snapshot
+        )
 
         self.assertEquals(
             set(["red", "green"]),
diff --git a/tests/utils.py b/tests/utils.py
index c68b17f7b9..f10bb8960f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -127,6 +127,13 @@ class MemoryDataStore(object):
         self.current_state = {}
         self.events = []
 
+    Snapshot = namedtuple("Snapshot", "room_id user_id membership_state")
+
+    def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
+        return self.Snapshot(
+            room_id, user_id, self.get_room_member(user_id, room_id)
+        )
+
     def register(self, user_id, token, password_hash):
         if user_id in self.tokens_to_users.values():
             raise StoreError(400, "User in use.")