diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index d1eca791ab..50ce7eb4cd 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -22,6 +22,7 @@ from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
from synapse.api.events.room import (
RoomMemberEvent, RoomPowerLevelsEvent, RoomRedactionEvent,
RoomJoinRulesEvent, RoomOpsPowerLevelsEvent, InviteJoinEvent,
+ RoomCreateEvent,
)
from synapse.util.logutils import log_function
@@ -59,6 +60,10 @@ class Auth(object):
is_state = hasattr(event, "state_key")
+ if event.type == RoomCreateEvent.TYPE:
+ # FIXME
+ defer.returnValue(True)
+
if event.type == RoomMemberEvent.TYPE:
yield self._can_replace_state(event)
allowed = yield self.is_membership_change_allowed(event)
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index d482193851..8c7d510ef6 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -403,12 +403,19 @@ class ReplicationLayer(object):
defer.returnValue(
(404, "No handler for Query type '%s'" % (query_type, ))
)
+
@defer.inlineCallbacks
def on_make_join_request(self, context, user_id):
pdu = yield self.handler.on_make_join_request(context, user_id)
defer.returnValue(pdu.get_dict())
@defer.inlineCallbacks
+ def on_invite_request(self, origin, content):
+ pdu = Pdu(**content)
+ ret_pdu = yield self.handler.on_send_join_request(origin, pdu)
+ defer.returnValue((200, ret_pdu.get_dict()))
+
+ @defer.inlineCallbacks
def on_send_join_request(self, origin, content):
pdu = Pdu(**content)
state = yield self.handler.on_send_join_request(origin, pdu)
@@ -426,8 +433,9 @@ class ReplicationLayer(object):
defer.returnValue(Pdu(**pdu_dict))
+ @defer.inlineCallbacks
def send_join(self, destination, pdu):
- return self.transport_layer.send_join(
+ _, content = yield self.transport_layer.send_join(
destination,
pdu.context,
pdu.pdu_id,
@@ -435,6 +443,13 @@ class ReplicationLayer(object):
pdu.get_dict(),
)
+ logger.debug("Got content: %s", content)
+ pdus = [Pdu(outlier=True, **p) for p in content.get("pdus", [])]
+ for pdu in pdus:
+ yield self._handle_new_pdu(destination, pdu)
+
+ defer.returnValue(pdus)
+
@defer.inlineCallbacks
@log_function
def _get_persisted_pdu(self, pdu_id, pdu_origin):
diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py
index a0d34fd24d..de64702e2f 100644
--- a/synapse/federation/transport.py
+++ b/synapse/federation/transport.py
@@ -229,13 +229,36 @@ class TransportLayer(object):
pdu_id,
)
- response = yield self.client.put_json(
+ code, content = yield self.client.put_json(
destination=destination,
path=path,
data=content,
)
- defer.returnValue(response)
+ if not 200 <= code < 300:
+ raise RuntimeError("Got %d from send_join", code)
+
+ defer.returnValue(json.loads(content))
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_invite(self, destination, context, pdu_id, origin, content):
+ path = PREFIX + "/invite/%s/%s/%s" % (
+ context,
+ origin,
+ pdu_id,
+ )
+
+ code, content = yield self.client.put_json(
+ destination=destination,
+ path=path,
+ data=content,
+ )
+
+ if not 200 <= code < 300:
+ raise RuntimeError("Got %d from send_invite", code)
+
+ defer.returnValue(json.loads(content))
@defer.inlineCallbacks
def _authenticate_request(self, request):
@@ -297,9 +320,13 @@ class TransportLayer(object):
@defer.inlineCallbacks
def new_handler(request, *args, **kwargs):
(origin, content) = yield self._authenticate_request(request)
- response = yield handler(
- origin, content, request.args, *args, **kwargs
- )
+ try:
+ response = yield handler(
+ origin, content, request.args, *args, **kwargs
+ )
+ except:
+ logger.exception("Callback failed")
+ raise
defer.returnValue(response)
return new_handler
@@ -419,6 +446,17 @@ class TransportLayer(object):
)
)
+ self.server.register_path(
+ "PUT",
+ re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)/([^/]*)$"),
+ self._with_authentication(
+ lambda origin, content, query, context, pdu_origin, pdu_id:
+ self._on_invite_request(
+ origin, content, query,
+ )
+ )
+ )
+
@defer.inlineCallbacks
@log_function
def _on_send_request(self, origin, content, query, transaction_id):
@@ -524,6 +562,15 @@ class TransportLayer(object):
defer.returnValue((200, content))
+ @defer.inlineCallbacks
+ @log_function
+ def _on_invite_request(self, origin, content, query):
+ content = yield self.request_handler.on_invite_request(
+ origin, content,
+ )
+
+ defer.returnValue((200, content))
+
class TransportReceivedHandler(object):
""" Callbacks used when we receive a transaction
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 0ae0541bd3..70790aaa72 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -62,6 +62,9 @@ class FederationHandler(BaseHandler):
self.pdu_codec = PduCodec(hs)
+ # When joining a room we need to queue any events for that room up
+ self.room_queues = {}
+
@log_function
@defer.inlineCallbacks
def handle_new_event(self, event, snapshot):
@@ -95,22 +98,25 @@ class FederationHandler(BaseHandler):
logger.debug("Got event: %s", event.event_id)
+ if event.room_id in self.room_queues:
+ self.room_queues[event.room_id].append(pdu)
+ return
+
if state:
state = [self.pdu_codec.event_from_pdu(p) for p in state]
state = {(e.type, e.state_key): e for e in state}
- yield self.state_handler.annotate_state_groups(event, state=state)
+
+ is_new_state = yield self.state_handler.annotate_state_groups(
+ event,
+ state=state
+ )
logger.debug("Event: %s", event)
if not backfilled:
yield self.auth.check(event, None, raises=True)
- if event.is_state and not backfilled:
- is_new_state = yield self.state_handler.handle_new_state(
- pdu
- )
- else:
- is_new_state = False
+ is_new_state = is_new_state and not backfilled
# TODO: Implement something in federation that allows us to
# respond to PDU.
@@ -211,6 +217,8 @@ class FederationHandler(BaseHandler):
assert(event.state_key == joinee)
assert(event.room_id == room_id)
+ self.room_queues[room_id] = []
+
event.event_id = self.event_factory.create_event_id()
event.content = content
@@ -219,15 +227,14 @@ class FederationHandler(BaseHandler):
self.pdu_codec.pdu_from_event(event)
)
- # TODO (erikj): Time out here.
- d = defer.Deferred()
- self.waiting_for_join_list.setdefault((joinee, room_id), []).append(d)
- reactor.callLater(10, d.cancel)
+ state = [self.pdu_codec.event_from_pdu(p) for p in state]
- try:
- yield d
- except defer.CancelledError:
- raise SynapseError(500, "Unable to join remote room")
+ logger.debug("do_invite_join state: %s", state)
+
+ is_new_state = yield self.state_handler.annotate_state_groups(
+ event,
+ state=state
+ )
try:
yield self.store.store_room(
@@ -239,6 +246,32 @@ class FederationHandler(BaseHandler):
# FIXME
pass
+ for e in state:
+ # FIXME: Auth these.
+ is_new_state = yield self.state_handler.annotate_state_groups(
+ e,
+ state=state
+ )
+
+ yield self.store.persist_event(
+ e,
+ backfilled=False,
+ is_new_state=False
+ )
+
+ yield self.store.persist_event(
+ event,
+ backfilled=False,
+ is_new_state=is_new_state
+ )
+
+ room_queue = self.room_queues[room_id]
+ del self.room_queues[room_id]
+
+ for p in room_queue:
+ p.outlier = True
+ yield self.on_receive_pdu(p, backfilled=False)
+
defer.returnValue(True)
@defer.inlineCallbacks
@@ -264,13 +297,9 @@ class FederationHandler(BaseHandler):
def on_send_join_request(self, origin, pdu):
event = self.pdu_codec.event_from_pdu(pdu)
- yield self.state_handler.annotate_state_groups(event)
+ is_new_state= yield self.state_handler.annotate_state_groups(event)
yield self.auth.check(event, None, raises=True)
- is_new_state = yield self.state_handler.handle_new_state(
- pdu
- )
-
# FIXME (erikj): All this is duplicated above :(
yield self.store.persist_event(
@@ -303,7 +332,10 @@ class FederationHandler(BaseHandler):
yield self.replication_layer.send_pdu(new_pdu)
- defer.returnValue(event.state_events.values())
+ defer.returnValue([
+ self.pdu_codec.pdu_from_event(e)
+ for e in event.state_events.values()
+ ])
@defer.inlineCallbacks
def get_state_for_pdu(self, pdu_id, pdu_origin):
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 1c2cbce151..4aaf97a83e 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -199,7 +199,7 @@ class MessageHandler(BaseHandler):
raise RoomError(
403, "Member does not meet private room rules.")
- data = yield self.store.get_current_state(
+ data = yield self.state_handler.get_current_state(
room_id, event_type, state_key
)
defer.returnValue(data)
@@ -238,7 +238,7 @@ class MessageHandler(BaseHandler):
yield self.auth.check_joined_room(room_id, user_id)
# TODO: This is duplicating logic from snapshot_all_rooms
- current_state = yield self.store.get_current_state(room_id)
+ current_state = yield self.state_handler.get_current_state(room_id)
defer.returnValue([self.hs.serialize_event(c) for c in current_state])
@defer.inlineCallbacks
@@ -315,7 +315,7 @@ class MessageHandler(BaseHandler):
"end": end_token.to_string(),
}
- current_state = yield self.store.get_current_state(
+ current_state = yield self.state_handler.get_current_state(
event.room_id
)
d["state"] = [self.hs.serialize_event(c) for c in current_state]
diff --git a/synapse/rest/base.py b/synapse/rest/base.py
index 2e8e3fa7d4..dc784c1527 100644
--- a/synapse/rest/base.py
+++ b/synapse/rest/base.py
@@ -18,6 +18,11 @@ from synapse.api.urls import CLIENT_PREFIX
from synapse.rest.transactions import HttpTransactionStore
import re
+import logging
+
+
+logger = logging.getLogger(__name__)
+
def client_path_pattern(path_regex):
"""Creates a regex compiled client path with the correct client path
diff --git a/synapse/rest/events.py b/synapse/rest/events.py
index 097195d7cc..92ff5e5ca7 100644
--- a/synapse/rest/events.py
+++ b/synapse/rest/events.py
@@ -20,6 +20,12 @@ from synapse.api.errors import SynapseError
from synapse.streams.config import PaginationConfig
from synapse.rest.base import RestServlet, client_path_pattern
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
class EventStreamRestServlet(RestServlet):
PATTERN = client_path_pattern("/events$")
@@ -29,18 +35,22 @@ class EventStreamRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
auth_user = yield self.auth.get_user_by_req(request)
-
- handler = self.handlers.event_stream_handler
- pagin_config = PaginationConfig.from_request(request)
- timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
- if "timeout" in request.args:
- try:
- timeout = int(request.args["timeout"][0])
- except ValueError:
- raise SynapseError(400, "timeout must be in milliseconds.")
-
- chunk = yield handler.get_stream(auth_user.to_string(), pagin_config,
- timeout=timeout)
+ try:
+ handler = self.handlers.event_stream_handler
+ pagin_config = PaginationConfig.from_request(request)
+ timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
+ if "timeout" in request.args:
+ try:
+ timeout = int(request.args["timeout"][0])
+ except ValueError:
+ raise SynapseError(400, "timeout must be in milliseconds.")
+
+ chunk = yield handler.get_stream(
+ auth_user.to_string(), pagin_config, timeout=timeout
+ )
+ except:
+ logger.exception("Event stream failed")
+ raise
defer.returnValue((200, chunk))
diff --git a/synapse/state.py b/synapse/state.py
index 8c4eeb8924..24685c6fb4 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.federation.pdu_codec import encode_event_id, decode_event_id
from synapse.util.logutils import log_function
+from synapse.federation.pdu_codec import encode_event_id
from collections import namedtuple
@@ -130,54 +131,89 @@ class StateHandler(object):
defer.returnValue(is_new)
@defer.inlineCallbacks
+ @log_function
def annotate_state_groups(self, event, state=None):
if state:
event.state_group = None
event.old_state_events = None
- event.state_events = state
+ event.state_events = {(s.type, s.state_key): s for s in state}
+ defer.returnValue(False)
+ return
+
+ if hasattr(event, "outlier") and event.outlier:
+ event.state_group = None
+ event.old_state_events = None
+ event.state_events = None
+ defer.returnValue(False)
return
+ new_state = yield self.resolve_state_groups(event.prev_events)
+
+ event.old_state_events = new_state
+
+ if hasattr(event, "state_key"):
+ new_state[(event.type, event.state_key)] = event
+
+ event.state_group = None
+ event.state_events = new_state
+
+ defer.returnValue(hasattr(event, "state_key"))
+
+ @defer.inlineCallbacks
+ def get_current_state(self, room_id, event_type=None, state_key=""):
+ # FIXME: HACK!
+ pdus = yield self.store.get_latest_pdus_in_context(room_id)
+
+ event_ids = [encode_event_id(p.pdu_id, p.origin) for p in pdus]
+
+ res = self.resolve_state_groups(event_ids)
+
+ if event_type:
+ defer.returnValue(res.get((event_type, state_key)))
+ return
+
+ defer.returnValue(res.values())
+
+ @defer.inlineCallbacks
+ @log_function
+ def resolve_state_groups(self, event_ids):
state_groups = yield self.store.get_state_groups(
- event.prev_events
+ event_ids
)
state = {}
- state_sets = {}
for group in state_groups:
for s in group.state:
- state.setdefault((s.type, s.state_key), []).append(s)
-
- state_sets.setdefault(
+ state.setdefault(
(s.type, s.state_key),
- set()
- ).add(s.event_id)
+ {}
+ )[s.event_id] = s
unconflicted_state = {
- k: state[k].pop() for k, v in state_sets.items()
- if len(v) == 1
+ k: v.values()[0] for k, v in state.items()
+ if len(v.values()) == 1
}
conflicted_state = {
- k: state[k]
- for k, v in state_sets.items()
- if len(v) > 1
+ k: v.values()
+ for k, v in state.items()
+ if len(v.values()) > 1
}
- new_state = {}
- new_state.update(unconflicted_state)
- for key, events in conflicted_state.items():
- new_state[key] = yield self.resolve(events)
+ try:
+ new_state = {}
+ new_state.update(unconflicted_state)
+ for key, events in conflicted_state.items():
+ new_state[key] = yield self._resolve_state_events(events)
+ except:
+ logger.exception("Failed to resolve state")
+ raise
- event.old_state_events = new_state
-
- if hasattr(event, "state_key"):
- new_state[(event.type, event.state_key)] = event
-
- event.state_group = None
- event.state_events = new_state
+ defer.returnValue(new_state)
@defer.inlineCallbacks
- def resolve(self, events):
+ @log_function
+ def _resolve_state_events(self, events):
curr_events = events
new_powers_deferreds = []
diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py
index d70467dcd6..b1cb0185a6 100644
--- a/synapse/storage/pdu.py
+++ b/synapse/storage/pdu.py
@@ -277,6 +277,12 @@ class PduStore(SQLBaseStore):
(context, depth)
)
+ def get_latest_pdus_in_context(self, context):
+ return self.runInteraction(
+ self._get_latest_pdus_in_context,
+ context
+ )
+
def _get_latest_pdus_in_context(self, txn, context):
"""Get's a list of the most current pdus for a given context. This is
used when we are sending a Pdu and need to fill out the `prev_pdus`
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 9496c935a7..0aa979c9f0 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -63,6 +63,9 @@ class StateStore(SQLBaseStore):
)
def _store_state_groups_txn(self, txn, event):
+ if not event.state_events:
+ return
+
state_group = event.state_group
if not state_group:
state_group = self._simple_insert_txn(
|