summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2014-10-17 15:04:17 +0100
committerErik Johnston <erik@matrix.org>2014-10-17 15:04:17 +0100
commitf71627567b4aa58c5aba7e79c6d972b8ac26b449 (patch)
tree08f077bc1cbcc2d84e5783d41b2bbe84fbc3b24c /synapse
parentStart implementing the invite/join dance. Continue moving auth to use event.s... (diff)
downloadsynapse-f71627567b4aa58c5aba7e79c6d972b8ac26b449.tar.xz
Finish implementing the new join dance.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/auth.py9
-rw-r--r--synapse/api/events/factory.py14
-rw-r--r--synapse/federation/replication.py66
-rw-r--r--synapse/federation/transport.py68
-rw-r--r--synapse/handlers/federation.py181
-rw-r--r--synapse/state.py10
6 files changed, 222 insertions, 126 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 12ddef1b00..d1eca791ab 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -48,6 +48,15 @@ class Auth(object):
         """
         try:
             if hasattr(event, "room_id"):
+                if not event.old_state_events:
+                    # Oh, we don't know what the state of the room was, so we
+                    # are trusting that this is allowed (at least for now)
+                    defer.returnValue(True)
+
+                if hasattr(event, "outlier") and event.outlier:
+                    # TODO (erikj): Auth for outliers is done differently.
+                    defer.returnValue(True)
+
                 is_state = hasattr(event, "state_key")
 
                 if event.type == RoomMemberEvent.TYPE:
diff --git a/synapse/api/events/factory.py b/synapse/api/events/factory.py
index 0d94850cec..c6d1151cac 100644
--- a/synapse/api/events/factory.py
+++ b/synapse/api/events/factory.py
@@ -51,12 +51,20 @@ class EventFactory(object):
         self.clock = hs.get_clock()
         self.hs = hs
 
+        self.event_id_count = 0
+
+    def create_event_id(self):
+        i = str(self.event_id_count)
+        self.event_id_count += 1
+
+        local_part = str(int(self.clock.time())) + i + random_string(5)
+
+        return "%s@%s" % (local_part, self.hs.hostname)
+
     def create_event(self, etype=None, **kwargs):
         kwargs["type"] = etype
         if "event_id" not in kwargs:
-            kwargs["event_id"] = "%s@%s" % (
-                random_string(10), self.hs.hostname
-            )
+            kwargs["event_id"] = self.create_event_id()
 
         if "ts" not in kwargs:
             kwargs["ts"] = int(self.clock.time_msec())
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 08c29dece5..d482193851 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -244,13 +244,14 @@ class ReplicationLayer(object):
         pdu = None
         if pdu_list:
             pdu = pdu_list[0]
-            yield self._handle_new_pdu(pdu)
+            yield self._handle_new_pdu(destination, pdu)
 
         defer.returnValue(pdu)
 
     @defer.inlineCallbacks
     @log_function
-    def get_state_for_context(self, destination, context):
+    def get_state_for_context(self, destination, context, pdu_id=None,
+                              pdu_origin=None):
         """Requests all of the `current` state PDUs for a given context from
         a remote home server.
 
