summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2014-10-17 18:56:42 +0100
committerErik Johnston <erik@matrix.org>2014-10-17 18:56:42 +0100
commit5ffe5ab43fa090111a0141b04ce6342172f60724 (patch)
tree45af4a0c2fdbb3c89853645cafe1440b13c6d3f4
parentFinish implementing the new join dance. (diff)
downloadsynapse-5ffe5ab43fa090111a0141b04ce6342172f60724.tar.xz
Use state groups to get current state. Make join dance actually work.
-rw-r--r--synapse/api/auth.py5
-rw-r--r--synapse/federation/replication.py17
-rw-r--r--synapse/federation/transport.py57
-rw-r--r--synapse/handlers/federation.py74
-rw-r--r--synapse/handlers/message.py6
-rw-r--r--synapse/rest/base.py5
-rw-r--r--synapse/rest/events.py34
-rw-r--r--synapse/state.py86
-rw-r--r--synapse/storage/pdu.py6
-rw-r--r--synapse/storage/state.py3
10 files changed, 226 insertions, 67 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index d1eca791ab..50ce7eb4cd 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -22,6 +22,7 @@ from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
 from synapse.api.events.room import (
     RoomMemberEvent, RoomPowerLevelsEvent, RoomRedactionEvent,
     RoomJoinRulesEvent, RoomOpsPowerLevelsEvent, InviteJoinEvent,
+    RoomCreateEvent,
 )
 from synapse.util.logutils import log_function
 
@@ -59,6 +60,10 @@ class Auth(object):
 
                 is_state = hasattr(event, "state_key")
 
+                if event.type == RoomCreateEvent.TYPE:
+                    # FIXME
+                    defer.returnValue(True)
+
                 if event.type == RoomMemberEvent.TYPE:
                     yield self._can_replace_state(event)
                     allowed = yield self.is_membership_change_allowed(event)
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index d482193851..8c7d510ef6 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -403,12 +403,19 @@ class ReplicationLayer(object):
             defer.returnValue(
                 (404, "No handler for Query type '%s'" % (query_type, ))
             )
+
     @defer.inlineCallbacks
     def on_make_join_request(self, context, user_id):
         pdu = yield self.handler.on_make_join_request(context, user_id)
         defer.returnValue(pdu.get_dict())
 
     @defer.inlineCallbacks
+    def on_invite_request(self, origin, content):
+        pdu = Pdu(**content)
+        ret_pdu = yield self.handler.on_send_join_request(origin, pdu)
+        defer.returnValue((200, ret_pdu.get_dict()))
+
+    @defer.inlineCallbacks
     def on_send_join_request(self, origin, content):
         pdu = Pdu(**content)
         state = yield self.handler.on_send_join_request(origin, pdu)
@@ -426,8 +433,9 @@ class ReplicationLayer(object):
 
         defer.returnValue(Pdu(**pdu_dict))
 
+    @defer.inlineCallbacks
     def send_join(self, destination, pdu):
-        return self.transport_layer.send_join(
+        _, content = yield self.transport_layer.send_join(
             destination,
             pdu.context,
             pdu.pdu_id,
@@ -435,6 +443,13 @@ class ReplicationLayer(object):
             pdu.get_dict(),
         )
 
+        logger.debug("Got content: %s", content)
+        pdus = [Pdu(outlier=True, **p) for p in content.get("pdus", [])]
+        for pdu in pdus:
+            yield self._handle_new_pdu(destination, pdu)
+
+        defer.returnValue(pdus)
+
     @defer.inlineCallbacks
     @log_function
     def _get_persisted_pdu(self, pdu_id, pdu_origin):
diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py
index a0d34fd24d..de64702e2f 100644
--- a/synapse/federation/transport.py
+++ b/synapse/federation/transport.py
@@ -229,13 +229,36 @@ class TransportLayer(object):
             pdu_id,
         )
 