@@ -263,13 +264,14 @@ class ReplicationLayer(object):
         """
 
         transaction_data = yield self.transport_layer.get_context_state(
-            destination, context)
+            destination, context, pdu_id=pdu_id, pdu_origin=pdu_origin,
+        )
 
         transaction = Transaction(**transaction_data)
 
         pdus = [Pdu(outlier=True, **p) for p in transaction.pdus]
         for pdu in pdus:
-            yield self._handle_new_pdu(pdu)
+            yield self._handle_new_pdu(destination, pdu)
 
         defer.returnValue(pdus)
 
@@ -315,7 +317,7 @@ class ReplicationLayer(object):
 
         dl = []
         for pdu in pdu_list:
-            dl.append(self._handle_new_pdu(pdu))
+            dl.append(self._handle_new_pdu(transaction.origin, pdu))
 
         if hasattr(transaction, "edus"):
             for edu in [Edu(**x) for x in transaction.edus]:
@@ -347,14 +349,19 @@ class ReplicationLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def on_context_state_request(self, context):
-        results = yield self.store.get_current_state_for_context(
-            context
-        )
+    def on_context_state_request(self, context, pdu_id, pdu_origin):
+        if pdu_id and pdu_origin:
+            pdus = yield self.handler.get_state_for_pdu(
+                pdu_id, pdu_origin
+            )
+        else:
+            results = yield self.store.get_current_state_for_context(
+                context
+            )
+            pdus = [Pdu.from_pdu_tuple(p) for p in results]
 
-        logger.debug("Context returning %d results", len(results))
+        logger.debug("Context returning %d results", len(pdus))
 
-        pdus = [Pdu.from_pdu_tuple(p) for p in results]
         defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
 
     @defer.inlineCallbacks
@@ -396,9 +403,10 @@ class ReplicationLayer(object):
             defer.returnValue(
                 (404, "No handler for Query type '%s'" % (query_type, ))
             )
-
+    @defer.inlineCallbacks
     def on_make_join_request(self, context, user_id):
-        return self.handler.on_make_join_request(context, user_id)
+        pdu = yield self.handler.on_make_join_request(context, user_id)
+        defer.returnValue(pdu.get_dict())
 
     @defer.inlineCallbacks
     def on_send_join_request(self, origin, content):
@@ -406,13 +414,27 @@ class ReplicationLayer(object):
         state = yield self.handler.on_send_join_request(origin, pdu)
         defer.returnValue((200, self._transaction_from_pdus(state).get_dict()))
 
+    @defer.inlineCallbacks
     def make_join(self, destination, context, user_id):
-        return self.transport_layer.make_join(
+        pdu_dict = yield self.transport_layer.make_join(
             destination=destination,
             context=context,
             user_id=user_id,
         )
 
+        logger.debug("Got response to make_join: %s", pdu_dict)
+
+        defer.returnValue(Pdu(**pdu_dict))
+
+    def send_join(self, destination, pdu):
+        return self.transport_layer.send_join(
+            destination,
+            pdu.context,
+            pdu.pdu_id,
+            pdu.origin,
+            pdu.get_dict(),
+        )
+
     @defer.inlineCallbacks
     @log_function
     def _get_persisted_pdu(self, pdu_id, pdu_origin):
@@ -443,7 +465,7 @@ class ReplicationLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def _handle_new_pdu(self, pdu, backfilled=False):
+    def _handle_new_pdu(self, origin, pdu, backfilled=False):
         # We reprocess pdus when we have seen them only as outliers
         existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin)
 
@@ -452,6 +474,8 @@ class ReplicationLayer(object):
             defer.returnValue({})
             return
 
+        state = None
+
         # Get missing pdus if necessary.
         is_new = yield self.pdu_actions.is_new(pdu)
         if is_new and not pdu.outlier:
@@ -475,12 +499,22 @@ class ReplicationLayer(object):
                         except:
                             # TODO(erikj): Do some more intelligent retries.
                             logger.exception("Failed to get PDU")
+            else:
+                # We need to get the state at this event, since we have reached
+                # a backward extremity edge.
+                state = yield self.get_state_for_context(
+                    origin, pdu.context, pdu.pdu_id, pdu.origin,
+                )
 
         # Persist the Pdu, but don't mark it as processed yet.
         yield self.store.persist_event(pdu=pdu)
 
         if not backfilled:
-            ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled)
+            ret = yield self.handler.on_receive_pdu(
+                pdu,
+                backfilled=backfilled,
+                state=state,
+            )
         else:
             ret = None
 
diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py
index 4f552272e6..a0d34fd24d 100644
--- a/synapse/federation/transport.py
+++ b/synapse/federation/transport.py
@@ -72,7 +72,8 @@ class TransportLayer(object):
         self.received_handler = None
 
     @log_function
-    def get_context_state(self, destination, context):
+    def get_context_state(self, destination, context, pdu_id=None,
+                          pdu_origin=None):
         """ Requests all state for a given context (i.e. room) from the
         given server.
 
@@ -89,7 +90,14 @@ class TransportLayer(object):
 
         subpath = "/state/%s/" % context
 