-        response = yield self.client.put_json(
+        code, content = yield self.client.put_json(
             destination=destination,
             path=path,
             data=content,
         )
 
-        defer.returnValue(response)
+        if not 200 <= code < 300:
+            raise RuntimeError("Got %d from send_join", code)
+
+        defer.returnValue(json.loads(content))
+
+    @defer.inlineCallbacks
+    @log_function
+    def send_invite(self, destination, context, pdu_id, origin, content):
+        path = PREFIX + "/invite/%s/%s/%s" % (
+            context,
+            origin,
+            pdu_id,
+        )
+
+        code, content = yield self.client.put_json(
+            destination=destination,
+            path=path,
+            data=content,
+        )
+
+        if not 200 <= code < 300:
+            raise RuntimeError("Got %d from send_invite", code)
+
+        defer.returnValue(json.loads(content))
 
     @defer.inlineCallbacks
     def _authenticate_request(self, request):
@@ -297,9 +320,13 @@ class TransportLayer(object):
         @defer.inlineCallbacks
         def new_handler(request, *args, **kwargs):
             (origin, content) = yield self._authenticate_request(request)
-            response = yield handler(
-                origin, content, request.args, *args, **kwargs
-            )
+            try:
+                response = yield handler(
+                    origin, content, request.args, *args, **kwargs
+                )
+            except:
+                logger.exception("Callback failed")
+                raise
             defer.returnValue(response)
         return new_handler
 
@@ -419,6 +446,17 @@ class TransportLayer(object):
             )
         )
 
+        self.server.register_path(
+            "PUT",
+            re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)/([^/]*)$"),
+            self._with_authentication(
+                lambda origin, content, query, context, pdu_origin, pdu_id:
+                self._on_invite_request(
+                    origin, content, query,
+                )
+            )
+        )
+
     @defer.inlineCallbacks
     @log_function
     def _on_send_request(self, origin, content, query, transaction_id):
@@ -524,6 +562,15 @@ class TransportLayer(object):
 
         defer.returnValue((200, content))
 
+    @defer.inlineCallbacks
+    @log_function
+    def _on_invite_request(self, origin, content, query):
+        content = yield self.request_handler.on_invite_request(
+            origin, content,
+        )
+
+        defer.returnValue((200, content))
+
 
 class TransportReceivedHandler(object):
     """ Callbacks used when we receive a transaction
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 0ae0541bd3..70790aaa72 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -62,6 +62,9 @@ class FederationHandler(BaseHandler):
 
         self.pdu_codec = PduCodec(hs)
 
+        # When joining a room we need to queue any events for that room up
+        self.room_queues = {}
+
     @log_function
     @defer.inlineCallbacks
     def handle_new_event(self, event, snapshot):
@@ -95,22 +98,25 @@ class FederationHandler(BaseHandler):
 
         logger.debug("Got event: %s", event.event_id)
 
+        if event.room_id in self.room_queues:
+            self.room_queues[event.room_id].append(pdu)
+            return
+
         if state:
             state = [self.pdu_codec.event_from_pdu(p) for p in state]
             state = {(e.type, e.state_key): e for e in state}
-        yield self.state_handler.annotate_state_groups(event, state=state)
+
+        is_new_state = yield self.state_handler.annotate_state_groups(
+            event,
+            state=state
+        )
 
         logger.debug("Event: %s", event)
 
         if not backfilled:
             yield self.auth.check(event, None, raises=True)
 
-        if event.is_state and not backfilled:
-            is_new_state = yield self.state_handler.handle_new_state(
-                pdu
-            )
-        else:
-            is_new_state = False
+        is_new_state = is_new_state and not backfilled
 
         # TODO: Implement something in federation that allows us to
         # respond to PDU.
@@ -211,6 +217,8 @@ class FederationHandler(BaseHandler):
         assert(event.state_key == joinee)
         assert(event.room_id == room_id)
 
+        self.room_queues[room_id] = []
+
         event.event_id = self.event_factory.create_event_id()
         event.content = content
 
@@ -219,15 +227,14 @@ class FederationHandler(BaseHandler):
             self.pdu_codec.pdu_from_event(event)
         )
 
-        # TODO (erikj): Time out here.
-        d = defer.Deferred()
-        self.waiting_for_join_list.setdefault((joinee, room_id), []).append(d)
-        reactor.callLater(10, d.cancel)
+        state = [self.pdu_codec.event_from_pdu(p) for p in state]
 
-        try:
-            yield d
-        except defer.CancelledError:
-            raise SynapseError(500, "Unable to join remote room")
+        logger.debug("do_invite_join state: %s", state)
+
+        is_new_state = yield self.state_handler.annotate_state_groups(
+            event,
+            state=state
+        )
 
         try:
             yield self.store.store_room(
@@ -239,6 +246,32 @@ class FederationHandler(BaseHandler):
             # FIXME
             pass
 
+        for e in state:
+            # FIXME: Auth these.
+            is_new_state = yield self.state_handler.annotate_state_groups(
+                e,
+                state=state
+            )
+
+            yield self.store.persist_event(
+                e,
+                backfilled=False,
+                is_new_state=False
+            )
+
+        yield self.store.persist_event(
+            event,
+            backfilled=False,
+            is_new_state=is_new_state
+        )
+
+        room_queue = self.room_queues[room_id]
+        del self.room_queues[room_id]
+
+        for p in room_queue:
+            p.outlier = True
+            yield self.on_receive_pdu(p, backfilled=False)
+
         defer.returnValue(True)
 
     @defer.inlineCallbacks
@@ -264,13 +297,9 @@ class FederationHandler(BaseHandler):
     def on_send_join_request(self, origin, pdu):
         event = self.pdu_codec.event_from_pdu(pdu)
 
-        yield self.state_handler.annotate_state_groups(event)
+        is_new_state= yield self.state_handler.annotate_state_groups(event)
         yield self.auth.check(event, None, raises=True)
 
-        is_new_state = yield self.state_handler.handle_new_state(
-            pdu
-        )
-
         # FIXME (erikj):  All this is duplicated above :(
 
         yield self.store.persist_event(
@@ -303,7 +332,10 @@ class FederationHandler(BaseHandler):
 
         yield self.replication_layer.send_pdu(new_pdu)
 
-        defer.returnValue(event.state_events.values())
+        defer.returnValue([
+            self.pdu_codec.pdu_from_event(e)
+            for e in event.state_events.values()
+        ])
 
     @defer.inlineCallbacks
     def get_state_for_pdu(self, pdu_id, pdu_origin):
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 1c2cbce151..4aaf97a83e 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -199,7 +199,7 @@ class MessageHandler(BaseHandler):
                 raise RoomError(
                     403, "Member does not meet private room rules.")
 
-        data = yield self.store.get_current_state(
+        data = yield self.state_handler.get_current_state(
             room_id, event_type, state_key
         )
         defer.returnValue(data)
@@ -238,7 +238,7 @@ class MessageHandler(BaseHandler):
         yield self.auth.check_joined_room(room_id, user_id)
 
         # TODO: This is duplicating logic from snapshot_all_rooms
-        current_state = yield self.store.get_current_state(room_id)
+        current_state = yield self.state_handler.get_current_state(room_id)
         defer.returnValue([self.hs.serialize_event(c) for c in current_state])
 
     @defer.inlineCallbacks
@@ -315,7 +315,7 @@ class MessageHandler(BaseHandler):
                     "end": end_token.to_string(),
                 }
 
-                current_state = yield self.store.get_current_state(
+                current_state = yield self.state_handler.get_current_state(
                     event.room_id
                 )
                 d["state"] = [self.hs.serialize_event(c) for c in current_state]
diff --git a/synapse/rest/base.py b/synapse/rest/base.py
index 2e8e3fa7d4..dc784c1527 100644
--- a/synapse/rest/base.py
+++ b/synapse/rest/base.py
@@ -18,6 +18,11 @@ from synapse.api.urls import CLIENT_PREFIX
 from synapse.rest.transactions import HttpTransactionStore
 import re
 
+import logging
+
+
+logger = logging.getLogger(__name__)
+
 
 def client_path_pattern(path_regex):
     """Creates a regex compiled client path with the correct client path
diff --git a/synapse/rest/events.py b/synapse/rest/events.py
index 097195d7cc..92ff5e5ca7 100644
--- a/synapse/rest/events.py
+++ b/synapse/rest/events.py
@@ -20,6 +20,12 @@ from synapse.api.errors import SynapseError
 from synapse.streams.config import PaginationConfig
 from synapse.rest.base import RestServlet, client_path_pattern
 
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
 
 class EventStreamRestServlet(RestServlet):
     PATTERN = client_path_pattern("/events$")
@@ -29,18 +35,22 @@ class EventStreamRestServlet(RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request):
         auth_user = yield self.auth.get_user_by_req(request)
-
-        handler = self.handlers.event_stream_handler
-        pagin_config = PaginationConfig.from_request(request)
-        timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
-        if "timeout" in request.args:
-            try:
-                timeout = int(request.args["timeout"][0])
-            except ValueError:
-                raise SynapseError(400, "timeout must be in milliseconds.")
-
-        chunk = yield handler.get_stream(auth_user.to_string(), pagin_config,
-                                         timeout=timeout)
+        try:
+            handler = self.handlers.event_stream_handler
+            pagin_config = PaginationConfig.from_request(request)
+            timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
+            if "timeout" in request.args:
+                try:
+                    timeout = int(request.args["timeout"][0])
+                except ValueError:
+                    raise SynapseError(400, "timeout must be in milliseconds.")
+
+            chunk = yield handler.get_stream(
+                auth_user.to_string(), pagin_config, timeout=timeout
+            )
+        except:
+            logger.exception("Event stream failed")
+            raise
 
         defer.returnValue((200, chunk))
 