-        return self._do_request_for_transaction(destination, subpath)
+        args = {}
+        if pdu_id and pdu_origin:
+            args["pdu_id"] = pdu_id
+            args["pdu_origin"] = pdu_origin
+
+        return self._do_request_for_transaction(
+            destination, subpath, args=args
+        )
 
     @log_function
     def get_pdu(self, destination, pdu_origin, pdu_id):
@@ -135,8 +143,10 @@ class TransportLayer(object):
 
         subpath = "/backfill/%s/" % context
 
-        args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]}
-        args["limit"] = limit
+        args = {
+            "v": ["%s,%s" % (i, o) for i, o in pdu_tuples],
+            "limit": limit,
+        }
 
         return self._do_request_for_transaction(
             dest,
@@ -211,6 +221,23 @@ class TransportLayer(object):
         defer.returnValue(response)
 
     @defer.inlineCallbacks
+    @log_function
+    def send_join(self, destination, context, pdu_id, origin, content):
+        path = PREFIX + "/send_join/%s/%s/%s" % (
+            context,
+            origin,
+            pdu_id,
+        )
+
+        response = yield self.client.put_json(
+            destination=destination,
+            path=path,
+            data=content,
+        )
+
+        defer.returnValue(response)
+
+    @defer.inlineCallbacks
     def _authenticate_request(self, request):
         json_request = {
             "method": request.method,
@@ -330,7 +357,11 @@ class TransportLayer(object):
             re.compile("^" + PREFIX + "/state/([^/]*)/$"),
             self._with_authentication(
                 lambda origin, content, query, context:
-                handler.on_context_state_request(context)
+                handler.on_context_state_request(
+                    context,
+                    query.get("pdu_id", [None])[0],
+                    query.get("pdu_origin", [None])[0]
+                )
             )
         )
 
@@ -369,7 +400,23 @@ class TransportLayer(object):
         self.server.register_path(
             "GET",
             re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"),
-            self._on_make_join_request
+            self._with_authentication(
+                lambda origin, content, query, context, user_id:
+                self._on_make_join_request(
+                    origin, content, query, context, user_id
+                )
+            )
+        )
+
+        self.server.register_path(
+            "PUT",
+            re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)/([^/]*)$"),
+            self._with_authentication(
+                lambda origin, content, query, context, pdu_origin, pdu_id:
+                self._on_send_join_request(
+                    origin, content, query,
+                )
+            )
         )
 
     @defer.inlineCallbacks
@@ -460,18 +507,23 @@ class TransportLayer(object):
             context, versions, limit
         )
 
+    @defer.inlineCallbacks
     @log_function
     def _on_make_join_request(self, origin, content, query, context, user_id):
-        return self.request_handler.on_make_join_request(
+        content = yield self.request_handler.on_make_join_request(
             context, user_id,
         )
+        defer.returnValue((200, content))
 
+    @defer.inlineCallbacks
     @log_function
     def _on_send_join_request(self, origin, content, query):
-        return self.request_handler.on_send_join_request(
+        content = yield self.request_handler.on_send_join_request(
             origin, content,
         )
 
+        defer.returnValue((200, content))
+
 
 class TransportReceivedHandler(object):
     """ Callbacks used when we receive a transaction
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index a4f6c739c3..0ae0541bd3 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -20,7 +20,7 @@ from ._base import BaseHandler
 from synapse.api.events.room import InviteJoinEvent, RoomMemberEvent
 from synapse.api.constants import Membership
 from synapse.util.logutils import log_function
-from synapse.federation.pdu_codec import PduCodec
+from synapse.federation.pdu_codec import PduCodec, encode_event_id
 from synapse.api.errors import SynapseError
 
 from twisted.internet import defer, reactor
@@ -87,7 +87,7 @@ class FederationHandler(BaseHandler):
 
     @log_function
     @defer.inlineCallbacks
-    def on_receive_pdu(self, pdu, backfilled):
+    def on_receive_pdu(self, pdu, backfilled, state=None):
         """ Called by the ReplicationLayer when we have a new pdu. We need to
         do auth checks and put it through the StateHandler.
         """
@@ -95,7 +95,10 @@ class FederationHandler(BaseHandler):
 
         logger.debug("Got event: %s", event.event_id)
 
-        yield self.state_handler.annotate_state_groups(event)
+        if state:
+            state = [self.pdu_codec.event_from_pdu(p) for p in state]
+            state = {(e.type, e.state_key): e for e in state}
+        yield self.state_handler.annotate_state_groups(event, state=state)
 
         logger.debug("Event: %s", event)
 
@@ -108,83 +111,55 @@ class FederationHandler(BaseHandler):
             )
         else:
             is_new_state = False
+
         # TODO: Implement something in federation that allows us to
         # respond to PDU.
 
-        target_is_mine = False
-        if hasattr(event, "target_host"):
-            target_is_mine = event.target_host == self.hs.hostname
-
-        if event.type == InviteJoinEvent.TYPE:
-            if not target_is_mine:
-                logger.debug("Ignoring invite/join event %s", event)
-                return
-
-            # If we receive an invite/join event then we need to join the
-            # sender to the given room.
-            # TODO: We should probably auth this or some such
-            content = event.content
-            content.update({"membership": Membership.JOIN})
-            new_event = self.event_factory.create_event(
-                etype=RoomMemberEvent.TYPE,
-                state_key=event.user_id,
-                room_id=event.room_id,
-                user_id=event.user_id,
-                membership=Membership.JOIN,
-                content=content
+        with (yield self.room_lock.lock(event.room_id)):
+            yield self.store.persist_event(
+                event,
+                backfilled,
+                is_new_state=is_new_state
             )
 
-            yield self.hs.get_handlers().room_member_handler.change_membership(
-                new_event,
-                do_auth=False,
-            )
+        room = yield self.store.get_room(event.room_id)
 
-        else:
-            with (yield self.room_lock.lock(event.room_id)):
-                yield self.store.persist_event(
-                    event,
-                    backfilled,
-                    is_new_state=is_new_state
+        if not room:
+            # Huh, let's try and get the current state
+            try:
+                yield self.replication_layer.get_state_for_context(
+                    event.origin, event.room_id, pdu.pdu_id, pdu.origin,
                 )
 
-            room = yield self.store.get_room(event.room_id)
-
-            if not room:
-                # Huh, let's try and get the current state
-                try:
-                    yield self.replication_layer.get_state_for_context(
-                        event.origin, event.room_id
-                    )
-
-                    hosts = yield self.store.get_joined_hosts_for_room(
-                        event.room_id
-                    )
-                    if self.hs.hostname in hosts:
-                        try:
-                            yield self.store.store_room(
-                                room_id=event.room_id,
-                                room_creator_user_id="",
-                                is_public=False,
-                            )
-                        except:
-                            pass
-                except:
-                    logger.exception(
-                        "Failed to get current state for room %s",
-                        event.room_id
-                    )
-
-            if not backfilled:
-                extra_users = []
-                if event.type == RoomMemberEvent.TYPE:
-                    target_user_id = event.state_key
-                    target_user = self.hs.parse_userid(target_user_id)
-                    extra_users.append(target_user)
-
-                yield self.notifier.on_new_room_event(
-                    event, extra_users=extra_users
+                hosts = yield self.store.get_joined_hosts_for_room(
+                    event.room_id
+                )
+                if self.hs.hostname in hosts:
+                    try:
+                        yield self.store.store_room(
+                            room_id=event.room_id,
+                            room_creator_user_id="",
+                            is_public=False,
+                        )
+                    except:
+                        pass
+            except:
+                logger.exception(
+                    "Failed to get current state for room %s",
+                    event.room_id
                 )
 
+        if not backfilled:
+            extra_users = []
+            if event.type == RoomMemberEvent.TYPE:
+                target_user_id = event.state_key
+                target_user = self.hs.parse_userid(target_user_id)
+                extra_users.append(target_user)
+
+            yield self.notifier.on_new_room_event(
+                event, extra_users=extra_users
+            )
+
         if event.type == RoomMemberEvent.TYPE:
             if event.membership == Membership.JOIN:
                 user = self.hs.parse_userid(event.state_key)
@@ -214,40 +189,35 @@ class FederationHandler(BaseHandler):
     @log_function
     @defer.inlineCallbacks
     def do_invite_join(self, target_host, room_id, joinee, content, snapshot):
-
         hosts = yield self.store.get_joined_hosts_for_room(room_id)
         if self.hs.hostname in hosts:
             # We are already in the room.
             logger.debug("We're already in the room apparently")
             defer.returnValue(False)
 
-        # First get current state to see if we are already joined.
-        try:
-            yield self.replication_layer.get_state_for_context(
-                target_host, room_id
-            )
-
-            hosts = yield self.store.get_joined_hosts_for_room(room_id)
-            if self.hs.hostname in hosts:
-                # Oh, we were actually in the room already.
-                logger.debug("We're already in the room apparently")
-                defer.returnValue(False)
-        except Exception:
-            logger.exception("Failed to get current state")
-
-        new_event = self.event_factory.create_event(
-            etype=InviteJoinEvent.TYPE,
-            target_host=target_host,
-            room_id=room_id,
-            user_id=joinee,
-            content=content
+        pdu = yield self.replication_layer.make_join(
+            target_host,
+            room_id,
+            joinee
         )
 
-        new_event.destinations = [target_host]
+        logger.debug("Got response to make_join: %s", pdu)
 
-        snapshot.fill_out_prev_events(new_event)
-        yield self.state_handler.annotate_state_groups(new_event)
-        yield self.handle_new_event(new_event, snapshot)
+        event = self.pdu_codec.event_from_pdu(pdu)
+
+        # We should assert some things.
+        assert(event.type == RoomMemberEvent.TYPE)
+        assert(event.user_id == joinee)
+        assert(event.state_key == joinee)
+        assert(event.room_id == room_id)
+
+        event.event_id = self.event_factory.create_event_id()
+        event.content = content
+
+        state = yield self.replication_layer.send_join(
+            target_host,
+            self.pdu_codec.pdu_from_event(event)
+        )
 
         # TODO (erikj): Time out here.
         d = defer.Deferred()
@@ -326,14 +296,31 @@ class FederationHandler(BaseHandler):
                     "user_joined_room", user=user, room_id=event.room_id
                 )
 
-        pdu.destinations = yield self.store.get_joined_hosts_for_room(
+        new_pdu = self.pdu_codec.pdu_from_event(event);
+        new_pdu.destinations = yield self.store.get_joined_hosts_for_room(
             event.room_id
         )
 
-        yield self.replication_layer.send_pdu(pdu)
+        yield self.replication_layer.send_pdu(new_pdu)
 
         defer.returnValue(event.state_events.values())
 
+    @defer.inlineCallbacks
+    def get_state_for_pdu(self, pdu_id, pdu_origin):
+        state_groups = yield self.store.get_state_groups(
+            [encode_event_id(pdu_id, pdu_origin)]
+        )
+
+        if state_groups:
+            defer.returnValue(
+                [
+                    self.pdu_codec.pdu_from_event(s)
+                    for s in state_groups[0].state
+                ]
+            )
+        else:
+            defer.returnValue([])
+
     @log_function
     def _on_user_joined(self, user, room_id):
         waiters = self.waiting_for_join_list.get((user.to_string(), room_id), [])
diff --git a/synapse/state.py b/synapse/state.py
index 9be6b716e2..8c4eeb8924 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -130,7 +130,13 @@ class StateHandler(object):
         defer.returnValue(is_new)
 
     @defer.inlineCallbacks
-    def annotate_state_groups(self, event):
+    def annotate_state_groups(self, event, state=None):
+        if state:
+            event.state_group = None
+            event.old_state_events = None
+            event.state_events = state
+            return
+
         state_groups = yield self.store.get_state_groups(
             event.prev_events
         )
@@ -177,7 +183,7 @@ class StateHandler(object):
         new_powers_deferreds = []
         for e in curr_events:
             new_powers_deferreds.append(
-                self.store.get_power_level(e.context, e.user_id)
+                self.store.get_power_level(e.room_id, e.user_id)
             )
 
         new_powers = yield defer.gatherResults(