diff --git a/synapse/state.py b/synapse/state.py
index 8c4eeb8924..24685c6fb4 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
 
 from synapse.federation.pdu_codec import encode_event_id, decode_event_id
 from synapse.util.logutils import log_function
+from synapse.federation.pdu_codec import encode_event_id
 
 from collections import namedtuple
 
@@ -130,54 +131,89 @@ class StateHandler(object):
         defer.returnValue(is_new)
 
     @defer.inlineCallbacks
+    @log_function
     def annotate_state_groups(self, event, state=None):
         if state:
             event.state_group = None
             event.old_state_events = None
-            event.state_events = state
+            event.state_events = {(s.type, s.state_key): s for s in state}
+            defer.returnValue(False)
+            return
+
+        if hasattr(event, "outlier") and event.outlier:
+            event.state_group = None
+            event.old_state_events = None
+            event.state_events = None
+            defer.returnValue(False)
             return
 
+        new_state = yield self.resolve_state_groups(event.prev_events)
+
+        event.old_state_events = new_state
+
+        if hasattr(event, "state_key"):
+            new_state[(event.type, event.state_key)] = event
+
+        event.state_group = None
+        event.state_events = new_state
+
+        defer.returnValue(hasattr(event, "state_key"))
+
+    @defer.inlineCallbacks
+    def get_current_state(self, room_id, event_type=None, state_key=""):
+        # FIXME: HACK!
+        pdus = yield self.store.get_latest_pdus_in_context(room_id)
+
+        event_ids = [encode_event_id(p.pdu_id, p.origin) for p in pdus]
+
+        res = self.resolve_state_groups(event_ids)
+
+        if event_type:
+            defer.returnValue(res.get((event_type, state_key)))
+            return
+
+        defer.returnValue(res.values())
+
+    @defer.inlineCallbacks
+    @log_function
+    def resolve_state_groups(self, event_ids):
         state_groups = yield self.store.get_state_groups(
-            event.prev_events
+            event_ids
         )
 
         state = {}
-        state_sets = {}
         for group in state_groups:
             for s in group.state:
-                state.setdefault((s.type, s.state_key), []).append(s)
-
-                state_sets.setdefault(
+                state.setdefault(
                     (s.type, s.state_key),
-                    set()
-                ).add(s.event_id)
+                    {}
+                )[s.event_id] = s
 
         unconflicted_state = {
-            k: state[k].pop() for k, v in state_sets.items()
-            if len(v) == 1
+            k: v.values()[0] for k, v in state.items()
+            if len(v.values()) == 1
         }
 
         conflicted_state = {
-            k: state[k]
-            for k, v in state_sets.items()
-            if len(v) > 1
+            k: v.values()
+            for k, v in state.items()
+            if len(v.values()) > 1
         }
 
-        new_state = {}
-        new_state.update(unconflicted_state)
-        for key, events in conflicted_state.items():
-            new_state[key] = yield self.resolve(events)
+        try:
+            new_state = {}
+            new_state.update(unconflicted_state)
+            for key, events in conflicted_state.items():
+                new_state[key] = yield self._resolve_state_events(events)
+        except:
+            logger.exception("Failed to resolve state")
+            raise
 
-        event.old_state_events = new_state
-
-        if hasattr(event, "state_key"):
-            new_state[(event.type, event.state_key)] = event
-
-        event.state_group = None
-        event.state_events = new_state
+        defer.returnValue(new_state)
 
     @defer.inlineCallbacks
-    def resolve(self, events):
+    @log_function
+    def _resolve_state_events(self, events):
         curr_events = events
 
         new_powers_deferreds = []
diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py
index d70467dcd6..b1cb0185a6 100644
--- a/synapse/storage/pdu.py
+++ b/synapse/storage/pdu.py
@@ -277,6 +277,12 @@ class PduStore(SQLBaseStore):
                 (context, depth)
             )
 
+    def get_latest_pdus_in_context(self, context):
+        return self.runInteraction(
+            self._get_latest_pdus_in_context,
+            context
+        )
+
     def _get_latest_pdus_in_context(self, txn, context):
         """Get's a list of the most current pdus for a given context. This is
         used when we are sending a Pdu and need to fill out the `prev_pdus`
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 9496c935a7..0aa979c9f0 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -63,6 +63,9 @@ class StateStore(SQLBaseStore):
         )
 
     def _store_state_groups_txn(self, txn, event):
+        if not event.state_events:
+            return
+
         state_group = event.state_group
         if not state_group:
             state_group = self._simple_insert_txn(