diff --git a/synapse/__init__.py b/synapse/__init__.py
new file mode 100644
index 0000000000..aa760fb341
--- /dev/null
+++ b/synapse/__init__.py
@@ -0,0 +1,16 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" This is a reference implementation of a synapse home server.
+"""
diff --git a/synapse/api/__init__.py b/synapse/api/__init__.py
new file mode 100644
index 0000000000..fe8a073cd3
--- /dev/null
+++ b/synapse/api/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
new file mode 100644
index 0000000000..5c66a7261f
--- /dev/null
+++ b/synapse/api/auth.py
@@ -0,0 +1,164 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This module contains classes for authenticating the user."""
+from twisted.internet import defer
+
+from synapse.api.constants import Membership
+from synapse.api.errors import AuthError, StoreError
+from synapse.api.events.room import (RoomTopicEvent, RoomMemberEvent,
+ MessageEvent, FeedbackEvent)
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class Auth(object):
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def check(self, event, raises=False):
+ """ Checks if this event is correctly authed.
+
+ Returns:
+ True if the auth checks pass.
+ Raises:
+ AuthError if there was a problem authorising this event. This will
+ be raised only if raises=True.
+ """
+ try:
+ if event.type in [RoomTopicEvent.TYPE, MessageEvent.TYPE,
+ FeedbackEvent.TYPE]:
+ yield self.check_joined_room(event.room_id, event.user_id)
+ defer.returnValue(True)
+ elif event.type == RoomMemberEvent.TYPE:
+ allowed = yield self.is_membership_change_allowed(event)
+ defer.returnValue(allowed)
+ else:
+ raise AuthError(500, "Unknown event type %s" % event.type)
+ except AuthError as e:
+ logger.info("Event auth check failed on event %s with msg: %s",
+ event, e.msg)
+ if raises:
+ raise e
+ defer.returnValue(False)
+
+ @defer.inlineCallbacks
+ def check_joined_room(self, room_id, user_id):
+ try:
+ member = yield self.store.get_room_member(
+ room_id=room_id,
+ user_id=user_id
+ )
+ if not member or member.membership != Membership.JOIN:
+ raise AuthError(403, "User %s not in room %s" %
+ (user_id, room_id))
+ defer.returnValue(member)
+ except AttributeError:
+ pass
+ defer.returnValue(None)
+
+ @defer.inlineCallbacks
+ def is_membership_change_allowed(self, event):
+ # does this room even exist
+ room = yield self.store.get_room(event.room_id)
+ if not room:
+ raise AuthError(403, "Room does not exist")
+
+ # get info about the caller
+ try:
+ caller = yield self.store.get_room_member(
+ user_id=event.user_id,
+ room_id=event.room_id)
+ except:
+ caller = None
+ caller_in_room = caller and caller.membership == "join"
+
+ # get info about the target
+ try:
+ target = yield self.store.get_room_member(
+ user_id=event.target_user_id,
+ room_id=event.room_id)
+ except:
+ target = None
+ target_in_room = target and target.membership == "join"
+
+ membership = event.content["membership"]
+
+ if Membership.INVITE == membership:
+ # Invites are valid iff caller is in the room and target isn't.
+ if not caller_in_room: # caller isn't joined
+ raise AuthError(403, "You are not in room %s." % event.room_id)
+ elif target_in_room: # the target is already in the room.
+ raise AuthError(403, "%s is already in the room." %
+ event.target_user_id)
+ elif Membership.JOIN == membership:
+ # Joins are valid iff caller == target and they were:
+ # invited: They are accepting the invitation
+ # joined: It's a NOOP
+ if event.user_id != event.target_user_id:
+ raise AuthError(403, "Cannot force another user to join.")
+ elif room.is_public:
+ pass # anyone can join public rooms.
+ elif (not caller or caller.membership not in
+ [Membership.INVITE, Membership.JOIN]):
+ raise AuthError(403, "You are not invited to this room.")
+ elif Membership.LEAVE == membership:
+ if not caller_in_room: # trying to leave a room you aren't joined
+ raise AuthError(403, "You are not in room %s." % event.room_id)
+ elif event.target_user_id != event.user_id:
+ # trying to force another user to leave
+ raise AuthError(403, "Cannot force %s to leave." %
+ event.target_user_id)
+ else:
+ raise AuthError(500, "Unknown membership %s" % membership)
+
+ defer.returnValue(True)
+
+ def get_user_by_req(self, request):
+ """ Get a registered user's ID.
+
+ Args:
+ request - An HTTP request with an access_token query parameter.
+ Returns:
+ UserID : User ID object of the user making the request
+ Raises:
+ AuthError if no user by that token exists or the token is invalid.
+ """
+ # Can optionally look elsewhere in the request (e.g. headers)
+ try:
+ return self.get_user_by_token(request.args["access_token"][0])
+ except KeyError:
+ raise AuthError(403, "Missing access token.")
+
+ @defer.inlineCallbacks
+ def get_user_by_token(self, token):
+ """ Get a registered user's ID.
+
+ Args:
+ token (str)- The access token to get the user by.
+ Returns:
+ UserID : User ID object of the user who has that access token.
+ Raises:
+ AuthError if no user by that token exists or the token is invalid.
+ """
+ try:
+ user_id = yield self.store.get_user_by_token(token=token)
+ defer.returnValue(self.hs.parse_userid(user_id))
+ except StoreError:
+ raise AuthError(403, "Unrecognised access token.")
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
new file mode 100644
index 0000000000..37bf41bfb3
--- /dev/null
+++ b/synapse/api/constants.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains constants from the specification."""
+
+
+class Membership(object):
+
+ """Represents the membership states of a user in a room."""
+ INVITE = u"invite"
+ JOIN = u"join"
+ KNOCK = u"knock"
+ LEAVE = u"leave"
+
+
+class Feedback(object):
+
+ """Represents the types of feedback a user can send in response to a
+ message."""
+
+ DELIVERED = u"d"
+ READ = u"r"
+ LIST = (DELIVERED, READ)
+
+
+class PresenceState(object):
+ """Represents the presence state of a user."""
+ OFFLINE = 0
+ BUSY = 1
+ ONLINE = 2
+ FREE_FOR_CHAT = 3
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
new file mode 100644
index 0000000000..7ad4d636c2
--- /dev/null
+++ b/synapse/api/errors.py
@@ -0,0 +1,114 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains exceptions and error codes."""
+
+import logging
+
+
+class Codes(object):
+ FORBIDDEN = "M_FORBIDDEN"
+ BAD_JSON = "M_BAD_JSON"
+ NOT_JSON = "M_NOT_JSON"
+ USER_IN_USE = "M_USER_IN_USE"
+ ROOM_IN_USE = "M_ROOM_IN_USE"
+ BAD_PAGINATION = "M_BAD_PAGINATION"
+ UNKNOWN = "M_UNKNOWN"
+ NOT_FOUND = "M_NOT_FOUND"
+
+
+class CodeMessageException(Exception):
+ """An exception with integer code and message string attributes."""
+
+ def __init__(self, code, msg):
+ logging.error("%s: %s, %s", type(self).__name__, code, msg)
+ super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
+ self.code = code
+ self.msg = msg
+
+
+class SynapseError(CodeMessageException):
+ """A base error which can be caught for all synapse events."""
+ def __init__(self, code, msg, errcode=""):
+ """Constructs a synapse error.
+
+ Args:
+ code (int): The integer error code (typically an HTTP response code)
+ msg (str): The human-readable error message.
+ err (str): The error code e.g 'M_FORBIDDEN'
+ """
+ super(SynapseError, self).__init__(code, msg)
+ self.errcode = errcode
+
+
+class RoomError(SynapseError):
+ """An error raised when a room event fails."""
+ pass
+
+
+class RegistrationError(SynapseError):
+ """An error raised when a registration event fails."""
+ pass
+
+
+class AuthError(SynapseError):
+ """An error raised when there was a problem authorising an event."""
+
+ def __init__(self, *args, **kwargs):
+ if "errcode" not in kwargs:
+ kwargs["errcode"] = Codes.FORBIDDEN
+ super(AuthError, self).__init__(*args, **kwargs)
+
+
+class EventStreamError(SynapseError):
+ """An error raised when there a problem with the event stream."""
+ pass
+
+
+class LoginError(SynapseError):
+ """An error raised when there was a problem logging in."""
+ pass
+
+
+class StoreError(SynapseError):
+ """An error raised when there was a problem storing some data."""
+ pass
+
+
+def cs_exception(exception):
+ if isinstance(exception, SynapseError):
+ return cs_error(
+ exception.msg,
+ Codes.UNKNOWN if not exception.errcode else exception.errcode)
+ elif isinstance(exception, CodeMessageException):
+ return cs_error(exception.msg)
+ else:
+ logging.error("Unknown exception type: %s", type(exception))
+
+
+def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
+ """ Utility method for constructing an error response for client-server
+ interactions.
+
+ Args:
+ msg (str): The error message.
+ code (int): The error code.
+ kwargs : Additional keys to add to the response.
+ Returns:
+ A dict representing the error response JSON.
+ """
+ err = {"error": msg, "errcode": code}
+ for key, value in kwargs.iteritems():
+ err[key] = value
+ return err
diff --git a/synapse/api/events/__init__.py b/synapse/api/events/__init__.py
new file mode 100644
index 0000000000..bc2daf3361
--- /dev/null
+++ b/synapse/api/events/__init__.py
@@ -0,0 +1,152 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.api.errors import SynapseError, Codes
+from synapse.util.jsonobject import JsonEncodedObject
+
+
+class SynapseEvent(JsonEncodedObject):
+
+ """Base class for Synapse events. These are JSON objects which must abide
+ by a certain well-defined structure.
+ """
+
+ # Attributes that are currently assumed by the federation side:
+ # Mandatory:
+ # - event_id
+ # - room_id
+ # - type
+ # - is_state
+ #
+ # Optional:
+ # - state_key (mandatory when is_state is True)
+ # - prev_events (these can be filled out by the federation layer itself.)
+ # - prev_state
+
+ valid_keys = [
+ "event_id",
+ "type",
+ "room_id",
+ "user_id", # sender/initiator
+ "content", # HTTP body, JSON
+ ]
+
+ internal_keys = [
+ "is_state",
+ "state_key",
+ "prev_events",
+ "prev_state",
+ "depth",
+ "destinations",
+ "origin",
+ ]
+
+ required_keys = [
+ "event_id",
+ "room_id",
+ "content",
+ ]
+
+ def __init__(self, raises=True, **kwargs):
+ super(SynapseEvent, self).__init__(**kwargs)
+ if "content" in kwargs:
+ self.check_json(self.content, raises=raises)
+
+ def get_content_template(self):
+ """ Retrieve the JSON template for this event as a dict.
+
+ The template must be a dict representing the JSON to match. Only
+ required keys should be present. The values of the keys in the template
+ are checked via type() to the values of the same keys in the actual
+ event JSON.
+
+ NB: If loading content via json.loads, you MUST define strings as
+ unicode.
+
+ For example:
+ Content:
+ {
+ "name": u"bob",
+ "age": 18,
+ "friends": [u"mike", u"jill"]
+ }
+ Template:
+ {
+ "name": u"string",
+ "age": 0,
+ "friends": [u"string"]
+ }
+ The values "string" and 0 could be anything, so long as the types
+ are the same as the content.
+ """
+ raise NotImplementedError("get_content_template not implemented.")
+
+ def check_json(self, content, raises=True):
+ """Checks the given JSON content abides by the rules of the template.
+
+ Args:
+ content : A JSON object to check.
+ raises: True to raise a SynapseError if the check fails.
+ Returns:
+ True if the content passes the template. Returns False if the check
+ fails and raises=False.
+ Raises:
+ SynapseError if the check fails and raises=True.
+ """
+ # recursively call to inspect each layer
+ err_msg = self._check_json(content, self.get_content_template())
+ if err_msg:
+ if raises:
+ raise SynapseError(400, err_msg, Codes.BAD_JSON)
+ else:
+ return False
+ else:
+ return True
+
+ def _check_json(self, content, template):
+ """Check content and template matches.
+
+ If the template is a dict, each key in the dict will be validated with
+ the content, else it will just compare the types of content and
+ template. This basic type check is required because this function will
+ be recursively called and could be called with just strs or ints.
+
+ Args:
+ content: The content to validate.
+ template: The validation template.
+ Returns:
+ str: An error message if the validation fails, else None.
+ """
+ if type(content) != type(template):
+ return "Mismatched types: %s" % template
+
+ if type(template) == dict:
+ for key in template:
+ if key not in content:
+ return "Missing %s key" % key
+
+ if type(content[key]) != type(template[key]):
+ return "Key %s is of the wrong type." % key
+
+ if type(content[key]) == dict:
+ # we must go deeper
+ msg = self._check_json(content[key], template[key])
+ if msg:
+ return msg
+ elif type(content[key]) == list:
+ # make sure each item type in content matches the template
+ for entry in content[key]:
+ msg = self._check_json(entry, template[key][0])
+ if msg:
+ return msg
diff --git a/synapse/api/events/factory.py b/synapse/api/events/factory.py
new file mode 100644
index 0000000000..ea7afa234e
--- /dev/null
+++ b/synapse/api/events/factory.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.api.events.room import (
+ RoomTopicEvent, MessageEvent, RoomMemberEvent, FeedbackEvent,
+ InviteJoinEvent, RoomConfigEvent
+)
+
+from synapse.util.stringutils import random_string
+
+
+class EventFactory(object):
+
+ _event_classes = [
+ RoomTopicEvent,
+ MessageEvent,
+ RoomMemberEvent,
+ FeedbackEvent,
+ InviteJoinEvent,
+ RoomConfigEvent
+ ]
+
+ def __init__(self):
+ self._event_list = {} # dict of TYPE to event class
+ for event_class in EventFactory._event_classes:
+ self._event_list[event_class.TYPE] = event_class
+
+ def create_event(self, etype=None, **kwargs):
+ kwargs["type"] = etype
+ if "event_id" not in kwargs:
+ kwargs["event_id"] = random_string(10)
+
+ try:
+ handler = self._event_list[etype]
+ except KeyError: # unknown event type
+ # TODO allow custom event types.
+ raise NotImplementedError("Unknown etype=%s" % etype)
+
+ return handler(**kwargs)
diff --git a/synapse/api/events/room.py b/synapse/api/events/room.py
new file mode 100644
index 0000000000..b31cd19f4b
--- /dev/null
+++ b/synapse/api/events/room.py
@@ -0,0 +1,99 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from . import SynapseEvent
+
+
+class RoomTopicEvent(SynapseEvent):
+ TYPE = "m.room.topic"
+
+ def __init__(self, **kwargs):
+ kwargs["state_key"] = ""
+ super(RoomTopicEvent, self).__init__(**kwargs)
+
+ def get_content_template(self):
+ return {"topic": u"string"}
+
+
+class RoomMemberEvent(SynapseEvent):
+ TYPE = "m.room.member"
+
+ valid_keys = SynapseEvent.valid_keys + [
+ "target_user_id", # target
+ "membership", # action
+ ]
+
+ def __init__(self, **kwargs):
+ if "target_user_id" in kwargs:
+ kwargs["state_key"] = kwargs["target_user_id"]
+ super(RoomMemberEvent, self).__init__(**kwargs)
+
+ def get_content_template(self):
+ return {"membership": u"string"}
+
+
+class MessageEvent(SynapseEvent):
+ TYPE = "m.room.message"
+
+ valid_keys = SynapseEvent.valid_keys + [
+ "msg_id", # unique per room + user combo
+ ]
+
+ def __init__(self, **kwargs):
+ super(MessageEvent, self).__init__(**kwargs)
+
+ def get_content_template(self):
+ return {"msgtype": u"string"}
+
+
+class FeedbackEvent(SynapseEvent):
+ TYPE = "m.room.message.feedback"
+
+ valid_keys = SynapseEvent.valid_keys + [
+ "msg_id", # the message ID being acknowledged
+ "msg_sender_id", # person who is sending the feedback is 'user_id'
+ "feedback_type", # the type of feedback (delivery, read, etc)
+ ]
+
+ def __init__(self, **kwargs):
+ super(FeedbackEvent, self).__init__(**kwargs)
+
+ def get_content_template(self):
+ return {}
+
+
+class InviteJoinEvent(SynapseEvent):
+ TYPE = "m.room.invite_join"
+
+ valid_keys = SynapseEvent.valid_keys + [
+ "target_user_id",
+ "target_host",
+ ]
+
+ def __init__(self, **kwargs):
+ super(InviteJoinEvent, self).__init__(**kwargs)
+
+ def get_content_template(self):
+ return {}
+
+
+class RoomConfigEvent(SynapseEvent):
+ TYPE = "m.room.config"
+
+ def __init__(self, **kwargs):
+ kwargs["state_key"] = ""
+ super(RoomConfigEvent, self).__init__(**kwargs)
+
+ def get_content_template(self):
+ return {}
diff --git a/synapse/api/notifier.py b/synapse/api/notifier.py
new file mode 100644
index 0000000000..974f7f0ba0
--- /dev/null
+++ b/synapse/api/notifier.py
@@ -0,0 +1,186 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.api.constants import Membership
+from synapse.api.events.room import RoomMemberEvent
+
+from twisted.internet import defer
+from twisted.internet import reactor
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class Notifier(object):
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.hs = hs
+ self.stored_event_listeners = {}
+
+ @defer.inlineCallbacks
+ def on_new_room_event(self, event, store_id):
+ """Called when there is a new room event which may potentially be sent
+ down listening users' event streams.
+
+ This function looks for interested *users* who may want to be notified
+ for this event. This is different to users requesting from the event
+ stream which looks for interested *events* for this user.
+
+ Args:
+ event (SynapseEvent): The new event, which must have a room_id
+ store_id (int): The ID of this event after it was stored with the
+ data store.
+ '"""
+ member_list = yield self.store.get_room_members(room_id=event.room_id,
+ membership="join")
+ if not member_list:
+ member_list = []
+
+ member_list = [u.user_id for u in member_list]
+
+ # invites MUST prod the person being invited, who won't be in the room.
+ if (event.type == RoomMemberEvent.TYPE and
+ event.content["membership"] == Membership.INVITE):
+ member_list.append(event.target_user_id)
+
+ for user_id in member_list:
+ if user_id in self.stored_event_listeners:
+ self._notify_and_callback(
+ user_id=user_id,
+ event_data=event.get_dict(),
+ stream_type=event.type,
+ store_id=store_id)
+
+ def on_new_user_event(self, user_id, event_data, stream_type, store_id):
+ if user_id in self.stored_event_listeners:
+ self._notify_and_callback(
+ user_id=user_id,
+ event_data=event_data,
+ stream_type=stream_type,
+ store_id=store_id
+ )
+
+ def _notify_and_callback(self, user_id, event_data, stream_type, store_id):
+ logger.debug(
+ "Notifying %s of a new event.",
+ user_id
+ )
+
+ stream_ids = list(self.stored_event_listeners[user_id])
+ for stream_id in stream_ids:
+ self._notify_and_callback_stream(user_id, stream_id, event_data,
+ stream_type, store_id)
+
+ if not self.stored_event_listeners[user_id]:
+ del self.stored_event_listeners[user_id]
+
+ def _notify_and_callback_stream(self, user_id, stream_id, event_data,
+ stream_type, store_id):
+
+ event_listener = self.stored_event_listeners[user_id].pop(stream_id)
+ return_event_object = {
+ k: event_listener[k] for k in ["start", "chunk", "end"]
+ }
+
+ # work out the new end token
+ token = event_listener["start"]
+ end = self._next_token(stream_type, store_id, token)
+ return_event_object["end"] = end
+
+ # add the event to the chunk
+ chunk = event_listener["chunk"]
+ chunk.append(event_data)
+
+ # callback the defer. We know this can't have been resolved before as
+ # we always remove the event_listener from the map before resolving.
+ event_listener["defer"].callback(return_event_object)
+
+ def _next_token(self, stream_type, store_id, current_token):
+ stream_handler = self.hs.get_handlers().event_stream_handler
+ return stream_handler.get_event_stream_token(
+ stream_type,
+ store_id,
+ current_token
+ )
+
+ def store_events_for(self, user_id=None, stream_id=None, from_tok=None):
+ """Store all incoming events for this user. This should be paired with
+ get_events_for to return chunked data.
+
+ Args:
+ user_id (str): The user to monitor incoming events for.
+ stream (object): The stream that is receiving events
+ from_tok (str): The token to monitor incoming events from.
+ """
+ event_listener = {
+ "start": from_tok,
+ "chunk": [],
+ "end": from_tok,
+ "defer": defer.Deferred(),
+ }
+
+ if user_id not in self.stored_event_listeners:
+ self.stored_event_listeners[user_id] = {stream_id: event_listener}
+ else:
+ self.stored_event_listeners[user_id][stream_id] = event_listener
+
+ def purge_events_for(self, user_id=None, stream_id=None):
+ """Purges any stored events for this user.
+
+ Args:
+ user_id (str): The user to purge stored events for.
+ """
+ try:
+ del self.stored_event_listeners[user_id][stream_id]
+ if not self.stored_event_listeners[user_id]:
+ del self.stored_event_listeners[user_id]
+ except KeyError:
+ pass
+
+ def get_events_for(self, user_id=None, stream_id=None, timeout=0):
+ """Retrieve stored events for this user, waiting if necessary.
+
+ It is advisable to wrap this call in a maybeDeferred.
+
+ Args:
+ user_id (str): The user to get events for.
+ timeout (int): The time in seconds to wait before giving up.
+ Returns:
+ A Deferred or a dict containing the chunk data, depending on if
+ there was data to return yet. The Deferred callback may be None if
+ there were no events before the timeout expired.
+ """
+ logger.debug("%s is listening for events.", user_id)
+
+ if len(self.stored_event_listeners[user_id][stream_id]["chunk"]) > 0:
+ logger.debug("%s returning existing chunk.", user_id)
+ return self.stored_event_listeners[user_id][stream_id]
+
+ reactor.callLater(
+ (timeout / 1000.0), self._timeout, user_id, stream_id
+ )
+ return self.stored_event_listeners[user_id][stream_id]["defer"]
+
+ def _timeout(self, user_id, stream_id):
+ try:
+ # We remove the event_listener from the map so that we can't
+ # resolve the deferred twice.
+ event_listeners = self.stored_event_listeners[user_id]
+ event_listener = event_listeners.pop(stream_id)
+ event_listener["defer"].callback(None)
+ logger.debug("%s event listening timed out.", user_id)
+ except KeyError:
+ pass
diff --git a/synapse/api/streams/__init__.py b/synapse/api/streams/__init__.py
new file mode 100644
index 0000000000..08137c1e79
--- /dev/null
+++ b/synapse/api/streams/__init__.py
@@ -0,0 +1,96 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.api.errors import SynapseError
+
+
+class PaginationConfig(object):
+
+ """A configuration object which stores pagination parameters."""
+
+ def __init__(self, from_tok=None, to_tok=None, limit=0):
+ self.from_tok = from_tok
+ self.to_tok = to_tok
+ self.limit = limit
+
+ @classmethod
+ def from_request(cls, request, raise_invalid_params=True):
+ params = {
+ "from_tok": PaginationStream.TOK_START,
+ "to_tok": PaginationStream.TOK_END,
+ "limit": 0
+ }
+
+ query_param_mappings = [ # 3-tuple of qp_key, attribute, rules
+ ("from", "from_tok", lambda x: type(x) == str),
+ ("to", "to_tok", lambda x: type(x) == str),
+ ("limit", "limit", lambda x: x.isdigit())
+ ]
+
+ for qp, attr, is_valid in query_param_mappings:
+ if qp in request.args:
+ if is_valid(request.args[qp][0]):
+ params[attr] = request.args[qp][0]
+ elif raise_invalid_params:
+ raise SynapseError(400, "%s parameter is invalid." % qp)
+
+ return PaginationConfig(**params)
+
+
+class PaginationStream(object):
+
+ """ An interface for streaming data as chunks. """
+
+ TOK_START = "START"
+ TOK_END = "END"
+
+ def get_chunk(self, config=None):
+ """ Return the next chunk in the stream.
+
+ Args:
+ config (PaginationConfig): The config to aid which chunk to get.
+ Returns:
+ A dict containing the new start token "start", the new end token
+ "end" and the data "chunk" as a list.
+ """
+ raise NotImplementedError()
+
+
+class StreamData(object):
+
+ """ An interface for obtaining streaming data from a table. """
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+
+ def get_rows(self, user_id, from_pkey, to_pkey, limit):
+ """ Get event stream data between the specified pkeys.
+
+ Args:
+ user_id : The user's ID
+ from_pkey : The starting pkey.
+ to_pkey : The end pkey. May be -1 to mean "latest".
+ limit: The max number of results to return.
+ Returns:
+ A tuple containing the list of event stream data and the last pkey.
+ """
+ raise NotImplementedError()
+
+ def max_token(self):
+ """ Get the latest currently-valid token.
+
+ Returns:
+ The latest token."""
+ raise NotImplementedError()
diff --git a/synapse/api/streams/event.py b/synapse/api/streams/event.py
new file mode 100644
index 0000000000..0cc1a3e36a
--- /dev/null
+++ b/synapse/api/streams/event.py
@@ -0,0 +1,247 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This module contains classes for streaming from the event stream: /events.
+"""
+from twisted.internet import defer
+
+from synapse.api.errors import EventStreamError
+from synapse.api.events.room import (
+ RoomMemberEvent, MessageEvent, FeedbackEvent, RoomTopicEvent
+)
+from synapse.api.streams import PaginationStream, StreamData
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class MessagesStreamData(StreamData):
+ EVENT_TYPE = MessageEvent.TYPE
+
+ def __init__(self, hs, room_id=None, feedback=False):
+ super(MessagesStreamData, self).__init__(hs)
+ self.room_id = room_id
+ self.with_feedback = feedback
+
+ @defer.inlineCallbacks
+ def get_rows(self, user_id, from_key, to_key, limit):
+ (data, latest_ver) = yield self.store.get_message_stream(
+ user_id=user_id,
+ from_key=from_key,
+ to_key=to_key,
+ limit=limit,
+ room_id=self.room_id,
+ with_feedback=self.with_feedback
+ )
+ defer.returnValue((data, latest_ver))
+
+ @defer.inlineCallbacks
+ def max_token(self):
+ val = yield self.store.get_max_message_id()
+ defer.returnValue(val)
+
+
+class RoomMemberStreamData(StreamData):
+ EVENT_TYPE = RoomMemberEvent.TYPE
+
+ @defer.inlineCallbacks
+ def get_rows(self, user_id, from_key, to_key, limit):
+ (data, latest_ver) = yield self.store.get_room_member_stream(
+ user_id=user_id,
+ from_key=from_key,
+ to_key=to_key
+ )
+
+ defer.returnValue((data, latest_ver))
+
+ @defer.inlineCallbacks
+ def max_token(self):
+ val = yield self.store.get_max_room_member_id()
+ defer.returnValue(val)
+
+
+class FeedbackStreamData(StreamData):
+ EVENT_TYPE = FeedbackEvent.TYPE
+
+ def __init__(self, hs, room_id=None):
+ super(FeedbackStreamData, self).__init__(hs)
+ self.room_id = room_id
+
+ @defer.inlineCallbacks
+ def get_rows(self, user_id, from_key, to_key, limit):
+ (data, latest_ver) = yield self.store.get_feedback_stream(
+ user_id=user_id,
+ from_key=from_key,
+ to_key=to_key,
+ limit=limit,
+ room_id=self.room_id
+ )
+ defer.returnValue((data, latest_ver))
+
+ @defer.inlineCallbacks
+ def max_token(self):
+ val = yield self.store.get_max_feedback_id()
+ defer.returnValue(val)
+
+
+class RoomDataStreamData(StreamData):
+ EVENT_TYPE = RoomTopicEvent.TYPE # TODO need multiple event types
+
+ def __init__(self, hs, room_id=None):
+ super(RoomDataStreamData, self).__init__(hs)
+ self.room_id = room_id
+
+ @defer.inlineCallbacks
+ def get_rows(self, user_id, from_key, to_key, limit):
+ (data, latest_ver) = yield self.store.get_room_data_stream(
+ user_id=user_id,
+ from_key=from_key,
+ to_key=to_key,
+ limit=limit,
+ room_id=self.room_id
+ )
+ defer.returnValue((data, latest_ver))
+
+ @defer.inlineCallbacks
+ def max_token(self):
+ val = yield self.store.get_max_room_data_id()
+ defer.returnValue(val)
+
+
+class EventStream(PaginationStream):
+
+ SEPARATOR = '_'
+
+ def __init__(self, user_id, stream_data_list):
+ super(EventStream, self).__init__()
+ self.user_id = user_id
+ self.stream_data = stream_data_list
+
+ @defer.inlineCallbacks
+ def fix_tokens(self, pagination_config):
+ pagination_config.from_tok = yield self.fix_token(
+ pagination_config.from_tok)
+ pagination_config.to_tok = yield self.fix_token(
+ pagination_config.to_tok)
+ defer.returnValue(pagination_config)
+
+ @defer.inlineCallbacks
+ def fix_token(self, token):
+ """Fixes unknown values in a token to known values.
+
+ Args:
+ token (str): The token to fix up.
+ Returns:
+ The fixed-up token, which may == token.
+ """
+ # replace TOK_START and TOK_END with 0_0_0 or -1_-1_-1 depending.
+ replacements = [
+ (PaginationStream.TOK_START, "0"),
+ (PaginationStream.TOK_END, "-1")
+ ]
+ for magic_token, key in replacements:
+ if magic_token == token:
+ token = EventStream.SEPARATOR.join(
+ [key] * len(self.stream_data)
+ )
+
+ # replace -1 values with an actual pkey
+ token_segments = self._split_token(token)
+ for i, tok in enumerate(token_segments):
+ if tok == -1:
+ # add 1 to the max token because results are EXCLUSIVE from the
+ # latest version.
+ token_segments[i] = 1 + (yield self.stream_data[i].max_token())
+ defer.returnValue(EventStream.SEPARATOR.join(
+ str(x) for x in token_segments
+ ))
+
+ @defer.inlineCallbacks
+ def get_chunk(self, config=None):
+ # no support for limit on >1 streams, makes no sense.
+ if config.limit and len(self.stream_data) > 1:
+ raise EventStreamError(
+ 400, "Limit not supported on multiplexed streams."
+ )
+
+ (chunk_data, next_tok) = yield self._get_chunk_data(config.from_tok,
+ config.to_tok,
+ config.limit)
+
+ defer.returnValue({
+ "chunk": chunk_data,
+ "start": config.from_tok,
+ "end": next_tok
+ })
+
+ @defer.inlineCallbacks
+ def _get_chunk_data(self, from_tok, to_tok, limit):
+ """ Get event data between the two tokens.
+
+ Tokens are SEPARATOR separated values representing pkey values of
+ certain tables, and the position determines the StreamData invoked
+ according to the STREAM_DATA list.
+
+ The magic value '-1' can be used to get the latest value.
+
+ Args:
+ from_tok - The token to start from.
+ to_tok - The token to end at. Must have values > from_tok or be -1.
+ Returns:
+ A list of event data.
+ Raises:
+ EventStreamError if something went wrong.
+ """
+ # sanity check
+ if (from_tok.count(EventStream.SEPARATOR) !=
+ to_tok.count(EventStream.SEPARATOR) or
+ (from_tok.count(EventStream.SEPARATOR) + 1) !=
+ len(self.stream_data)):
+ raise EventStreamError(400, "Token lengths don't match.")
+
+ chunk = []
+ next_ver = []
+ for i, (from_pkey, to_pkey) in enumerate(zip(
+ self._split_token(from_tok),
+ self._split_token(to_tok)
+ )):
+ if from_pkey == to_pkey:
+ # tokens are the same, we have nothing to do.
+ next_ver.append(str(to_pkey))
+ continue
+
+ (event_chunk, max_pkey) = yield self.stream_data[i].get_rows(
+ self.user_id, from_pkey, to_pkey, limit
+ )
+
+ chunk += event_chunk
+ next_ver.append(str(max_pkey))
+
+ defer.returnValue((chunk, EventStream.SEPARATOR.join(next_ver)))
+
+ def _split_token(self, token):
+ """Splits the given token into a list of pkeys.
+
+ Args:
+ token (str): The token with SEPARATOR values.
+ Returns:
+ A list of ints.
+ """
+ segments = token.split(EventStream.SEPARATOR)
+ try:
+ int_segments = [int(x) for x in segments]
+ except ValueError:
+ raise EventStreamError(400, "Bad token: %s" % token)
+ return int_segments
diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py
new file mode 100644
index 0000000000..fe8a073cd3
--- /dev/null
+++ b/synapse/app/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
new file mode 100644
index 0000000000..5708b3ad95
--- /dev/null
+++ b/synapse/app/homeserver.py
@@ -0,0 +1,172 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#!/usr/bin/env python
+
+from synapse.storage import read_schema
+
+from synapse.server import HomeServer
+
+from twisted.internet import reactor
+from twisted.enterprise import adbapi
+from twisted.python.log import PythonLoggingObserver
+from synapse.http.server import TwistedHttpServer
+from synapse.http.client import TwistedHttpClient
+
+from daemonize import Daemonize
+
+import argparse
+import logging
+import logging.config
+import sqlite3
+
+logger = logging.getLogger(__name__)
+
+
+class SynapseHomeServer(HomeServer):
+ def build_http_server(self):
+ return TwistedHttpServer()
+
+ def build_http_client(self):
+ return TwistedHttpClient()
+
+ def build_db_pool(self):
+ """ Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
+ don't have to worry about overwriting existing content.
+ """
+ logging.info("Preparing database: %s...", self.db_name)
+ pool = adbapi.ConnectionPool(
+ 'sqlite3', self.db_name, check_same_thread=False,
+ cp_min=1, cp_max=1)
+
+ schemas = [
+ "transactions",
+ "pdu",
+ "users",
+ "profiles",
+ "presence",
+ "im",
+ "room_aliases",
+ ]
+
+ for sql_loc in schemas:
+ sql_script = read_schema(sql_loc)
+
+ with sqlite3.connect(self.db_name) as db_conn:
+ c = db_conn.cursor()
+ c.executescript(sql_script)
+ c.close()
+ db_conn.commit()
+
+ logging.info("Database prepared in %s.", self.db_name)
+
+ return pool
+
+
+def setup_logging(verbosity=0, filename=None, config_path=None):
+ """ Sets up logging with verbosity levels.
+
+ Args:
+ verbosity: The verbosity level.
+ filename: Log to the given file rather than to the console.
+ config_path: Path to a python logging config file.
+ """
+
+ if config_path is None:
+ log_format = (
+ '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s'
+ )
+
+ if not verbosity or verbosity == 0:
+ level = logging.WARNING
+ elif verbosity == 1:
+ level = logging.INFO
+ else:
+ level = logging.DEBUG
+
+ logging.basicConfig(level=level, filename=filename, format=log_format)
+ else:
+ logging.config.fileConfig(config_path)
+
+ observer = PythonLoggingObserver()
+ observer.start()
+
+
+def run():
+ reactor.run()
+
+
+def setup():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-p", "--port", dest="port", type=int, default=8080,
+ help="The port to listen on.")
+ parser.add_argument("-d", "--database", dest="db", default="homeserver.db",
+ help="The database name.")
+ parser.add_argument("-H", "--host", dest="host", default="localhost",
+ help="The hostname of the server.")
+ parser.add_argument('-v', '--verbose', dest="verbose", action='count',
+ help="The verbosity level.")
+ parser.add_argument('-f', '--log-file', dest="log_file", default=None,
+ help="File to log to.")
+ parser.add_argument('--log-config', dest="log_config", default=None,
+ help="Python logging config")
+ parser.add_argument('-D', '--daemonize', action='store_true',
+ default=False, help="Daemonize the home server")
+ parser.add_argument('--pid-file', dest="pid", help="When running as a "
+ "daemon, the file to store the pid in",
+ default="hs.pid")
+ args = parser.parse_args()
+
+ verbosity = int(args.verbose) if args.verbose else None
+
+ setup_logging(
+ verbosity=verbosity,
+ filename=args.log_file,
+ config_path=args.log_config,
+ )
+
+ logger.info("Server hostname: %s", args.host)
+
+ hs = SynapseHomeServer(
+ args.host,
+ db_name=args.db
+ )
+
+ # This object doesn't need to be saved because it's set as the handler for
+ # the replication layer
+ hs.get_federation()
+
+ hs.register_servlets()
+
+ hs.get_http_server().start_listening(args.port)
+
+ hs.build_db_pool()
+
+ if args.daemonize:
+ daemon = Daemonize(
+ app="synapse-homeserver",
+ pid=args.pid,
+ action=run,
+ auto_close_fds=False,
+ verbose=True,
+ logger=logger,
+ )
+
+ daemon.start()
+ else:
+ run()
+
+
+if __name__ == '__main__':
+ setup()
diff --git a/synapse/crypto/__init__.py b/synapse/crypto/__init__.py
new file mode 100644
index 0000000000..fe8a073cd3
--- /dev/null
+++ b/synapse/crypto/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/crypto/config.py b/synapse/crypto/config.py
new file mode 100644
index 0000000000..801dfd8656
--- /dev/null
+++ b/synapse/crypto/config.py
@@ -0,0 +1,159 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import ConfigParser as configparser
+import argparse
+import socket
+import sys
+import os
+from OpenSSL import crypto
+import nacl.signing
+from syutil.base64util import encode_base64
+import subprocess
+
+
+def load_config(description, argv):
+ config_parser = argparse.ArgumentParser(add_help=False)
+ config_parser.add_argument("-c", "--config-path", metavar="CONFIG_FILE",
+ help="Specify config file")
+ config_args, remaining_args = config_parser.parse_known_args(argv)
+ if config_args.config_path:
+ config = configparser.SafeConfigParser()
+ config.read([config_args.config_path])
+ defaults = dict(config.items("KeyServer"))
+ else:
+ defaults = {}
+ parser = argparse.ArgumentParser(
+ parents=[config_parser],
+ description=description,
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+ parser.set_defaults(**defaults)
+ parser.add_argument("--server-name", default=socket.getfqdn(),
+ help="The name of the server")
+ parser.add_argument("--signing-key-path",
+ help="The signing key to sign responses with")
+ parser.add_argument("--tls-certificate-path",
+ help="PEM encoded X509 certificate for TLS")
+ parser.add_argument("--tls-private-key-path",
+ help="PEM encoded private key for TLS")
+ parser.add_argument("--tls-dh-params-path",
+ help="PEM encoded dh parameters for ephemeral keys")
+ parser.add_argument("--bind-port", type=int,
+ help="TCP port to listen on")
+ parser.add_argument("--bind-host", default="",
+ help="Local interface to listen on")
+
+ args = parser.parse_args(remaining_args)
+
+ server_config = vars(args)
+ del server_config["config_path"]
+ return server_config
+
+
+def generate_config(argv):
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-c", "--config-path", help="Specify config file",
+ metavar="CONFIG_FILE", required=True)
+ parser.add_argument("--server-name", default=socket.getfqdn(),
+ help="The name of the server")
+ parser.add_argument("--signing-key-path",
+ help="The signing key to sign responses with")
+ parser.add_argument("--tls-certificate-path",
+ help="PEM encoded X509 certificate for TLS")
+ parser.add_argument("--tls-private-key-path",
+ help="PEM encoded private key for TLS")
+ parser.add_argument("--tls-dh-params-path",
+ help="PEM encoded dh parameters for ephemeral keys")
+ parser.add_argument("--bind-port", type=int, required=True,
+ help="TCP port to listen on")
+ parser.add_argument("--bind-host", default="",
+ help="Local interface to listen on")
+
+ args = parser.parse_args(argv)
+
+ dir_name = os.path.dirname(args.config_path)
+ base_key_name = os.path.join(dir_name, args.server_name)
+
+ if args.signing_key_path is None:
+ args.signing_key_path = base_key_name + ".signing.key"
+
+ if args.tls_certificate_path is None:
+ args.tls_certificate_path = base_key_name + ".tls.crt"
+
+ if args.tls_private_key_path is None:
+ args.tls_private_key_path = base_key_name + ".tls.key"
+
+ if args.tls_dh_params_path is None:
+ args.tls_dh_params_path = base_key_name + ".tls.dh"
+
+ if not os.path.exists(args.signing_key_path):
+ with open(args.signing_key_path, "w") as signing_key_file:
+ key = nacl.signing.SigningKey.generate()
+ signing_key_file.write(encode_base64(key.encode()))
+
+ if not os.path.exists(args.tls_private_key_path):
+ with open(args.tls_private_key_path, "w") as private_key_file:
+ tls_private_key = crypto.PKey()
+ tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
+ private_key_pem = crypto.dump_privatekey(
+ crypto.FILETYPE_PEM, tls_private_key
+ )
+ private_key_file.write(private_key_pem)
+ else:
+ with open(args.tls_private_key_path) as private_key_file:
+ private_key_pem = private_key_file.read()
+ tls_private_key = crypto.load_privatekey(
+ crypto.FILETYPE_PEM, private_key_pem
+ )
+
+ if not os.path.exists(args.tls_certificate_path):
+ with open(args.tls_certificate_path, "w") as certifcate_file:
+ cert = crypto.X509()
+ subject = cert.get_subject()
+ subject.CN = args.server_name
+
+ cert.set_serial_number(1000)
+ cert.gmtime_adj_notBefore(0)
+ cert.gmtime_adj_notAfter(10 * 365 * 24 * 60 * 60)
+ cert.set_issuer(cert.get_subject())
+ cert.set_pubkey(tls_private_key)
+
+ cert.sign(tls_private_key, 'sha256')
+
+ cert_pem = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
+
+ certifcate_file.write(cert_pem)
+
+ if not os.path.exists(args.tls_dh_params_path):
+ subprocess.check_call([
+ "openssl", "dhparam",
+ "-outform", "PEM",
+ "-out", args.tls_dh_params_path,
+ "2048"
+ ])
+
+ config = configparser.SafeConfigParser()
+ config.add_section("KeyServer")
+ for key, value in vars(args).items():
+ if key != "config_path":
+ config.set("KeyServer", key, str(value))
+
+ with open(args.config_path, "w") as config_file:
+ config.write(config_file)
+
+
+if __name__ == "__main__":
+ generate_config(sys.argv[1:])
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
new file mode 100644
index 0000000000..b53d1c572b
--- /dev/null
+++ b/synapse/crypto/keyclient.py
@@ -0,0 +1,118 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.web.http import HTTPClient
+from twisted.internet import defer, reactor
+from twisted.internet.protocol import ClientFactory
+from twisted.names.srvconnect import SRVConnector
+import json
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+@defer.inlineCallbacks
+def fetch_server_key(server_name, ssl_context_factory):
+ """Fetch the keys for a remote server."""
+
+ factory = SynapseKeyClientFactory()
+
+ SRVConnector(
+ reactor, "matrix", server_name, factory,
+ protocol="tcp", connectFuncName="connectSSL", defaultPort=443,
+ connectFuncKwArgs=dict(contextFactory=ssl_context_factory)).connect()
+
+ server_key, server_certificate = yield factory.remote_key
+
+ defer.returnValue((server_key, server_certificate))
+
+
+class SynapseKeyClientError(Exception):
+ """The key wasn't retireved from the remote server."""
+ pass
+
+
+class SynapseKeyClientProtocol(HTTPClient):
+ """Low level HTTPS client which retrieves an application/json response from
+ the server and extracts the X.509 certificate for the remote peer from the
+ SSL connection."""
+
+ def connectionMade(self):
+ logger.debug("Connected to %s", self.transport.getHost())
+ self.sendCommand(b"GET", b"/key")
+ self.endHeaders()
+ self.timer = reactor.callLater(
+ self.factory.timeout_seconds,
+ self.on_timeout
+ )
+
+ def handleStatus(self, version, status, message):
+ if status != b"200":
+ logger.info("Non-200 response from %s: %s %s",
+ self.transport.getHost(), status, message)
+ self.transport.abortConnection()
+
+ def handleResponse(self, response_body_bytes):
+ try:
+ json_response = json.loads(response_body_bytes)
+ except ValueError:
+ logger.info("Invalid JSON response from %s",
+ self.transport.getHost())
+ self.transport.abortConnection()
+ return
+
+ certificate = self.transport.getPeerCertificate()
+ self.factory.on_remote_key((json_response, certificate))
+ self.transport.abortConnection()
+ self.timer.cancel()
+
+ def on_timeout(self):
+ logger.debug("Timeout waiting for response from %s",
+ self.transport.getHost())
+ self.transport.abortConnection()
+
+
+class SynapseKeyClientFactory(ClientFactory):
+ protocol = SynapseKeyClientProtocol
+ max_retries = 5
+ timeout_seconds = 30
+
+ def __init__(self):
+ self.succeeded = False
+ self.retries = 0
+ self.remote_key = defer.Deferred()
+
+ def on_remote_key(self, key):
+ self.succeeded = True
+ self.remote_key.callback(key)
+
+ def retry_connection(self, connector):
+ self.retries += 1
+ if self.retries < self.max_retries:
+ connector.connector = None
+ connector.connect()
+ else:
+ self.remote_key.errback(
+ SynapseKeyClientError("Max retries exceeded"))
+
+ def clientConnectionFailed(self, connector, reason):
+ logger.info("Connection failed %s", reason)
+ self.retry_connection(connector)
+
+ def clientConnectionLost(self, connector, reason):
+ logger.info("Connection lost %s", reason)
+ if not self.succeeded:
+ self.retry_connection(connector)
diff --git a/synapse/crypto/keyserver.py b/synapse/crypto/keyserver.py
new file mode 100644
index 0000000000..48bd380781
--- /dev/null
+++ b/synapse/crypto/keyserver.py
@@ -0,0 +1,110 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import reactor, ssl
+from twisted.web import server
+from twisted.web.resource import Resource
+from twisted.python.log import PythonLoggingObserver
+
+from synapse.crypto.resource.key import LocalKey
+from synapse.crypto.config import load_config
+
+from syutil.base64util import decode_base64
+
+from OpenSSL import crypto, SSL
+
+import logging
+import nacl.signing
+import sys
+
+
+class KeyServerSSLContextFactory(ssl.ContextFactory):
+ """Factory for PyOpenSSL SSL contexts that are used to handle incoming
+ connections and to make connections to remote servers."""
+
+ def __init__(self, key_server):
+ self._context = SSL.Context(SSL.SSLv23_METHOD)
+ self.configure_context(self._context, key_server)
+
+ @staticmethod
+ def configure_context(context, key_server):
+ context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
+ context.use_certificate(key_server.tls_certificate)
+ context.use_privatekey(key_server.tls_private_key)
+ context.load_tmp_dh(key_server.tls_dh_params_path)
+ context.set_cipher_list("!ADH:HIGH+kEDH:!AECDH:HIGH+kEECDH")
+
+ def getContext(self):
+ return self._context
+
+
+class KeyServer(object):
+ """An HTTPS server serving LocalKey and RemoteKey resources."""
+
+ def __init__(self, server_name, tls_certificate_path, tls_private_key_path,
+ tls_dh_params_path, signing_key_path, bind_host, bind_port):
+ self.server_name = server_name
+ self.tls_certificate = self.read_tls_certificate(tls_certificate_path)
+ self.tls_private_key = self.read_tls_private_key(tls_private_key_path)
+ self.tls_dh_params_path = tls_dh_params_path
+ self.signing_key = self.read_signing_key(signing_key_path)
+ self.bind_host = bind_host
+ self.bind_port = int(bind_port)
+ self.ssl_context_factory = KeyServerSSLContextFactory(self)
+
+ @staticmethod
+ def read_tls_certificate(cert_path):
+ with open(cert_path) as cert_file:
+ cert_pem = cert_file.read()
+ return crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
+
+ @staticmethod
+ def read_tls_private_key(private_key_path):
+ with open(private_key_path) as private_key_file:
+ private_key_pem = private_key_file.read()
+ return crypto.load_privatekey(crypto.FILETYPE_PEM, private_key_pem)
+
+ @staticmethod
+ def read_signing_key(signing_key_path):
+ with open(signing_key_path) as signing_key_file:
+ signing_key_b64 = signing_key_file.read()
+ signing_key_bytes = decode_base64(signing_key_b64)
+ return nacl.signing.SigningKey(signing_key_bytes)
+
+ def run(self):
+ root = Resource()
+ root.putChild("key", LocalKey(self))
+ site = server.Site(root)
+ reactor.listenSSL(
+ self.bind_port,
+ site,
+ self.ssl_context_factory,
+ interface=self.bind_host
+ )
+
+ logging.basicConfig(level=logging.DEBUG)
+ observer = PythonLoggingObserver()
+ observer.start()
+
+ reactor.run()
+
+
+def main():
+ key_server = KeyServer(**load_config(__doc__, sys.argv[1:]))
+ key_server.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/synapse/crypto/resource/__init__.py b/synapse/crypto/resource/__init__.py
new file mode 100644
index 0000000000..fe8a073cd3
--- /dev/null
+++ b/synapse/crypto/resource/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/crypto/resource/key.py b/synapse/crypto/resource/key.py
new file mode 100644
index 0000000000..6ce6e0b034
--- /dev/null
+++ b/synapse/crypto/resource/key.py
@@ -0,0 +1,160 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+from twisted.internet import defer
+from synapse.http.server import respond_with_json_bytes
+from synapse.crypto.keyclient import fetch_server_key
+from syutil.crypto.jsonsign import sign_json, verify_signed_json
+from syutil.base64util import encode_base64, decode_base64
+from syutil.jsonutil import encode_canonical_json
+from OpenSSL import crypto
+from nacl.signing import VerifyKey
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class LocalKey(Resource):
+ """HTTP resource containing encoding the TLS X.509 certificate and NACL
+ signature verification keys for this server::
+
+ GET /key HTTP/1.1
+
+ HTTP/1.1 200 OK
+ Content-Type: application/json
+ {
+ "server_name": "this.server.example.com"
+ "signature_verify_key": # base64 encoded NACL verification key.
+ "tls_certificate": # base64 ASN.1 DER encoded X.509 tls cert.
+ "signatures": {
+ "this.server.example.com": # NACL signature for this server.
+ }
+ }
+ """
+
+ def __init__(self, key_server):
+ self.key_server = key_server
+ self.response_body = encode_canonical_json(
+ self.response_json_object(key_server)
+ )
+ Resource.__init__(self)
+
+ @staticmethod
+ def response_json_object(key_server):
+ verify_key_bytes = key_server.signing_key.verify_key.encode()
+ x509_certificate_bytes = crypto.dump_certificate(
+ crypto.FILETYPE_ASN1,
+ key_server.tls_certificate
+ )
+ json_object = {
+ u"server_name": key_server.server_name,
+ u"signature_verify_key": encode_base64(verify_key_bytes),
+ u"tls_certificate": encode_base64(x509_certificate_bytes)
+ }
+ signed_json = sign_json(
+ json_object,
+ key_server.server_name,
+ key_server.signing_key
+ )
+ return signed_json
+
+ def getChild(self, name, request):
+ logger.info("getChild %s %s", name, request)
+ if name == '':
+ return self
+ else:
+ return RemoteKey(name, self.key_server)
+
+ def render_GET(self, request):
+ return respond_with_json_bytes(request, 200, self.response_body)
+
+
+class RemoteKey(Resource):
+ """HTTP resource for retreiving the TLS certificate and NACL signature
+ verification keys for a another server. Checks that the reported X.509 TLS
+ certificate matches the one used in the HTTPS connection. Checks that the
+ NACL signature for the remote server is valid. Returns JSON signed by both
+ the remote server and by this server.
+
+ GET /key/remote.server.example.com HTTP/1.1
+
+ HTTP/1.1 200 OK
+ Content-Type: application/json
+ {
+ "server_name": "remote.server.example.com"
+ "signature_verify_key": # base64 encoded NACL verification key.
+ "tls_certificate": # base64 ASN.1 DER encoded X.509 tls cert.
+ "signatures": {
+ "remote.server.example.com": # NACL signature for remote server.
+ "this.server.example.com": # NACL signature for this server.
+ }
+ }
+ """
+
+ isLeaf = True
+
+ def __init__(self, server_name, key_server):
+ self.server_name = server_name
+ self.key_server = key_server
+ Resource.__init__(self)
+
+ def render_GET(self, request):
+ self._async_render_GET(request)
+ return NOT_DONE_YET
+
+ @defer.inlineCallbacks
+ def _async_render_GET(self, request):
+ try:
+ server_keys, certificate = yield fetch_server_key(
+ self.server_name,
+ self.key_server.ssl_context_factory
+ )
+
+ resp_server_name = server_keys[u"server_name"]
+ verify_key_b64 = server_keys[u"signature_verify_key"]
+ tls_certificate_b64 = server_keys[u"tls_certificate"]
+ verify_key = VerifyKey(decode_base64(verify_key_b64))
+
+ if resp_server_name != self.server_name:
+ raise ValueError("Wrong server name '%s' != '%s'" %
+ (resp_server_name, self.server_name))
+
+ x509_certificate_bytes = crypto.dump_certificate(
+ crypto.FILETYPE_ASN1,
+ certificate
+ )
+
+ if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
+ raise ValueError("TLS certificate doesn't match")
+
+ verify_signed_json(server_keys, self.server_name, verify_key)
+
+ signed_json = sign_json(
+ server_keys,
+ self.key_server.server_name,
+ self.key_server.signing_key
+ )
+
+ json_bytes = encode_canonical_json(signed_json)
+ respond_with_json_bytes(request, 200, json_bytes)
+
+ except Exception as e:
+ json_bytes = encode_canonical_json({
+ u"error": {u"code": 502, u"message": e.message}
+ })
+ respond_with_json_bytes(request, 502, json_bytes)
diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py
new file mode 100644
index 0000000000..b4d95ed5ac
--- /dev/null
+++ b/synapse/federation/__init__.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" This package includes all the federation specific logic.
+"""
+
+from .replication import ReplicationLayer
+from .transport import TransportLayer
+
+
+def initialize_http_replication(homeserver):
+ transport = TransportLayer(
+ homeserver.hostname,
+ server=homeserver.get_http_server(),
+ client=homeserver.get_http_client()
+ )
+
+ return ReplicationLayer(homeserver, transport)
diff --git a/synapse/federation/handler.py b/synapse/federation/handler.py
new file mode 100644
index 0000000000..31e8470b33
--- /dev/null
+++ b/synapse/federation/handler.py
@@ -0,0 +1,148 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from .pdu_codec import PduCodec
+
+from synapse.api.errors import AuthError
+from synapse.util.logutils import log_function
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class FederationEventHandler(object):
+ """ Responsible for:
+ a) handling received Pdus before handing them on as Events to the rest
+ of the home server (including auth and state conflict resoultion)
+ b) converting events that were produced by local clients that may need
+ to be sent to remote home servers.
+ """
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.replication_layer = hs.get_replication_layer()
+ self.state_handler = hs.get_state_handler()
+ # self.auth_handler = gs.get_auth_handler()
+ self.event_handler = hs.get_handlers().federation_handler
+ self.server_name = hs.hostname
+
+ self.lock_manager = hs.get_room_lock_manager()
+
+ self.replication_layer.set_handler(self)
+
+ self.pdu_codec = PduCodec(hs)
+
+ @log_function
+ @defer.inlineCallbacks
+ def handle_new_event(self, event):
+ """ Takes in an event from the client to server side, that has already
+ been authed and handled by the state module, and sends it to any
+ remote home servers that may be interested.
+
+ Args:
+ event
+
+ Returns:
+ Deferred: Resolved when it has successfully been queued for
+ processing.
+ """
+ yield self._fill_out_prev_events(event)
+
+ pdu = self.pdu_codec.pdu_from_event(event)
+
+ if not hasattr(pdu, "destinations") or not pdu.destinations:
+ pdu.destinations = []
+
+ yield self.replication_layer.send_pdu(pdu)
+
+ @log_function
+ @defer.inlineCallbacks
+ def backfill(self, room_id, limit):
+ # TODO: Work out which destinations to ask for pagination
+ # self.replication_layer.paginate(dest, room_id, limit)
+ pass
+
+ @log_function
+ def get_state_for_room(self, destination, room_id):
+ return self.replication_layer.get_state_for_context(
+ destination, room_id
+ )
+
+ @log_function
+ @defer.inlineCallbacks
+ def on_receive_pdu(self, pdu):
+ """ Called by the ReplicationLayer when we have a new pdu. We need to
+ do auth checks and put it throught the StateHandler.
+ """
+ event = self.pdu_codec.event_from_pdu(pdu)
+
+ try:
+ with (yield self.lock_manager.lock(pdu.context)):
+ if event.is_state:
+ is_new_state = yield self.state_handler.handle_new_state(
+ pdu
+ )
+ if not is_new_state:
+ return
+ else:
+ is_new_state = False
+
+ yield self.event_handler.on_receive(event, is_new_state)
+
+ except AuthError:
+ # TODO: Implement something in federation that allows us to
+ # respond to PDU.
+ raise
+
+ return
+
+ @defer.inlineCallbacks
+ def _on_new_state(self, pdu, new_state_event):
+ # TODO: Do any store stuff here. Notifiy C2S about this new
+ # state.
+
+ yield self.store.update_current_state(
+ pdu_id=pdu.pdu_id,
+ origin=pdu.origin,
+ context=pdu.context,
+ pdu_type=pdu.pdu_type,
+ state_key=pdu.state_key
+ )
+
+ yield self.event_handler.on_receive(new_state_event)
+
+ @defer.inlineCallbacks
+ def _fill_out_prev_events(self, event):
+ if hasattr(event, "prev_events"):
+ return
+
+ results = yield self.store.get_latest_pdus_in_context(
+ event.room_id
+ )
+
+ es = [
+ "%s@%s" % (p_id, origin) for p_id, origin, _ in results
+ ]
+
+ event.prev_events = [e for e in es if e != event.event_id]
+
+ if results:
+ event.depth = max([int(v) for _, _, v in results]) + 1
+ else:
+ event.depth = 0
diff --git a/synapse/federation/pdu_codec.py b/synapse/federation/pdu_codec.py
new file mode 100644
index 0000000000..9155930e47
--- /dev/null
+++ b/synapse/federation/pdu_codec.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .units import Pdu
+
+import copy
+
+
+def decode_event_id(event_id, server_name):
+ parts = event_id.split("@")
+ if len(parts) < 2:
+ return (event_id, server_name)
+ else:
+ return (parts[0], "".join(parts[1:]))
+
+
+def encode_event_id(pdu_id, origin):
+ return "%s@%s" % (pdu_id, origin)
+
+
+class PduCodec(object):
+
+ def __init__(self, hs):
+ self.server_name = hs.hostname
+ self.event_factory = hs.get_event_factory()
+ self.clock = hs.get_clock()
+
+ def event_from_pdu(self, pdu):
+ kwargs = {}
+
+ kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin)
+ kwargs["room_id"] = pdu.context
+ kwargs["etype"] = pdu.pdu_type
+ kwargs["prev_events"] = [
+ encode_event_id(p[0], p[1]) for p in pdu.prev_pdus
+ ]
+
+ if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"):
+ kwargs["prev_state"] = encode_event_id(
+ pdu.prev_state_id, pdu.prev_state_origin
+ )
+
+ kwargs.update({
+ k: v
+ for k, v in pdu.get_full_dict().items()
+ if k not in [
+ "pdu_id",
+ "context",
+ "pdu_type",
+ "prev_pdus",
+ "prev_state_id",
+ "prev_state_origin",
+ ]
+ })
+
+ return self.event_factory.create_event(**kwargs)
+
+ def pdu_from_event(self, event):
+ d = event.get_full_dict()
+
+ d["pdu_id"], d["origin"] = decode_event_id(
+ event.event_id, self.server_name
+ )
+ d["context"] = event.room_id
+ d["pdu_type"] = event.type
+
+ if hasattr(event, "prev_events"):
+ d["prev_pdus"] = [
+ decode_event_id(e, self.server_name)
+ for e in event.prev_events
+ ]
+
+ if hasattr(event, "prev_state"):
+ d["prev_state_id"], d["prev_state_origin"] = (
+ decode_event_id(event.prev_state, self.server_name)
+ )
+
+ if hasattr(event, "state_key"):
+ d["is_state"] = True
+
+ kwargs = copy.deepcopy(event.unrecognized_keys)
+ kwargs.update({
+ k: v for k, v in d.items()
+ if k not in ["event_id", "room_id", "type", "prev_events"]
+ })
+
+ if "ts" not in kwargs:
+ kwargs["ts"] = int(self.clock.time_msec())
+
+ return Pdu(**kwargs)
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
new file mode 100644
index 0000000000..ad4111c683
--- /dev/null
+++ b/synapse/federation/persistence.py
@@ -0,0 +1,240 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" This module contains all the persistence actions done by the federation
+package.
+
+These actions are mostly only used by the :py:mod:`.replication` module.
+"""
+
+from twisted.internet import defer
+
+from .units import Pdu
+
+from synapse.util.logutils import log_function
+
+import copy
+import json
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class PduActions(object):
+ """ Defines persistence actions that relate to handling PDUs.
+ """
+
+ def __init__(self, datastore):
+ self.store = datastore
+
+ @log_function
+ def persist_received(self, pdu):
+ """ Persists the given `Pdu` that was received from a remote home
+ server.
+
+ Returns:
+ Deferred
+ """
+ return self._persist(pdu)
+
+ @defer.inlineCallbacks
+ @log_function
+ def persist_outgoing(self, pdu):
+ """ Persists the given `Pdu` that this home server created.
+
+ Returns:
+ Deferred
+ """
+ ret = yield self._persist(pdu)
+
+ defer.returnValue(ret)
+
+ @log_function
+ def mark_as_processed(self, pdu):
+ """ Persist the fact that we have fully processed the given `Pdu`
+
+ Returns:
+ Deferred
+ """
+ return self.store.mark_pdu_as_processed(pdu.pdu_id, pdu.origin)
+
+ @defer.inlineCallbacks
+ @log_function
+ def populate_previous_pdus(self, pdu):
+ """ Given an outgoing `Pdu` fill out its `prev_ids` key with the `Pdu`s
+ that we have received.
+
+ Returns:
+ Deferred
+ """
+ results = yield self.store.get_latest_pdus_in_context(pdu.context)
+
+ pdu.prev_pdus = [(p_id, origin) for p_id, origin, _ in results]
+
+ vs = [int(v) for _, _, v in results]
+ if vs:
+ pdu.depth = max(vs) + 1
+ else:
+ pdu.depth = 0
+
+ @defer.inlineCallbacks
+ @log_function
+ def after_transaction(self, transaction_id, destination, origin):
+ """ Returns all `Pdu`s that we sent to the given remote home server
+ after a given transaction id.
+
+ Returns:
+ Deferred: Results in a list of `Pdu`s
+ """
+ results = yield self.store.get_pdus_after_transaction(
+ transaction_id,
+ destination
+ )
+
+ defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
+
+ @defer.inlineCallbacks
+ @log_function
+ def get_all_pdus_from_context(self, context):
+ results = yield self.store.get_all_pdus_from_context(context)
+ defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
+
+ @defer.inlineCallbacks
+ @log_function
+ def paginate(self, context, pdu_list, limit):
+ """ For a given list of PDU id and origins return the proceeding
+ `limit` `Pdu`s in the given `context`.
+
+ Returns:
+ Deferred: Results in a list of `Pdu`s.
+ """
+ results = yield self.store.get_pagination(
+ context, pdu_list, limit
+ )
+
+ defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
+
+ @log_function
+ def is_new(self, pdu):
+ """ When we receive a `Pdu` from a remote home server, we want to
+ figure out whether it is `new`, i.e. it is not some historic PDU that
+ we haven't seen simply because we haven't paginated back that far.
+
+ Returns:
+ Deferred: Results in a `bool`
+ """
+ return self.store.is_pdu_new(
+ pdu_id=pdu.pdu_id,
+ origin=pdu.origin,
+ context=pdu.context,
+ depth=pdu.depth
+ )
+
+ @defer.inlineCallbacks
+ @log_function
+ def _persist(self, pdu):
+ kwargs = copy.copy(pdu.__dict__)
+ unrec_keys = copy.copy(pdu.unrecognized_keys)
+ del kwargs["content"]
+ kwargs["content_json"] = json.dumps(pdu.content)
+ kwargs["unrecognized_keys"] = json.dumps(unrec_keys)
+
+ logger.debug("Persisting: %s", repr(kwargs))
+
+ if pdu.is_state:
+ ret = yield self.store.persist_state(**kwargs)
+ else:
+ ret = yield self.store.persist_pdu(**kwargs)
+
+ yield self.store.update_min_depth_for_context(
+ pdu.context, pdu.depth
+ )
+
+ defer.returnValue(ret)
+
+
+class TransactionActions(object):
+ """ Defines persistence actions that relate to handling Transactions.
+ """
+
+ def __init__(self, datastore):
+ self.store = datastore
+
+ @log_function
+ def have_responded(self, transaction):
+ """ Have we already responded to a transaction with the same id and
+ origin?
+
+ Returns:
+ Deferred: Results in `None` if we have not previously responded to
+ this transaction or a 2-tuple of `(int, dict)` representing the
+ response code and response body.
+ """
+ if not transaction.transaction_id:
+ raise RuntimeError("Cannot persist a transaction with no "
+ "transaction_id")
+
+ return self.store.get_received_txn_response(
+ transaction.transaction_id, transaction.origin
+ )
+
+ @log_function
+ def set_response(self, transaction, code, response):
+ """ Persist how we responded to a transaction.
+
+ Returns:
+ Deferred
+ """
+ if not transaction.transaction_id:
+ raise RuntimeError("Cannot persist a transaction with no "
+ "transaction_id")
+
+ return self.store.set_received_txn_response(
+ transaction.transaction_id,
+ transaction.origin,
+ code,
+ json.dumps(response)
+ )
+
+ @defer.inlineCallbacks
+ @log_function
+ def prepare_to_send(self, transaction):
+ """ Persists the `Transaction` we are about to send and works out the
+ correct value for the `prev_ids` key.
+
+ Returns:
+ Deferred
+ """
+ transaction.prev_ids = yield self.store.prep_send_transaction(
+ transaction.transaction_id,
+ transaction.destination,
+ transaction.ts,
+ [(p["pdu_id"], p["origin"]) for p in transaction.pdus]
+ )
+
+ @log_function
+ def delivered(self, transaction, response_code, response_dict):
+ """ Marks the given `Transaction` as having been successfully
+ delivered to the remote homeserver, and what the response was.
+
+ Returns:
+ Deferred
+ """
+ return self.store.delivered_txn(
+ transaction.transaction_id,
+ transaction.destination,
+ response_code,
+ json.dumps(response_dict)
+ )
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
new file mode 100644
index 0000000000..0f5b974291
--- /dev/null
+++ b/synapse/federation/replication.py
@@ -0,0 +1,582 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This layer is responsible for replicating with remote home servers using
+a given transport.
+"""
+
+from twisted.internet import defer
+
+from .units import Transaction, Pdu, Edu
+
+from .persistence import PduActions, TransactionActions
+
+from synapse.util.logutils import log_function
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationLayer(object):
+ """This layer is responsible for replicating with remote home servers over
+ the given transport. I.e., does the sending and receiving of PDUs to
+ remote home servers.
+
+ The layer communicates with the rest of the server via a registered
+ ReplicationHandler.
+
+ In more detail, the layer:
+ * Receives incoming data and processes it into transactions and pdus.
+ * Fetches any PDUs it thinks it might have missed.
+ * Keeps the current state for contexts up to date by applying the
+ suitable conflict resolution.
+ * Sends outgoing pdus wrapped in transactions.
+ * Fills out the references to previous pdus/transactions appropriately
+ for outgoing data.
+ """
+
+ def __init__(self, hs, transport_layer):
+ self.server_name = hs.hostname
+
+ self.transport_layer = transport_layer
+ self.transport_layer.register_received_handler(self)
+ self.transport_layer.register_request_handler(self)
+
+ self.store = hs.get_datastore()
+ self.pdu_actions = PduActions(self.store)
+ self.transaction_actions = TransactionActions(self.store)
+
+ self._transaction_queue = _TransactionQueue(
+ hs, self.transaction_actions, transport_layer
+ )
+
+ self.handler = None
+ self.edu_handlers = {}
+
+ self._order = 0
+
+ self._clock = hs.get_clock()
+
+ def set_handler(self, handler):
+ """Sets the handler that the replication layer will use to communicate
+ receipt of new PDUs from other home servers. The required methods are
+ documented on :py:class:`.ReplicationHandler`.
+ """
+ self.handler = handler
+
+ def register_edu_handler(self, edu_type, handler):
+ if edu_type in self.edu_handlers:
+ raise KeyError("Already have an EDU handler for %s" % (edu_type))
+
+ self.edu_handlers[edu_type] = handler
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_pdu(self, pdu):
+ """Informs the replication layer about a new PDU generated within the
+ home server that should be transmitted to others.
+
+ This will fill out various attributes on the PDU object, e.g. the
+ `prev_pdus` key.
+
+ *Note:* The home server should always call `send_pdu` even if it knows
+ that it does not need to be replicated to other home servers. This is
+ in case e.g. someone else joins via a remote home server and then
+ paginates.
+
+ TODO: Figure out when we should actually resolve the deferred.
+
+ Args:
+ pdu (Pdu): The new Pdu.
+
+ Returns:
+ Deferred: Completes when we have successfully processed the PDU
+ and replicated it to any interested remote home servers.
+ """
+ order = self._order
+ self._order += 1
+
+ logger.debug("[%s] Persisting PDU", pdu.pdu_id)
+
+ #yield self.pdu_actions.populate_previous_pdus(pdu)
+
+ # Save *before* trying to send
+ yield self.pdu_actions.persist_outgoing(pdu)
+
+ logger.debug("[%s] Persisted PDU", pdu.pdu_id)
+ logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id)
+
+ # TODO, add errback, etc.
+ self._transaction_queue.enqueue_pdu(pdu, order)
+
+ logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.pdu_id)
+
+ @log_function
+ def send_edu(self, destination, edu_type, content):
+ edu = Edu(
+ origin=self.server_name,
+ destination=destination,
+ edu_type=edu_type,
+ content=content,
+ )
+
+ # TODO, add errback, etc.
+ self._transaction_queue.enqueue_edu(edu)
+
+ @defer.inlineCallbacks
+ @log_function
+ def paginate(self, dest, context, limit):
+ """Requests some more historic PDUs for the given context from the
+ given destination server.
+
+ Args:
+ dest (str): The remote home server to ask.
+ context (str): The context to paginate back on.
+ limit (int): The maximum number of PDUs to return.
+
+ Returns:
+ Deferred: Results in the received PDUs.
+ """
+ extremities = yield self.store.get_oldest_pdus_in_context(context)
+
+ logger.debug("paginate extrem=%s", extremities)
+
+ # If there are no extremeties then we've (probably) reached the start.
+ if not extremities:
+ return
+
+ transaction_data = yield self.transport_layer.paginate(
+ dest, context, extremities, limit)
+
+ logger.debug("paginate transaction_data=%s", repr(transaction_data))
+
+ transaction = Transaction(**transaction_data)
+
+ pdus = [Pdu(outlier=False, **p) for p in transaction.pdus]
+ for pdu in pdus:
+ yield self._handle_new_pdu(pdu)
+
+ defer.returnValue(pdus)
+
+ @defer.inlineCallbacks
+ @log_function
+ def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False):
+ """Requests the PDU with given origin and ID from the remote home
+ server.
+
+ This will persist the PDU locally upon receipt.
+
+ Args:
+ destination (str): Which home server to query
+ pdu_origin (str): The home server that originally sent the pdu.
+ pdu_id (str)
+ outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
+ it's from an arbitary point in the context as opposed to part
+ of the current block of PDUs. Defaults to `False`
+
+ Returns:
+ Deferred: Results in the requested PDU.
+ """
+
+ transaction_data = yield self.transport_layer.get_pdu(
+ destination, pdu_origin, pdu_id)
+
+ transaction = Transaction(**transaction_data)
+
+ pdu_list = [Pdu(outlier=outlier, **p) for p in transaction.pdus]
+
+ pdu = None
+ if pdu_list:
+ pdu = pdu_list[0]
+ yield self._handle_new_pdu(pdu)
+
+ defer.returnValue(pdu)
+
+ @defer.inlineCallbacks
+ @log_function
+ def get_state_for_context(self, destination, context):
+ """Requests all of the `current` state PDUs for a given context from
+ a remote home server.
+
+ Args:
+ destination (str): The remote homeserver to query for the state.
+ context (str): The context we're interested in.
+
+ Returns:
+ Deferred: Results in a list of PDUs.
+ """
+
+ transaction_data = yield self.transport_layer.get_context_state(
+ destination, context)
+
+ transaction = Transaction(**transaction_data)
+
+ pdus = [Pdu(outlier=True, **p) for p in transaction.pdus]
+ for pdu in pdus:
+ yield self._handle_new_pdu(pdu)
+
+ defer.returnValue(pdus)
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_context_pdus_request(self, context):
+ pdus = yield self.pdu_actions.get_all_pdus_from_context(
+ context
+ )
+ defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_paginate_request(self, context, versions, limit):
+
+ pdus = yield self.pdu_actions.paginate(context, versions, limit)
+
+ defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_incoming_transaction(self, transaction_data):
+ transaction = Transaction(**transaction_data)
+
+ logger.debug("[%s] Got transaction", transaction.transaction_id)
+
+ response = yield self.transaction_actions.have_responded(transaction)
+
+ if response:
+ logger.debug("[%s] We've already responed to this request",
+ transaction.transaction_id)
+ defer.returnValue(response)
+ return
+
+ logger.debug("[%s] Transacition is new", transaction.transaction_id)
+
+ pdu_list = [Pdu(**p) for p in transaction.pdus]
+
+ dl = []
+ for pdu in pdu_list:
+ dl.append(self._handle_new_pdu(pdu))
+
+ if hasattr(transaction, "edus"):
+ for edu in [Edu(**x) for x in transaction.edus]:
+ self.received_edu(edu.origin, edu.edu_type, edu.content)
+
+ results = yield defer.DeferredList(dl)
+
+ ret = []
+ for r in results:
+ if r[0]:
+ ret.append({})
+ else:
+ logger.exception(r[1])
+ ret.append({"error": str(r[1])})
+
+ logger.debug("Returning: %s", str(ret))
+
+ yield self.transaction_actions.set_response(
+ transaction,
+ 200, response
+ )
+ defer.returnValue((200, response))
+
+ def received_edu(self, origin, edu_type, content):
+ if edu_type in self.edu_handlers:
+ self.edu_handlers[edu_type](origin, content)
+ else:
+ logger.warn("Received EDU of type %s with no handler", edu_type)
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_context_state_request(self, context):
+ results = yield self.store.get_current_state_for_context(
+ context
+ )
+
+ logger.debug("Context returning %d results", len(results))
+
+ pdus = [Pdu.from_pdu_tuple(p) for p in results]
+ defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_pdu_request(self, pdu_origin, pdu_id):
+ pdu = yield self._get_persisted_pdu(pdu_id, pdu_origin)
+
+ if pdu:
+ defer.returnValue(
+ (200, self._transaction_from_pdus([pdu]).get_dict())
+ )
+ else:
+ defer.returnValue((404, ""))
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_pull_request(self, origin, versions):
+ transaction_id = max([int(v) for v in versions])
+
+ response = yield self.pdu_actions.after_transaction(
+ transaction_id,
+ origin,
+ self.server_name
+ )
+
+ if not response:
+ response = []
+
+ defer.returnValue(
+ (200, self._transaction_from_pdus(response).get_dict())
+ )
+
+ @defer.inlineCallbacks
+ @log_function
+ def _get_persisted_pdu(self, pdu_id, pdu_origin):
+ """ Get a PDU from the database with given origin and id.
+
+ Returns:
+ Deferred: Results in a `Pdu`.
+ """
+ pdu_tuple = yield self.store.get_pdu(pdu_id, pdu_origin)
+
+ defer.returnValue(Pdu.from_pdu_tuple(pdu_tuple))
+
+ def _transaction_from_pdus(self, pdu_list):
+ """Returns a new Transaction containing the given PDUs suitable for
+ transmission.
+ """
+ return Transaction(
+ pdus=[p.get_dict() for p in pdu_list],
+ origin=self.server_name,
+ ts=int(self._clock.time_msec()),
+ destination=None,
+ )
+
+ @defer.inlineCallbacks
+ @log_function
+ def _handle_new_pdu(self, pdu):
+ # We reprocess pdus when we have seen them only as outliers
+ existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin)
+
+ if existing and (not existing.outlier or pdu.outlier):
+ logger.debug("Already seen pdu %s %s", pdu.pdu_id, pdu.origin)
+ defer.returnValue({})
+ return
+
+ # Get missing pdus if necessary.
+ is_new = yield self.pdu_actions.is_new(pdu)
+ if is_new and not pdu.outlier:
+ # We only paginate backwards to the min depth.
+ min_depth = yield self.store.get_min_depth_for_context(pdu.context)
+
+ if min_depth and pdu.depth > min_depth:
+ for pdu_id, origin in pdu.prev_pdus:
+ exists = yield self._get_persisted_pdu(pdu_id, origin)
+
+ if not exists:
+ logger.debug("Requesting pdu %s %s", pdu_id, origin)
+
+ try:
+ yield self.get_pdu(
+ pdu.origin,
+ pdu_id=pdu_id,
+ pdu_origin=origin
+ )
+ logger.debug("Processed pdu %s %s", pdu_id, origin)
+ except:
+ # TODO(erikj): Do some more intelligent retries.
+ logger.exception("Failed to get PDU")
+
+ # Persist the Pdu, but don't mark it as processed yet.
+ yield self.pdu_actions.persist_received(pdu)
+
+ ret = yield self.handler.on_receive_pdu(pdu)
+
+ yield self.pdu_actions.mark_as_processed(pdu)
+
+ defer.returnValue(ret)
+
+ def __str__(self):
+ return "<ReplicationLayer(%s)>" % self.server_name
+
+
+class ReplicationHandler(object):
+ """This defines the methods that the :py:class:`.ReplicationLayer` will
+ use to communicate with the rest of the home server.
+ """
+ def on_receive_pdu(self, pdu):
+ raise NotImplementedError("on_receive_pdu")
+
+
+class _TransactionQueue(object):
+ """This class makes sure we only have one transaction in flight at
+ a time for a given destination.
+
+ It batches pending PDUs into single transactions.
+ """
+
+ def __init__(self, hs, transaction_actions, transport_layer):
+
+ self.server_name = hs.hostname
+ self.transaction_actions = transaction_actions
+ self.transport_layer = transport_layer
+
+ self._clock = hs.get_clock()
+
+ # Is a mapping from destinations -> deferreds. Used to keep track
+ # of which destinations have transactions in flight and when they are
+ # done
+ self.pending_transactions = {}
+
+ # Is a mapping from destination -> list of
+ # tuple(pending pdus, deferred, order)
+ self.pending_pdus_by_dest = {}
+ # destination -> list of tuple(edu, deferred)
+ self.pending_edus_by_dest = {}
+
+ # HACK to get unique tx id
+ self._next_txn_id = int(self._clock.time_msec())
+
+ @defer.inlineCallbacks
+ @log_function
+ def enqueue_pdu(self, pdu, order):
+ # We loop through all destinations to see whether we already have
+ # a transaction in progress. If we do, stick it in the pending_pdus
+ # table and we'll get back to it later.
+
+ destinations = [
+ d for d in pdu.destinations
+ if d != self.server_name
+ ]
+
+ logger.debug("Sending to: %s", str(destinations))
+
+ if not destinations:
+ return
+
+ deferreds = []
+
+ for destination in destinations:
+ deferred = defer.Deferred()
+ self.pending_pdus_by_dest.setdefault(destination, []).append(
+ (pdu, deferred, order)
+ )
+
+ self._attempt_new_transaction(destination)
+
+ deferreds.append(deferred)
+
+ yield defer.DeferredList(deferreds)
+
+ # NO inlineCallbacks
+ def enqueue_edu(self, edu):
+ destination = edu.destination
+
+ deferred = defer.Deferred()
+ self.pending_edus_by_dest.setdefault(destination, []).append(
+ (edu, deferred)
+ )
+
+ def eb(failure):
+ deferred.errback(failure)
+ self._attempt_new_transaction(destination).addErrback(eb)
+
+ return deferred
+
+ @defer.inlineCallbacks
+ @log_function
+ def _attempt_new_transaction(self, destination):
+ if destination in self.pending_transactions:
+ return
+
+ # list of (pending_pdu, deferred, order)
+ pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+ pending_edus = self.pending_edus_by_dest.pop(destination, [])
+
+ if not pending_pdus and not pending_edus:
+ return
+
+ logger.debug("TX [%s] Attempting new transaction", destination)
+
+ # Sort based on the order field
+ pending_pdus.sort(key=lambda t: t[2])
+
+ pdus = [x[0] for x in pending_pdus]
+ edus = [x[0] for x in pending_edus]
+ deferreds = [x[1] for x in pending_pdus + pending_edus]
+
+ try:
+ self.pending_transactions[destination] = 1
+
+ logger.debug("TX [%s] Persisting transaction...", destination)
+
+ transaction = Transaction.create_new(
+ ts=self._clock.time_msec(),
+ transaction_id=self._next_txn_id,
+ origin=self.server_name,
+ destination=destination,
+ pdus=pdus,
+ edus=edus,
+ )
+
+ self._next_txn_id += 1
+
+ yield self.transaction_actions.prepare_to_send(transaction)
+
+ logger.debug("TX [%s] Persisted transaction", destination)
+ logger.debug("TX [%s] Sending transaction...", destination)
+
+ # Actually send the transaction
+ code, response = yield self.transport_layer.send_transaction(
+ transaction
+ )
+
+ logger.debug("TX [%s] Sent transaction", destination)
+ logger.debug("TX [%s] Marking as delivered...", destination)
+
+ yield self.transaction_actions.delivered(
+ transaction, code, response
+ )
+
+ logger.debug("TX [%s] Marked as delivered", destination)
+ logger.debug("TX [%s] Yielding to callbacks...", destination)
+
+ for deferred in deferreds:
+ if code == 200:
+ deferred.callback(None)
+ else:
+ deferred.errback(RuntimeError("Got status %d" % code))
+
+ # Ensures we don't continue until all callbacks on that
+ # deferred have fired
+ yield deferred
+
+ logger.debug("TX [%s] Yielded to callbacks", destination)
+
+ except Exception as e:
+ logger.error("TX Problem in _attempt_transaction")
+
+ # We capture this here as there as nothing actually listens
+ # for this finishing functions deferred.
+ logger.exception(e)
+
+ for deferred in deferreds:
+ deferred.errback(e)
+ yield deferred
+
+ finally:
+ # We want to be *very* sure we delete this after we stop processing
+ self.pending_transactions.pop(destination, None)
+
+ # Check to see if there is anything else to send.
+ self._attempt_new_transaction(destination)
diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py
new file mode 100644
index 0000000000..2136adf8d7
--- /dev/null
+++ b/synapse/federation/transport.py
@@ -0,0 +1,454 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The transport layer is responsible for both sending transactions to remote
+home servers and receiving a variety of requests from other home servers.
+
+Typically, this is done over HTTP (and all home servers are required to
+support HTTP), however individual pairings of servers may decide to communicate
+over a different (albeit still reliable) protocol.
+"""
+
+from twisted.internet import defer
+
+from synapse.util.logutils import log_function
+
+import logging
+import json
+import re
+
+
+logger = logging.getLogger(__name__)
+
+
+class TransportLayer(object):
+ """This is a basic implementation of the transport layer that translates
+ transactions and other requests to/from HTTP.
+
+ Attributes:
+ server_name (str): Local home server host
+
+ server (synapse.http.server.HttpServer): the http server to
+ register listeners on
+
+ client (synapse.http.client.HttpClient): the http client used to
+ send requests
+
+ request_handler (TransportRequestHandler): The handler to fire when we
+ receive requests for data.
+
+ received_handler (TransportReceivedHandler): The handler to fire when
+ we receive data.
+ """
+
+ def __init__(self, server_name, server, client):
+ """
+ Args:
+ server_name (str): Local home server host
+ server (synapse.protocol.http.HttpServer): the http server to
+ register listeners on
+ client (synapse.protocol.http.HttpClient): the http client used to
+ send requests
+ """
+ self.server_name = server_name
+ self.server = server
+ self.client = client
+ self.request_handler = None
+ self.received_handler = None
+
+ @log_function
+ def get_context_state(self, destination, context):
+ """ Requests all state for a given context (i.e. room) from the
+ given server.
+
+ Args:
+ destination (str): The host name of the remote home server we want
+ to get the state from.
+ context (str): The name of the context we want the state of
+
+ Returns:
+ Deferred: Results in a dict received from the remote homeserver.
+ """
+ logger.debug("get_context_state dest=%s, context=%s",
+ destination, context)
+
+ path = "/state/%s/" % context
+
+ return self._do_request_for_transaction(destination, path)
+
+ @log_function
+ def get_pdu(self, destination, pdu_origin, pdu_id):
+ """ Requests the pdu with give id and origin from the given server.
+
+ Args:
+ destination (str): The host name of the remote home server we want
+ to get the state from.
+ pdu_origin (str): The home server which created the PDU.
+ pdu_id (str): The id of the PDU being requested.
+
+ Returns:
+ Deferred: Results in a dict received from the remote homeserver.
+ """
+ logger.debug("get_pdu dest=%s, pdu_origin=%s, pdu_id=%s",
+ destination, pdu_origin, pdu_id)
+
+ path = "/pdu/%s/%s/" % (pdu_origin, pdu_id)
+
+ return self._do_request_for_transaction(destination, path)
+
+ @log_function
+ def paginate(self, dest, context, pdu_tuples, limit):
+ """ Requests `limit` previous PDUs in a given context before list of
+ PDUs.
+
+ Args:
+ dest (str)
+ context (str)
+ pdu_tuples (list)
+ limt (int)
+
+ Returns:
+ Deferred: Results in a dict received from the remote homeserver.
+ """
+ logger.debug(
+ "paginate dest=%s, context=%s, pdu_tuples=%s, limit=%s",
+ dest, context, repr(pdu_tuples), str(limit)
+ )
+
+ if not pdu_tuples:
+ return
+
+ path = "/paginate/%s/" % context
+
+ args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]}
+ args["limit"] = limit
+
+ return self._do_request_for_transaction(
+ dest,
+ path,
+ args=args,
+ )
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_transaction(self, transaction):
+ """ Sends the given Transaction to it's destination
+
+ Args:
+ transaction (Transaction)
+
+ Returns:
+ Deferred: Results of the deferred is a tuple in the form of
+ (response_code, response_body) where the response_body is a
+ python dict decoded from json
+ """
+ logger.debug(
+ "send_data dest=%s, txid=%s",
+ transaction.destination, transaction.transaction_id
+ )
+
+ if transaction.destination == self.server_name:
+ raise RuntimeError("Transport layer cannot send to itself!")
+
+ data = transaction.get_dict()
+
+ code, response = yield self.client.put_json(
+ transaction.destination,
+ path="/send/%s/" % transaction.transaction_id,
+ data=data
+ )
+
+ logger.debug(
+ "send_data dest=%s, txid=%s, got response: %d",
+ transaction.destination, transaction.transaction_id, code
+ )
+
+ defer.returnValue((code, response))
+
+ @log_function
+ def register_received_handler(self, handler):
+ """ Register a handler that will be fired when we receive data.
+
+ Args:
+ handler (TransportReceivedHandler)
+ """
+ self.received_handler = handler
+
+ # This is when someone is trying to send us a bunch of data.
+ self.server.register_path(
+ "PUT",
+ re.compile("^/send/([^/]*)/$"),
+ self._on_send_request
+ )
+
+ @log_function
+ def register_request_handler(self, handler):
+ """ Register a handler that will be fired when we get asked for data.
+
+ Args:
+ handler (TransportRequestHandler)
+ """
+ self.request_handler = handler
+
+ # TODO(markjh): Namespace the federation URI paths
+
+ # This is for when someone asks us for everything since version X
+ self.server.register_path(
+ "GET",
+ re.compile("^/pull/$"),
+ lambda request: handler.on_pull_request(
+ request.args["origin"][0],
+ request.args["v"]
+ )
+ )
+
+ # This is when someone asks for a data item for a given server
+ # data_id pair.
+ self.server.register_path(
+ "GET",
+ re.compile("^/pdu/([^/]*)/([^/]*)/$"),
+ lambda request, pdu_origin, pdu_id: handler.on_pdu_request(
+ pdu_origin, pdu_id
+ )
+ )
+
+ # This is when someone asks for all data for a given context.
+ self.server.register_path(
+ "GET",
+ re.compile("^/state/([^/]*)/$"),
+ lambda request, context: handler.on_context_state_request(
+ context
+ )
+ )
+
+ self.server.register_path(
+ "GET",
+ re.compile("^/paginate/([^/]*)/$"),
+ lambda request, context: self._on_paginate_request(
+ context, request.args["v"],
+ request.args["limit"]
+ )
+ )
+
+ self.server.register_path(
+ "GET",
+ re.compile("^/context/([^/]*)/$"),
+ lambda request, context: handler.on_context_pdus_request(context)
+ )
+
+ @defer.inlineCallbacks
+ @log_function
+ def _on_send_request(self, request, transaction_id):
+ """ Called on PUT /send/<transaction_id>/
+
+ Args:
+ request (twisted.web.http.Request): The HTTP request.
+ transaction_id (str): The transaction_id associated with this
+ request. This is *not* None.
+
+ Returns:
+ Deferred: Results in a tuple of `(code, response)`, where
+ `response` is a python dict to be converted into JSON that is
+ used as the response body.
+ """
+ # Parse the request
+ try:
+ data = request.content.read()
+
+ l = data[:20].encode("string_escape")
+ logger.debug("Got data: \"%s\"", l)
+
+ transaction_data = json.loads(data)
+
+ logger.debug(
+ "Decoded %s: %s",
+ transaction_id, str(transaction_data)
+ )
+
+ # We should ideally be getting this from the security layer.
+ # origin = body["origin"]
+
+ # Add some extra data to the transaction dict that isn't included
+ # in the request body.
+ transaction_data.update(
+ transaction_id=transaction_id,
+ destination=self.server_name
+ )
+
+ except Exception as e:
+ logger.exception(e)
+ defer.returnValue((400, {"error": "Invalid transaction"}))
+ return
+
+ code, response = yield self.received_handler.on_incoming_transaction(
+ transaction_data
+ )
+
+ defer.returnValue((code, response))
+
+ @defer.inlineCallbacks
+ @log_function
+ def _do_request_for_transaction(self, destination, path, args={}):
+ """
+ Args:
+ destination (str)
+ path (str)
+ args (dict): This is parsed directly to the HttpClient.
+
+ Returns:
+ Deferred: Results in a dict.
+ """
+
+ data = yield self.client.get_json(
+ destination,
+ path=path,
+ args=args,
+ )
+
+ # Add certain keys to the JSON, ready for decoding as a Transaction
+ data.update(
+ origin=destination,
+ destination=self.server_name,
+ transaction_id=None
+ )
+
+ defer.returnValue(data)
+
+ @log_function
+ def _on_paginate_request(self, context, v_list, limits):
+ if not limits:
+ return defer.succeed(
+ (400, {"error": "Did not include limit param"})
+ )
+
+ limit = int(limits[-1])
+
+ versions = [v.split(",", 1) for v in v_list]
+
+ return self.request_handler.on_paginate_request(
+ context, versions, limit)
+
+
+class TransportReceivedHandler(object):
+ """ Callbacks used when we receive a transaction
+ """
+ def on_incoming_transaction(self, transaction):
+ """ Called on PUT /send/<transaction_id>, or on response to a request
+ that we sent (e.g. a pagination request)
+
+ Args:
+ transaction (synapse.transaction.Transaction): The transaction that
+ was sent to us.
+
+ Returns:
+ twisted.internet.defer.Deferred: A deferred that get's fired when
+ the transaction has finished being processed.
+
+ The result should be a tuple in the form of
+ `(response_code, respond_body)`, where `response_body` is a python
+ dict that will get serialized to JSON.
+
+ On errors, the dict should have an `error` key with a brief message
+ of what went wrong.
+ """
+ pass
+
+
+class TransportRequestHandler(object):
+ """ Handlers used when someone want's data from us
+ """
+ def on_pull_request(self, versions):
+ """ Called on GET /pull/?v=...
+
+ This is hit when a remote home server wants to get all data
+ after a given transaction. Mainly used when a home server comes back
+ online and wants to get everything it has missed.
+
+ Args:
+ versions (list): A list of transaction_ids that should be used to
+ determine what PDUs the remote side have not yet seen.
+
+ Returns:
+ Deferred: Resultsin a tuple in the form of
+ `(response_code, respond_body)`, where `response_body` is a python
+ dict that will get serialized to JSON.
+
+ On errors, the dict should have an `error` key with a brief message
+ of what went wrong.
+ """
+ pass
+
+ def on_pdu_request(self, pdu_origin, pdu_id):
+ """ Called on GET /pdu/<pdu_origin>/<pdu_id>/
+
+ Someone wants a particular PDU. This PDU may or may not have originated
+ from us.
+
+ Args:
+ pdu_origin (str)
+ pdu_id (str)
+
+ Returns:
+ Deferred: Resultsin a tuple in the form of
+ `(response_code, respond_body)`, where `response_body` is a python
+ dict that will get serialized to JSON.
+
+ On errors, the dict should have an `error` key with a brief message
+ of what went wrong.
+ """
+ pass
+
+ def on_context_state_request(self, context):
+ """ Called on GET /state/<context>/
+
+ Get's hit when someone wants all the *current* state for a given
+ contexts.
+
+ Args:
+ context (str): The name of the context that we're interested in.
+
+ Returns:
+ twisted.internet.defer.Deferred: A deferred that get's fired when
+ the transaction has finished being processed.
+
+ The result should be a tuple in the form of
+ `(response_code, respond_body)`, where `response_body` is a python
+ dict that will get serialized to JSON.
+
+ On errors, the dict should have an `error` key with a brief message
+ of what went wrong.
+ """
+ pass
+
+ def on_paginate_request(self, context, versions, limit):
+ """ Called on GET /paginate/<context>/?v=...&limit=...
+
+ Get's hit when we want to paginate backwards on a given context from
+ the given point.
+
+ Args:
+ context (str): The context to paginate on
+ versions (list): A list of 2-tuple's representing where to paginate
+ from, in the form `(pdu_id, origin)`
+ limit (int): How many pdus to return.
+
+ Returns:
+ Deferred: Resultsin a tuple in the form of
+ `(response_code, respond_body)`, where `response_body` is a python
+ dict that will get serialized to JSON.
+
+ On errors, the dict should have an `error` key with a brief message
+ of what went wrong.
+ """
+ pass
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
new file mode 100644
index 0000000000..0efea7b768
--- /dev/null
+++ b/synapse/federation/units.py
@@ -0,0 +1,236 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Defines the JSON structure of the protocol units used by the server to
+server protocol.
+"""
+
+from synapse.util.jsonobject import JsonEncodedObject
+
+import logging
+import json
+import copy
+
+
+logger = logging.getLogger(__name__)
+
+
+class Pdu(JsonEncodedObject):
+ """ A Pdu represents a piece of data sent from a server and is associated
+ with a context.
+
+ A Pdu can be classified as "state". For a given context, we can efficiently
+ retrieve all state pdu's that haven't been clobbered. Clobbering is done
+ via a unique constraint on the tuple (context, pdu_type, state_key). A pdu
+ is a state pdu if `is_state` is True.
+
+ Example pdu::
+
+ {
+ "pdu_id": "78c",
+ "ts": 1404835423000,
+ "origin": "bar",
+ "prev_ids": [
+ ["23b", "foo"],
+ ["56a", "bar"],
+ ],
+ "content": { ... },
+ }
+
+ """
+
+ valid_keys = [
+ "pdu_id",
+ "context",
+ "origin",
+ "ts",
+ "pdu_type",
+ "destinations",
+ "transaction_id",
+ "prev_pdus",
+ "depth",
+ "content",
+ "outlier",
+ "is_state", # Below this are keys valid only for State Pdus.
+ "state_key",
+ "power_level",
+ "prev_state_id",
+ "prev_state_origin",
+ ]
+
+ internal_keys = [
+ "destinations",
+ "transaction_id",
+ "outlier",
+ ]
+
+ required_keys = [
+ "pdu_id",
+ "context",
+ "origin",
+ "ts",
+ "pdu_type",
+ "content",
+ ]
+
+ # TODO: We need to make this properly load content rather than
+ # just leaving it as a dict. (OR DO WE?!)
+
+ def __init__(self, destinations=[], is_state=False, prev_pdus=[],
+ outlier=False, **kwargs):
+ if is_state:
+ for required_key in ["state_key"]:
+ if required_key not in kwargs:
+ raise RuntimeError("Key %s is required" % required_key)
+
+ super(Pdu, self).__init__(
+ destinations=destinations,
+ is_state=is_state,
+ prev_pdus=prev_pdus,
+ outlier=outlier,
+ **kwargs
+ )
+
+ @classmethod
+ def from_pdu_tuple(cls, pdu_tuple):
+ """ Converts a PduTuple to a Pdu
+
+ Args:
+ pdu_tuple (synapse.persistence.transactions.PduTuple): The tuple to
+ convert
+
+ Returns:
+ Pdu
+ """
+ if pdu_tuple:
+ d = copy.copy(pdu_tuple.pdu_entry._asdict())
+
+ d["content"] = json.loads(d["content_json"])
+ del d["content_json"]
+
+ args = {f: d[f] for f in cls.valid_keys if f in d}
+ if "unrecognized_keys" in d and d["unrecognized_keys"]:
+ args.update(json.loads(d["unrecognized_keys"]))
+
+ return Pdu(
+ prev_pdus=pdu_tuple.prev_pdu_list,
+ **args
+ )
+ else:
+ return None
+
+ def __str__(self):
+ return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
+
+ def __repr__(self):
+ return "<%s, %s>" % (self.__class__.__name__, repr(self.__dict__))
+
+
+class Edu(JsonEncodedObject):
+ """ An Edu represents a piece of data sent from one homeserver to another.
+
+ In comparison to Pdus, Edus are not persisted for a long time on disk, are
+ not meaningful beyond a given pair of homeservers, and don't have an
+ internal ID or previous references graph.
+ """
+
+ valid_keys = [
+ "origin",
+ "destination",
+ "edu_type",
+ "content",
+ ]
+
+ required_keys = [
+ "origin",
+ "destination",
+ "edu_type",
+ ]
+
+
+class Transaction(JsonEncodedObject):
+ """ A transaction is a list of Pdus and Edus to be sent to a remote home
+ server with some extra metadata.
+
+ Example transaction::
+
+ {
+ "origin": "foo",
+ "prev_ids": ["abc", "def"],
+ "pdus": [
+ ...
+ ],
+ }
+
+ """
+
+ valid_keys = [
+ "transaction_id",
+ "origin",
+ "destination",
+ "ts",
+ "previous_ids",
+ "pdus",
+ "edus",
+ ]
+
+ internal_keys = [
+ "transaction_id",
+ "destination",
+ ]
+
+ required_keys = [
+ "transaction_id",
+ "origin",
+ "destination",
+ "ts",
+ "pdus",
+ ]
+
+ def __init__(self, transaction_id=None, pdus=[], **kwargs):
+ """ If we include a list of pdus then we decode then as PDU's
+ automatically.
+ """
+
+ # If there's no EDUs then remove the arg
+ if "edus" in kwargs and not kwargs["edus"]:
+ del kwargs["edus"]
+
+ super(Transaction, self).__init__(
+ transaction_id=transaction_id,
+ pdus=pdus,
+ **kwargs
+ )
+
+ @staticmethod
+ def create_new(pdus, **kwargs):
+ """ Used to create a new transaction. Will auto fill out
+ transaction_id and ts keys.
+ """
+ if "ts" not in kwargs:
+ raise KeyError("Require 'ts' to construct a Transaction")
+ if "transaction_id" not in kwargs:
+ raise KeyError(
+ "Require 'transaction_id' to construct a Transaction"
+ )
+
+ for p in pdus:
+ p.transaction_id = kwargs["transaction_id"]
+
+ kwargs["pdus"] = [p.get_dict() for p in pdus]
+
+ return Transaction(**kwargs)
+
+
+
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
new file mode 100644
index 0000000000..5688b68e49
--- /dev/null
+++ b/synapse/handlers/__init__.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .register import RegistrationHandler
+from .room import (
+ MessageHandler, RoomCreationHandler, RoomMemberHandler, RoomListHandler
+)
+from .events import EventStreamHandler
+from .federation import FederationHandler
+from .login import LoginHandler
+from .profile import ProfileHandler
+from .presence import PresenceHandler
+from .directory import DirectoryHandler
+
+
+class Handlers(object):
+
+ """ A collection of all the event handlers.
+
+ There's no need to lazily create these; we'll just make them all eagerly
+ at construction time.
+ """
+
+ def __init__(self, hs):
+ self.registration_handler = RegistrationHandler(hs)
+ self.message_handler = MessageHandler(hs)
+ self.room_creation_handler = RoomCreationHandler(hs)
+ self.room_member_handler = RoomMemberHandler(hs)
+ self.event_stream_handler = EventStreamHandler(hs)
+ self.federation_handler = FederationHandler(hs)
+ self.profile_handler = ProfileHandler(hs)
+ self.presence_handler = PresenceHandler(hs)
+ self.room_list_handler = RoomListHandler(hs)
+ self.login_handler = LoginHandler(hs)
+ self.directory_handler = DirectoryHandler(hs)
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
new file mode 100644
index 0000000000..87a392dd77
--- /dev/null
+++ b/synapse/handlers/_base.py
@@ -0,0 +1,26 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class BaseHandler(object):
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.event_factory = hs.get_event_factory()
+ self.auth = hs.get_auth()
+ self.notifier = hs.get_notifier()
+ self.room_lock = hs.get_room_lock_manager()
+ self.state_handler = hs.get_state_handler()
+ self.hs = hs
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
new file mode 100644
index 0000000000..456007c71d
--- /dev/null
+++ b/synapse/handlers/directory.py
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+from ._base import BaseHandler
+
+from synapse.api.errors import SynapseError
+
+import logging
+import json
+import urllib
+
+
+logger = logging.getLogger(__name__)
+
+
+# TODO(erikj): This needs to be factored out somewere
+PREFIX = "/matrix/client/api/v1"
+
+
+class DirectoryHandler(BaseHandler):
+
+ def __init__(self, hs):
+ super(DirectoryHandler, self).__init__(hs)
+ self.hs = hs
+ self.http_client = hs.get_http_client()
+ self.clock = hs.get_clock()
+
+ @defer.inlineCallbacks
+ def create_association(self, room_alias, room_id, servers):
+ # TODO(erikj): Do auth.
+
+ if not room_alias.is_mine:
+ raise SynapseError(400, "Room alias must be local")
+ # TODO(erikj): Change this.
+
+ # TODO(erikj): Add transactions.
+
+ # TODO(erikj): Check if there is a current association.
+
+ yield self.store.create_room_alias_association(
+ room_alias,
+ room_id,
+ servers
+ )
+
+ @defer.inlineCallbacks
+ def get_association(self, room_alias, local_only=False):
+ # TODO(erikj): Do auth
+
+ room_id = None
+ if room_alias.is_mine:
+ result = yield self.store.get_association_from_room_alias(
+ room_alias
+ )
+
+ if result:
+ room_id = result.room_id
+ servers = result.servers
+ elif not local_only:
+ path = "%s/ds/room/%s?local_only=1" % (
+ PREFIX,
+ urllib.quote(room_alias.to_string())
+ )
+
+ result = None
+ try:
+ result = yield self.http_client.get_json(
+ destination=room_alias.domain,
+ path=path,
+ )
+ except:
+ # TODO(erikj): Handle this better?
+ logger.exception("Failed to get remote room alias")
+
+ if result and "room_id" in result and "servers" in result:
+ room_id = result["room_id"]
+ servers = result["servers"]
+
+ if not room_id:
+ defer.returnValue({})
+ return
+
+ defer.returnValue({
+ "room_id": room_id,
+ "servers": servers,
+ })
+ return
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
new file mode 100644
index 0000000000..79742a4e1c
--- /dev/null
+++ b/synapse/handlers/events.py
@@ -0,0 +1,149 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+from ._base import BaseHandler
+from synapse.api.streams.event import (
+ EventStream, MessagesStreamData, RoomMemberStreamData, FeedbackStreamData,
+ RoomDataStreamData
+)
+from synapse.handlers.presence import PresenceStreamData
+
+
+class EventStreamHandler(BaseHandler):
+
+ stream_data_classes = [
+ MessagesStreamData,
+ RoomMemberStreamData,
+ FeedbackStreamData,
+ RoomDataStreamData,
+ PresenceStreamData,
+ ]
+
+ def __init__(self, hs):
+ super(EventStreamHandler, self).__init__(hs)
+
+ # Count of active streams per user
+ self._streams_per_user = {}
+ # Grace timers per user to delay the "stopped" signal
+ self._stop_timer_per_user = {}
+
+ self.distributor = hs.get_distributor()
+ self.distributor.declare("started_user_eventstream")
+ self.distributor.declare("stopped_user_eventstream")
+
+ self.clock = hs.get_clock()
+
+ def get_event_stream_token(self, stream_type, store_id, start_token):
+ """Return the next token after this event.
+
+ Args:
+ stream_type (str): The StreamData.EVENT_TYPE
+ store_id (int): The new storage ID assigned from the data store.
+ start_token (str): The token the user started with.
+ Returns:
+ str: The end token.
+ """
+ for i, stream_cls in enumerate(EventStreamHandler.stream_data_classes):
+ if stream_cls.EVENT_TYPE == stream_type:
+ # this is the stream for this event, so replace this part of
+ # the token
+ store_ids = start_token.split(EventStream.SEPARATOR)
+ store_ids[i] = str(store_id)
+ return EventStream.SEPARATOR.join(store_ids)
+ raise RuntimeError("Didn't find a stream type %s" % stream_type)
+
+ @defer.inlineCallbacks
+ def get_stream(self, auth_user_id, pagin_config, timeout=0):
+ """Gets events as an event stream for this user.
+
+ This function looks for interesting *events* for this user. This is
+ different from the notifier, which looks for interested *users* who may
+ want to know about a single event.
+
+ Args:
+ auth_user_id (str): The user requesting their event stream.
+ pagin_config (synapse.api.streams.PaginationConfig): The config to
+ use when obtaining the stream.
+ timeout (int): The max time to wait for an incoming event in ms.
+ Returns:
+ A pagination stream API dict
+ """
+ auth_user = self.hs.parse_userid(auth_user_id)
+
+ stream_id = object()
+
+ try:
+ if auth_user not in self._streams_per_user:
+ self._streams_per_user[auth_user] = 0
+ if auth_user in self._stop_timer_per_user:
+ self.clock.cancel_call_later(
+ self._stop_timer_per_user.pop(auth_user))
+ else:
+ self.distributor.fire(
+ "started_user_eventstream", auth_user
+ )
+ self._streams_per_user[auth_user] += 1
+
+ # construct an event stream with the correct data ordering
+ stream_data_list = []
+ for stream_class in EventStreamHandler.stream_data_classes:
+ stream_data_list.append(stream_class(self.hs))
+ event_stream = EventStream(auth_user_id, stream_data_list)
+
+ # fix unknown tokens to known tokens
+ pagin_config = yield event_stream.fix_tokens(pagin_config)
+
+ # register interest in receiving new events
+ self.notifier.store_events_for(user_id=auth_user_id,
+ stream_id=stream_id,
+ from_tok=pagin_config.from_tok)
+
+ # see if we can grab a chunk now
+ data_chunk = yield event_stream.get_chunk(config=pagin_config)
+
+ # if there are previous events, return those. If not, wait on the
+ # new events for 'timeout' seconds.
+ if len(data_chunk["chunk"]) == 0 and timeout != 0:
+ results = yield defer.maybeDeferred(
+ self.notifier.get_events_for,
+ user_id=auth_user_id,
+ stream_id=stream_id,
+ timeout=timeout
+ )
+ if results:
+ defer.returnValue(results)
+
+ defer.returnValue(data_chunk)
+ finally:
+ # cleanup
+ self.notifier.purge_events_for(user_id=auth_user_id,
+ stream_id=stream_id)
+
+ self._streams_per_user[auth_user] -= 1
+ if not self._streams_per_user[auth_user]:
+ del self._streams_per_user[auth_user]
+
+ # 10 seconds of grace to allow the client to reconnect again
+ # before we think they're gone
+ def _later():
+ self.distributor.fire(
+ "stopped_user_eventstream", auth_user
+ )
+ del self._stop_timer_per_user[auth_user]
+
+ self._stop_timer_per_user[auth_user] = (
+ self.clock.call_later(5, _later)
+ )
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
new file mode 100644
index 0000000000..12e7afca4c
--- /dev/null
+++ b/synapse/handlers/federation.py
@@ -0,0 +1,74 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains handlers for federation events."""
+
+from ._base import BaseHandler
+
+from synapse.api.events.room import InviteJoinEvent, RoomMemberEvent
+from synapse.api.constants import Membership
+from synapse.util.logutils import log_function
+
+from twisted.internet import defer
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class FederationHandler(BaseHandler):
+
+ """Handles events that originated from federation."""
+
+ @log_function
+ @defer.inlineCallbacks
+ def on_receive(self, event, is_new_state):
+ if hasattr(event, "state_key") and not is_new_state:
+ logger.debug("Ignoring old state.")
+ return
+
+ target_is_mine = False
+ if hasattr(event, "target_host"):
+ target_is_mine = event.target_host == self.hs.hostname
+
+ if event.type == InviteJoinEvent.TYPE:
+ if not target_is_mine:
+ logger.debug("Ignoring invite/join event %s", event)
+ return
+
+ # If we receive an invite/join event then we need to join the
+ # sender to the given room.
+ # TODO: We should probably auth this or some such
+ content = event.content
+ content.update({"membership": Membership.JOIN})
+ new_event = self.event_factory.create_event(
+ etype=RoomMemberEvent.TYPE,
+ target_user_id=event.user_id,
+ room_id=event.room_id,
+ user_id=event.user_id,
+ membership=Membership.JOIN,
+ content=content
+ )
+
+ yield self.hs.get_handlers().room_member_handler.change_membership(
+ new_event,
+ True
+ )
+
+ else:
+ with (yield self.room_lock.lock(event.room_id)):
+ store_id = yield self.store.persist_event(event)
+
+ yield self.notifier.on_new_room_event(event, store_id)
diff --git a/synapse/handlers/login.py b/synapse/handlers/login.py
new file mode 100644
index 0000000000..5a1acd7102
--- /dev/null
+++ b/synapse/handlers/login.py
@@ -0,0 +1,64 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+from ._base import BaseHandler
+from synapse.api.errors import LoginError
+
+import bcrypt
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class LoginHandler(BaseHandler):
+
+ def __init__(self, hs):
+ super(LoginHandler, self).__init__(hs)
+ self.hs = hs
+
+ @defer.inlineCallbacks
+ def login(self, user, password):
+ """Login as the specified user with the specified password.
+
+ Args:
+ user (str): The user ID.
+ password (str): The password.
+ Returns:
+ The newly allocated access token.
+ Raises:
+ StoreError if there was a problem storing the token.
+ LoginError if there was an authentication problem.
+ """
+ # TODO do this better, it can't go in __init__ else it cyclic loops
+ if not hasattr(self, "reg_handler"):
+ self.reg_handler = self.hs.get_handlers().registration_handler
+
+ # pull out the hash for this user if they exist
+ user_info = yield self.store.get_user_by_id(user_id=user)
+ if not user_info:
+ logger.warn("Attempted to login as %s but they do not exist.", user)
+ raise LoginError(403, "")
+
+ stored_hash = user_info[0]["password_hash"]
+ if bcrypt.checkpw(password, stored_hash):
+ # generate an access token and store it.
+ token = self.reg_handler._generate_token(user)
+ logger.info("Adding token %s for user %s", token, user)
+ yield self.store.add_access_token_to_user(user, token)
+ defer.returnValue(token)
+ else:
+ logger.warn("Failed password login for user %s", user)
+ raise LoginError(403, "")
\ No newline at end of file
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
new file mode 100644
index 0000000000..38db4b1d67
--- /dev/null
+++ b/synapse/handlers/presence.py
@@ -0,0 +1,697 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError, AuthError
+from synapse.api.constants import PresenceState
+from synapse.api.streams import StreamData
+
+from ._base import BaseHandler
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+# TODO(paul): Maybe there's one of these I can steal from somewhere
+def partition(l, func):
+ """Partition the list by the result of func applied to each element."""
+ ret = {}
+
+ for x in l:
+ key = func(x)
+ if key not in ret:
+ ret[key] = []
+ ret[key].append(x)
+
+ return ret
+
+
+def partitionbool(l, func):
+ def boolfunc(x):
+ return bool(func(x))
+
+ ret = partition(l, boolfunc)
+ return ret.get(True, []), ret.get(False, [])
+
+
+class PresenceHandler(BaseHandler):
+
+ def __init__(self, hs):
+ super(PresenceHandler, self).__init__(hs)
+
+ self.homeserver = hs
+
+ distributor = hs.get_distributor()
+ distributor.observe("registered_user", self.registered_user)
+
+ distributor.observe(
+ "started_user_eventstream", self.started_user_eventstream
+ )
+ distributor.observe(
+ "stopped_user_eventstream", self.stopped_user_eventstream
+ )
+
+ distributor.observe("user_joined_room",
+ self.user_joined_room
+ )
+
+ distributor.declare("collect_presencelike_data")
+
+ distributor.declare("changed_presencelike_data")
+ distributor.observe(
+ "changed_presencelike_data", self.changed_presencelike_data
+ )
+
+ self.distributor = distributor
+
+ self.federation = hs.get_replication_layer()
+
+ self.federation.register_edu_handler(
+ "m.presence", self.incoming_presence
+ )
+ self.federation.register_edu_handler(
+ "m.presence_invite",
+ lambda origin, content: self.invite_presence(
+ observed_user=hs.parse_userid(content["observed_user"]),
+ observer_user=hs.parse_userid(content["observer_user"]),
+ )
+ )
+ self.federation.register_edu_handler(
+ "m.presence_accept",
+ lambda origin, content: self.accept_presence(
+ observed_user=hs.parse_userid(content["observed_user"]),
+ observer_user=hs.parse_userid(content["observer_user"]),
+ )
+ )
+ self.federation.register_edu_handler(
+ "m.presence_deny",
+ lambda origin, content: self.deny_presence(
+ observed_user=hs.parse_userid(content["observed_user"]),
+ observer_user=hs.parse_userid(content["observer_user"]),
+ )
+ )
+
+ # IN-MEMORY store, mapping local userparts to sets of local users to
+ # be informed of state changes.
+ self._local_pushmap = {}
+ # map local users to sets of remote /domain names/ who are interested
+ # in them
+ self._remote_sendmap = {}
+ # map remote users to sets of local users who're interested in them
+ self._remote_recvmap = {}
+
+ # map any user to a UserPresenceCache
+ self._user_cachemap = {}
+ self._user_cachemap_latest_serial = 0
+
+ def _get_or_make_usercache(self, user):
+ """If the cache entry doesn't exist, initialise a new one."""
+ if user not in self._user_cachemap:
+ self._user_cachemap[user] = UserPresenceCache()
+ return self._user_cachemap[user]
+
+ def _get_or_offline_usercache(self, user):
+ """If the cache entry doesn't exist, return an OFFLINE one but do not
+ store it into the cache."""
+ if user in self._user_cachemap:
+ return self._user_cachemap[user]
+ else:
+ statuscache = UserPresenceCache()
+ statuscache.update({"state": PresenceState.OFFLINE}, user)
+ return statuscache
+
+ def registered_user(self, user):
+ self.store.create_presence(user.localpart)
+
+ @defer.inlineCallbacks
+ def is_presence_visible(self, observer_user, observed_user):
+ assert(observed_user.is_mine)
+
+ if observer_user == observed_user:
+ defer.returnValue(True)
+
+ allowed_by_subscription = yield self.store.is_presence_visible(
+ observed_localpart=observed_user.localpart,
+ observer_userid=observer_user.to_string(),
+ )
+
+ if allowed_by_subscription:
+ defer.returnValue(True)
+
+ # TODO(paul): Check same channel
+
+ defer.returnValue(False)
+
+ @defer.inlineCallbacks
+ def get_state(self, target_user, auth_user):
+ if target_user.is_mine:
+ visible = yield self.is_presence_visible(observer_user=auth_user,
+ observed_user=target_user
+ )
+
+ if visible:
+ state = yield self.store.get_presence_state(
+ target_user.localpart
+ )
+ defer.returnValue(state)
+ else:
+ raise SynapseError(404, "Presence information not visible")
+ else:
+ # TODO(paul): Have remote server send us permissions set
+ defer.returnValue(
+ self._get_or_offline_usercache(target_user).get_state()
+ )
+
+ @defer.inlineCallbacks
+ def set_state(self, target_user, auth_user, state):
+ if not target_user.is_mine:
+ raise SynapseError(400, "User is not hosted on this Home Server")
+
+ if target_user != auth_user:
+ raise AuthError(400, "Cannot set another user's displayname")
+
+ # TODO(paul): Sanity-check 'state'
+ if "status_msg" not in state:
+ state["status_msg"] = None
+
+ for k in state.keys():
+ if k not in ("state", "status_msg"):
+ raise SynapseError(
+ 400, "Unexpected presence state key '%s'" % (k,)
+ )
+
+ logger.debug("Updating presence state of %s to %s",
+ target_user.localpart, state["state"])
+
+ state_to_store = dict(state)
+
+ yield defer.DeferredList([
+ self.store.set_presence_state(
+ target_user.localpart, state_to_store
+ ),
+ self.distributor.fire(
+ "collect_presencelike_data", target_user, state
+ ),
+ ])
+
+ now_online = state["state"] != PresenceState.OFFLINE
+ was_polling = target_user in self._user_cachemap
+
+ if now_online and not was_polling:
+ self.start_polling_presence(target_user, state=state)
+ elif not now_online and was_polling:
+ self.stop_polling_presence(target_user)
+
+ # TODO(paul): perform a presence push as part of start/stop poll so
+ # we don't have to do this all the time
+ self.changed_presencelike_data(target_user, state)
+
+ if not now_online:
+ del self._user_cachemap[target_user]
+
+ def changed_presencelike_data(self, user, state):
+ statuscache = self._get_or_make_usercache(user)
+
+ self._user_cachemap_latest_serial += 1
+ statuscache.update(state, serial=self._user_cachemap_latest_serial)
+
+ self.push_presence(user, statuscache=statuscache)
+
+ def started_user_eventstream(self, user):
+ # TODO(paul): Use "last online" state
+ self.set_state(user, user, {"state": PresenceState.ONLINE})
+
+ def stopped_user_eventstream(self, user):
+ # TODO(paul): Save current state as "last online" state
+ self.set_state(user, user, {"state": PresenceState.OFFLINE})
+
+ @defer.inlineCallbacks
+ def user_joined_room(self, user, room_id):
+ localusers = set()
+ remotedomains = set()
+
+ rm_handler = self.homeserver.get_handlers().room_member_handler
+ yield rm_handler.fetch_room_distributions_into(room_id,
+ localusers=localusers, remotedomains=remotedomains,
+ ignore_user=user)
+
+ if user.is_mine:
+ yield self._send_presence_to_distribution(srcuser=user,
+ localusers=localusers, remotedomains=remotedomains,
+ statuscache=self._get_or_offline_usercache(user),
+ )
+
+ for srcuser in localusers:
+ yield self._send_presence(srcuser=srcuser, destuser=user,
+ statuscache=self._get_or_offline_usercache(srcuser),
+ )
+
+ @defer.inlineCallbacks
+ def send_invite(self, observer_user, observed_user):
+ if not observer_user.is_mine:
+ raise SynapseError(400, "User is not hosted on this Home Server")
+
+ yield self.store.add_presence_list_pending(
+ observer_user.localpart, observed_user.to_string()
+ )
+
+ if observed_user.is_mine:
+ yield self.invite_presence(observed_user, observer_user)
+ else:
+ yield self.federation.send_edu(
+ destination=observed_user.domain,
+ edu_type="m.presence_invite",
+ content={
+ "observed_user": observed_user.to_string(),
+ "observer_user": observer_user.to_string(),
+ }
+ )
+
+ @defer.inlineCallbacks
+ def _should_accept_invite(self, observed_user, observer_user):
+ if not observed_user.is_mine:
+ defer.returnValue(False)
+
+ row = yield self.store.has_presence_state(observed_user.localpart)
+ if not row:
+ defer.returnValue(False)
+
+ # TODO(paul): Eventually we'll ask the user's permission for this
+ # before accepting. For now just accept any invite request
+ defer.returnValue(True)
+
+ @defer.inlineCallbacks
+ def invite_presence(self, observed_user, observer_user):
+ accept = yield self._should_accept_invite(observed_user, observer_user)
+
+ if accept:
+ yield self.store.allow_presence_visible(
+ observed_user.localpart, observer_user.to_string()
+ )
+
+ if observer_user.is_mine:
+ if accept:
+ yield self.accept_presence(observed_user, observer_user)
+ else:
+ yield self.deny_presence(observed_user, observer_user)
+ else:
+ edu_type = "m.presence_accept" if accept else "m.presence_deny"
+
+ yield self.federation.send_edu(
+ destination=observer_user.domain,
+ edu_type=edu_type,
+ content={
+ "observed_user": observed_user.to_string(),
+ "observer_user": observer_user.to_string(),
+ }
+ )
+
+ @defer.inlineCallbacks
+ def accept_presence(self, observed_user, observer_user):
+ yield self.store.set_presence_list_accepted(
+ observer_user.localpart, observed_user.to_string()
+ )
+
+ self.start_polling_presence(observer_user, target_user=observed_user)
+
+ @defer.inlineCallbacks
+ def deny_presence(self, observed_user, observer_user):
+ yield self.store.del_presence_list(
+ observer_user.localpart, observed_user.to_string()
+ )
+
+ # TODO(paul): Inform the user somehow?
+
+ @defer.inlineCallbacks
+ def drop(self, observed_user, observer_user):
+ if not observer_user.is_mine:
+ raise SynapseError(400, "User is not hosted on this Home Server")
+
+ yield self.store.del_presence_list(
+ observer_user.localpart, observed_user.to_string()
+ )
+
+ self.stop_polling_presence(observer_user, target_user=observed_user)
+
+ @defer.inlineCallbacks
+ def get_presence_list(self, observer_user, accepted=None):
+ if not observer_user.is_mine:
+ raise SynapseError(400, "User is not hosted on this Home Server")
+
+ presence = yield self.store.get_presence_list(
+ observer_user.localpart, accepted=accepted
+ )
+
+ for p in presence:
+ observed_user = self.hs.parse_userid(p.pop("observed_user_id"))
+ p["observed_user"] = observed_user
+ p.update(self._get_or_offline_usercache(observed_user).get_state())
+
+ defer.returnValue(presence)
+
+ @defer.inlineCallbacks
+ def start_polling_presence(self, user, target_user=None, state=None):
+ logger.debug("Start polling for presence from %s", user)
+
+ if target_user:
+ target_users = [target_user]
+ else:
+ presence = yield self.store.get_presence_list(
+ user.localpart, accepted=True
+ )
+ target_users = [
+ self.hs.parse_userid(x["observed_user_id"]) for x in presence
+ ]
+
+ if state is None:
+ state = yield self.store.get_presence_state(user.localpart)
+
+ localusers, remoteusers = partitionbool(
+ target_users,
+ lambda u: u.is_mine
+ )
+
+ for target_user in localusers:
+ self._start_polling_local(user, target_user)
+
+ deferreds = []
+ remoteusers_by_domain = partition(remoteusers, lambda u: u.domain)
+ for domain in remoteusers_by_domain:
+ remoteusers = remoteusers_by_domain[domain]
+
+ deferreds.append(self._start_polling_remote(
+ user, domain, remoteusers
+ ))
+
+ yield defer.DeferredList(deferreds)
+
+ def _start_polling_local(self, user, target_user):
+ target_localpart = target_user.localpart
+
+ if not self.is_presence_visible(observer_user=user,
+ observed_user=target_user):
+ return
+
+ if target_localpart not in self._local_pushmap:
+ self._local_pushmap[target_localpart] = set()
+
+ self._local_pushmap[target_localpart].add(user)
+
+ self.push_update_to_clients(
+ observer_user=user,
+ observed_user=target_user,
+ statuscache=self._get_or_offline_usercache(target_user),
+ )
+
+ def _start_polling_remote(self, user, domain, remoteusers):
+ for u in remoteusers:
+ if u not in self._remote_recvmap:
+ self._remote_recvmap[u] = set()
+
+ self._remote_recvmap[u].add(user)
+
+ return self.federation.send_edu(
+ destination=domain,
+ edu_type="m.presence",
+ content={"poll": [u.to_string() for u in remoteusers]}
+ )
+
+ def stop_polling_presence(self, user, target_user=None):
+ logger.debug("Stop polling for presence from %s", user)
+
+ if not target_user or target_user.is_mine:
+ self._stop_polling_local(user, target_user=target_user)
+
+ deferreds = []
+
+ if target_user:
+ raise NotImplementedError("TODO: remove one user")
+
+ remoteusers = [u for u in self._remote_recvmap
+ if user in self._remote_recvmap[u]]
+ remoteusers_by_domain = partition(remoteusers, lambda u: u.domain)
+
+ for domain in remoteusers_by_domain:
+ remoteusers = remoteusers_by_domain[domain]
+
+ deferreds.append(
+ self._stop_polling_remote(user, domain, remoteusers)
+ )
+
+ return defer.DeferredList(deferreds)
+
+ def _stop_polling_local(self, user, target_user):
+ for localpart in self._local_pushmap.keys():
+ if target_user and localpart != target_user.localpart:
+ continue
+
+ if user in self._local_pushmap[localpart]:
+ self._local_pushmap[localpart].remove(user)
+
+ if not self._local_pushmap[localpart]:
+ del self._local_pushmap[localpart]
+
+ def _stop_polling_remote(self, user, domain, remoteusers):
+ for u in remoteusers:
+ self._remote_recvmap[u].remove(user)
+
+ if not self._remote_recvmap[u]:
+ del self._remote_recvmap[u]
+
+ return self.federation.send_edu(
+ destination=domain,
+ edu_type="m.presence",
+ content={"unpoll": [u.to_string() for u in remoteusers]}
+ )
+
+ @defer.inlineCallbacks
+ def push_presence(self, user, statuscache):
+ assert(user.is_mine)
+
+ logger.debug("Pushing presence update from %s", user)
+
+ localusers = set(self._local_pushmap.get(user.localpart, set()))
+ remotedomains = set(self._remote_sendmap.get(user.localpart, set()))
+
+ # Reflect users' status changes back to themselves, so UIs look nice
+ # and also user is informed of server-forced pushes
+ localusers.add(user)
+
+ rm_handler = self.homeserver.get_handlers().room_member_handler
+ room_ids = yield rm_handler.get_rooms_for_user(user)
+
+ for room_id in room_ids:
+ yield rm_handler.fetch_room_distributions_into(
+ room_id, localusers=localusers, remotedomains=remotedomains,
+ ignore_user=user,
+ )
+
+ if not localusers and not remotedomains:
+ defer.returnValue(None)
+
+ yield self._send_presence_to_distribution(user,
+ localusers=localusers, remotedomains=remotedomains,
+ statuscache=statuscache
+ )
+
+ def _send_presence(self, srcuser, destuser, statuscache):
+ if destuser.is_mine:
+ self.push_update_to_clients(
+ observer_user=destuser,
+ observed_user=srcuser,
+ statuscache=statuscache)
+ return defer.succeed(None)
+ else:
+ return self._push_presence_remote(srcuser, destuser.domain,
+ state=statuscache.get_state()
+ )
+
+ @defer.inlineCallbacks
+ def _send_presence_to_distribution(self, srcuser, localusers=set(),
+ remotedomains=set(), statuscache=None):
+
+ for u in localusers:
+ logger.debug(" | push to local user %s", u)
+ self.push_update_to_clients(
+ observer_user=u,
+ observed_user=srcuser,
+ statuscache=statuscache,
+ )
+
+ deferreds = []
+ for domain in remotedomains:
+ logger.debug(" | push to remote domain %s", domain)
+ deferreds.append(self._push_presence_remote(srcuser, domain,
+ state=statuscache.get_state())
+ )
+
+ yield defer.DeferredList(deferreds)
+
+ @defer.inlineCallbacks
+ def _push_presence_remote(self, user, destination, state=None):
+ if state is None:
+ state = yield self.store.get_presence_state(user.localpart)
+ yield self.distributor.fire(
+ "collect_presencelike_data", user, state
+ )
+
+ yield self.federation.send_edu(
+ destination=destination,
+ edu_type="m.presence",
+ content={
+ "push": [
+ dict(user_id=user.to_string(), **state),
+ ],
+ }
+ )
+
+ @defer.inlineCallbacks
+ def incoming_presence(self, origin, content):
+ deferreds = []
+
+ for push in content.get("push", []):
+ user = self.hs.parse_userid(push["user_id"])
+
+ logger.debug("Incoming presence update from %s", user)
+
+ observers = set(self._remote_recvmap.get(user, set()))
+
+ rm_handler = self.homeserver.get_handlers().room_member_handler
+ room_ids = yield rm_handler.get_rooms_for_user(user)
+
+ for room_id in room_ids:
+ yield rm_handler.fetch_room_distributions_into(
+ room_id, localusers=observers, ignore_user=user
+ )
+
+ if not observers:
+ break
+
+ state = dict(push)
+ del state["user_id"]
+
+ statuscache = self._get_or_make_usercache(user)
+
+ self._user_cachemap_latest_serial += 1
+ statuscache.update(state, serial=self._user_cachemap_latest_serial)
+
+ for observer_user in observers:
+ self.push_update_to_clients(
+ observer_user=observer_user,
+ observed_user=user,
+ statuscache=statuscache,
+ )
+
+ if state["state"] == PresenceState.OFFLINE:
+ del self._user_cachemap[user]
+
+ for poll in content.get("poll", []):
+ user = self.hs.parse_userid(poll)
+
+ if not user.is_mine:
+ continue
+
+ # TODO(paul) permissions checks
+
+ if not user in self._remote_sendmap:
+ self._remote_sendmap[user] = set()
+
+ self._remote_sendmap[user].add(origin)
+
+ deferreds.append(self._push_presence_remote(user, origin))
+
+ for unpoll in content.get("unpoll", []):
+ user = self.hs.parse_userid(unpoll)
+
+ if not user.is_mine:
+ continue
+
+ if user in self._remote_sendmap:
+ self._remote_sendmap[user].remove(origin)
+
+ if not self._remote_sendmap[user]:
+ del self._remote_sendmap[user]
+
+ yield defer.DeferredList(deferreds)
+
+ def push_update_to_clients(self, observer_user, observed_user,
+ statuscache):
+ self.notifier.on_new_user_event(
+ observer_user.to_string(),
+ event_data=statuscache.make_event(user=observed_user),
+ stream_type=PresenceStreamData,
+ store_id=statuscache.serial
+ )
+
+
+class PresenceStreamData(StreamData):
+ def __init__(self, hs):
+ super(PresenceStreamData, self).__init__(hs)
+ self.presence = hs.get_handlers().presence_handler
+
+ def get_rows(self, user_id, from_key, to_key, limit):
+ cachemap = self.presence._user_cachemap
+
+ # TODO(paul): limit, and filter by visibility
+ updates = [(k, cachemap[k]) for k in cachemap
+ if from_key < cachemap[k].serial <= to_key]
+
+ if updates:
+ latest_serial = max([x[1].serial for x in updates])
+ data = [x[1].make_event(user=x[0]) for x in updates]
+ return ((data, latest_serial))
+ else:
+ return (([], self.presence._user_cachemap_latest_serial))
+
+ def max_token(self):
+ return self.presence._user_cachemap_latest_serial
+
+PresenceStreamData.EVENT_TYPE = PresenceStreamData
+
+
+class UserPresenceCache(object):
+ """Store an observed user's state and status message.
+
+ Includes the update timestamp.
+ """
+ def __init__(self):
+ self.state = {}
+ self.serial = None
+
+ def update(self, state, serial):
+ self.state.update(state)
+ # Delete keys that are now 'None'
+ for k in self.state.keys():
+ if self.state[k] is None:
+ del self.state[k]
+
+ self.serial = serial
+
+ if "status_msg" in state:
+ self.status_msg = state["status_msg"]
+ else:
+ self.status_msg = None
+
+ def get_state(self):
+ # clone it so caller can't break our cache
+ return dict(self.state)
+
+ def make_event(self, user):
+ content = self.get_state()
+ content["user_id"] = user.to_string()
+
+ return {"type": "m.presence", "content": content}
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
new file mode 100644
index 0000000000..a27206b002
--- /dev/null
+++ b/synapse/handlers/profile.py
@@ -0,0 +1,169 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError, AuthError
+
+from synapse.api.errors import CodeMessageException
+
+from ._base import BaseHandler
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+PREFIX = "/matrix/client/api/v1"
+
+
+class ProfileHandler(BaseHandler):
+
+ def __init__(self, hs):
+ super(ProfileHandler, self).__init__(hs)
+
+ self.client = hs.get_http_client()
+
+ distributor = hs.get_distributor()
+ self.distributor = distributor
+
+ distributor.observe("registered_user", self.registered_user)
+
+ distributor.observe(
+ "collect_presencelike_data", self.collect_presencelike_data
+ )
+
+ def registered_user(self, user):
+ self.store.create_profile(user.localpart)
+
+ @defer.inlineCallbacks
+ def get_displayname(self, target_user, local_only=False):
+ if target_user.is_mine:
+ displayname = yield self.store.get_profile_displayname(
+ target_user.localpart
+ )
+
+ defer.returnValue(displayname)
+ elif not local_only:
+ # TODO(paul): This should use the server-server API to ask another
+ # HS. For now we'll just have it use the http client to talk to the
+ # other HS's REST client API
+ path = PREFIX + "/profile/%s/displayname?local_only=1" % (
+ target_user.to_string()
+ )
+
+ try:
+ result = yield self.client.get_json(
+ destination=target_user.domain,
+ path=path
+ )
+ except CodeMessageException as e:
+ if e.code != 404:
+ logger.exception("Failed to get displayname")
+
+ raise
+ except:
+ logger.exception("Failed to get displayname")
+
+ defer.returnValue(result["displayname"])
+ else:
+ raise SynapseError(400, "User is not hosted on this Home Server")
+
+ @defer.inlineCallbacks
+ def set_displayname(self, target_user, auth_user, new_displayname):
+ """target_user is the user whose displayname is to be changed;
+ auth_user is the user attempting to make this change."""
+ if not target_user.is_mine:
+ raise SynapseError(400, "User is not hosted on this Home Server")
+
+ if target_user != auth_user:
+ raise AuthError(400, "Cannot set another user's displayname")
+
+ yield self.store.set_profile_displayname(
+ target_user.localpart, new_displayname
+ )
+
+ yield self.distributor.fire(
+ "changed_presencelike_data", target_user, {
+ "displayname": new_displayname,
+ }
+ )
+
+ @defer.inlineCallbacks
+ def get_avatar_url(self, target_user, local_only=False):
+ if target_user.is_mine:
+ avatar_url = yield self.store.get_profile_avatar_url(
+ target_user.localpart
+ )
+
+ defer.returnValue(avatar_url)
+ elif not local_only:
+ # TODO(paul): This should use the server-server API to ask another
+ # HS. For now we'll just have it use the http client to talk to the
+ # other HS's REST client API
+ destination = target_user.domain
+ path = PREFIX + "/profile/%s/avatar_url?local_only=1" % (
+ target_user.to_string(),
+ )
+
+ try:
+ result = yield self.client.get_json(
+ destination=destination,
+ path=path
+ )
+ except CodeMessageException as e:
+ if e.code != 404:
+ logger.exception("Failed to get avatar_url")
+ raise
+ except:
+ logger.exception("Failed to get avatar_url")
+
+ defer.returnValue(result["avatar_url"])
+ else:
+ raise SynapseError(400, "User is not hosted on this Home Server")
+
+ @defer.inlineCallbacks
+ def set_avatar_url(self, target_user, auth_user, new_avatar_url):
+ """target_user is the user whose avatar_url is to be changed;
+ auth_user is the user attempting to make this change."""
+ if not target_user.is_mine:
+ raise SynapseError(400, "User is not hosted on this Home Server")
+
+ if target_user != auth_user:
+ raise AuthError(400, "Cannot set another user's avatar_url")
+
+ yield self.store.set_profile_avatar_url(
+ target_user.localpart, new_avatar_url
+ )
+
+ yield self.distributor.fire(
+ "changed_presencelike_data", target_user, {
+ "avatar_url": new_avatar_url,
+ }
+ )
+
+ @defer.inlineCallbacks
+ def collect_presencelike_data(self, user, state):
+ if not user.is_mine:
+ defer.returnValue(None)
+
+ (displayname, avatar_url) = yield defer.gatherResults([
+ self.store.get_profile_displayname(user.localpart),
+ self.store.get_profile_avatar_url(user.localpart),
+ ])
+
+ state["displayname"] = displayname
+ state["avatar_url"] = avatar_url
+
+ defer.returnValue(None)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
new file mode 100644
index 0000000000..246c1f6530
--- /dev/null
+++ b/synapse/handlers/register.py
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains functions for registering clients."""
+from twisted.internet import defer
+
+from synapse.types import UserID
+from synapse.api.errors import SynapseError, RegistrationError
+from ._base import BaseHandler
+import synapse.util.stringutils as stringutils
+
+import base64
+import bcrypt
+
+
+class RegistrationHandler(BaseHandler):
+
+ def __init__(self, hs):
+ super(RegistrationHandler, self).__init__(hs)
+
+ self.distributor = hs.get_distributor()
+ self.distributor.declare("registered_user")
+
+ @defer.inlineCallbacks
+ def register(self, localpart=None, password=None):
+ """Registers a new client on the server.
+
+ Args:
+ localpart : The local part of the user ID to register. If None,
+ one will be randomly generated.
+ password (str) : The password to assign to this user so they can
+ login again.
+ Returns:
+ A tuple of (user_id, access_token).
+ Raises:
+ RegistrationError if there was a problem registering.
+ """
+ password_hash = None
+ if password:
+ password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
+
+ if localpart:
+ user = UserID(localpart, self.hs.hostname, True)
+ user_id = user.to_string()
+
+ token = self._generate_token(user_id)
+ yield self.store.register(user_id=user_id,
+ token=token,
+ password_hash=password_hash)
+
+ self.distributor.fire("registered_user", user)
+ defer.returnValue((user_id, token))
+ else:
+ # autogen a random user ID
+ attempts = 0
+ user_id = None
+ token = None
+ while not user_id and not token:
+ try:
+ localpart = self._generate_user_id()
+ user = UserID(localpart, self.hs.hostname, True)
+ user_id = user.to_string()
+
+ token = self._generate_token(user_id)
+ yield self.store.register(
+ user_id=user_id,
+ token=token,
+ password_hash=password_hash)
+
+ self.distributor.fire("registered_user", user)
+ defer.returnValue((user_id, token))
+ except SynapseError:
+ # if user id is taken, just generate another
+ user_id = None
+ token = None
+ attempts += 1
+ if attempts > 5:
+ raise RegistrationError(
+ 500, "Cannot generate user ID.")
+
+ def _generate_token(self, user_id):
+ # urlsafe variant uses _ and - so use . as the separator and replace
+ # all =s with .s so http clients don't quote =s when it is used as
+ # query params.
+ return (base64.urlsafe_b64encode(user_id).replace('=', '.') + '.' +
+ stringutils.random_string(18))
+
+ def _generate_user_id(self):
+ return "-" + stringutils.random_string(18)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
new file mode 100644
index 0000000000..4d82b33993
--- /dev/null
+++ b/synapse/handlers/room.py
@@ -0,0 +1,808 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains functions for performing events on rooms."""
+from twisted.internet import defer
+
+from synapse.types import UserID, RoomAlias, RoomID
+from synapse.api.constants import Membership
+from synapse.api.errors import RoomError, StoreError, SynapseError
+from synapse.api.events.room import (
+ RoomTopicEvent, MessageEvent, InviteJoinEvent, RoomMemberEvent,
+ RoomConfigEvent
+)
+from synapse.api.streams.event import EventStream, MessagesStreamData
+from synapse.util import stringutils
+from ._base import BaseHandler
+
+import logging
+import json
+
+logger = logging.getLogger(__name__)
+
+
+class MessageHandler(BaseHandler):
+
+ def __init__(self, hs):
+ super(MessageHandler, self).__init__(hs)
+ self.hs = hs
+ self.clock = hs.get_clock()
+ self.event_factory = hs.get_event_factory()
+
+ @defer.inlineCallbacks
+ def get_message(self, msg_id=None, room_id=None, sender_id=None,
+ user_id=None):
+ """ Retrieve a message.
+
+ Args:
+ msg_id (str): The message ID to obtain.
+ room_id (str): The room where the message resides.
+ sender_id (str): The user ID of the user who sent the message.
+ user_id (str): The user ID of the user making this request.
+ Returns:
+ The message, or None if no message exists.
+ Raises:
+ SynapseError if something went wrong.
+ """
+ yield self.auth.check_joined_room(room_id, user_id)
+
+ # Pull out the message from the db
+ msg = yield self.store.get_message(room_id=room_id,
+ msg_id=msg_id,
+ user_id=sender_id)
+
+ if msg:
+ defer.returnValue(msg)
+ defer.returnValue(None)
+
+ @defer.inlineCallbacks
+ def send_message(self, event=None, suppress_auth=False, stamp_event=True):
+ """ Send a message.
+
+ Args:
+ event : The message event to store.
+ suppress_auth (bool) : True to suppress auth for this message. This
+ is primarily so the home server can inject messages into rooms at
+ will.
+ stamp_event (bool) : True to stamp event content with server keys.
+ Raises:
+ SynapseError if something went wrong.
+ """
+ if stamp_event:
+ event.content["hsob_ts"] = int(self.clock.time_msec())
+
+ with (yield self.room_lock.lock(event.room_id)):
+ if not suppress_auth:
+ yield self.auth.check(event, raises=True)
+
+ # store message in db
+ store_id = yield self.store.persist_event(event)
+
+ event.destinations = yield self.store.get_joined_hosts_for_room(
+ event.room_id
+ )
+
+ yield self.hs.get_federation().handle_new_event(event)
+
+ self.notifier.on_new_room_event(event, store_id)
+
+ @defer.inlineCallbacks
+ def get_messages(self, user_id=None, room_id=None, pagin_config=None,
+ feedback=False):
+ """Get messages in a room.
+
+ Args:
+ user_id (str): The user requesting messages.
+ room_id (str): The room they want messages from.
+ pagin_config (synapse.api.streams.PaginationConfig): The pagination
+ config rules to apply, if any.
+ feedback (bool): True to get compressed feedback with the messages
+ Returns:
+ dict: Pagination API results
+ """
+ yield self.auth.check_joined_room(room_id, user_id)
+
+ data_source = [MessagesStreamData(self.hs, room_id=room_id,
+ feedback=feedback)]
+ event_stream = EventStream(user_id, data_source)
+ pagin_config = yield event_stream.fix_tokens(pagin_config)
+ data_chunk = yield event_stream.get_chunk(config=pagin_config)
+ defer.returnValue(data_chunk)
+
+ @defer.inlineCallbacks
+ def store_room_data(self, event=None, stamp_event=True):
+ """ Stores data for a room.
+
+ Args:
+ event : The room path event
+ stamp_event (bool) : True to stamp event content with server keys.
+ Raises:
+ SynapseError if something went wrong.
+ """
+
+ with (yield self.room_lock.lock(event.room_id)):
+ yield self.auth.check(event, raises=True)
+
+ if stamp_event:
+ event.content["hsob_ts"] = int(self.clock.time_msec())
+
+ yield self.state_handler.handle_new_event(event)
+
+ # store in db
+ store_id = yield self.store.store_room_data(
+ room_id=event.room_id,
+ etype=event.type,
+ state_key=event.state_key,
+ content=json.dumps(event.content)
+ )
+
+ event.destinations = yield self.store.get_joined_hosts_for_room(
+ event.room_id
+ )
+ self.notifier.on_new_room_event(event, store_id)
+
+ yield self.hs.get_federation().handle_new_event(event)
+
+ @defer.inlineCallbacks
+ def get_room_data(self, user_id=None, room_id=None,
+ event_type=None, state_key="",
+ public_room_rules=[],
+ private_room_rules=["join"]):
+ """ Get data from a room.
+
+ Args:
+ event : The room path event
+ public_room_rules : A list of membership states the user can be in,
+ in order to read this data IN A PUBLIC ROOM. An empty list means
+ 'any state'.
+ private_room_rules : A list of membership states the user can be
+ in, in order to read this data IN A PRIVATE ROOM. An empty list
+ means 'any state'.
+ Returns:
+ The path data content.
+ Raises:
+ SynapseError if something went wrong.
+ """
+ if event_type == RoomTopicEvent.TYPE:
+ # anyone invited/joined can read the topic
+ private_room_rules = ["invite", "join"]
+
+ # does this room exist
+ room = yield self.store.get_room(room_id)
+ if not room:
+ raise RoomError(403, "Room does not exist.")
+
+ # does this user exist in this room
+ member = yield self.store.get_room_member(
+ room_id=room_id,
+ user_id="" if not user_id else user_id)
+
+ member_state = member.membership if member else None
+
+ if room.is_public and public_room_rules:
+ # make sure the user meets public room rules
+ if member_state not in public_room_rules:
+ raise RoomError(403, "Member does not meet public room rules.")
+ elif not room.is_public and private_room_rules:
+ # make sure the user meets private room rules
+ if member_state not in private_room_rules:
+ raise RoomError(
+ 403, "Member does not meet private room rules.")
+
+ data = yield self.store.get_room_data(room_id, event_type, state_key)
+ defer.returnValue(data)
+
+ @defer.inlineCallbacks
+ def get_feedback(self, room_id=None, msg_sender_id=None, msg_id=None,
+ user_id=None, fb_sender_id=None, fb_type=None):
+ yield self.auth.check_joined_room(room_id, user_id)
+
+ # Pull out the feedback from the db
+ fb = yield self.store.get_feedback(
+ room_id=room_id, msg_id=msg_id, msg_sender_id=msg_sender_id,
+ fb_sender_id=fb_sender_id, fb_type=fb_type
+ )
+
+ if fb:
+ defer.returnValue(fb)
+ defer.returnValue(None)
+
+ @defer.inlineCallbacks
+ def send_feedback(self, event, stamp_event=True):
+ if stamp_event:
+ event.content["hsob_ts"] = int(self.clock.time_msec())
+
+ with (yield self.room_lock.lock(event.room_id)):
+ yield self.auth.check(event, raises=True)
+
+ # store message in db
+ store_id = yield self.store.persist_event(event)
+
+ event.destinations = yield self.store.get_joined_hosts_for_room(
+ event.room_id
+ )
+ yield self.hs.get_federation().handle_new_event(event)
+
+ self.notifier.on_new_room_event(event, store_id)
+
+ @defer.inlineCallbacks
+ def snapshot_all_rooms(self, user_id=None, pagin_config=None,
+ feedback=False):
+ """Retrieve a snapshot of all rooms the user is invited or has joined.
+
+ This snapshot may include messages for all rooms where the user is
+ joined, depending on the pagination config.
+
+ Args:
+ user_id (str): The ID of the user making the request.
+ pagin_config (synapse.api.streams.PaginationConfig): The pagination
+ config used to determine how many messages *PER ROOM* to return.
+ feedback (bool): True to get feedback along with these messages.
+ Returns:
+ A list of dicts with "room_id" and "membership" keys for all rooms
+ the user is currently invited or joined in on. Rooms where the user
+ is joined on, may return a "messages" key with messages, depending
+ on the specified PaginationConfig.
+ """
+ room_list = yield self.store.get_rooms_for_user_where_membership_is(
+ user_id=user_id,
+ membership_list=[Membership.INVITE, Membership.JOIN]
+ )
+ for room_info in room_list:
+ if room_info["membership"] != Membership.JOIN:
+ continue
+ try:
+ event_chunk = yield self.get_messages(
+ user_id=user_id,
+ pagin_config=pagin_config,
+ feedback=feedback,
+ room_id=room_info["room_id"]
+ )
+ room_info["messages"] = event_chunk
+ except:
+ pass
+ defer.returnValue(room_list)
+
+
+class RoomCreationHandler(BaseHandler):
+
+ @defer.inlineCallbacks
+ def create_room(self, user_id, room_id, config):
+ """ Creates a new room.
+
+ Args:
+ user_id (str): The ID of the user creating the new room.
+ room_id (str): The proposed ID for the new room. Can be None, in
+ which case one will be created for you.
+ config (dict) : A dict of configuration options.
+ Returns:
+ The new room ID.
+ Raises:
+ SynapseError if the room ID was taken, couldn't be stored, or
+ something went horribly wrong.
+ """
+
+ if "room_alias_name" in config:
+ room_alias = RoomAlias.create_local(
+ config["room_alias_name"],
+ self.hs
+ )
+ mapping = yield self.store.get_association_from_room_alias(
+ room_alias
+ )
+
+ if mapping:
+ raise SynapseError(400, "Room alias already taken")
+ else:
+ room_alias = None
+
+ if room_id:
+ # Ensure room_id is the correct type
+ room_id_obj = RoomID.from_string(room_id, self.hs)
+ if not room_id_obj.is_mine:
+ raise SynapseError(400, "Room id must be local")
+
+ yield self.store.store_room(
+ room_id=room_id,
+ room_creator_user_id=user_id,
+ is_public=config["visibility"] == "public"
+ )
+ else:
+ # autogen room IDs and try to create it. We may clash, so just
+ # try a few times till one goes through, giving up eventually.
+ attempts = 0
+ room_id = None
+ while attempts < 5:
+ try:
+ random_string = stringutils.random_string(18)
+ gen_room_id = RoomID.create_local(random_string, self.hs)
+ yield self.store.store_room(
+ room_id=gen_room_id.to_string(),
+ room_creator_user_id=user_id,
+ is_public=config["visibility"] == "public"
+ )
+ room_id = gen_room_id.to_string()
+ break
+ except StoreError:
+ attempts += 1
+ if not room_id:
+ raise StoreError(500, "Couldn't generate a room ID.")
+
+ config_event = self.event_factory.create_event(
+ etype=RoomConfigEvent.TYPE,
+ room_id=room_id,
+ user_id=user_id,
+ content=config,
+ )
+
+ if room_alias:
+ yield self.store.create_room_alias_association(
+ room_id=room_id,
+ room_alias=room_alias,
+ servers=[self.hs.hostname],
+ )
+
+ yield self.state_handler.handle_new_event(config_event)
+ # store_id = persist...
+
+ yield self.hs.get_federation().handle_new_event(config_event)
+ # self.notifier.on_new_room_event(event, store_id)
+
+ content = {"membership": Membership.JOIN}
+ join_event = self.event_factory.create_event(
+ etype=RoomMemberEvent.TYPE,
+ target_user_id=user_id,
+ room_id=room_id,
+ user_id=user_id,
+ membership=Membership.JOIN,
+ content=content
+ )
+
+ yield self.hs.get_handlers().room_member_handler.change_membership(
+ join_event,
+ broadcast_msg=True,
+ do_auth=False
+ )
+
+ result = {"room_id": room_id}
+ if room_alias:
+ result["room_alias"] = room_alias.to_string()
+
+ defer.returnValue(result)
+
+
+class RoomMemberHandler(BaseHandler):
+ # TODO(paul): This handler currently contains a messy conflation of
+ # low-level API that works on UserID objects and so on, and REST-level
+ # API that takes ID strings and returns pagination chunks. These concerns
+ # ought to be separated out a lot better.
+
+ def __init__(self, hs):
+ super(RoomMemberHandler, self).__init__(hs)
+
+ self.clock = hs.get_clock()
+
+ self.distributor = hs.get_distributor()
+ self.distributor.declare("user_joined_room")
+
+ @defer.inlineCallbacks
+ def get_room_members(self, room_id, membership=Membership.JOIN):
+ hs = self.hs
+
+ memberships = yield self.store.get_room_members(
+ room_id=room_id, membership=membership
+ )
+
+ defer.returnValue([hs.parse_userid(m.user_id) for m in memberships])
+
+ @defer.inlineCallbacks
+ def fetch_room_distributions_into(self, room_id, localusers=None,
+ remotedomains=None, ignore_user=None):
+ """Fetch the distribution of a room, adding elements to either
+ 'localusers' or 'remotedomains', which should be a set() if supplied.
+ If ignore_user is set, ignore that user.
+
+ This function returns nothing; its result is performed by the
+ side-effect on the two passed sets. This allows easy accumulation of
+ member lists of multiple rooms at once if required.
+ """
+ members = yield self.get_room_members(room_id)
+ for member in members:
+ if ignore_user is not None and member == ignore_user:
+ continue
+
+ if member.is_mine:
+ if localusers is not None:
+ localusers.add(member)
+ else:
+ if remotedomains is not None:
+ remotedomains.add(member.domain)
+
+ @defer.inlineCallbacks
+ def get_room_members_as_pagination_chunk(self, room_id=None, user_id=None,
+ limit=0, start_tok=None,
+ end_tok=None):
+ """Retrieve a list of room members in the room.
+
+ Args:
+ room_id (str): The room to get the member list for.
+ user_id (str): The ID of the user making the request.
+ limit (int): The max number of members to return.
+ start_tok (str): Optional. The start token if known.
+ end_tok (str): Optional. The end token if known.
+ Returns:
+ dict: A Pagination streamable dict.
+ Raises:
+ SynapseError if something goes wrong.
+ """
+ yield self.auth.check_joined_room(room_id, user_id)
+
+ member_list = yield self.store.get_room_members(room_id=room_id)
+ event_list = [
+ entry.as_event(self.event_factory).get_dict()
+ for entry in member_list
+ ]
+ chunk_data = {
+ "start": "START",
+ "end": "END",
+ "chunk": event_list
+ }
+ # TODO honor Pagination stream params
+ # TODO snapshot this list to return on subsequent requests when
+ # paginating
+ defer.returnValue(chunk_data)
+
+ @defer.inlineCallbacks
+ def get_room_member(self, room_id, member_user_id, auth_user_id):
+ """Retrieve a room member from a room.
+
+ Args:
+ room_id : The room the member is in.
+ member_user_id : The member's user ID
+ auth_user_id : The user ID of the user making this request.
+ Returns:
+ The room member, or None if this member does not exist.
+ Raises:
+ SynapseError if something goes wrong.
+ """
+ yield self.auth.check_joined_room(room_id, auth_user_id)
+
+ member = yield self.store.get_room_member(user_id=member_user_id,
+ room_id=room_id)
+ defer.returnValue(member)
+
+ @defer.inlineCallbacks
+ def change_membership(self, event=None, broadcast_msg=False, do_auth=True):
+ """ Change the membership status of a user in a room.
+
+ Args:
+ event (SynapseEvent): The membership event
+ broadcast_msg (bool): True to inject a membership message into this
+ room on success.
+ Raises:
+ SynapseError if there was a problem changing the membership.
+ """
+
+ #broadcast_msg = False
+
+ prev_state = yield self.store.get_room_member(
+ event.target_user_id, event.room_id
+ )
+
+ if prev_state and prev_state.membership == event.membership:
+ # treat this event as a NOOP.
+ if do_auth: # This is mainly to fix a unit test.
+ yield self.auth.check(event, raises=True)
+ defer.returnValue({})
+ return
+
+ room_id = event.room_id
+
+ # If we're trying to join a room then we have to do this differently
+ # if this HS is not currently in the room, i.e. we have to do the
+ # invite/join dance.
+ if event.membership == Membership.JOIN:
+ yield self._do_join(
+ event, do_auth=do_auth, broadcast_msg=broadcast_msg
+ )
+ else:
+ # This is not a JOIN, so we can handle it normally.
+ if do_auth:
+ yield self.auth.check(event, raises=True)
+
+ prev_state = yield self.store.get_room_member(
+ event.target_user_id, event.room_id
+ )
+ if prev_state and prev_state.membership == event.membership:
+ # double same action, treat this event as a NOOP.
+ defer.returnValue({})
+ return
+
+ yield self.state_handler.handle_new_event(event)
+ yield self._do_local_membership_update(
+ event,
+ membership=event.content["membership"],
+ broadcast_msg=broadcast_msg,
+ )
+
+ defer.returnValue({"room_id": room_id})
+
+ @defer.inlineCallbacks
+ def join_room_alias(self, joinee, room_alias, do_auth=True, content={}):
+ directory_handler = self.hs.get_handlers().directory_handler
+ mapping = yield directory_handler.get_association(room_alias)
+
+ if not mapping:
+ raise SynapseError(404, "No such room alias")
+
+ room_id = mapping["room_id"]
+ hosts = mapping["servers"]
+ if not hosts:
+ raise SynapseError(404, "No known servers")
+
+ host = hosts[0]
+
+ content.update({"membership": Membership.JOIN})
+ new_event = self.event_factory.create_event(
+ etype=RoomMemberEvent.TYPE,
+ target_user_id=joinee.to_string(),
+ room_id=room_id,
+ user_id=joinee.to_string(),
+ membership=Membership.JOIN,
+ content=content,
+ )
+
+ yield self._do_join(new_event, room_host=host, do_auth=True)
+
+ defer.returnValue({"room_id": room_id})
+
+ @defer.inlineCallbacks
+ def _do_join(self, event, room_host=None, do_auth=True, broadcast_msg=True):
+ joinee = self.hs.parse_userid(event.target_user_id)
+ # room_id = RoomID.from_string(event.room_id, self.hs)
+ room_id = event.room_id
+
+ # If event doesn't include a display name, add one.
+ yield self._fill_out_join_content(
+ joinee, event.content
+ )
+
+ # XXX: We don't do an auth check if we are doing an invite
+ # join dance for now, since we're kinda implicitly checking
+ # that we are allowed to join when we decide whether or not we
+ # need to do the invite/join dance.
+
+ room = yield self.store.get_room(room_id)
+
+ if room:
+ should_do_dance = False
+ elif room_host:
+ should_do_dance = True
+ else:
+ prev_state = yield self.store.get_room_member(
+ joinee.to_string(), room_id
+ )
+
+ if prev_state and prev_state.membership == Membership.INVITE:
+ room = yield self.store.get_room(room_id)
+ inviter = UserID.from_string(
+ prev_state.sender, self.hs
+ )
+
+ should_do_dance = not inviter.is_mine and not room
+ room_host = inviter.domain
+ else:
+ should_do_dance = False
+
+ # We want to do the _do_update inside the room lock.
+ if not should_do_dance:
+ logger.debug("Doing normal join")
+
+ if do_auth:
+ yield self.auth.check(event, raises=True)
+
+ yield self.state_handler.handle_new_event(event)
+ yield self._do_local_membership_update(
+ event,
+ membership=event.content["membership"],
+ broadcast_msg=broadcast_msg,
+ )
+
+
+ if should_do_dance:
+ yield self._do_invite_join_dance(
+ room_id=room_id,
+ joinee=event.user_id,
+ target_host=room_host,
+ content=event.content,
+ )
+
+ user = self.hs.parse_userid(event.user_id)
+ self.distributor.fire(
+ "user_joined_room", user=user, room_id=room_id
+ )
+
+ @defer.inlineCallbacks
+ def _fill_out_join_content(self, user_id, content):
+ # If event doesn't include a display name, add one.
+ profile_handler = self.hs.get_handlers().profile_handler
+ if "displayname" not in content:
+ try:
+ display_name = yield profile_handler.get_displayname(
+ user_id
+ )
+
+ if display_name:
+ content["displayname"] = display_name
+ except:
+ logger.exception("Failed to set display_name")
+
+ if "avatar_url" not in content:
+ try:
+ avatar_url = yield profile_handler.get_avatar_url(
+ user_id
+ )
+
+ if avatar_url:
+ content["avatar_url"] = avatar_url
+ except:
+ logger.exception("Failed to set display_name")
+
+ @defer.inlineCallbacks
+ def _should_invite_join(self, room_id, prev_state, do_auth):
+ logger.debug("_should_invite_join: room_id: %s", room_id)
+
+ # XXX: We don't do an auth check if we are doing an invite
+ # join dance for now, since we're kinda implicitly checking
+ # that we are allowed to join when we decide whether or not we
+ # need to do the invite/join dance.
+
+ # Only do an invite join dance if a) we were invited,
+ # b) the person inviting was from a differnt HS and c) we are
+ # not currently in the room
+ room_host = None
+ if prev_state and prev_state.membership == Membership.INVITE:
+ room = yield self.store.get_room(room_id)
+ inviter = UserID.from_string(
+ prev_state.sender, self.hs
+ )
+
+ is_remote_invite_join = not inviter.is_mine and not room
+ room_host = inviter.domain
+ else:
+ is_remote_invite_join = False
+
+ defer.returnValue((is_remote_invite_join, room_host))
+
+ @defer.inlineCallbacks
+ def get_rooms_for_user(self, user, membership_list=[Membership.JOIN]):
+ """Returns a list of roomids that the user has any of the given
+ membership states in."""
+ rooms = yield self.store.get_rooms_for_user_where_membership_is(
+ user_id=user.to_string(), membership_list=membership_list
+ )
+
+ defer.returnValue([r["room_id"] for r in rooms])
+
+ @defer.inlineCallbacks
+ def _do_local_membership_update(self, event, membership, broadcast_msg):
+ # store membership
+ store_id = yield self.store.store_room_member(
+ user_id=event.target_user_id,
+ sender=event.user_id,
+ room_id=event.room_id,
+ content=event.content,
+ membership=membership
+ )
+
+ # Send a PDU to all hosts who have joined the room.
+ destinations = yield self.store.get_joined_hosts_for_room(
+ event.room_id
+ )
+
+ # If we're inviting someone, then we should also send it to that
+ # HS.
+ if membership == Membership.INVITE:
+ host = UserID.from_string(
+ event.target_user_id, self.hs
+ ).domain
+ destinations.append(host)
+
+ # If we are joining a remote HS, include that.
+ if membership == Membership.JOIN:
+ host = UserID.from_string(
+ event.target_user_id, self.hs
+ ).domain
+ destinations.append(host)
+
+ event.destinations = list(set(destinations))
+
+ yield self.hs.get_federation().handle_new_event(event)
+ self.notifier.on_new_room_event(event, store_id)
+
+ if broadcast_msg:
+ yield self._inject_membership_msg(
+ source=event.user_id,
+ target=event.target_user_id,
+ room_id=event.room_id,
+ membership=event.content["membership"]
+ )
+
+ @defer.inlineCallbacks
+ def _do_invite_join_dance(self, room_id, joinee, target_host, content):
+ logger.debug("Doing remote join dance")
+
+ # do invite join dance
+ federation = self.hs.get_federation()
+ new_event = self.event_factory.create_event(
+ etype=InviteJoinEvent.TYPE,
+ target_host=target_host,
+ room_id=room_id,
+ user_id=joinee,
+ content=content
+ )
+
+ new_event.destinations = [target_host]
+
+ yield self.store.store_room(
+ room_id, "", is_public=False
+ )
+
+ #yield self.state_handler.handle_new_event(event)
+ yield federation.handle_new_event(new_event)
+ yield federation.get_state_for_room(
+ target_host, room_id
+ )
+
+ @defer.inlineCallbacks
+ def _inject_membership_msg(self, room_id=None, source=None, target=None,
+ membership=None):
+ # TODO this should be a different type of message, not m.text
+ if membership == Membership.INVITE:
+ body = "%s invited %s to the room." % (source, target)
+ elif membership == Membership.JOIN:
+ body = "%s joined the room." % (target)
+ elif membership == Membership.LEAVE:
+ body = "%s left the room." % (target)
+ else:
+ raise RoomError(500, "Unknown membership value %s" % membership)
+
+ membership_json = {
+ "msgtype": u"m.text",
+ "body": body,
+ "membership_source": source,
+ "membership_target": target,
+ "membership": membership,
+ }
+
+ msg_id = "m%s" % int(self.clock.time_msec())
+
+ event = self.event_factory.create_event(
+ etype=MessageEvent.TYPE,
+ room_id=room_id,
+ user_id="_homeserver_",
+ msg_id=msg_id,
+ content=membership_json
+ )
+
+ handler = self.hs.get_handlers().message_handler
+ yield handler.send_message(event, suppress_auth=True)
+
+
+class RoomListHandler(BaseHandler):
+
+ @defer.inlineCallbacks
+ def get_public_room_list(self):
+ chunk = yield self.store.get_rooms(is_public=True, with_topics=True)
+ defer.returnValue({"start": "START", "end": "END", "chunk": chunk})
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
new file mode 100644
index 0000000000..fe8a073cd3
--- /dev/null
+++ b/synapse/http/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/http/client.py b/synapse/http/client.py
new file mode 100644
index 0000000000..bb22b0ee9a
--- /dev/null
+++ b/synapse/http/client.py
@@ -0,0 +1,246 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer, reactor
+from twisted.web.client import _AgentBase, _URI, readBody
+from twisted.web.http_headers import Headers
+
+from synapse.http.endpoint import matrix_endpoint
+from synapse.util.async import sleep
+
+from syutil.jsonutil import encode_canonical_json
+
+from synapse.api.errors import CodeMessageException
+
+import json
+import logging
+import urllib
+
+
+logger = logging.getLogger(__name__)
+
+
+_destination_mappings = {
+ "red": "localhost:8080",
+ "blue": "localhost:8081",
+ "green": "localhost:8082",
+}
+
+
+class HttpClient(object):
+ """ Interface for talking json over http
+ """
+
+ def put_json(self, destination, path, data):
+ """ Sends the specifed json data using PUT
+
+ Args:
+ destination (str): The remote server to send the HTTP request
+ to.
+ path (str): The HTTP path.
+ data (dict): A dict containing the data that will be used as
+ the request body. This will be encoded as JSON.
+
+ Returns:
+ Deferred: Succeeds when we get *any* HTTP response.
+
+ The result of the deferred is a tuple of `(code, response)`,
+ where `response` is a dict representing the decoded JSON body.
+ """
+ pass
+
+ def get_json(self, destination, path, args=None):
+ """ Get's some json from the given host homeserver and path
+
+ Args:
+ destination (str): The remote server to send the HTTP request
+ to.
+ path (str): The HTTP path.
+ args (dict): A dictionary used to create query strings, defaults to
+ None.
+ **Note**: The value of each key is assumed to be an iterable
+ and *not* a string.
+
+ Returns:
+ Deferred: Succeeds when we get *any* HTTP response.
+
+ The result of the deferred is a tuple of `(code, response)`,
+ where `response` is a dict representing the decoded JSON body.
+ """
+ pass
+
+
+class MatrixHttpAgent(_AgentBase):
+
+ def __init__(self, reactor, pool=None):
+ _AgentBase.__init__(self, reactor, pool)
+
+ def request(self, destination, endpoint, method, path, params, query,
+ headers, body_producer):
+
+ host = b""
+ port = 0
+ fragment = b""
+
+ parsed_URI = _URI(b"http", destination, host, port, path, params,
+ query, fragment)
+
+ # Set the connection pool key to be the destination.
+ key = destination
+
+ return self._requestWithEndpoint(key, endpoint, method, parsed_URI,
+ headers, body_producer,
+ parsed_URI.originForm)
+
+
+class TwistedHttpClient(HttpClient):
+ """ Wrapper around the twisted HTTP client api.
+
+ Attributes:
+ agent (twisted.web.client.Agent): The twisted Agent used to send the
+ requests.
+ """
+
+ def __init__(self):
+ self.agent = MatrixHttpAgent(reactor)
+
+ @defer.inlineCallbacks
+ def put_json(self, destination, path, data):
+ if destination in _destination_mappings:
+ destination = _destination_mappings[destination]
+
+ response = yield self._create_request(
+ destination.encode("ascii"),
+ "PUT",
+ path.encode("ascii"),
+ producer=_JsonProducer(data),
+ headers_dict={"Content-Type": ["application/json"]}
+ )
+
+ logger.debug("Getting resp body")
+ body = yield readBody(response)
+ logger.debug("Got resp body")
+
+ defer.returnValue((response.code, body))
+
+ @defer.inlineCallbacks
+ def get_json(self, destination, path, args={}):
+ if destination in _destination_mappings:
+ destination = _destination_mappings[destination]
+
+ logger.debug("get_json args: %s", args)
+ query_bytes = urllib.urlencode(args, True)
+
+ response = yield self._create_request(
+ destination.encode("ascii"),
+ "GET",
+ path.encode("ascii"),
+ query_bytes
+ )
+
+ body = yield readBody(response)
+
+ defer.returnValue(json.loads(body))
+
+ @defer.inlineCallbacks
+ def _create_request(self, destination, method, path_bytes, param_bytes=b"",
+ query_bytes=b"", producer=None, headers_dict={}):
+ """ Creates and sends a request to the given url
+ """
+ headers_dict[b"User-Agent"] = [b"Synapse"]
+ headers_dict[b"Host"] = [destination]
+
+ logger.debug("Sending request to %s: %s %s;%s?%s",
+ destination, method, path_bytes, param_bytes, query_bytes)
+
+ logger.debug(
+ "Types: %s",
+ [
+ type(destination), type(method), type(path_bytes),
+ type(param_bytes),
+ type(query_bytes)
+ ]
+ )
+
+ retries_left = 5
+
+ # TODO: setup and pass in an ssl_context to enable TLS
+ endpoint = matrix_endpoint(reactor, destination, timeout=10)
+
+ while True:
+ try:
+ response = yield self.agent.request(
+ destination,
+ endpoint,
+ method,
+ path_bytes,
+ param_bytes,
+ query_bytes,
+ Headers(headers_dict),
+ producer
+ )
+
+ logger.debug("Got response to %s", method)
+ break
+ except Exception as e:
+ logger.exception("Got error in _create_request")
+ _print_ex(e)
+
+ if retries_left:
+ yield sleep(2 ** (5 - retries_left))
+ retries_left -= 1
+ else:
+ raise
+
+ if 200 <= response.code < 300:
+ # We need to update the transactions table to say it was sent?
+ pass
+ else:
+ # :'(
+ # Update transactions table?
+ logger.error(
+ "Got response %d %s", response.code, response.phrase
+ )
+ raise CodeMessageException(
+ response.code, response.phrase
+ )
+
+ defer.returnValue(response)
+
+
+def _print_ex(e):
+ if hasattr(e, "reasons") and e.reasons:
+ for ex in e.reasons:
+ _print_ex(ex)
+ else:
+ logger.exception(e)
+
+
+class _JsonProducer(object):
+ """ Used by the twisted http client to create the HTTP body from json
+ """
+ def __init__(self, jsn):
+ self.body = encode_canonical_json(jsn)
+ self.length = len(self.body)
+
+ def startProducing(self, consumer):
+ consumer.write(self.body)
+ return defer.succeed(None)
+
+ def pauseProducing(self):
+ pass
+
+ def stopProducing(self):
+ pass
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
new file mode 100644
index 0000000000..c4e6e63a80
--- /dev/null
+++ b/synapse/http/endpoint.py
@@ -0,0 +1,171 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
+from twisted.internet import defer
+from twisted.internet.error import ConnectError
+from twisted.names import client, dns
+from twisted.names.error import DNSNameError
+
+import collections
+import logging
+import random
+
+
+logger = logging.getLogger(__name__)
+
+
+def matrix_endpoint(reactor, destination, ssl_context_factory=None,
+ timeout=None):
+ """Construct an endpoint for the given matrix destination.
+
+ Args:
+ reactor: Twisted reactor.
+ destination (bytes): The name of the server to connect to.
+ ssl_context_factory (twisted.internet.ssl.ContextFactory): Factory
+ which generates SSL contexts to use for TLS.
+ timeout (int): connection timeout in seconds
+ """
+
+ domain_port = destination.split(":")
+ domain = domain_port[0]
+ port = int(domain_port[1]) if domain_port[1:] else None
+
+ endpoint_kw_args = {}
+
+ if timeout is not None:
+ endpoint_kw_args.update(timeout=timeout)
+
+ if ssl_context_factory is None:
+ transport_endpoint = TCP4ClientEndpoint
+ default_port = 8080
+ else:
+ transport_endpoint = SSL4ClientEndpoint
+ endpoint_kw_args.update(ssl_context_factory=ssl_context_factory)
+ default_port = 443
+
+ if port is None:
+ return SRVClientEndpoint(
+ reactor, "matrix", domain, protocol="tcp",
+ default_port=default_port, endpoint=transport_endpoint,
+ endpoint_kw_args=endpoint_kw_args
+ )
+ else:
+ return transport_endpoint(reactor, domain, port, **endpoint_kw_args)
+
+
+class SRVClientEndpoint(object):
+ """An endpoint which looks up SRV records for a service.
+ Cycles through the list of servers starting with each call to connect
+ picking the next server.
+ Implements twisted.internet.interfaces.IStreamClientEndpoint.
+ """
+
+ _Server = collections.namedtuple(
+ "_Server", "priority weight host port"
+ )
+
+ def __init__(self, reactor, service, domain, protocol="tcp",
+ default_port=None, endpoint=TCP4ClientEndpoint,
+ endpoint_kw_args={}):
+ self.reactor = reactor
+ self.service_name = "_%s._%s.%s" % (service, protocol, domain)
+
+ if default_port is not None:
+ self.default_server = self._Server(
+ host=domain,
+ port=default_port,
+ priority=0,
+ weight=0
+ )
+ else:
+ self.default_server = None
+
+ self.endpoint = endpoint
+ self.endpoint_kw_args = endpoint_kw_args
+
+ self.servers = None
+ self.used_servers = None
+
+ @defer.inlineCallbacks
+ def fetch_servers(self):
+ try:
+ answers, auth, add = yield client.lookupService(self.service_name)
+ except DNSNameError:
+ answers = []
+
+ if (len(answers) == 1
+ and answers[0].type == dns.SRV
+ and answers[0].payload
+ and answers[0].payload.target == dns.Name('.')):
+ raise ConnectError("Service %s unavailable", self.service_name)
+
+ self.servers = []
+ self.used_servers = []
+
+ for answer in answers:
+ if answer.type != dns.SRV or not answer.payload:
+ continue
+ payload = answer.payload
+ self.servers.append(self._Server(
+ host=str(payload.target),
+ port=int(payload.port),
+ priority=int(payload.priority),
+ weight=int(payload.weight)
+ ))
+
+ self.servers.sort()
+
+ def pick_server(self):
+ if not self.servers:
+ if self.used_servers:
+ self.servers = self.used_servers
+ self.used_servers = []
+ self.servers.sort()
+ elif self.default_server:
+ return self.default_server
+ else:
+ raise ConnectError(
+ "Not server available for %s", self.service_name
+ )
+
+ min_priority = self.servers[0].priority
+ weight_indexes = list(
+ (index, server.weight + 1)
+ for index, server in enumerate(self.servers)
+ if server.priority == min_priority
+ )
+
+ total_weight = sum(weight for index, weight in weight_indexes)
+ target_weight = random.randint(0, total_weight)
+
+ for index, weight in weight_indexes:
+ target_weight -= weight
+ if target_weight <= 0:
+ server = self.servers[index]
+ del self.servers[index]
+ self.used_servers.append(server)
+ return server
+
+ @defer.inlineCallbacks
+ def connect(self, protocolFactory):
+ if self.servers is None:
+ yield self.fetch_servers()
+ server = self.pick_server()
+ logger.info("Connecting to %s:%s", server.host, server.port)
+ endpoint = self.endpoint(
+ self.reactor, server.host, server.port, **self.endpoint_kw_args
+ )
+ connection = yield endpoint.connect(protocolFactory)
+ defer.returnValue(connection)
diff --git a/synapse/http/server.py b/synapse/http/server.py
new file mode 100644
index 0000000000..8823aade78
--- /dev/null
+++ b/synapse/http/server.py
@@ -0,0 +1,181 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from syutil.jsonutil import (
+ encode_canonical_json, encode_pretty_printed_json
+)
+from synapse.api.errors import cs_exception, CodeMessageException
+
+from twisted.internet import defer, reactor
+from twisted.web import server, resource
+from twisted.web.server import NOT_DONE_YET
+
+import collections
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class HttpServer(object):
+ """ Interface for registering callbacks on a HTTP server
+ """
+
+ def register_path(self, method, path_pattern, callback):
+ """ Register a callback that get's fired if we receive a http request
+ with the given method for a path that matches the given regex.
+
+ If the regex contains groups these get's passed to the calback via
+ an unpacked tuple.
+
+ Args:
+ method (str): The method to listen to.
+ path_pattern (str): The regex used to match requests.
+ callback (function): The function to fire if we receive a matched
+ request. The first argument will be the request object and
+ subsequent arguments will be any matched groups from the regex.
+ This should return a tuple of (code, response).
+ """
+ pass
+
+
+# The actual HTTP server impl, using twisted http server
+class TwistedHttpServer(HttpServer, resource.Resource):
+ """ This wraps the twisted HTTP server, and triggers the correct callbacks
+ on the transport_layer.
+
+ Register callbacks via register_path()
+ """
+
+ isLeaf = True
+
+ _PathEntry = collections.namedtuple("_PathEntry", ["pattern", "callback"])
+
+ def __init__(self):
+ resource.Resource.__init__(self)
+
+ self.path_regexs = {}
+
+ def register_path(self, method, path_pattern, callback):
+ self.path_regexs.setdefault(method, []).append(
+ self._PathEntry(path_pattern, callback)
+ )
+
+ def start_listening(self, port):
+ """ Registers the http server with the twisted reactor.
+
+ Args:
+ port (int): The port to listen on.
+
+ """
+ reactor.listenTCP(port, server.Site(self))
+
+ # Gets called by twisted
+ def render(self, request):
+ """ This get's called by twisted every time someone sends us a request.
+ """
+ self._async_render(request)
+ return server.NOT_DONE_YET
+
+ @defer.inlineCallbacks
+ def _async_render(self, request):
+ """ This get's called by twisted every time someone sends us a request.
+ This checks if anyone has registered a callback for that method and
+ path.
+ """
+ try:
+ # Loop through all the registered callbacks to check if the method
+ # and path regex match
+ for path_entry in self.path_regexs.get(request.method, []):
+ m = path_entry.pattern.match(request.path)
+ if m:
+ # We found a match! Trigger callback and then return the
+ # returned response. We pass both the request and any
+ # matched groups from the regex to the callback.
+ code, response = yield path_entry.callback(
+ request,
+ *m.groups()
+ )
+
+ self._send_response(request, code, response)
+ return
+
+ # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
+ self._send_response(
+ request,
+ 400,
+ {"error": "Unrecognized request"}
+ )
+ except CodeMessageException as e:
+ logger.exception(e)
+ self._send_response(
+ request,
+ e.code,
+ cs_exception(e)
+ )
+ except Exception as e:
+ logger.exception(e)
+ self._send_response(
+ request,
+ 500,
+ {"error": "Internal server error"}
+ )
+
+ def _send_response(self, request, code, response_json_object):
+
+ if not self._request_user_agent_is_curl(request):
+ json_bytes = encode_canonical_json(response_json_object)
+ else:
+ json_bytes = encode_pretty_printed_json(response_json_object)
+
+ # TODO: Only enable CORS for the requests that need it.
+ respond_with_json_bytes(request, code, json_bytes, send_cors=True)
+
+ @staticmethod
+ def _request_user_agent_is_curl(request):
+ user_agents = request.requestHeaders.getRawHeaders(
+ "User-Agent", default=[]
+ )
+ for user_agent in user_agents:
+ if "curl" in user_agent:
+ return True
+ return False
+
+
+def respond_with_json_bytes(request, code, json_bytes, send_cors=False):
+ """Sends encoded JSON in response to the given request.
+
+ Args:
+ request (twisted.web.http.Request): The http request to respond to.
+ code (int): The HTTP response code.
+ json_bytes (bytes): The json bytes to use as the response body.
+ send_cors (bool): Whether to send Cross-Origin Resource Sharing headers
+ http://www.w3.org/TR/cors/
+ Returns:
+ twisted.web.server.NOT_DONE_YET"""
+
+ request.setResponseCode(code)
+ request.setHeader(b"Content-Type", b"application/json")
+
+ if send_cors:
+ request.setHeader("Access-Control-Allow-Origin", "*")
+ request.setHeader("Access-Control-Allow-Methods",
+ "GET, POST, PUT, DELETE, OPTIONS")
+ request.setHeader("Access-Control-Allow-Headers",
+ "Origin, X-Requested-With, Content-Type, Accept")
+
+ request.write(json_bytes)
+ request.finish()
+ return NOT_DONE_YET
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
new file mode 100644
index 0000000000..5598295793
--- /dev/null
+++ b/synapse/rest/__init__.py
@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import (
+ room, events, register, profile, public, presence, im, directory
+)
+
+class RestServletFactory(object):
+
+ """ A factory for creating REST servlets.
+
+ These REST servlets represent the entire client-server REST API. Generally
+ speaking, they serve as wrappers around events and the handlers that
+ process them.
+
+ See synapse.api.events for information on synapse events.
+ """
+
+ def __init__(self, hs):
+ http_server = hs.get_http_server()
+
+ # TODO(erikj): There *must* be a better way of doing this.
+ room.register_servlets(hs, http_server)
+ events.register_servlets(hs, http_server)
+ register.register_servlets(hs, http_server)
+ profile.register_servlets(hs, http_server)
+ public.register_servlets(hs, http_server)
+ presence.register_servlets(hs, http_server)
+ im.register_servlets(hs, http_server)
+ directory.register_servlets(hs, http_server)
+
+
diff --git a/synapse/rest/base.py b/synapse/rest/base.py
new file mode 100644
index 0000000000..d90ac611fe
--- /dev/null
+++ b/synapse/rest/base.py
@@ -0,0 +1,113 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" This module contains base REST classes for constructing REST servlets. """
+import re
+
+
+def client_path_pattern(path_regex):
+ """Creates a regex compiled client path with the correct client path
+ prefix.
+
+ Args:
+ path_regex (str): The regex string to match. This should NOT have a ^
+ as this will be prefixed.
+ Returns:
+ SRE_Pattern
+ """
+ return re.compile("^/matrix/client/api/v1" + path_regex)
+
+
+class RestServletFactory(object):
+
+ """ A factory for creating REST servlets.
+
+ These REST servlets represent the entire client-server REST API. Generally
+ speaking, they serve as wrappers around events and the handlers that
+ process them.
+
+ See synapse.api.events for information on synapse events.
+ """
+
+ def __init__(self, hs):
+ http_server = hs.get_http_server()
+
+ # You get import errors if you try to import before the classes in this
+ # file are defined, hence importing here instead.
+
+ import room
+ room.register_servlets(hs, http_server)
+
+ import events
+ events.register_servlets(hs, http_server)
+
+ import register
+ register.register_servlets(hs, http_server)
+
+ import profile
+ profile.register_servlets(hs, http_server)
+
+ import public
+ public.register_servlets(hs, http_server)
+
+ import presence
+ presence.register_servlets(hs, http_server)
+
+ import im
+ im.register_servlets(hs, http_server)
+
+ import login
+ login.register_servlets(hs, http_server)
+
+
+class RestServlet(object):
+
+ """ A Synapse REST Servlet.
+
+ An implementing class can either provide its own custom 'register' method,
+ or use the automatic pattern handling provided by the base class.
+
+ To use this latter, the implementing class instead provides a `PATTERN`
+ class attribute containing a pre-compiled regular expression. The automatic
+ register method will then use this method to register any of the following
+ instance methods associated with the corresponding HTTP method:
+
+ on_GET
+ on_PUT
+ on_POST
+ on_DELETE
+ on_OPTIONS
+
+ Automatically handles turning CodeMessageExceptions thrown by these methods
+ into the appropriate HTTP response.
+ """
+
+ def __init__(self, hs):
+ self.hs = hs
+
+ self.handlers = hs.get_handlers()
+ self.event_factory = hs.get_event_factory()
+ self.auth = hs.get_auth()
+
+ def register(self, http_server):
+ """ Register this servlet with the given HTTP server. """
+ if hasattr(self, "PATTERN"):
+ pattern = self.PATTERN
+
+ for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
+ if hasattr(self, "on_%s" % (method)):
+ method_handler = getattr(self, "on_%s" % (method))
+ http_server.register_path(method, pattern, method_handler)
+ else:
+ raise NotImplementedError("RestServlet must register something.")
diff --git a/synapse/rest/directory.py b/synapse/rest/directory.py
new file mode 100644
index 0000000000..a426003a38
--- /dev/null
+++ b/synapse/rest/directory.py
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.types import RoomAlias, RoomID
+from base import RestServlet, client_path_pattern
+
+import json
+import logging
+import urllib
+
+
+logger = logging.getLogger(__name__)
+
+
+def register_servlets(hs, http_server):
+ ClientDirectoryServer(hs).register(http_server)
+
+
+class ClientDirectoryServer(RestServlet):
+ PATTERN = client_path_pattern("/ds/room/(?P<room_alias>[^/]*)$")
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_alias):
+ # TODO(erikj): Handle request
+ local_only = "local_only" in request.args
+
+ room_alias = urllib.unquote(room_alias)
+ room_alias_obj = RoomAlias.from_string(room_alias, self.hs)
+
+ dir_handler = self.handlers.directory_handler
+ res = yield dir_handler.get_association(
+ room_alias_obj,
+ local_only=local_only
+ )
+
+ defer.returnValue((200, res))
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, room_alias):
+ # TODO(erikj): Exceptions
+ content = json.loads(request.content.read())
+
+ logger.debug("Got content: %s", content)
+
+ room_alias = urllib.unquote(room_alias)
+ room_alias_obj = RoomAlias.from_string(room_alias, self.hs)
+
+ logger.debug("Got room name: %s", room_alias_obj.to_string())
+
+ room_id = content["room_id"]
+ servers = content["servers"]
+
+ logger.debug("Got room_id: %s", room_id)
+ logger.debug("Got servers: %s", servers)
+
+ # TODO(erikj): Check types.
+ # TODO(erikj): Check that room exists
+
+ dir_handler = self.handlers.directory_handler
+
+ try:
+ yield dir_handler.create_association(
+ room_alias_obj, room_id, servers
+ )
+ except:
+ logger.exception("Failed to create association")
+
+ defer.returnValue((200, {}))
diff --git a/synapse/rest/events.py b/synapse/rest/events.py
new file mode 100644
index 0000000000..147257a940
--- /dev/null
+++ b/synapse/rest/events.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This module contains REST servlets to do with event streaming, /events."""
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.api.streams import PaginationConfig
+from synapse.rest.base import RestServlet, client_path_pattern
+
+
+class EventStreamRestServlet(RestServlet):
+ PATTERN = client_path_pattern("/events$")
+
+ DEFAULT_LONGPOLL_TIME_MS = 5000
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ auth_user = yield self.auth.get_user_by_req(request)
+
+ handler = self.handlers.event_stream_handler
+ pagin_config = PaginationConfig.from_request(request)
+ timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
+ if "timeout" in request.args:
+ try:
+ timeout = int(request.args["timeout"][0])
+ except ValueError:
+ raise SynapseError(400, "timeout must be in milliseconds.")
+
+ chunk = yield handler.get_stream(auth_user.to_string(), pagin_config,
+ timeout=timeout)
+ defer.returnValue((200, chunk))
+
+ def on_OPTIONS(self, request):
+ return (200, {})
+
+
+def register_servlets(hs, http_server):
+ EventStreamRestServlet(hs).register(http_server)
diff --git a/synapse/rest/im.py b/synapse/rest/im.py
new file mode 100644
index 0000000000..39f2dbd749
--- /dev/null
+++ b/synapse/rest/im.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+from synapse.api.streams import PaginationConfig
+from base import RestServlet, client_path_pattern
+
+
+class ImSyncRestServlet(RestServlet):
+ PATTERN = client_path_pattern("/im/sync$")
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ user = yield self.auth.get_user_by_req(request)
+ with_feedback = "feedback" in request.args
+ pagination_config = PaginationConfig.from_request(request)
+ handler = self.handlers.message_handler
+ content = yield handler.snapshot_all_rooms(
+ user_id=user.to_string(),
+ pagin_config=pagination_config,
+ feedback=with_feedback)
+
+ defer.returnValue((200, content))
+
+
+def register_servlets(hs, http_server):
+ ImSyncRestServlet(hs).register(http_server)
diff --git a/synapse/rest/login.py b/synapse/rest/login.py
new file mode 100644
index 0000000000..0284e125b4
--- /dev/null
+++ b/synapse/rest/login.py
@@ -0,0 +1,80 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from base import RestServlet, client_path_pattern
+
+import json
+
+
+class LoginRestServlet(RestServlet):
+ PATTERN = client_path_pattern("/login$")
+ PASS_TYPE = "m.login.password"
+
+ def on_GET(self, request):
+ return (200, {"type": LoginRestServlet.PASS_TYPE})
+
+ def on_OPTIONS(self, request):
+ return (200, {})
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ login_submission = _parse_json(request)
+ try:
+ if login_submission["type"] == LoginRestServlet.PASS_TYPE:
+ result = yield self.do_password_login(login_submission)
+ defer.returnValue(result)
+ else:
+ raise SynapseError(400, "Bad login type.")
+ except KeyError:
+ raise SynapseError(400, "Missing JSON keys.")
+
+ @defer.inlineCallbacks
+ def do_password_login(self, login_submission):
+ handler = self.handlers.login_handler
+ token = yield handler.login(
+ user=login_submission["user"],
+ password=login_submission["password"])
+
+ result = {
+ "access_token": token,
+ "home_server": self.hs.hostname,
+ }
+
+ defer.returnValue((200, result))
+
+
+class LoginFallbackRestServlet(RestServlet):
+ PATTERN = client_path_pattern("/login/fallback$")
+
+ def on_GET(self, request):
+ # TODO(kegan): This should be returning some HTML which is capable of
+ # hitting LoginRestServlet
+ return (200, "")
+
+
+def _parse_json(request):
+ try:
+ content = json.loads(request.content.read())
+ if type(content) != dict:
+ raise SynapseError(400, "Content must be a JSON object.")
+ return content
+ except ValueError:
+ raise SynapseError(400, "Content not JSON.")
+
+
+def register_servlets(hs, http_server):
+ LoginRestServlet(hs).register(http_server)
diff --git a/synapse/rest/presence.py b/synapse/rest/presence.py
new file mode 100644
index 0000000000..e4925c20a5
--- /dev/null
+++ b/synapse/rest/presence.py
@@ -0,0 +1,134 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" This module contains REST servlets to do with presence: /presence/<paths>
+"""
+from twisted.internet import defer
+
+from base import RestServlet, client_path_pattern
+
+import json
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class PresenceStatusRestServlet(RestServlet):
+ PATTERN = client_path_pattern("/presence/(?P<user_id>[^/]*)/status")
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id):
+ auth_user = yield self.auth.get_user_by_req(request)
+ user = self.hs.parse_userid(user_id)
+
+ state = yield self.handlers.presence_handler.get_state(
+ target_user=user, auth_user=auth_user)
+
+ defer.returnValue((200, state))
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, user_id):
+ auth_user = yield self.auth.get_user_by_req(request)
+ user = self.hs.parse_userid(user_id)
+
+ state = {}
+ try:
+ content = json.loads(request.content.read())
+
+ state["state"] = content.pop("state")
+
+ if "status_msg" in content:
+ state["status_msg"] = content.pop("status_msg")
+
+ if content:
+ raise KeyError()
+ except:
+ defer.returnValue((400, "Unable to parse state"))
+
+ yield self.handlers.presence_handler.set_state(
+ target_user=user, auth_user=auth_user, state=state)
+
+ defer.returnValue((200, ""))
+
+ def on_OPTIONS(self, request):
+ return (200, {})
+
+
+class PresenceListRestServlet(RestServlet):
+ PATTERN = client_path_pattern("/presence_list/(?P<user_id>[^/]*)")
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id):
+ auth_user = yield self.auth.get_user_by_req(request)
+ user = self.hs.parse_userid(user_id)
+
+ if not user.is_mine:
+ defer.returnValue((400, "User not hosted on this Home Server"))
+
+ if auth_user != user:
+ defer.returnValue((400, "Cannot get another user's presence list"))
+
+ presence = yield self.handlers.presence_handler.get_presence_list(
+ observer_user=user, accepted=True)
+
+ for p in presence:
+ observed_user = p.pop("observed_user")
+ p["user_id"] = observed_user.to_string()
+
+ defer.returnValue((200, presence))
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, user_id):
+ auth_user = yield self.auth.get_user_by_req(request)
+ user = self.hs.parse_userid(user_id)
+
+ if not user.is_mine:
+ defer.returnValue((400, "User not hosted on this Home Server"))
+
+ if auth_user != user:
+ defer.returnValue((
+ 400, "Cannot modify another user's presence list"))
+
+ try:
+ content = json.loads(request.content.read())
+ except:
+ logger.exception("JSON parse error")
+ defer.returnValue((400, "Unable to parse content"))
+
+ deferreds = []
+
+ if "invite" in content:
+ for u in content["invite"]:
+ invited_user = self.hs.parse_userid(u)
+ deferreds.append(self.handlers.presence_handler.send_invite(
+ observer_user=user, observed_user=invited_user))
+
+ if "drop" in content:
+ for u in content["drop"]:
+ dropped_user = self.hs.parse_userid(u)
+ deferreds.append(self.handlers.presence_handler.drop(
+ observer_user=user, observed_user=dropped_user))
+
+ yield defer.DeferredList(deferreds)
+
+ defer.returnValue((200, ""))
+
+ def on_OPTIONS(self, request):
+ return (200, {})
+
+
+def register_servlets(hs, http_server):
+ PresenceStatusRestServlet(hs).register(http_server)
+ PresenceListRestServlet(hs).register(http_server)
diff --git a/synapse/rest/profile.py b/synapse/rest/profile.py
new file mode 100644
index 0000000000..f384227c29
--- /dev/null
+++ b/synapse/rest/profile.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" This module contains REST servlets to do with profile: /profile/<paths> """
+from twisted.internet import defer
+
+from base import RestServlet, client_path_pattern
+
+import json
+
+
+class ProfileDisplaynameRestServlet(RestServlet):
+ PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/displayname")
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id):
+ user = self.hs.parse_userid(user_id)
+
+ displayname = yield self.handlers.profile_handler.get_displayname(
+ user,
+ local_only="local_only" in request.args
+ )
+
+ defer.returnValue((200, {"displayname": displayname}))
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, user_id):
+ auth_user = yield self.auth.get_user_by_req(request)
+ user = self.hs.parse_userid(user_id)
+
+ try:
+ content = json.loads(request.content.read())
+ new_name = content["displayname"]
+ except:
+ defer.returnValue((400, "Unable to parse name"))
+
+ yield self.handlers.profile_handler.set_displayname(
+ user, auth_user, new_name)
+
+ defer.returnValue((200, ""))
+
+ def on_OPTIONS(self, request, user_id):
+ return (200, {})
+
+
+class ProfileAvatarURLRestServlet(RestServlet):
+ PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/avatar_url")
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id):
+ user = self.hs.parse_userid(user_id)
+
+ avatar_url = yield self.handlers.profile_handler.get_avatar_url(
+ user,
+ local_only="local_only" in request.args
+ )
+
+ defer.returnValue((200, {"avatar_url": avatar_url}))
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, user_id):
+ auth_user = yield self.auth.get_user_by_req(request)
+ user = self.hs.parse_userid(user_id)
+
+ try:
+ content = json.loads(request.content.read())
+ new_name = content["avatar_url"]
+ except:
+ defer.returnValue((400, "Unable to parse name"))
+
+ yield self.handlers.profile_handler.set_avatar_url(
+ user, auth_user, new_name)
+
+ defer.returnValue((200, ""))
+
+ def on_OPTIONS(self, request, user_id):
+ return (200, {})
+
+
+def register_servlets(hs, http_server):
+ ProfileDisplaynameRestServlet(hs).register(http_server)
+ ProfileAvatarURLRestServlet(hs).register(http_server)
diff --git a/synapse/rest/public.py b/synapse/rest/public.py
new file mode 100644
index 0000000000..6fd1731a61
--- /dev/null
+++ b/synapse/rest/public.py
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This module contains REST servlets to do with public paths: /public"""
+from twisted.internet import defer
+
+from base import RestServlet, client_path_pattern
+
+
+class PublicRoomListRestServlet(RestServlet):
+ PATTERN = client_path_pattern("/public/rooms$")
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ handler = self.handlers.room_list_handler
+ data = yield handler.get_public_room_list()
+ defer.returnValue((200, data))
+
+
+def register_servlets(hs, http_server):
+ PublicRoomListRestServlet(hs).register(http_server)
diff --git a/synapse/rest/register.py b/synapse/rest/register.py
new file mode 100644
index 0000000000..f1cbce5c67
--- /dev/null
+++ b/synapse/rest/register.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This module contains REST servlets to do with registration: /register"""
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from base import RestServlet, client_path_pattern
+
+import json
+import urllib
+
+
+class RegisterRestServlet(RestServlet):
+ PATTERN = client_path_pattern("/register$")
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ desired_user_id = None
+ password = None
+ try:
+ register_json = json.loads(request.content.read())
+ if "password" in register_json:
+ password = register_json["password"]
+
+ if type(register_json["user_id"]) == unicode:
+ desired_user_id = register_json["user_id"]
+ if urllib.quote(desired_user_id) != desired_user_id:
+ raise SynapseError(
+ 400,
+ "User ID must only contain characters which do not " +
+ "require URL encoding.")
+ except ValueError:
+ defer.returnValue((400, "No JSON object."))
+ except KeyError:
+ pass # user_id is optional
+
+ handler = self.handlers.registration_handler
+ (user_id, token) = yield handler.register(
+ localpart=desired_user_id,
+ password=password)
+
+ result = {
+ "user_id": user_id,
+ "access_token": token,
+ "home_server": self.hs.hostname,
+ }
+ defer.returnValue(
+ (200, result)
+ )
+
+ def on_OPTIONS(self, request):
+ return (200, {})
+
+
+def register_servlets(hs, http_server):
+ RegisterRestServlet(hs).register(http_server)
diff --git a/synapse/rest/room.py b/synapse/rest/room.py
new file mode 100644
index 0000000000..c96de5e65d
--- /dev/null
+++ b/synapse/rest/room.py
@@ -0,0 +1,394 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" This module contains REST servlets to do with rooms: /rooms/<paths> """
+from twisted.internet import defer
+
+from base import RestServlet, client_path_pattern
+from synapse.api.errors import SynapseError, Codes
+from synapse.api.events.room import (RoomTopicEvent, MessageEvent,
+ RoomMemberEvent, FeedbackEvent)
+from synapse.api.constants import Feedback, Membership
+from synapse.api.streams import PaginationConfig
+from synapse.types import RoomAlias
+
+import json
+import logging
+import urllib
+
+
+logger = logging.getLogger(__name__)
+
+
+class RoomCreateRestServlet(RestServlet):
+ # No PATTERN; we have custom dispatch rules here
+
+ def register(self, http_server):
+ # /rooms OR /rooms/<roomid>
+ http_server.register_path("POST",
+ client_path_pattern("/rooms$"),
+ self.on_POST)
+ http_server.register_path("PUT",
+ client_path_pattern(
+ "/rooms/(?P<room_id>[^/]*)$"),
+ self.on_PUT)
+ # define CORS for all of /rooms in RoomCreateRestServlet for simplicity
+ http_server.register_path("OPTIONS",
+ client_path_pattern("/rooms(?:/.*)?$"),
+ self.on_OPTIONS)
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, room_id):
+ room_id = urllib.unquote(room_id)
+ auth_user = yield self.auth.get_user_by_req(request)
+
+ if not room_id:
+ raise SynapseError(400, "PUT must specify a room ID")
+
+ room_config = self.get_room_config(request)
+ info = yield self.make_room(room_config, auth_user, room_id)
+ room_config.update(info)
+ defer.returnValue((200, info))
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ auth_user = yield self.auth.get_user_by_req(request)
+
+ room_config = self.get_room_config(request)
+ info = yield self.make_room(room_config, auth_user, None)
+ room_config.update(info)
+ defer.returnValue((200, info))
+
+ @defer.inlineCallbacks
+ def make_room(self, room_config, auth_user, room_id):
+ handler = self.handlers.room_creation_handler
+ info = yield handler.create_room(
+ user_id=auth_user.to_string(),
+ room_id=room_id,
+ config=room_config
+ )
+ defer.returnValue(info)
+
+ def get_room_config(self, request):
+ try:
+ user_supplied_config = json.loads(request.content.read())
+ if "visibility" not in user_supplied_config:
+ # default visibility
+ user_supplied_config["visibility"] = "public"
+ return user_supplied_config
+ except (ValueError, TypeError):
+ raise SynapseError(400, "Body must be JSON.",
+ errcode=Codes.BAD_JSON)
+
+ def on_OPTIONS(self, request):
+ return (200, {})
+
+
+class RoomTopicRestServlet(RestServlet):
+ PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/topic$")
+
+ def get_event_type(self):
+ return RoomTopicEvent.TYPE
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id):
+ user = yield self.auth.get_user_by_req(request)
+
+ msg_handler = self.handlers.message_handler
+ data = yield msg_handler.get_room_data(
+ user_id=user.to_string(),
+ room_id=urllib.unquote(room_id),
+ event_type=RoomTopicEvent.TYPE,
+ state_key="",
+ )
+
+ if not data:
+ raise SynapseError(404, "Topic not found.", errcode=Codes.NOT_FOUND)
+ defer.returnValue((200, json.loads(data.content)))
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, room_id):
+ user = yield self.auth.get_user_by_req(request)
+
+ content = _parse_json(request)
+
+ event = self.event_factory.create_event(
+ etype=self.get_event_type(),
+ content=content,
+ room_id=urllib.unquote(room_id),
+ user_id=user.to_string(),
+ )
+
+ msg_handler = self.handlers.message_handler
+ yield msg_handler.store_room_data(
+ event=event
+ )
+ defer.returnValue((200, ""))
+
+
+class JoinRoomAliasServlet(RestServlet):
+ PATTERN = client_path_pattern("/join/(?P<room_alias>[^/]+)$")
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, room_alias):
+ user = yield self.auth.get_user_by_req(request)
+
+ if not user:
+ defer.returnValue((403, "Unrecognized user"))
+
+ logger.debug("room_alias: %s", room_alias)
+
+ room_alias = RoomAlias.from_string(
+ urllib.unquote(room_alias),
+ self.hs
+ )
+
+ handler = self.handlers.room_member_handler
+ ret_dict = yield handler.join_room_alias(user, room_alias)
+
+ defer.returnValue((200, ret_dict))
+
+
+class RoomMemberRestServlet(RestServlet):
+ PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/members/"
+ + "(?P<target_user_id>[^/]*)/state$")
+
+ def get_event_type(self):
+ return RoomMemberEvent.TYPE
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id, target_user_id):
+ room_id = urllib.unquote(room_id)
+ user = yield self.auth.get_user_by_req(request)
+
+ handler = self.handlers.room_member_handler
+ member = yield handler.get_room_member(room_id, target_user_id,
+ user.to_string())
+ if not member:
+ raise SynapseError(404, "Member not found.",
+ errcode=Codes.NOT_FOUND)
+ defer.returnValue((200, json.loads(member.content)))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request, roomid, target_user_id):
+ user = yield self.auth.get_user_by_req(request)
+
+ event = self.event_factory.create_event(
+ etype=self.get_event_type(),
+ target_user_id=target_user_id,
+ room_id=urllib.unquote(roomid),
+ user_id=user.to_string(),
+ membership=Membership.LEAVE,
+ content={"membership": Membership.LEAVE}
+ )
+
+ handler = self.handlers.room_member_handler
+ yield handler.change_membership(event, broadcast_msg=True)
+ defer.returnValue((200, ""))
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, roomid, target_user_id):
+ user = yield self.auth.get_user_by_req(request)
+
+ content = _parse_json(request)
+ if "membership" not in content:
+ raise SynapseError(400, "No membership key.",
+ errcode=Codes.BAD_JSON)
+
+ valid_membership_values = [Membership.JOIN, Membership.INVITE]
+ if (content["membership"] not in valid_membership_values):
+ raise SynapseError(400, "Membership value must be %s." % (
+ valid_membership_values,), errcode=Codes.BAD_JSON)
+
+ event = self.event_factory.create_event(
+ etype=self.get_event_type(),
+ target_user_id=target_user_id,
+ room_id=urllib.unquote(roomid),
+ user_id=user.to_string(),
+ membership=content["membership"],
+ content=content
+ )
+
+ handler = self.handlers.room_member_handler
+ result = yield handler.change_membership(event, broadcast_msg=True)
+ defer.returnValue((200, result))
+
+
+class MessageRestServlet(RestServlet):
+ PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/messages/"
+ + "(?P<sender_id>[^/]*)/(?P<msg_id>[^/]*)$")
+
+ def get_event_type(self):
+ return MessageEvent.TYPE
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id, sender_id, msg_id):
+ user = yield self.auth.get_user_by_req(request)
+
+ msg_handler = self.handlers.message_handler
+ msg = yield msg_handler.get_message(room_id=urllib.unquote(room_id),
+ sender_id=sender_id,
+ msg_id=msg_id,
+ user_id=user.to_string(),
+ )
+
+ if not msg:
+ raise SynapseError(404, "Message not found.",
+ errcode=Codes.NOT_FOUND)
+
+ defer.returnValue((200, json.loads(msg.content)))
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, room_id, sender_id, msg_id):
+ user = yield self.auth.get_user_by_req(request)
+
+ if user.to_string() != sender_id:
+ raise SynapseError(403, "Must send messages as yourself.",
+ errcode=Codes.FORBIDDEN)
+
+ content = _parse_json(request)
+
+ event = self.event_factory.create_event(
+ etype=self.get_event_type(),
+ room_id=urllib.unquote(room_id),
+ user_id=user.to_string(),
+ msg_id=msg_id,
+ content=content
+ )
+
+ msg_handler = self.handlers.message_handler
+ yield msg_handler.send_message(event)
+
+ defer.returnValue((200, ""))
+
+
+class FeedbackRestServlet(RestServlet):
+ PATTERN = client_path_pattern(
+ "/rooms/(?P<room_id>[^/]*)/messages/" +
+ "(?P<msg_sender_id>[^/]*)/(?P<msg_id>[^/]*)/feedback/" +
+ "(?P<sender_id>[^/]*)/(?P<feedback_type>[^/]*)$"
+ )
+
+ def get_event_type(self):
+ return FeedbackEvent.TYPE
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id, msg_sender_id, msg_id, fb_sender_id,
+ feedback_type):
+ user = yield (self.auth.get_user_by_req(request))
+
+ if feedback_type not in Feedback.LIST:
+ raise SynapseError(400, "Bad feedback type.",
+ errcode=Codes.BAD_JSON)
+
+ msg_handler = self.handlers.message_handler
+ feedback = yield msg_handler.get_feedback(
+ room_id=urllib.unquote(room_id),
+ msg_sender_id=msg_sender_id,
+ msg_id=msg_id,
+ user_id=user.to_string(),
+ fb_sender_id=fb_sender_id,
+ fb_type=feedback_type
+ )
+
+ if not feedback:
+ raise SynapseError(404, "Feedback not found.",
+ errcode=Codes.NOT_FOUND)
+
+ defer.returnValue((200, json.loads(feedback.content)))
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, room_id, sender_id, msg_id, fb_sender_id,
+ feedback_type):
+ user = yield (self.auth.get_user_by_req(request))
+
+ if user.to_string() != fb_sender_id:
+ raise SynapseError(403, "Must send feedback as yourself.",
+ errcode=Codes.FORBIDDEN)
+
+ if feedback_type not in Feedback.LIST:
+ raise SynapseError(400, "Bad feedback type.",
+ errcode=Codes.BAD_JSON)
+
+ content = _parse_json(request)
+
+ event = self.event_factory.create_event(
+ etype=self.get_event_type(),
+ room_id=urllib.unquote(room_id),
+ msg_sender_id=sender_id,
+ msg_id=msg_id,
+ user_id=user.to_string(), # user sending the feedback
+ feedback_type=feedback_type,
+ content=content
+ )
+
+ msg_handler = self.handlers.message_handler
+ yield msg_handler.send_feedback(event)
+
+ defer.returnValue((200, ""))
+
+
+class RoomMemberListRestServlet(RestServlet):
+ PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/members/list$")
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id):
+ # TODO support Pagination stream API (limit/tokens)
+ user = yield self.auth.get_user_by_req(request)
+ handler = self.handlers.room_member_handler
+ members = yield handler.get_room_members_as_pagination_chunk(
+ room_id=urllib.unquote(room_id),
+ user_id=user.to_string())
+
+ defer.returnValue((200, members))
+
+
+class RoomMessageListRestServlet(RestServlet):
+ PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/messages/list$")
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id):
+ user = yield self.auth.get_user_by_req(request)
+ pagination_config = PaginationConfig.from_request(request)
+ with_feedback = "feedback" in request.args
+ handler = self.handlers.message_handler
+ msgs = yield handler.get_messages(
+ room_id=urllib.unquote(room_id),
+ user_id=user.to_string(),
+ pagin_config=pagination_config,
+ feedback=with_feedback)
+
+ defer.returnValue((200, msgs))
+
+
+def _parse_json(request):
+ try:
+ content = json.loads(request.content.read())
+ if type(content) != dict:
+ raise SynapseError(400, "Content must be a JSON object.",
+ errcode=Codes.NOT_JSON)
+ return content
+ except ValueError:
+ raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
+
+
+def register_servlets(hs, http_server):
+ RoomTopicRestServlet(hs).register(http_server)
+ RoomMemberRestServlet(hs).register(http_server)
+ MessageRestServlet(hs).register(http_server)
+ FeedbackRestServlet(hs).register(http_server)
+ RoomCreateRestServlet(hs).register(http_server)
+ RoomMemberListRestServlet(hs).register(http_server)
+ RoomMessageListRestServlet(hs).register(http_server)
+ JoinRoomAliasServlet(hs).register(http_server)
diff --git a/synapse/server.py b/synapse/server.py
new file mode 100644
index 0000000000..0aff75f399
--- /dev/null
+++ b/synapse/server.py
@@ -0,0 +1,176 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file provides some classes for setting up (partially-populated)
+# homeservers; either as a full homeserver as a real application, or a small
+# partial one for unit test mocking.
+
+# Imports required for the default HomeServer() implementation
+from synapse.federation import initialize_http_replication
+from synapse.federation.handler import FederationEventHandler
+from synapse.api.events.factory import EventFactory
+from synapse.api.notifier import Notifier
+from synapse.api.auth import Auth
+from synapse.handlers import Handlers
+from synapse.rest import RestServletFactory
+from synapse.state import StateHandler
+from synapse.storage import DataStore
+from synapse.types import UserID
+from synapse.util import Clock
+from synapse.util.distributor import Distributor
+from synapse.util.lockutils import LockManager
+
+
+class BaseHomeServer(object):
+ """A basic homeserver object without lazy component builders.
+
+ This will need all of the components it requires to either be passed as
+ constructor arguments, or the relevant methods overriding to create them.
+ Typically this would only be used for unit tests.
+
+ For every dependency in the DEPENDENCIES list below, this class creates one
+ method,
+ def get_DEPENDENCY(self)
+ which returns the value of that dependency. If no value has yet been set
+ nor was provided to the constructor, it will attempt to call a lazy builder
+ method called
+ def build_DEPENDENCY(self)
+ which must be implemented by the subclass. This code may call any of the
+ required "get" methods on the instance to obtain the sub-dependencies that
+ one requires.
+ """
+
+ DEPENDENCIES = [
+ 'clock',
+ 'http_server',
+ 'http_client',
+ 'db_pool',
+ 'persistence_service',
+ 'federation',
+ 'replication_layer',
+ 'datastore',
+ 'event_factory',
+ 'handlers',
+ 'auth',
+ 'rest_servlet_factory',
+ 'state_handler',
+ 'room_lock_manager',
+ 'notifier',
+ 'distributor',
+ ]
+
+ def __init__(self, hostname, **kwargs):
+ """
+ Args:
+ hostname : The hostname for the server.
+ """
+ self.hostname = hostname
+ self._building = {}
+
+ # Other kwargs are explicit dependencies
+ for depname in kwargs:
+ setattr(self, depname, kwargs[depname])
+
+ @classmethod
+ def _make_dependency_method(cls, depname):
+ def _get(self):
+ if hasattr(self, depname):
+ return getattr(self, depname)
+
+ if hasattr(self, "build_%s" % (depname)):
+ # Prevent cyclic dependencies from deadlocking
+ if depname in self._building:
+ raise ValueError("Cyclic dependency while building %s" % (
+ depname,
+ ))
+ self._building[depname] = 1
+
+ builder = getattr(self, "build_%s" % (depname))
+ dep = builder()
+ setattr(self, depname, dep)
+
+ del self._building[depname]
+
+ return dep
+
+ raise NotImplementedError(
+ "%s has no %s nor a builder for it" % (
+ type(self).__name__, depname,
+ )
+ )
+
+ setattr(BaseHomeServer, "get_%s" % (depname), _get)
+
+ # Other utility methods
+ def parse_userid(self, s):
+ """Parse the string given by 's' as a User ID and return a UserID
+ object."""
+ return UserID.from_string(s, hs=self)
+
+# Build magic accessors for every dependency
+for depname in BaseHomeServer.DEPENDENCIES:
+ BaseHomeServer._make_dependency_method(depname)
+
+
+class HomeServer(BaseHomeServer):
+ """A homeserver object that will construct most of its dependencies as
+ required.
+
+ It still requires the following to be specified by the caller:
+ http_server
+ http_client
+ db_pool
+ """
+
+ def build_clock(self):
+ return Clock()
+
+ def build_replication_layer(self):
+ return initialize_http_replication(self)
+
+ def build_federation(self):
+ return FederationEventHandler(self)
+
+ def build_datastore(self):
+ return DataStore(self)
+
+ def build_event_factory(self):
+ return EventFactory()
+
+ def build_handlers(self):
+ return Handlers(self)
+
+ def build_notifier(self):
+ return Notifier(self)
+
+ def build_auth(self):
+ return Auth(self)
+
+ def build_rest_servlet_factory(self):
+ return RestServletFactory(self)
+
+ def build_state_handler(self):
+ return StateHandler(self)
+
+ def build_room_lock_manager(self):
+ return LockManager()
+
+ def build_distributor(self):
+ return Distributor()
+
+ def register_servlets(self):
+ """Simply building the ServletFactory is sufficient to have it
+ register."""
+ self.get_rest_servlet_factory()
diff --git a/synapse/state.py b/synapse/state.py
new file mode 100644
index 0000000000..439c0b519a
--- /dev/null
+++ b/synapse/state.py
@@ -0,0 +1,223 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.federation.pdu_codec import encode_event_id
+from synapse.util.logutils import log_function
+
+from collections import namedtuple
+
+import logging
+import hashlib
+
+logger = logging.getLogger(__name__)
+
+
+def _get_state_key_from_event(event):
+ return event.state_key
+
+
+KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
+
+
+class StateHandler(object):
+ """ Repsonsible for doing state conflict resolution.
+ """
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self._replication = hs.get_replication_layer()
+ self.server_name = hs.hostname
+
+ @defer.inlineCallbacks
+ @log_function
+ def handle_new_event(self, event):
+ """ Given an event this works out if a) we have sufficient power level
+ to update the state and b) works out what the prev_state should be.
+
+ Returns:
+ Deferred: Resolved with a boolean indicating if we succesfully
+ updated the state.
+
+ Raised:
+ AuthError
+ """
+ # This needs to be done in a transaction.
+
+ if not hasattr(event, "state_key"):
+ return
+
+ key = KeyStateTuple(
+ event.room_id,
+ event.type,
+ _get_state_key_from_event(event)
+ )
+
+ # Now I need to fill out the prev state and work out if it has auth
+ # (w.r.t. to power levels)
+
+ results = yield self.store.get_latest_pdus_in_context(
+ event.room_id
+ )
+
+ event.prev_events = [
+ encode_event_id(p_id, origin) for p_id, origin, _ in results
+ ]
+ event.prev_events = [
+ e for e in event.prev_events if e != event.event_id
+ ]
+
+ if results:
+ event.depth = max([int(v) for _, _, v in results]) + 1
+ else:
+ event.depth = 0
+
+ current_state = yield self.store.get_current_state(
+ key.context, key.type, key.state_key
+ )
+
+ if current_state:
+ event.prev_state = encode_event_id(
+ current_state.pdu_id, current_state.origin
+ )
+
+ # TODO check current_state to see if the min power level is less
+ # than the power level of the user
+ # power_level = self._get_power_level_for_event(event)
+
+ yield self.store.update_current_state(
+ pdu_id=event.event_id,
+ origin=self.server_name,
+ context=key.context,
+ pdu_type=key.type,
+ state_key=key.state_key
+ )
+
+ defer.returnValue(True)
+
+ @defer.inlineCallbacks
+ @log_function
+ def handle_new_state(self, new_pdu):
+ """ Apply conflict resolution to `new_pdu`.
+
+ This should be called on every new state pdu, regardless of whether or
+ not there is a conflict.
+
+ This function is safe against the race of it getting called with two
+ `PDU`s trying to update the same state.
+ """
+
+ # This needs to be done in a transaction.
+
+ is_new = yield self._handle_new_state(new_pdu)
+
+ if is_new:
+ yield self.store.update_current_state(
+ pdu_id=new_pdu.pdu_id,
+ origin=new_pdu.origin,
+ context=new_pdu.context,
+ pdu_type=new_pdu.pdu_type,
+ state_key=new_pdu.state_key
+ )
+
+ defer.returnValue(is_new)
+
+ def _get_power_level_for_event(self, event):
+ # return self._persistence.get_power_level_for_user(event.room_id,
+ # event.sender)
+ return event.power_level
+
+ @defer.inlineCallbacks
+ @log_function
+ def _handle_new_state(self, new_pdu):
+ tree = yield self.store.get_unresolved_state_tree(new_pdu)
+ new_branch, current_branch = tree
+
+ logger.debug(
+ "_handle_new_state new=%s, current=%s",
+ new_branch, current_branch
+ )
+
+ if not current_branch:
+ # There is no current state
+ defer.returnValue(True)
+ return
+
+ if new_branch[-1] == current_branch[-1]:
+ # We have all the PDUs we need, so we can just do the conflict
+ # resolution.
+
+ if len(current_branch) == 1:
+ # This is a direct clobber so we can just...
+ defer.returnValue(True)
+
+ conflict_res = [
+ self._do_power_level_conflict_res,
+ self._do_chain_length_conflict_res,
+ self._do_hash_conflict_res,
+ ]
+
+ for algo in conflict_res:
+ new_res, curr_res = algo(new_branch, current_branch)
+
+ if new_res < curr_res:
+ defer.returnValue(False)
+ elif new_res > curr_res:
+ defer.returnValue(True)
+
+ raise Exception("Conflict resolution failed.")
+
+ else:
+ # We need to ask for PDUs.
+ missing_prev = max(
+ new_branch[-1], current_branch[-1],
+ key=lambda x: x.depth
+ )
+
+ yield self._replication.get_pdu(
+ destination=missing_prev.origin,
+ pdu_origin=missing_prev.prev_state_origin,
+ pdu_id=missing_prev.prev_state_id,
+ outlier=True
+ )
+
+ updated_current = yield self._handle_new_state(new_pdu)
+ defer.returnValue(updated_current)
+
+ def _do_power_level_conflict_res(self, new_branch, current_branch):
+ max_power_new = max(
+ new_branch[:-1],
+ key=lambda t: t.power_level
+ ).power_level
+
+ max_power_current = max(
+ current_branch[:-1],
+ key=lambda t: t.power_level
+ ).power_level
+
+ return (max_power_new, max_power_current)
+
+ def _do_chain_length_conflict_res(self, new_branch, current_branch):
+ return (len(new_branch), len(current_branch))
+
+ def _do_hash_conflict_res(self, new_branch, current_branch):
+ new_str = "".join([p.pdu_id + p.origin for p in new_branch])
+ c_str = "".join([p.pdu_id + p.origin for p in current_branch])
+
+ return (
+ hashlib.sha1(new_str).hexdigest(),
+ hashlib.sha1(c_str).hexdigest()
+ )
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
new file mode 100644
index 0000000000..ec93f9f8a7
--- /dev/null
+++ b/synapse/storage/__init__.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.api.events.room import (
+ RoomMemberEvent, MessageEvent, RoomTopicEvent, FeedbackEvent,
+ RoomConfigEvent
+)
+
+from .directory import DirectoryStore
+from .feedback import FeedbackStore
+from .message import MessageStore
+from .presence import PresenceStore
+from .profile import ProfileStore
+from .registration import RegistrationStore
+from .room import RoomStore
+from .roommember import RoomMemberStore
+from .roomdata import RoomDataStore
+from .stream import StreamStore
+from .pdu import StatePduStore, PduStore
+from .transactions import TransactionStore
+
+import json
+import os
+
+
+class DataStore(RoomDataStore, RoomMemberStore, MessageStore, RoomStore,
+ RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
+ PresenceStore, PduStore, StatePduStore, TransactionStore,
+ DirectoryStore):
+
+ def __init__(self, hs):
+ super(DataStore, self).__init__(hs)
+ self.event_factory = hs.get_event_factory()
+ self.hs = hs
+
+ def persist_event(self, event):
+ if event.type == MessageEvent.TYPE:
+ return self.store_message(
+ user_id=event.user_id,
+ room_id=event.room_id,
+ msg_id=event.msg_id,
+ content=json.dumps(event.content)
+ )
+ elif event.type == RoomMemberEvent.TYPE:
+ return self.store_room_member(
+ user_id=event.target_user_id,
+ sender=event.user_id,
+ room_id=event.room_id,
+ content=event.content,
+ membership=event.content["membership"]
+ )
+ elif event.type == FeedbackEvent.TYPE:
+ return self.store_feedback(
+ room_id=event.room_id,
+ msg_id=event.msg_id,
+ msg_sender_id=event.msg_sender_id,
+ fb_sender_id=event.user_id,
+ fb_type=event.feedback_type,
+ content=json.dumps(event.content)
+ )
+ elif event.type == RoomTopicEvent.TYPE:
+ return self.store_room_data(
+ room_id=event.room_id,
+ etype=event.type,
+ state_key=event.state_key,
+ content=json.dumps(event.content)
+ )
+ elif event.type == RoomConfigEvent.TYPE:
+ if "visibility" in event.content:
+ visibility = event.content["visibility"]
+ return self.store_room_config(
+ room_id=event.room_id,
+ visibility=visibility
+ )
+
+ else:
+ raise NotImplementedError(
+ "Don't know how to persist type=%s" % event.type
+ )
+
+
+def schema_path(schema):
+ """ Get a filesystem path for the named database schema
+
+ Args:
+ schema: Name of the database schema.
+ Returns:
+ A filesystem path pointing at a ".sql" file.
+
+ """
+ dir_path = os.path.dirname(__file__)
+ schemaPath = os.path.join(dir_path, "schema", schema + ".sql")
+ return schemaPath
+
+
+def read_schema(schema):
+ """ Read the named database schema.
+
+ Args:
+ schema: Name of the datbase schema.
+ Returns:
+ A string containing the database schema.
+ """
+ with open(schema_path(schema)) as schema_file:
+ return schema_file.read()
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
new file mode 100644
index 0000000000..4d98a6fd0d
--- /dev/null
+++ b/synapse/storage/_base.py
@@ -0,0 +1,405 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.errors import StoreError
+
+import collections
+
+logger = logging.getLogger(__name__)
+
+
+class SQLBaseStore(object):
+
+ def __init__(self, hs):
+ self._db_pool = hs.get_db_pool()
+
+ def cursor_to_dict(self, cursor):
+ """Converts a SQL cursor into an list of dicts.
+
+ Args:
+ cursor : The DBAPI cursor which has executed a query.
+ Returns:
+ A list of dicts where the key is the column header.
+ """
+ col_headers = list(column[0] for column in cursor.description)
+ results = list(
+ dict(zip(col_headers, row)) for row in cursor.fetchall()
+ )
+ return results
+
+ def _execute(self, decoder, query, *args):
+ """Runs a single query for a result set.
+
+ Args:
+ decoder - The function which can resolve the cursor results to
+ something meaningful.
+ query - The query string to execute
+ *args - Query args.
+ Returns:
+ The result of decoder(results)
+ """
+ logger.debug(
+ "[SQL] %s Args=%s Func=%s", query, args, decoder.__name__
+ )
+
+ def interaction(txn):
+ cursor = txn.execute(query, args)
+ return decoder(cursor)
+ return self._db_pool.runInteraction(interaction)
+
+ # "Simple" SQL API methods that operate on a single table with no JOINs,
+ # no complex WHERE clauses, just a dict of values for columns.
+
+ def _simple_insert(self, table, values):
+ """Executes an INSERT query on the named table.
+
+ Args:
+ table : string giving the table name
+ values : dict of new column names and values for them
+ """
+ sql = "INSERT INTO %s (%s) VALUES(%s)" % (
+ table,
+ ", ".join(k for k in values),
+ ", ".join("?" for k in values)
+ )
+
+ def func(txn):
+ txn.execute(sql, values.values())
+ return txn.lastrowid
+ return self._db_pool.runInteraction(func)
+
+ def _simple_select_one(self, table, keyvalues, retcols,
+ allow_none=False):
+ """Executes a SELECT query on the named table, which is expected to
+ return a single row, returning a single column from it.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ retcols : list of strings giving the names of the columns to return
+
+ allow_none : If true, return None instead of failing if the SELECT
+ statement returns no rows
+ """
+ return self._simple_selectupdate_one(
+ table, keyvalues, retcols=retcols, allow_none=allow_none
+ )
+
+ @defer.inlineCallbacks
+ def _simple_select_one_onecol(self, table, keyvalues, retcol,
+ allow_none=False):
+ """Executes a SELECT query on the named table, which is expected to
+ return a single row, returning a single column from it."
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ retcol : string giving the name of the column to return
+ """
+ ret = yield self._simple_select_one(
+ table=table,
+ keyvalues=keyvalues,
+ retcols=[retcol],
+ allow_none=allow_none
+ )
+
+ if ret:
+ defer.returnValue(ret[retcol])
+ else:
+ defer.returnValue(None)
+
+ @defer.inlineCallbacks
+ def _simple_select_onecol(self, table, keyvalues, retcol):
+ """Executes a SELECT query on the named table, which returns a list
+ comprising of the values of the named column from the selected rows.
+
+ Args:
+ table (str): table name
+ keyvalues (dict): column names and values to select the rows with
+ retcol (str): column whos value we wish to retrieve.
+
+ Returns:
+ Deferred: Results in a list
+ """
+ sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
+ "retcol": retcol,
+ "table": table,
+ "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
+ }
+
+ def func(txn):
+ txn.execute(sql, keyvalues.values())
+ return txn.fetchall()
+
+ res = yield self._db_pool.runInteraction(func)
+
+ defer.returnValue([r[0] for r in res])
+
+ def _simple_select_list(self, table, keyvalues, retcols):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the rows with
+ retcols : list of strings giving the names of the columns to return
+ """
+ sql = "SELECT %s FROM %s WHERE %s" % (
+ ", ".join(retcols),
+ table,
+ " AND ".join("%s = ?" % (k) for k in keyvalues)
+ )
+
+ def func(txn):
+ txn.execute(sql, keyvalues.values())
+ return self.cursor_to_dict(txn)
+
+ return self._db_pool.runInteraction(func)
+
+ def _simple_update_one(self, table, keyvalues, updatevalues,
+ retcols=None):
+ """Executes an UPDATE query on the named table, setting new values for
+ columns in a row matching the key values.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ updatevalues : dict giving column names and values to update
+ retcols : optional list of column names to return
+
+ If present, retcols gives a list of column names on which to perform
+ a SELECT statement *before* performing the UPDATE statement. The values
+ of these will be returned in a dict.
+
+ These are performed within the same transaction, allowing an atomic
+ get-and-set. This can be used to implement compare-and-set by putting
+ the update column in the 'keyvalues' dict as well.
+ """
+ return self._simple_selectupdate_one(table, keyvalues, updatevalues,
+ retcols=retcols)
+
+ def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
+ retcols=None, allow_none=False):
+ """ Combined SELECT then UPDATE."""
+ if retcols:
+ select_sql = "SELECT %s FROM %s WHERE %s" % (
+ ", ".join(retcols),
+ table,
+ " AND ".join("%s = ?" % (k) for k in keyvalues)
+ )
+
+ if updatevalues:
+ update_sql = "UPDATE %s SET %s WHERE %s" % (
+ table,
+ ", ".join("%s = ?" % (k) for k in updatevalues),
+ " AND ".join("%s = ?" % (k) for k in keyvalues)
+ )
+
+ def func(txn):
+ ret = None
+ if retcols:
+ txn.execute(select_sql, keyvalues.values())
+
+ row = txn.fetchone()
+ if not row:
+ if allow_none:
+ return None
+ raise StoreError(404, "No row found")
+ if txn.rowcount > 1:
+ raise StoreError(500, "More than one row matched")
+
+ ret = dict(zip(retcols, row))
+
+ if updatevalues:
+ txn.execute(
+ update_sql,
+ updatevalues.values() + keyvalues.values()
+ )
+
+ if txn.rowcount == 0:
+ raise StoreError(404, "No row found")
+ if txn.rowcount > 1:
+ raise StoreError(500, "More than one row matched")
+
+ return ret
+ return self._db_pool.runInteraction(func)
+
+ def _simple_delete_one(self, table, keyvalues):
+ """Executes a DELETE query on the named table, expecting to delete a
+ single row.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ """
+ sql = "DELETE FROM %s WHERE %s" % (
+ table,
+ " AND ".join("%s = ?" % (k) for k in keyvalues)
+ )
+
+ def func(txn):
+ txn.execute(sql, keyvalues.values())
+ if txn.rowcount == 0:
+ raise StoreError(404, "No row found")
+ if txn.rowcount > 1:
+ raise StoreError(500, "more than one row matched")
+ return self._db_pool.runInteraction(func)
+
+ def _simple_max_id(self, table):
+ """Executes a SELECT query on the named table, expecting to return the
+ max value for the column "id".
+
+ Args:
+ table : string giving the table name
+ """
+ sql = "SELECT MAX(id) AS id FROM %s" % table
+
+ def func(txn):
+ txn.execute(sql)
+ max_id = self.cursor_to_dict(txn)[0]["id"]
+ if max_id is None:
+ return 0
+ return max_id
+
+ return self._db_pool.runInteraction(func)
+
+
+class Table(object):
+ """ A base class used to store information about a particular table.
+ """
+
+ table_name = None
+ """ str: The name of the table """
+
+ fields = None
+ """ list: The field names """
+
+ EntryType = None
+ """ Type: A tuple type used to decode the results """
+
+ _select_where_clause = "SELECT %s FROM %s WHERE %s"
+ _select_clause = "SELECT %s FROM %s"
+ _insert_clause = "INSERT OR REPLACE INTO %s (%s) VALUES (%s)"
+
+ @classmethod
+ def select_statement(cls, where_clause=None):
+ """
+ Args:
+ where_clause (str): The WHERE clause to use.
+
+ Returns:
+ str: An SQL statement to select rows from the table with the given
+ WHERE clause.
+ """
+ if where_clause:
+ return cls._select_where_clause % (
+ ", ".join(cls.fields),
+ cls.table_name,
+ where_clause
+ )
+ else:
+ return cls._select_clause % (
+ ", ".join(cls.fields),
+ cls.table_name,
+ )
+
+ @classmethod
+ def insert_statement(cls):
+ return cls._insert_clause % (
+ cls.table_name,
+ ", ".join(cls.fields),
+ ", ".join(["?"] * len(cls.fields)),
+ )
+
+ @classmethod
+ def decode_single_result(cls, results):
+ """ Given an iterable of tuples, return a single instance of
+ `EntryType` or None if the iterable is empty
+ Args:
+ results (list): The results list to convert to `EntryType`
+ Returns:
+ EntryType: An instance of `EntryType`
+ """
+ results = list(results)
+ if results:
+ return cls.EntryType(*results[0])
+ else:
+ return None
+
+ @classmethod
+ def decode_results(cls, results):
+ """ Given an iterable of tuples, return a list of `EntryType`
+ Args:
+ results (list): The results list to convert to `EntryType`
+
+ Returns:
+ list: A list of `EntryType`
+ """
+ return [cls.EntryType(*row) for row in results]
+
+ @classmethod
+ def get_fields_string(cls, prefix=None):
+ if prefix:
+ to_join = ("%s.%s" % (prefix, f) for f in cls.fields)
+ else:
+ to_join = cls.fields
+
+ return ", ".join(to_join)
+
+
+class JoinHelper(object):
+ """ Used to help do joins on tables by looking at the tables' fields and
+ creating a list of unique fields to use with SELECTs and a namedtuple
+ to dump the results into.
+
+ Attributes:
+ taples (list): List of `Table` classes
+ EntryType (type)
+ """
+
+ def __init__(self, *tables):
+ self.tables = tables
+
+ res = []
+ for table in self.tables:
+ res += [f for f in table.fields if f not in res]
+
+ self.EntryType = collections.namedtuple("JoinHelperEntry", res)
+
+ def get_fields(self, **prefixes):
+ """Get a string representing a list of fields for use in SELECT
+ statements with the given prefixes applied to each.
+
+ For example::
+
+ JoinHelper(PdusTable, StateTable).get_fields(
+ PdusTable="pdus",
+ StateTable="state"
+ )
+ """
+ res = []
+ for field in self.EntryType._fields:
+ for table in self.tables:
+ if field in table.fields:
+ res.append("%s.%s" % (prefixes[table.__name__], field))
+ break
+
+ return ", ".join(res)
+
+ def decode_results(self, rows):
+ return [self.EntryType(*row) for row in rows]
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
new file mode 100644
index 0000000000..71fa9d9c9c
--- /dev/null
+++ b/synapse/storage/directory.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore
+from twisted.internet import defer
+
+from collections import namedtuple
+
+
+RoomAliasMapping = namedtuple(
+ "RoomAliasMapping",
+ ("room_id", "room_alias", "servers",)
+)
+
+
+class DirectoryStore(SQLBaseStore):
+
+ @defer.inlineCallbacks
+ def get_association_from_room_alias(self, room_alias):
+ """ Get's the room_id and server list for a given room_alias
+
+ Args:
+ room_alias (RoomAlias)
+
+ Returns:
+ Deferred: results in namedtuple with keys "room_id" and
+ "servers" or None if no association can be found
+ """
+ room_id = yield self._simple_select_one_onecol(
+ "room_aliases",
+ {"room_alias": room_alias.to_string()},
+ "room_id",
+ allow_none=True,
+ )
+
+ if not room_id:
+ defer.returnValue(None)
+ return
+
+ servers = yield self._simple_select_onecol(
+ "room_alias_servers",
+ {"room_alias": room_alias.to_string()},
+ "server",
+ )
+
+ if not servers:
+ defer.returnValue(None)
+ return
+
+ defer.returnValue(
+ RoomAliasMapping(room_id, room_alias.to_string(), servers)
+ )
+
+ @defer.inlineCallbacks
+ def create_room_alias_association(self, room_alias, room_id, servers):
+ """ Creates an associatin between a room alias and room_id/servers
+
+ Args:
+ room_alias (RoomAlias)
+ room_id (str)
+ servers (list)
+
+ Returns:
+ Deferred
+ """
+ yield self._simple_insert(
+ "room_aliases",
+ {
+ "room_alias": room_alias.to_string(),
+ "room_id": room_id,
+ },
+ )
+
+ for server in servers:
+ # TODO(erikj): Fix this to bulk insert
+ yield self._simple_insert(
+ "room_alias_servers",
+ {
+ "room_alias": room_alias.to_string(),
+ "server": server,
+ }
+ )
diff --git a/synapse/storage/feedback.py b/synapse/storage/feedback.py
new file mode 100644
index 0000000000..2b421e3342
--- /dev/null
+++ b/synapse/storage/feedback.py
@@ -0,0 +1,74 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore, Table
+from synapse.api.events.room import FeedbackEvent
+
+import collections
+import json
+
+
+class FeedbackStore(SQLBaseStore):
+
+ def store_feedback(self, room_id, msg_id, msg_sender_id,
+ fb_sender_id, fb_type, content):
+ return self._simple_insert(FeedbackTable.table_name, dict(
+ room_id=room_id,
+ msg_id=msg_id,
+ msg_sender_id=msg_sender_id,
+ fb_sender_id=fb_sender_id,
+ fb_type=fb_type,
+ content=content,
+ ))
+
+ def get_feedback(self, room_id=None, msg_id=None, msg_sender_id=None,
+ fb_sender_id=None, fb_type=None):
+ query = FeedbackTable.select_statement(
+ "msg_sender_id = ? AND room_id = ? AND msg_id = ? " +
+ "AND fb_sender_id = ? AND feedback_type = ? " +
+ "ORDER BY id DESC LIMIT 1")
+ return self._execute(
+ FeedbackTable.decode_single_result,
+ query, msg_sender_id, room_id, msg_id, fb_sender_id, fb_type,
+ )
+
+ def get_max_feedback_id(self):
+ return self._simple_max_id(FeedbackTable.table_name)
+
+
+class FeedbackTable(Table):
+ table_name = "feedback"
+
+ fields = [
+ "id",
+ "content",
+ "feedback_type",
+ "fb_sender_id",
+ "msg_id",
+ "room_id",
+ "msg_sender_id"
+ ]
+
+ class EntryType(collections.namedtuple("FeedbackEntry", fields)):
+
+ def as_event(self, event_factory):
+ return event_factory.create_event(
+ etype=FeedbackEvent.TYPE,
+ room_id=self.room_id,
+ msg_id=self.msg_id,
+ msg_sender_id=self.msg_sender_id,
+ user_id=self.fb_sender_id,
+ feedback_type=self.feedback_type,
+ content=json.loads(self.content),
+ )
diff --git a/synapse/storage/message.py b/synapse/storage/message.py
new file mode 100644
index 0000000000..4822fa709d
--- /dev/null
+++ b/synapse/storage/message.py
@@ -0,0 +1,80 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore, Table
+from synapse.api.events.room import MessageEvent
+
+import collections
+import json
+
+
+class MessageStore(SQLBaseStore):
+
+ def get_message(self, user_id, room_id, msg_id):
+ """Get a message from the store.
+
+ Args:
+ user_id (str): The ID of the user who sent the message.
+ room_id (str): The room the message was sent in.
+ msg_id (str): The unique ID for this user/room combo.
+ """
+ query = MessagesTable.select_statement(
+ "user_id = ? AND room_id = ? AND msg_id = ? " +
+ "ORDER BY id DESC LIMIT 1")
+ return self._execute(
+ MessagesTable.decode_single_result,
+ query, user_id, room_id, msg_id,
+ )
+
+ def store_message(self, user_id, room_id, msg_id, content):
+ """Store a message in the store.
+
+ Args:
+ user_id (str): The ID of the user who sent the message.
+ room_id (str): The room the message was sent in.
+ msg_id (str): The unique ID for this user/room combo.
+ content (str): The content of the message (JSON)
+ """
+ return self._simple_insert(MessagesTable.table_name, dict(
+ user_id=user_id,
+ room_id=room_id,
+ msg_id=msg_id,
+ content=content,
+ ))
+
+ def get_max_message_id(self):
+ return self._simple_max_id(MessagesTable.table_name)
+
+
+class MessagesTable(Table):
+ table_name = "messages"
+
+ fields = [
+ "id",
+ "user_id",
+ "room_id",
+ "msg_id",
+ "content"
+ ]
+
+ class EntryType(collections.namedtuple("MessageEntry", fields)):
+
+ def as_event(self, event_factory):
+ return event_factory.create_event(
+ etype=MessageEvent.TYPE,
+ room_id=self.room_id,
+ user_id=self.user_id,
+ msg_id=self.msg_id,
+ content=json.loads(self.content),
+ )
diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py
new file mode 100644
index 0000000000..a1cdde0a3b
--- /dev/null
+++ b/synapse/storage/pdu.py
@@ -0,0 +1,993 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore, Table, JoinHelper
+
+from synapse.util.logutils import log_function
+
+from collections import namedtuple
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class PduStore(SQLBaseStore):
+ """A collection of queries for handling PDUs.
+ """
+
+ def get_pdu(self, pdu_id, origin):
+ """Given a pdu_id and origin, get a PDU.
+
+ Args:
+ txn
+ pdu_id (str)
+ origin (str)
+
+ Returns:
+ PduTuple: If the pdu does not exist in the database, returns None
+ """
+
+ return self._db_pool.runInteraction(
+ self._get_pdu_tuple, pdu_id, origin
+ )
+
+ def _get_pdu_tuple(self, txn, pdu_id, origin):
+ res = self._get_pdu_tuples(txn, [(pdu_id, origin)])
+ return res[0] if res else None
+
+ def _get_pdu_tuples(self, txn, pdu_id_tuples):
+ results = []
+ for pdu_id, origin in pdu_id_tuples:
+ txn.execute(
+ PduEdgesTable.select_statement("pdu_id = ? AND origin = ?"),
+ (pdu_id, origin)
+ )
+
+ edges = [
+ (r.prev_pdu_id, r.prev_origin)
+ for r in PduEdgesTable.decode_results(txn.fetchall())
+ ]
+
+ query = (
+ "SELECT %(fields)s FROM %(pdus)s as p "
+ "LEFT JOIN %(state)s as s "
+ "ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
+ "WHERE p.pdu_id = ? AND p.origin = ? "
+ ) % {
+ "fields": _pdu_state_joiner.get_fields(
+ PdusTable="p", StatePdusTable="s"),
+ "pdus": PdusTable.table_name,
+ "state": StatePdusTable.table_name,
+ }
+
+ txn.execute(query, (pdu_id, origin))
+
+ row = txn.fetchone()
+ if row:
+ results.append(PduTuple(PduEntry(*row), edges))
+
+ return results
+
+ def get_current_state_for_context(self, context):
+ """Get a list of PDUs that represent the current state for a given
+ context
+
+ Args:
+ context (str)
+
+ Returns:
+ list: A list of PduTuples
+ """
+
+ return self._db_pool.runInteraction(
+ self._get_current_state_for_context,
+ context
+ )
+
+ def _get_current_state_for_context(self, txn, context):
+ query = (
+ "SELECT pdu_id, origin FROM %s WHERE context = ?"
+ % CurrentStateTable.table_name
+ )
+
+ logger.debug("get_current_state %s, Args=%s", query, context)
+ txn.execute(query, (context,))
+
+ res = txn.fetchall()
+
+ logger.debug("get_current_state %d results", len(res))
+
+ return self._get_pdu_tuples(txn, res)
+
+ def persist_pdu(self, prev_pdus, **cols):
+ """Inserts a (non-state) PDU into the database.
+
+ Args:
+ txn,
+ prev_pdus (list)
+ **cols: The columns to insert into the PdusTable.
+ """
+ return self._db_pool.runInteraction(
+ self._persist_pdu, prev_pdus, cols
+ )
+
+ def _persist_pdu(self, txn, prev_pdus, cols):
+ entry = PdusTable.EntryType(
+ **{k: cols.get(k, None) for k in PdusTable.fields}
+ )
+
+ txn.execute(PdusTable.insert_statement(), entry)
+
+ self._handle_prev_pdus(
+ txn, entry.outlier, entry.pdu_id, entry.origin,
+ prev_pdus, entry.context
+ )
+
+ def mark_pdu_as_processed(self, pdu_id, pdu_origin):
+ """Mark a received PDU as processed.
+
+ Args:
+ txn
+ pdu_id (str)
+ pdu_origin (str)
+ """
+
+ return self._db_pool.runInteraction(
+ self._mark_as_processed, pdu_id, pdu_origin
+ )
+
+ def _mark_as_processed(self, txn, pdu_id, pdu_origin):
+ txn.execute("UPDATE %s SET have_processed = 1" % PdusTable.table_name)
+
+ def get_all_pdus_from_context(self, context):
+ """Get a list of all PDUs for a given context."""
+ return self._db_pool.runInteraction(
+ self._get_all_pdus_from_context, context,
+ )
+
+ def _get_all_pdus_from_context(self, txn, context):
+ query = (
+ "SELECT pdu_id, origin FROM %s "
+ "WHERE context = ?"
+ ) % PdusTable.table_name
+
+ txn.execute(query, (context,))
+
+ return self._get_pdu_tuples(txn, txn.fetchall())
+
+ def get_pagination(self, context, pdu_list, limit):
+ """Get a list of Pdus for a given topic that occured before (and
+ including) the pdus in pdu_list. Return a list of max size `limit`.
+
+ Args:
+ txn
+ context (str)
+ pdu_list (list)
+ limit (int)
+
+ Return:
+ list: A list of PduTuples
+ """
+ return self._db_pool.runInteraction(
+ self._get_paginate, context, pdu_list, limit
+ )
+
+ def _get_paginate(self, txn, context, pdu_list, limit):
+ logger.debug(
+ "paginate: %s, %s, %s",
+ context, repr(pdu_list), limit
+ )
+
+ # We seed the pdu_results with the things from the pdu_list.
+ pdu_results = pdu_list
+
+ front = pdu_list
+
+ query = (
+ "SELECT prev_pdu_id, prev_origin FROM %(edges_table)s "
+ "WHERE context = ? AND pdu_id = ? AND origin = ? "
+ "LIMIT ?"
+ ) % {
+ "edges_table": PduEdgesTable.table_name,
+ }
+
+ # We iterate through all pdu_ids in `front` to select their previous
+ # pdus. These are dumped in `new_front`. We continue until we reach the
+ # limit *or* new_front is empty (i.e., we've run out of things to
+ # select
+ while front and len(pdu_results) < limit:
+
+ new_front = []
+ for pdu_id, origin in front:
+ logger.debug(
+ "_paginate_interaction: i=%s, o=%s",
+ pdu_id, origin
+ )
+
+ txn.execute(
+ query,
+ (context, pdu_id, origin, limit - len(pdu_results))
+ )
+
+ for row in txn.fetchall():
+ logger.debug(
+ "_paginate_interaction: got i=%s, o=%s",
+ *row
+ )
+ new_front.append(row)
+
+ front = new_front
+ pdu_results += new_front
+
+ # We also want to update the `prev_pdus` attributes before returning.
+ return self._get_pdu_tuples(txn, pdu_results)
+
+ def get_min_depth_for_context(self, context):
+ """Get the current minimum depth for a context
+
+ Args:
+ txn
+ context (str)
+ """
+ return self._db_pool.runInteraction(
+ self._get_min_depth_for_context, context
+ )
+
+ def _get_min_depth_for_context(self, txn, context):
+ return self._get_min_depth_interaction(txn, context)
+
+ def _get_min_depth_interaction(self, txn, context):
+ txn.execute(
+ "SELECT min_depth FROM %s WHERE context = ?"
+ % ContextDepthTable.table_name,
+ (context,)
+ )
+
+ row = txn.fetchone()
+
+ return row[0] if row else None
+
+ def update_min_depth_for_context(self, context, depth):
+ """Update the minimum `depth` of the given context, which is the line
+ where we stop paginating backwards on.
+
+ Args:
+ context (str)
+ depth (int)
+ """
+ return self._db_pool.runInteraction(
+ self._update_min_depth_for_context, context, depth
+ )
+
+ def _update_min_depth_for_context(self, txn, context, depth):
+ min_depth = self._get_min_depth_interaction(txn, context)
+
+ do_insert = depth < min_depth if min_depth else True
+
+ if do_insert:
+ txn.execute(
+ "INSERT OR REPLACE INTO %s (context, min_depth) "
+ "VALUES (?,?)" % ContextDepthTable.table_name,
+ (context, depth)
+ )
+
+ def get_latest_pdus_in_context(self, context):
+ """Get's a list of the most current pdus for a given context. This is
+ used when we are sending a Pdu and need to fill out the `prev_pdus`
+ key
+
+ Args:
+ txn
+ context
+ """
+ return self._db_pool.runInteraction(
+ self._get_latest_pdus_in_context, context
+ )
+
+ def _get_latest_pdus_in_context(self, txn, context):
+ query = (
+ "SELECT p.pdu_id, p.origin, p.depth FROM %(pdus)s as p "
+ "INNER JOIN %(forward)s as f ON p.pdu_id = f.pdu_id "
+ "AND f.origin = p.origin "
+ "WHERE f.context = ?"
+ ) % {
+ "pdus": PdusTable.table_name,
+ "forward": PduForwardExtremitiesTable.table_name,
+ }
+
+ logger.debug("get_prev query: %s", query)
+
+ txn.execute(
+ query,
+ (context, )
+ )
+
+ results = txn.fetchall()
+
+ return [(row[0], row[1], row[2]) for row in results]
+
+ def get_oldest_pdus_in_context(self, context):
+ """Get a list of Pdus that we paginated beyond yet (and haven't seen).
+ This list is used when we want to paginate backwards and is the list we
+ send to the remote server.
+
+ Args:
+ txn
+ context (str)
+
+ Returns:
+ list: A list of PduIdTuple.
+ """
+ return self._db_pool.runInteraction(
+ self._get_oldest_pdus_in_context, context
+ )
+
+ def _get_oldest_pdus_in_context(self, txn, context):
+ txn.execute(
+ "SELECT pdu_id, origin FROM %(back)s WHERE context = ?"
+ % {"back": PduBackwardExtremitiesTable.table_name, },
+ (context,)
+ )
+ return [PduIdTuple(i, o) for i, o in txn.fetchall()]
+
+ def is_pdu_new(self, pdu_id, origin, context, depth):
+ """For a given Pdu, try and figure out if it's 'new', i.e., if it's
+ not something we got randomly from the past, for example when we
+ request the current state of the room that will probably return a bunch
+ of pdus from before we joined.
+
+ Args:
+ txn
+ pdu_id (str)
+ origin (str)
+ context (str)
+ depth (int)
+
+ Returns:
+ bool
+ """
+
+ return self._db_pool.runInteraction(
+ self._is_pdu_new,
+ pdu_id=pdu_id,
+ origin=origin,
+ context=context,
+ depth=depth
+ )
+
+ def _is_pdu_new(self, txn, pdu_id, origin, context, depth):
+ # If depth > min depth in back table, then we classify it as new.
+ # OR if there is nothing in the back table, then it kinda needs to
+ # be a new thing.
+ query = (
+ "SELECT min(p.depth) FROM %(edges)s as e "
+ "INNER JOIN %(back)s as b "
+ "ON e.prev_pdu_id = b.pdu_id AND e.prev_origin = b.origin "
+ "INNER JOIN %(pdus)s as p "
+ "ON e.pdu_id = p.pdu_id AND p.origin = e.origin "
+ "WHERE p.context = ?"
+ ) % {
+ "pdus": PdusTable.table_name,
+ "edges": PduEdgesTable.table_name,
+ "back": PduBackwardExtremitiesTable.table_name,
+ }
+
+ txn.execute(query, (context,))
+
+ min_depth, = txn.fetchone()
+
+ if not min_depth or depth > int(min_depth):
+ logger.debug(
+ "is_new true: id=%s, o=%s, d=%s min_depth=%s",
+ pdu_id, origin, depth, min_depth
+ )
+ return True
+
+ # If this pdu is in the forwards table, then it also is a new one
+ query = (
+ "SELECT * FROM %(forward)s WHERE pdu_id = ? AND origin = ?"
+ ) % {
+ "forward": PduForwardExtremitiesTable.table_name,
+ }
+
+ txn.execute(query, (pdu_id, origin))
+
+ # Did we get anything?
+ if txn.fetchall():
+ logger.debug(
+ "is_new true: id=%s, o=%s, d=%s was forward",
+ pdu_id, origin, depth
+ )
+ return True
+
+ logger.debug(
+ "is_new false: id=%s, o=%s, d=%s",
+ pdu_id, origin, depth
+ )
+
+ # FINE THEN. It's probably old.
+ return False
+
+ @staticmethod
+ @log_function
+ def _handle_prev_pdus(txn, outlier, pdu_id, origin, prev_pdus,
+ context):
+ txn.executemany(
+ PduEdgesTable.insert_statement(),
+ [(pdu_id, origin, p[0], p[1], context) for p in prev_pdus]
+ )
+
+ # Update the extremities table if this is not an outlier.
+ if not outlier:
+
+ # First, we delete the new one from the forwards extremities table.
+ query = (
+ "DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
+ % PduForwardExtremitiesTable.table_name
+ )
+ txn.executemany(query, prev_pdus)
+
+ # We only insert as a forward extremety the new pdu if there are no
+ # other pdus that reference it as a prev pdu
+ query = (
+ "INSERT INTO %(table)s (pdu_id, origin, context) "
+ "SELECT ?, ?, ? WHERE NOT EXISTS ("
+ "SELECT 1 FROM %(pdu_edges)s WHERE "
+ "prev_pdu_id = ? AND prev_origin = ?"
+ ")"
+ ) % {
+ "table": PduForwardExtremitiesTable.table_name,
+ "pdu_edges": PduEdgesTable.table_name
+ }
+
+ logger.debug("query: %s", query)
+
+ txn.execute(query, (pdu_id, origin, context, pdu_id, origin))
+
+ # Insert all the prev_pdus as a backwards thing, they'll get
+ # deleted in a second if they're incorrect anyway.
+ txn.executemany(
+ PduBackwardExtremitiesTable.insert_statement(),
+ [(i, o, context) for i, o in prev_pdus]
+ )
+
+ # Also delete from the backwards extremities table all ones that
+ # reference pdus that we have already seen
+ query = (
+ "DELETE FROM %(pdu_back)s WHERE EXISTS ("
+ "SELECT 1 FROM %(pdus)s AS pdus "
+ "WHERE "
+ "%(pdu_back)s.pdu_id = pdus.pdu_id "
+ "AND %(pdu_back)s.origin = pdus.origin "
+ "AND not pdus.outlier "
+ ")"
+ ) % {
+ "pdu_back": PduBackwardExtremitiesTable.table_name,
+ "pdus": PdusTable.table_name,
+ }
+ txn.execute(query)
+
+
+class StatePduStore(SQLBaseStore):
+ """A collection of queries for handling state PDUs.
+ """
+
+ def persist_state(self, prev_pdus, **cols):
+ """Inserts a state PDU into the database
+
+ Args:
+ txn,
+ prev_pdus (list)
+ **cols: The columns to insert into the PdusTable and StatePdusTable
+ """
+
+ return self._db_pool.runInteraction(
+ self._persist_state, prev_pdus, cols
+ )
+
+ def _persist_state(self, txn, prev_pdus, cols):
+ pdu_entry = PdusTable.EntryType(
+ **{k: cols.get(k, None) for k in PdusTable.fields}
+ )
+ state_entry = StatePdusTable.EntryType(
+ **{k: cols.get(k, None) for k in StatePdusTable.fields}
+ )
+
+ logger.debug("Inserting pdu: %s", repr(pdu_entry))
+ logger.debug("Inserting state: %s", repr(state_entry))
+
+ txn.execute(PdusTable.insert_statement(), pdu_entry)
+ txn.execute(StatePdusTable.insert_statement(), state_entry)
+
+ self._handle_prev_pdus(
+ txn,
+ pdu_entry.outlier, pdu_entry.pdu_id, pdu_entry.origin, prev_pdus,
+ pdu_entry.context
+ )
+
+ def get_unresolved_state_tree(self, new_state_pdu):
+ return self._db_pool.runInteraction(
+ self._get_unresolved_state_tree, new_state_pdu
+ )
+
+ @log_function
+ def _get_unresolved_state_tree(self, txn, new_pdu):
+ current = self._get_current_interaction(
+ txn,
+ new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
+ )
+
+ ReturnType = namedtuple(
+ "StateReturnType", ["new_branch", "current_branch"]
+ )
+ return_value = ReturnType([new_pdu], [])
+
+ if not current:
+ logger.debug("get_unresolved_state_tree No current state.")
+ return return_value
+
+ return_value.current_branch.append(current)
+
+ enum_branches = self._enumerate_state_branches(
+ txn, new_pdu, current
+ )
+
+ for branch, prev_state, state in enum_branches:
+ if state:
+ return_value[branch].append(state)
+ else:
+ break
+
+ return return_value
+
+ def update_current_state(self, pdu_id, origin, context, pdu_type,
+ state_key):
+ return self._db_pool.runInteraction(
+ self._update_current_state,
+ pdu_id, origin, context, pdu_type, state_key
+ )
+
+ def _update_current_state(self, txn, pdu_id, origin, context, pdu_type,
+ state_key):
+ query = (
+ "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
+ ) % {
+ "curr": CurrentStateTable.table_name,
+ "fields": CurrentStateTable.get_fields_string(),
+ "qs": ", ".join(["?"] * len(CurrentStateTable.fields))
+ }
+
+ query_args = CurrentStateTable.EntryType(
+ pdu_id=pdu_id,
+ origin=origin,
+ context=context,
+ pdu_type=pdu_type,
+ state_key=state_key
+ )
+
+ txn.execute(query, query_args)
+
+ def get_current_state(self, context, pdu_type, state_key):
+ """For a given context, pdu_type, state_key 3-tuple, return what is
+ currently considered the current state.
+
+ Args:
+ txn
+ context (str)
+ pdu_type (str)
+ state_key (str)
+
+ Returns:
+ PduEntry
+ """
+
+ return self._db_pool.runInteraction(
+ self._get_current_state, context, pdu_type, state_key
+ )
+
+ def _get_current_state(self, txn, context, pdu_type, state_key):
+ return self._get_current_interaction(txn, context, pdu_type, state_key)
+
+ def _get_current_interaction(self, txn, context, pdu_type, state_key):
+ logger.debug(
+ "_get_current_interaction %s %s %s",
+ context, pdu_type, state_key
+ )
+
+ fields = _pdu_state_joiner.get_fields(
+ PdusTable="p", StatePdusTable="s")
+
+ current_query = (
+ "SELECT %(fields)s FROM %(state)s as s "
+ "INNER JOIN %(pdus)s as p "
+ "ON s.pdu_id = p.pdu_id AND s.origin = p.origin "
+ "INNER JOIN %(curr)s as c "
+ "ON s.pdu_id = c.pdu_id AND s.origin = c.origin "
+ "WHERE s.context = ? AND s.pdu_type = ? AND s.state_key = ? "
+ ) % {
+ "fields": fields,
+ "curr": CurrentStateTable.table_name,
+ "state": StatePdusTable.table_name,
+ "pdus": PdusTable.table_name,
+ }
+
+ txn.execute(
+ current_query,
+ (context, pdu_type, state_key)
+ )
+
+ row = txn.fetchone()
+
+ result = PduEntry(*row) if row else None
+
+ if not result:
+ logger.debug("_get_current_interaction not found")
+ else:
+ logger.debug(
+ "_get_current_interaction found %s %s",
+ result.pdu_id, result.origin
+ )
+
+ return result
+
+ def get_next_missing_pdu(self, new_pdu):
+ """When we get a new state pdu we need to check whether we need to do
+ any conflict resolution, if we do then we need to check if we need
+ to go back and request some more state pdus that we haven't seen yet.
+
+ Args:
+ txn
+ new_pdu
+
+ Returns:
+ PduIdTuple: A pdu that we are missing, or None if we have all the
+ pdus required to do the conflict resolution.
+ """
+ return self._db_pool.runInteraction(
+ self._get_next_missing_pdu, new_pdu
+ )
+
+ def _get_next_missing_pdu(self, txn, new_pdu):
+ logger.debug(
+ "get_next_missing_pdu %s %s",
+ new_pdu.pdu_id, new_pdu.origin
+ )
+
+ current = self._get_current_interaction(
+ txn,
+ new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
+ )
+
+ if (not current or not current.prev_state_id
+ or not current.prev_state_origin):
+ return None
+
+ # Oh look, it's a straight clobber, so wooooo almost no-op.
+ if (new_pdu.prev_state_id == current.pdu_id
+ and new_pdu.prev_state_origin == current.origin):
+ return None
+
+ enum_branches = self._enumerate_state_branches(txn, new_pdu, current)
+ for branch, prev_state, state in enum_branches:
+ if not state:
+ return PduIdTuple(
+ prev_state.prev_state_id,
+ prev_state.prev_state_origin
+ )
+
+ return None
+
+ def handle_new_state(self, new_pdu):
+ """Actually perform conflict resolution on the new_pdu on the
+ assumption we have all the pdus required to perform it.
+
+ Args:
+ new_pdu
+
+ Returns:
+ bool: True if the new_pdu clobbered the current state, False if not
+ """
+ return self._db_pool.runInteraction(
+ self._handle_new_state, new_pdu
+ )
+
+ def _handle_new_state(self, txn, new_pdu):
+ logger.debug(
+ "handle_new_state %s %s",
+ new_pdu.pdu_id, new_pdu.origin
+ )
+
+ current = self._get_current_interaction(
+ txn,
+ new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
+ )
+
+ is_current = False
+
+ if (not current or not current.prev_state_id
+ or not current.prev_state_origin):
+ # Oh, we don't have any state for this yet.
+ is_current = True
+ elif (current.pdu_id == new_pdu.prev_state_id
+ and current.origin == new_pdu.prev_state_origin):
+ # Oh! A direct clobber. Just do it.
+ is_current = True
+ else:
+ ##
+ # Ok, now loop through until we get to a common ancestor.
+ max_new = int(new_pdu.power_level)
+ max_current = int(current.power_level)
+
+ enum_branches = self._enumerate_state_branches(
+ txn, new_pdu, current
+ )
+ for branch, prev_state, state in enum_branches:
+ if not state:
+ raise RuntimeError(
+ "Could not find state_pdu %s %s" %
+ (
+ prev_state.prev_state_id,
+ prev_state.prev_state_origin
+ )
+ )
+
+ if branch == 0:
+ max_new = max(int(state.depth), max_new)
+ else:
+ max_current = max(int(state.depth), max_current)
+
+ is_current = max_new > max_current
+
+ if is_current:
+ logger.debug("handle_new_state make current")
+
+ # Right, this is a new thing, so woo, just insert it.
+ txn.execute(
+ "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
+ % {
+ "curr": CurrentStateTable.table_name,
+ "fields": CurrentStateTable.get_fields_string(),
+ "qs": ", ".join(["?"] * len(CurrentStateTable.fields))
+ },
+ CurrentStateTable.EntryType(
+ *(new_pdu.__dict__[k] for k in CurrentStateTable.fields)
+ )
+ )
+ else:
+ logger.debug("handle_new_state not current")
+
+ logger.debug("handle_new_state done")
+
+ return is_current
+
+ @classmethod
+ @log_function
+ def _enumerate_state_branches(cls, txn, pdu_a, pdu_b):
+ branch_a = pdu_a
+ branch_b = pdu_b
+
+ get_query = (
+ "SELECT %(fields)s FROM %(pdus)s as p "
+ "LEFT JOIN %(state)s as s "
+ "ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
+ "WHERE p.pdu_id = ? AND p.origin = ? "
+ ) % {
+ "fields": _pdu_state_joiner.get_fields(
+ PdusTable="p", StatePdusTable="s"),
+ "pdus": PdusTable.table_name,
+ "state": StatePdusTable.table_name,
+ }
+
+ while True:
+ if (branch_a.pdu_id == branch_b.pdu_id
+ and branch_a.origin == branch_b.origin):
+ # Woo! We found a common ancestor
+ logger.debug("_enumerate_state_branches Found common ancestor")
+ break
+
+ do_branch_a = (
+ hasattr(branch_a, "prev_state_id") and
+ branch_a.prev_state_id
+ )
+
+ do_branch_b = (
+ hasattr(branch_b, "prev_state_id") and
+ branch_b.prev_state_id
+ )
+
+ logger.debug(
+ "do_branch_a=%s, do_branch_b=%s",
+ do_branch_a, do_branch_b
+ )
+
+ if do_branch_a and do_branch_b:
+ do_branch_a = int(branch_a.depth) > int(branch_b.depth)
+
+ if do_branch_a:
+ pdu_tuple = PduIdTuple(
+ branch_a.prev_state_id,
+ branch_a.prev_state_origin
+ )
+
+ logger.debug("getting branch_a prev %s", pdu_tuple)
+ txn.execute(get_query, pdu_tuple)
+
+ prev_branch = branch_a
+
+ res = txn.fetchone()
+ branch_a = PduEntry(*res) if res else None
+
+ logger.debug("branch_a=%s", branch_a)
+
+ yield (0, prev_branch, branch_a)
+
+ if not branch_a:
+ break
+ elif do_branch_b:
+ pdu_tuple = PduIdTuple(
+ branch_b.prev_state_id,
+ branch_b.prev_state_origin
+ )
+ txn.execute(get_query, pdu_tuple)
+
+ logger.debug("getting branch_b prev %s", pdu_tuple)
+
+ prev_branch = branch_b
+
+ res = txn.fetchone()
+ branch_b = PduEntry(*res) if res else None
+
+ logger.debug("branch_b=%s", branch_b)
+
+ yield (1, prev_branch, branch_b)
+
+ if not branch_b:
+ break
+ else:
+ break
+
+
+class PdusTable(Table):
+ table_name = "pdus"
+
+ fields = [
+ "pdu_id",
+ "origin",
+ "context",
+ "pdu_type",
+ "ts",
+ "depth",
+ "is_state",
+ "content_json",
+ "unrecognized_keys",
+ "outlier",
+ "have_processed",
+ ]
+
+ EntryType = namedtuple("PdusEntry", fields)
+
+
+class PduDestinationsTable(Table):
+ table_name = "pdu_destinations"
+
+ fields = [
+ "pdu_id",
+ "origin",
+ "destination",
+ "delivered_ts",
+ ]
+
+ EntryType = namedtuple("PduDestinationsEntry", fields)
+
+
+class PduEdgesTable(Table):
+ table_name = "pdu_edges"
+
+ fields = [
+ "pdu_id",
+ "origin",
+ "prev_pdu_id",
+ "prev_origin",
+ "context"
+ ]
+
+ EntryType = namedtuple("PduEdgesEntry", fields)
+
+
+class PduForwardExtremitiesTable(Table):
+ table_name = "pdu_forward_extremities"
+
+ fields = [
+ "pdu_id",
+ "origin",
+ "context",
+ ]
+
+ EntryType = namedtuple("PduForwardExtremitiesEntry", fields)
+
+
+class PduBackwardExtremitiesTable(Table):
+ table_name = "pdu_backward_extremities"
+
+ fields = [
+ "pdu_id",
+ "origin",
+ "context",
+ ]
+
+ EntryType = namedtuple("PduBackwardExtremitiesEntry", fields)
+
+
+class ContextDepthTable(Table):
+ table_name = "context_depth"
+
+ fields = [
+ "context",
+ "min_depth",
+ ]
+
+ EntryType = namedtuple("ContextDepthEntry", fields)
+
+
+class StatePdusTable(Table):
+ table_name = "state_pdus"
+
+ fields = [
+ "pdu_id",
+ "origin",
+ "context",
+ "pdu_type",
+ "state_key",
+ "power_level",
+ "prev_state_id",
+ "prev_state_origin",
+ ]
+
+ EntryType = namedtuple("StatePdusEntry", fields)
+
+
+class CurrentStateTable(Table):
+ table_name = "current_state"
+
+ fields = [
+ "pdu_id",
+ "origin",
+ "context",
+ "pdu_type",
+ "state_key",
+ ]
+
+ EntryType = namedtuple("CurrentStateEntry", fields)
+
+_pdu_state_joiner = JoinHelper(PdusTable, StatePdusTable)
+
+
+# TODO: These should probably be put somewhere more sensible
+PduIdTuple = namedtuple("PduIdTuple", ("pdu_id", "origin"))
+
+PduEntry = _pdu_state_joiner.EntryType
+""" We are always interested in the join of the PdusTable and StatePdusTable,
+rather than just the PdusTable.
+
+This does not include a prev_pdus key.
+"""
+
+PduTuple = namedtuple(
+ "PduTuple",
+ ("pdu_entry", "prev_pdu_list")
+)
+""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
+the `prev_pdus` key of a PDU.
+"""
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
new file mode 100644
index 0000000000..e57ddaf149
--- /dev/null
+++ b/synapse/storage/presence.py
@@ -0,0 +1,103 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore
+
+
+class PresenceStore(SQLBaseStore):
+ def create_presence(self, user_localpart):
+ return self._simple_insert(
+ table="presence",
+ values={"user_id": user_localpart},
+ )
+
+ def has_presence_state(self, user_localpart):
+ return self._simple_select_one(
+ table="presence",
+ keyvalues={"user_id": user_localpart},
+ retcols=["user_id"],
+ allow_none=True,
+ )
+
+ def get_presence_state(self, user_localpart):
+ return self._simple_select_one(
+ table="presence",
+ keyvalues={"user_id": user_localpart},
+ retcols=["state", "status_msg"],
+ )
+
+ def set_presence_state(self, user_localpart, new_state):
+ return self._simple_update_one(
+ table="presence",
+ keyvalues={"user_id": user_localpart},
+ updatevalues={"state": new_state["state"],
+ "status_msg": new_state["status_msg"]},
+ retcols=["state"],
+ )
+
+ def allow_presence_visible(self, observed_localpart, observer_userid):
+ return self._simple_insert(
+ table="presence_allow_inbound",
+ values={"observed_user_id": observed_localpart,
+ "observer_user_id": observer_userid},
+ )
+
+ def disallow_presence_visible(self, observed_localpart, observer_userid):
+ return self._simple_delete_one(
+ table="presence_allow_inbound",
+ keyvalues={"observed_user_id": observed_localpart,
+ "observer_user_id": observer_userid},
+ )
+
+ def is_presence_visible(self, observed_localpart, observer_userid):
+ return self._simple_select_one(
+ table="presence_allow_inbound",
+ keyvalues={"observed_user_id": observed_localpart,
+ "observer_user_id": observer_userid},
+ allow_none=True,
+ )
+
+ def add_presence_list_pending(self, observer_localpart, observed_userid):
+ return self._simple_insert(
+ table="presence_list",
+ values={"user_id": observer_localpart,
+ "observed_user_id": observed_userid,
+ "accepted": False},
+ )
+
+ def set_presence_list_accepted(self, observer_localpart, observed_userid):
+ return self._simple_update_one(
+ table="presence_list",
+ keyvalues={"user_id": observer_localpart,
+ "observed_user_id": observed_userid},
+ updatevalues={"accepted": True},
+ )
+
+ def get_presence_list(self, observer_localpart, accepted=None):
+ keyvalues = {"user_id": observer_localpart}
+ if accepted is not None:
+ keyvalues["accepted"] = accepted
+
+ return self._simple_select_list(
+ table="presence_list",
+ keyvalues=keyvalues,
+ retcols=["observed_user_id", "accepted"],
+ )
+
+ def del_presence_list(self, observer_localpart, observed_userid):
+ return self._simple_delete_one(
+ table="presence_list",
+ keyvalues={"user_id": observer_localpart,
+ "observed_user_id": observed_userid},
+ )
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
new file mode 100644
index 0000000000..d2f24930c1
--- /dev/null
+++ b/synapse/storage/profile.py
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore
+
+
+class ProfileStore(SQLBaseStore):
+ def create_profile(self, user_localpart):
+ return self._simple_insert(
+ table="profiles",
+ values={"user_id": user_localpart},
+ )
+
+ def get_profile_displayname(self, user_localpart):
+ return self._simple_select_one_onecol(
+ table="profiles",
+ keyvalues={"user_id": user_localpart},
+ retcol="displayname",
+ )
+
+ def set_profile_displayname(self, user_localpart, new_displayname):
+ return self._simple_update_one(
+ table="profiles",
+ keyvalues={"user_id": user_localpart},
+ updatevalues={"displayname": new_displayname},
+ )
+
+ def get_profile_avatar_url(self, user_localpart):
+ return self._simple_select_one_onecol(
+ table="profiles",
+ keyvalues={"user_id": user_localpart},
+ retcol="avatar_url",
+ )
+
+ def set_profile_avatar_url(self, user_localpart, new_avatar_url):
+ return self._simple_update_one(
+ table="profiles",
+ keyvalues={"user_id": user_localpart},
+ updatevalues={"avatar_url": new_avatar_url},
+ )
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
new file mode 100644
index 0000000000..4a970dd546
--- /dev/null
+++ b/synapse/storage/registration.py
@@ -0,0 +1,113 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+from sqlite3 import IntegrityError
+
+from synapse.api.errors import StoreError
+
+from ._base import SQLBaseStore
+
+
+class RegistrationStore(SQLBaseStore):
+
+ def __init__(self, hs):
+ super(RegistrationStore, self).__init__(hs)
+
+ self.clock = hs.get_clock()
+
+ @defer.inlineCallbacks
+ def add_access_token_to_user(self, user_id, token):
+ """Adds an access token for the given user.
+
+ Args:
+ user_id (str): The user ID.
+ token (str): The new access token to add.
+ Raises:
+ StoreError if there was a problem adding this.
+ """
+ row = yield self._simple_select_one("users", {"name": user_id}, ["id"])
+ if not row:
+ raise StoreError(400, "Bad user ID supplied.")
+ row_id = row["id"]
+ yield self._simple_insert(
+ "access_tokens",
+ {
+ "user_id": row_id,
+ "token": token
+ }
+ )
+
+ @defer.inlineCallbacks
+ def register(self, user_id, token, password_hash):
+ """Attempts to register an account.
+
+ Args:
+ user_id (str): The desired user ID to register.
+ token (str): The desired access token to use for this user.
+ password_hash (str): Optional. The password hash for this user.
+ Raises:
+ StoreError if the user_id could not be registered.
+ """
+ yield self._db_pool.runInteraction(self._register, user_id, token,
+ password_hash)
+
+ def _register(self, txn, user_id, token, password_hash):
+ now = int(self.clock.time())
+
+ try:
+ txn.execute("INSERT INTO users(name, password_hash, creation_ts) "
+ "VALUES (?,?,?)",
+ [user_id, password_hash, now])
+ except IntegrityError:
+ raise StoreError(400, "User ID already taken.")
+
+ # it's possible for this to get a conflict, but only for a single user
+ # since tokens are namespaced based on their user ID
+ txn.execute("INSERT INTO access_tokens(user_id, token) " +
+ "VALUES (?,?)", [txn.lastrowid, token])
+
+ def get_user_by_id(self, user_id):
+ query = ("SELECT users.name, users.password_hash FROM users "
+ "WHERE users.name = ?")
+ return self._execute(
+ self.cursor_to_dict,
+ query, user_id
+ )
+
+ @defer.inlineCallbacks
+ def get_user_by_token(self, token):
+ """Get a user from the given access token.
+
+ Args:
+ token (str): The access token of a user.
+ Returns:
+ str: The user ID of the user.
+ Raises:
+ StoreError if no user was found.
+ """
+ user_id = yield self._db_pool.runInteraction(self._query_for_auth,
+ token)
+ defer.returnValue(user_id)
+
+ def _query_for_auth(self, txn, token):
+ txn.execute("SELECT users.name FROM access_tokens LEFT JOIN users" +
+ " ON users.id = access_tokens.user_id WHERE token = ?",
+ [token])
+ row = txn.fetchone()
+ if row:
+ return row[0]
+
+ raise StoreError(404, "Token not found.")
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
new file mode 100644
index 0000000000..174cbcf3d8
--- /dev/null
+++ b/synapse/storage/room.py
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+from sqlite3 import IntegrityError
+
+from synapse.api.errors import StoreError
+from synapse.api.events.room import RoomTopicEvent
+
+from ._base import SQLBaseStore, Table
+
+import collections
+import json
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class RoomStore(SQLBaseStore):
+
+ @defer.inlineCallbacks
+ def store_room(self, room_id, room_creator_user_id, is_public):
+ """Stores a room.
+
+ Args:
+ room_id (str): The desired room ID, can be None.
+ room_creator_user_id (str): The user ID of the room creator.
+ is_public (bool): True to indicate that this room should appear in
+ public room lists.
+ Raises:
+ StoreError if the room could not be stored.
+ """
+ try:
+ yield self._simple_insert(RoomsTable.table_name, dict(
+ room_id=room_id,
+ creator=room_creator_user_id,
+ is_public=is_public
+ ))
+ except IntegrityError:
+ raise StoreError(409, "Room ID in use.")
+ except Exception as e:
+ logger.error("store_room with room_id=%s failed: %s", room_id, e)
+ raise StoreError(500, "Problem creating room.")
+
+ def store_room_config(self, room_id, visibility):
+ return self._simple_update_one(
+ table=RoomsTable.table_name,
+ keyvalues={"room_id": room_id},
+ updatevalues={"is_public": visibility}
+ )
+
+ def get_room(self, room_id):
+ """Retrieve a room.
+
+ Args:
+ room_id (str): The ID of the room to retrieve.
+ Returns:
+ A namedtuple containing the room information, or an empty list.
+ """
+ query = RoomsTable.select_statement("room_id=?")
+ return self._execute(
+ RoomsTable.decode_single_result, query, room_id,
+ )
+
+ @defer.inlineCallbacks
+ def get_rooms(self, is_public, with_topics):
+ """Retrieve a list of all public rooms.
+
+ Args:
+ is_public (bool): True if the rooms returned should be public.
+ with_topics (bool): True to include the current topic for the room
+ in the response.
+ Returns:
+ A list of room dicts containing at least a "room_id" key, and a
+ "topic" key if one is set and with_topic=True.
+ """
+ room_data_type = RoomTopicEvent.TYPE
+ public = 1 if is_public else 0
+
+ latest_topic = ("SELECT max(room_data.id) FROM room_data WHERE "
+ + "room_data.type = ? GROUP BY room_id")
+
+ query = ("SELECT rooms.*, room_data.content FROM rooms LEFT JOIN "
+ + "room_data ON rooms.room_id = room_data.room_id WHERE "
+ + "(room_data.id IN (" + latest_topic + ") "
+ + "OR room_data.id IS NULL) AND rooms.is_public = ?")
+
+ res = yield self._execute(
+ self.cursor_to_dict, query, room_data_type, public
+ )
+
+ # return only the keys the specification expects
+ ret_keys = ["room_id", "topic"]
+
+ # extract topic from the json (icky) FIXME
+ for i, room_row in enumerate(res):
+ try:
+ content_json = json.loads(room_row["content"])
+ room_row["topic"] = content_json["topic"]
+ except:
+ pass # no topic set
+ # filter the dict based on ret_keys
+ res[i] = {k: v for k, v in room_row.iteritems() if k in ret_keys}
+
+ defer.returnValue(res)
+
+
+class RoomsTable(Table):
+ table_name = "rooms"
+
+ fields = [
+ "room_id",
+ "is_public",
+ "creator"
+ ]
+
+ EntryType = collections.namedtuple("RoomEntry", fields)
diff --git a/synapse/storage/roomdata.py b/synapse/storage/roomdata.py
new file mode 100644
index 0000000000..781d477931
--- /dev/null
+++ b/synapse/storage/roomdata.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore, Table
+
+import collections
+import json
+
+
+class RoomDataStore(SQLBaseStore):
+
+ """Provides various CRUD operations for Room Events. """
+
+ def get_room_data(self, room_id, etype, state_key=""):
+ """Retrieve the data stored under this type and state_key.
+
+ Args:
+ room_id (str)
+ etype (str)
+ state_key (str)
+ Returns:
+ namedtuple: Or None if nothing exists at this path.
+ """
+ query = RoomDataTable.select_statement(
+ "room_id = ? AND type = ? AND state_key = ? "
+ "ORDER BY id DESC LIMIT 1"
+ )
+ return self._execute(
+ RoomDataTable.decode_single_result,
+ query, room_id, etype, state_key,
+ )
+
+ def store_room_data(self, room_id, etype, state_key="", content=None):
+ """Stores room specific data.
+
+ Args:
+ room_id (str)
+ etype (str)
+ state_key (str)
+ data (str)- The data to store for this path in JSON.
+ Returns:
+ The store ID for this data.
+ """
+ return self._simple_insert(RoomDataTable.table_name, dict(
+ etype=etype,
+ state_key=state_key,
+ room_id=room_id,
+ content=content,
+ ))
+
+ def get_max_room_data_id(self):
+ return self._simple_max_id(RoomDataTable.table_name)
+
+
+class RoomDataTable(Table):
+ table_name = "room_data"
+
+ fields = [
+ "id",
+ "room_id",
+ "type",
+ "state_key",
+ "content"
+ ]
+
+ class EntryType(collections.namedtuple("RoomDataEntry", fields)):
+
+ def as_event(self, event_factory):
+ return event_factory.create_event(
+ etype=self.type,
+ room_id=self.room_id,
+ content=json.loads(self.content),
+ )
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
new file mode 100644
index 0000000000..e6e7617797
--- /dev/null
+++ b/synapse/storage/roommember.py
@@ -0,0 +1,171 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+from synapse.types import UserID
+from synapse.api.constants import Membership
+from synapse.api.events.room import RoomMemberEvent
+
+from ._base import SQLBaseStore, Table
+
+
+import collections
+import json
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class RoomMemberStore(SQLBaseStore):
+
+ def get_room_member(self, user_id, room_id):
+ """Retrieve the current state of a room member.
+
+ Args:
+ user_id (str): The member's user ID.
+ room_id (str): The room the member is in.
+ Returns:
+ namedtuple: The room member from the database, or None if this
+ member does not exist.
+ """
+ query = RoomMemberTable.select_statement(
+ "room_id = ? AND user_id = ? ORDER BY id DESC LIMIT 1")
+ return self._execute(
+ RoomMemberTable.decode_single_result,
+ query, room_id, user_id,
+ )
+
+ def store_room_member(self, user_id, sender, room_id, membership, content):
+ """Store a room member in the database.
+
+ Args:
+ user_id (str): The member's user ID.
+ room_id (str): The room in relation to the member.
+ membership (synapse.api.constants.Membership): The new membership
+ state.
+ content (dict): The content of the membership (JSON).
+ """
+ content_json = json.dumps(content)
+ return self._simple_insert(RoomMemberTable.table_name, dict(
+ user_id=user_id,
+ sender=sender,
+ room_id=room_id,
+ membership=membership,
+ content=content_json,
+ ))
+
+ @defer.inlineCallbacks
+ def get_room_members(self, room_id, membership=None):
+ """Retrieve the current room member list for a room.
+
+ Args:
+ room_id (str): The room to get the list of members.
+ membership (synapse.api.constants.Membership): The filter to apply
+ to this list, or None to return all members with some state
+ associated with this room.
+ Returns:
+ list of namedtuples representing the members in this room.
+ """
+ query = RoomMemberTable.select_statement(
+ "id IN (SELECT MAX(id) FROM " + RoomMemberTable.table_name
+ + " WHERE room_id = ? GROUP BY user_id)"
+ )
+ res = yield self._execute(
+ RoomMemberTable.decode_results, query, room_id,
+ )
+ # strip memberships which don't match
+ if membership:
+ res = [entry for entry in res if entry.membership == membership]
+ defer.returnValue(res)
+
+ def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
+ """ Get all the rooms for this user where the membership for this user
+ matches one in the membership list.
+
+ Args:
+ user_id (str): The user ID.
+ membership_list (list): A list of synapse.api.constants.Membership
+ values which the user must be in.
+ Returns:
+ A list of dicts with "room_id" and "membership" keys.
+ """
+ if not membership_list:
+ return defer.succeed(None)
+
+ args = [user_id]
+ membership_placeholder = ["membership=?"] * len(membership_list)
+ where_membership = "(" + " OR ".join(membership_placeholder) + ")"
+ for membership in membership_list:
+ args.append(membership)
+
+ query = ("SELECT room_id, membership FROM room_memberships"
+ + " WHERE user_id=? AND " + where_membership
+ + " GROUP BY room_id ORDER BY id DESC")
+ return self._execute(
+ self.cursor_to_dict, query, *args
+ )
+
+ @defer.inlineCallbacks
+ def get_joined_hosts_for_room(self, room_id):
+ query = RoomMemberTable.select_statement(
+ "id IN (SELECT MAX(id) FROM " + RoomMemberTable.table_name
+ + " WHERE room_id = ? GROUP BY user_id)"
+ )
+
+ res = yield self._execute(
+ RoomMemberTable.decode_results, query, room_id,
+ )
+
+ def host_from_user_id_string(user_id):
+ domain = UserID.from_string(entry.user_id, self.hs).domain
+ return domain
+
+ # strip memberships which don't match
+ hosts = [
+ host_from_user_id_string(entry.user_id)
+ for entry in res
+ if entry.membership == Membership.JOIN
+ ]
+
+ logger.debug("Returning hosts: %s from results: %s", hosts, res)
+
+ defer.returnValue(hosts)
+
+ def get_max_room_member_id(self):
+ return self._simple_max_id(RoomMemberTable.table_name)
+
+
+class RoomMemberTable(Table):
+ table_name = "room_memberships"
+
+ fields = [
+ "id",
+ "user_id",
+ "sender",
+ "room_id",
+ "membership",
+ "content"
+ ]
+
+ class EntryType(collections.namedtuple("RoomMemberEntry", fields)):
+
+ def as_event(self, event_factory):
+ return event_factory.create_event(
+ etype=RoomMemberEvent.TYPE,
+ room_id=self.room_id,
+ target_user_id=self.user_id,
+ user_id=self.sender,
+ content=json.loads(self.content),
+ )
diff --git a/synapse/storage/schema/edge_pdus.sql b/synapse/storage/schema/edge_pdus.sql
new file mode 100644
index 0000000000..17b3c52f0d
--- /dev/null
+++ b/synapse/storage/schema/edge_pdus.sql
@@ -0,0 +1,31 @@
+/* Copyright 2014 matrix.org
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+CREATE TABLE IF NOT EXISTS context_edge_pdus(
+ id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this
+ pdu_id TEXT,
+ origin TEXT,
+ context TEXT,
+ CONSTRAINT context_edge_pdu_id_origin UNIQUE (pdu_id, origin)
+);
+
+CREATE TABLE IF NOT EXISTS origin_edge_pdus(
+ id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this
+ pdu_id TEXT,
+ origin TEXT,
+ CONSTRAINT origin_edge_pdu_id_origin UNIQUE (pdu_id, origin)
+);
+
+CREATE INDEX IF NOT EXISTS context_edge_pdu_id ON context_edge_pdus(pdu_id, origin);
+CREATE INDEX IF NOT EXISTS origin_edge_pdu_id ON origin_edge_pdus(pdu_id, origin);
diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/im.sql
new file mode 100644
index 0000000000..77096546b2
--- /dev/null
+++ b/synapse/storage/schema/im.sql
@@ -0,0 +1,54 @@
+/* Copyright 2014 matrix.org
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+CREATE TABLE IF NOT EXISTS rooms(
+ room_id TEXT PRIMARY KEY NOT NULL,
+ is_public INTEGER,
+ creator TEXT
+);
+
+CREATE TABLE IF NOT EXISTS room_memberships(
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id TEXT NOT NULL, -- no foreign key to users table, it could be an id belonging to another home server
+ sender TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ membership TEXT NOT NULL,
+ content TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS messages(
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id TEXT,
+ room_id TEXT,
+ msg_id TEXT,
+ content TEXT
+);
+
+CREATE TABLE IF NOT EXISTS feedback(
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ content TEXT,
+ feedback_type TEXT,
+ fb_sender_id TEXT,
+ msg_id TEXT,
+ room_id TEXT,
+ msg_sender_id TEXT
+);
+
+CREATE TABLE IF NOT EXISTS room_data(
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ state_key TEXT NOT NULL,
+ content TEXT
+);
diff --git a/synapse/storage/schema/pdu.sql b/synapse/storage/schema/pdu.sql
new file mode 100644
index 0000000000..ca3de005e9
--- /dev/null
+++ b/synapse/storage/schema/pdu.sql
@@ -0,0 +1,106 @@
+/* Copyright 2014 matrix.org
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+-- Stores pdus and their content
+CREATE TABLE IF NOT EXISTS pdus(
+ pdu_id TEXT,
+ origin TEXT,
+ context TEXT,
+ pdu_type TEXT,
+ ts INTEGER,
+ depth INTEGER DEFAULT 0 NOT NULL,
+ is_state BOOL,
+ content_json TEXT,
+ unrecognized_keys TEXT,
+ outlier BOOL NOT NULL,
+ have_processed BOOL,
+ CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
+);
+
+-- Stores what the current state pdu is for a given (context, pdu_type, key) tuple
+CREATE TABLE IF NOT EXISTS state_pdus(
+ pdu_id TEXT,
+ origin TEXT,
+ context TEXT,
+ pdu_type TEXT,
+ state_key TEXT,
+ power_level TEXT,
+ prev_state_id TEXT,
+ prev_state_origin TEXT,
+ CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
+ CONSTRAINT prev_pdu_id_origin UNIQUE (prev_state_id, prev_state_origin)
+);
+
+CREATE TABLE IF NOT EXISTS current_state(
+ pdu_id TEXT,
+ origin TEXT,
+ context TEXT,
+ pdu_type TEXT,
+ state_key TEXT,
+ CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
+ CONSTRAINT uniqueness UNIQUE (context, pdu_type, state_key) ON CONFLICT REPLACE
+);
+
+-- Stores where each pdu we want to send should be sent and the delivery status.
+create TABLE IF NOT EXISTS pdu_destinations(
+ pdu_id TEXT,
+ origin TEXT,
+ destination TEXT,
+ delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
+ CONSTRAINT uniqueness UNIQUE (pdu_id, origin, destination) ON CONFLICT REPLACE
+);
+
+CREATE TABLE IF NOT EXISTS pdu_forward_extremities(
+ pdu_id TEXT,
+ origin TEXT,
+ context TEXT,
+ CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
+);
+
+CREATE TABLE IF NOT EXISTS pdu_backward_extremities(
+ pdu_id TEXT,
+ origin TEXT,
+ context TEXT,
+ CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
+);
+
+CREATE TABLE IF NOT EXISTS pdu_edges(
+ pdu_id TEXT,
+ origin TEXT,
+ prev_pdu_id TEXT,
+ prev_origin TEXT,
+ context TEXT,
+ CONSTRAINT uniqueness UNIQUE (pdu_id, origin, prev_pdu_id, prev_origin, context)
+);
+
+CREATE TABLE IF NOT EXISTS context_depth(
+ context TEXT,
+ min_depth INTEGER,
+ CONSTRAINT uniqueness UNIQUE (context)
+);
+
+CREATE INDEX IF NOT EXISTS context_depth_context ON context_depth(context);
+
+
+CREATE INDEX IF NOT EXISTS pdu_id ON pdus(pdu_id, origin);
+
+CREATE INDEX IF NOT EXISTS dests_id ON pdu_destinations (pdu_id, origin);
+-- CREATE INDEX IF NOT EXISTS dests ON pdu_destinations (destination);
+
+CREATE INDEX IF NOT EXISTS pdu_extrem_context ON pdu_forward_extremities(context);
+CREATE INDEX IF NOT EXISTS pdu_extrem_id ON pdu_forward_extremities(pdu_id, origin);
+
+CREATE INDEX IF NOT EXISTS pdu_edges_id ON pdu_edges(pdu_id, origin);
+
+CREATE INDEX IF NOT EXISTS pdu_b_extrem_context ON pdu_backward_extremities(context);
diff --git a/synapse/storage/schema/presence.sql b/synapse/storage/schema/presence.sql
new file mode 100644
index 0000000000..b22e3ba863
--- /dev/null
+++ b/synapse/storage/schema/presence.sql
@@ -0,0 +1,37 @@
+/* Copyright 2014 matrix.org
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+CREATE TABLE IF NOT EXISTS presence(
+ user_id INTEGER NOT NULL,
+ state INTEGER,
+ status_msg TEXT,
+ FOREIGN KEY(user_id) REFERENCES users(id)
+);
+
+-- For each of /my/ users which possibly-remote users are allowed to see their
+-- presence state
+CREATE TABLE IF NOT EXISTS presence_allow_inbound(
+ observed_user_id INTEGER NOT NULL,
+ observer_user_id TEXT, -- a UserID,
+ FOREIGN KEY(observed_user_id) REFERENCES users(id)
+);
+
+-- For each of /my/ users (watcher), which possibly-remote users are they
+-- watching?
+CREATE TABLE IF NOT EXISTS presence_list(
+ user_id INTEGER NOT NULL,
+ observed_user_id TEXT, -- a UserID,
+ accepted BOOLEAN,
+ FOREIGN KEY(user_id) REFERENCES users(id)
+);
diff --git a/synapse/storage/schema/profiles.sql b/synapse/storage/schema/profiles.sql
new file mode 100644
index 0000000000..1092d7672c
--- /dev/null
+++ b/synapse/storage/schema/profiles.sql
@@ -0,0 +1,20 @@
+/* Copyright 2014 matrix.org
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+CREATE TABLE IF NOT EXISTS profiles(
+ user_id INTEGER NOT NULL,
+ displayname TEXT,
+ avatar_url TEXT,
+ FOREIGN KEY(user_id) REFERENCES users(id)
+);
diff --git a/synapse/storage/schema/room_aliases.sql b/synapse/storage/schema/room_aliases.sql
new file mode 100644
index 0000000000..71a8b90e4d
--- /dev/null
+++ b/synapse/storage/schema/room_aliases.sql
@@ -0,0 +1,12 @@
+CREATE TABLE IF NOT EXISTS room_aliases(
+ room_alias TEXT NOT NULL,
+ room_id TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS room_alias_servers(
+ room_alias TEXT NOT NULL,
+ server TEXT NOT NULL
+);
+
+
+
diff --git a/synapse/storage/schema/transactions.sql b/synapse/storage/schema/transactions.sql
new file mode 100644
index 0000000000..4b1a2368f6
--- /dev/null
+++ b/synapse/storage/schema/transactions.sql
@@ -0,0 +1,61 @@
+/* Copyright 2014 matrix.org
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+-- Stores what transaction ids we have received and what our response was
+CREATE TABLE IF NOT EXISTS received_transactions(
+ transaction_id TEXT,
+ origin TEXT,
+ ts INTEGER,
+ response_code INTEGER,
+ response_json TEXT,
+ has_been_referenced BOOL default 0, -- Whether thishas been referenced by a prev_tx
+ CONSTRAINT uniquesss UNIQUE (transaction_id, origin) ON CONFLICT REPLACE
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS transactions_txid ON received_transactions(transaction_id, origin);
+CREATE INDEX IF NOT EXISTS transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0;
+
+
+-- Stores what transactions we've sent, what their response was (if we got one) and whether we have
+-- since referenced the transaction in another outgoing transaction
+CREATE TABLE IF NOT EXISTS sent_transactions(
+ id INTEGER PRIMARY KEY AUTOINCREMENT, -- This is used to apply insertion ordering
+ transaction_id TEXT,
+ destination TEXT,
+ response_code INTEGER DEFAULT 0,
+ response_json TEXT,
+ ts INTEGER
+);
+
+CREATE INDEX IF NOT EXISTS sent_transaction_dest ON sent_transactions(destination);
+CREATE INDEX IF NOT EXISTS sent_transaction_dest_referenced ON sent_transactions(
+ destination
+);
+-- So that we can do an efficient look up of all transactions that have yet to be successfully
+-- sent.
+CREATE INDEX IF NOT EXISTS sent_transaction_sent ON sent_transactions(response_code);
+
+
+-- For sent transactions only.
+CREATE TABLE IF NOT EXISTS transaction_id_to_pdu(
+ transaction_id INTEGER,
+ destination TEXT,
+ pdu_id TEXT,
+ pdu_origin TEXT
+);
+
+CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_tx ON transaction_id_to_pdu(transaction_id, destination);
+CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination);
+CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_index ON transaction_id_to_pdu(transaction_id, destination);
+
diff --git a/synapse/storage/schema/users.sql b/synapse/storage/schema/users.sql
new file mode 100644
index 0000000000..46b60297cb
--- /dev/null
+++ b/synapse/storage/schema/users.sql
@@ -0,0 +1,31 @@
+/* Copyright 2014 matrix.org
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+CREATE TABLE IF NOT EXISTS users(
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ name TEXT,
+ password_hash TEXT,
+ creation_ts INTEGER,
+ UNIQUE(name) ON CONFLICT ROLLBACK
+);
+
+CREATE TABLE IF NOT EXISTS access_tokens(
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL,
+ device_id TEXT,
+ token TEXT NOT NULL,
+ last_used INTEGER,
+ FOREIGN KEY(user_id) REFERENCES users(id),
+ UNIQUE(token) ON CONFLICT ROLLBACK
+);
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
new file mode 100644
index 0000000000..c3b1bfeb32
--- /dev/null
+++ b/synapse/storage/stream.py
@@ -0,0 +1,282 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore
+from .message import MessagesTable
+from .feedback import FeedbackTable
+from .roomdata import RoomDataTable
+from .roommember import RoomMemberTable
+
+import json
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class StreamStore(SQLBaseStore):
+
+ def get_message_stream(self, user_id, from_key, to_key, room_id, limit=0,
+ with_feedback=False):
+ """Get all messages for this user between the given keys.
+
+ Args:
+ user_id (str): The user who is requesting messages.
+ from_key (int): The ID to start returning results from (exclusive).
+ to_key (int): The ID to stop returning results (exclusive).
+ room_id (str): Gets messages only for this room. Can be None, in
+ which case all room messages will be returned.
+ Returns:
+ A tuple of rows (list of namedtuples), new_id(int)
+ """
+ if with_feedback and room_id: # with fb MUST specify a room ID
+ return self._db_pool.runInteraction(
+ self._get_message_rows_with_feedback,
+ user_id, from_key, to_key, room_id, limit
+ )
+ else:
+ return self._db_pool.runInteraction(
+ self._get_message_rows,
+ user_id, from_key, to_key, room_id, limit
+ )
+
+ def _get_message_rows(self, txn, user_id, from_pkey, to_pkey, room_id,
+ limit):
+ # work out which rooms this user is joined in on and join them with
+ # the room id on the messages table, bounded by the specified pkeys
+
+ # get all messages where the *current* membership state is 'join' for
+ # this user in that room.
+ query = ("SELECT messages.* FROM messages WHERE ? IN"
+ + " (SELECT membership from room_memberships WHERE user_id=?"
+ + " AND room_id = messages.room_id ORDER BY id DESC LIMIT 1)")
+ query_args = ["join", user_id]
+
+ if room_id:
+ query += " AND messages.room_id=?"
+ query_args.append(room_id)
+
+ (query, query_args) = self._append_stream_operations(
+ "messages", query, query_args, from_pkey, to_pkey, limit=limit
+ )
+
+ logger.debug("[SQL] %s : %s", query, query_args)
+ cursor = txn.execute(query, query_args)
+ return self._as_events(cursor, MessagesTable, from_pkey)
+
+ def _get_message_rows_with_feedback(self, txn, user_id, from_pkey, to_pkey,
+ room_id, limit):
+ # this col represents the compressed feedback JSON as per spec
+ compressed_feedback_col = (
+ "'[' || group_concat('{\"sender_id\":\"' || f.fb_sender_id"
+ + " || '\",\"feedback_type\":\"' || f.feedback_type"
+ + " || '\",\"content\":' || f.content || '}') || ']'"
+ )
+
+ global_msg_id_join = ("f.room_id = messages.room_id"
+ + " and f.msg_id = messages.msg_id"
+ + " and messages.user_id = f.msg_sender_id")
+
+ select_query = (
+ "SELECT messages.*, f.content AS fb_content, f.fb_sender_id"
+ + ", " + compressed_feedback_col + " AS compressed_fb"
+ + " FROM messages LEFT JOIN feedback f ON " + global_msg_id_join)
+
+ current_membership_sub_query = (
+ "(SELECT membership from room_memberships rm"
+ + " WHERE user_id=? AND room_id = rm.room_id"
+ + " ORDER BY id DESC LIMIT 1)")
+
+ where = (" WHERE ? IN " + current_membership_sub_query
+ + " AND messages.room_id=?")
+
+ query = select_query + where
+ query_args = ["join", user_id, room_id]
+
+ (query, query_args) = self._append_stream_operations(
+ "messages", query, query_args, from_pkey, to_pkey,
+ limit=limit, group_by=" GROUP BY messages.id "
+ )
+
+ logger.debug("[SQL] %s : %s", query, query_args)
+ cursor = txn.execute(query, query_args)
+
+ # convert the result set into events
+ entries = self.cursor_to_dict(cursor)
+ events = []
+ for entry in entries:
+ # TODO we should spec the cursor > event mapping somewhere else.
+ event = {}
+ straight_mappings = ["msg_id", "user_id", "room_id"]
+ for key in straight_mappings:
+ event[key] = entry[key]
+ event["content"] = json.loads(entry["content"])
+ if entry["compressed_fb"]:
+ event["feedback"] = json.loads(entry["compressed_fb"])
+ events.append(event)
+
+ latest_pkey = from_pkey if len(entries) == 0 else entries[-1]["id"]
+
+ return (events, latest_pkey)
+
+ def get_room_member_stream(self, user_id, from_key, to_key):
+ """Get all room membership events for this user between the given keys.
+
+ Args:
+ user_id (str): The user who is requesting membership events.
+ from_key (int): The ID to start returning results from (exclusive).
+ to_key (int): The ID to stop returning results (exclusive).
+ Returns:
+ A tuple of rows (list of namedtuples), new_id(int)
+ """
+ return self._db_pool.runInteraction(
+ self._get_room_member_rows, user_id, from_key, to_key
+ )
+
+ def _get_room_member_rows(self, txn, user_id, from_pkey, to_pkey):
+ # get all room membership events for rooms which the user is
+ # *currently* joined in on, or all invite events for this user.
+ current_membership_sub_query = (
+ "(SELECT membership FROM room_memberships"
+ + " WHERE user_id=? AND room_id = rm.room_id"
+ + " ORDER BY id DESC LIMIT 1)")
+
+ query = ("SELECT rm.* FROM room_memberships rm "
+ # all membership events for rooms you've currently joined.
+ + " WHERE (? IN " + current_membership_sub_query
+ # all invite membership events for this user
+ + " OR rm.membership=? AND user_id=?)"
+ + " AND rm.id > ?")
+ query_args = ["join", user_id, "invite", user_id, from_pkey]
+
+ if to_pkey != -1:
+ query += " AND rm.id < ?"
+ query_args.append(to_pkey)
+
+ cursor = txn.execute(query, query_args)
+ return self._as_events(cursor, RoomMemberTable, from_pkey)
+
+ def get_feedback_stream(self, user_id, from_key, to_key, room_id, limit=0):
+ return self._db_pool.runInteraction(
+ self._get_feedback_rows,
+ user_id, from_key, to_key, room_id, limit
+ )
+
+ def _get_feedback_rows(self, txn, user_id, from_pkey, to_pkey, room_id,
+ limit):
+ # work out which rooms this user is joined in on and join them with
+ # the room id on the feedback table, bounded by the specified pkeys
+
+ # get all messages where the *current* membership state is 'join' for
+ # this user in that room.
+ query = (
+ "SELECT feedback.* FROM feedback WHERE ? IN "
+ + "(SELECT membership from room_memberships WHERE user_id=?"
+ + " AND room_id = feedback.room_id ORDER BY id DESC LIMIT 1)")
+ query_args = ["join", user_id]
+
+ if room_id:
+ query += " AND feedback.room_id=?"
+ query_args.append(room_id)
+
+ (query, query_args) = self._append_stream_operations(
+ "feedback", query, query_args, from_pkey, to_pkey, limit=limit
+ )
+
+ logger.debug("[SQL] %s : %s", query, query_args)
+ cursor = txn.execute(query, query_args)
+ return self._as_events(cursor, FeedbackTable, from_pkey)
+
+ def get_room_data_stream(self, user_id, from_key, to_key, room_id,
+ limit=0):
+ return self._db_pool.runInteraction(
+ self._get_room_data_rows,
+ user_id, from_key, to_key, room_id, limit
+ )
+
+ def _get_room_data_rows(self, txn, user_id, from_pkey, to_pkey, room_id,
+ limit):
+ # work out which rooms this user is joined in on and join them with
+ # the room id on the feedback table, bounded by the specified pkeys
+
+ # get all messages where the *current* membership state is 'join' for
+ # this user in that room.
+ query = (
+ "SELECT room_data.* FROM room_data WHERE ? IN "
+ + "(SELECT membership from room_memberships WHERE user_id=?"
+ + " AND room_id = room_data.room_id ORDER BY id DESC LIMIT 1)")
+ query_args = ["join", user_id]
+
+ if room_id:
+ query += " AND room_data.room_id=?"
+ query_args.append(room_id)
+
+ (query, query_args) = self._append_stream_operations(
+ "room_data", query, query_args, from_pkey, to_pkey, limit=limit
+ )
+
+ logger.debug("[SQL] %s : %s", query, query_args)
+ cursor = txn.execute(query, query_args)
+ return self._as_events(cursor, RoomDataTable, from_pkey)
+
+ def _append_stream_operations(self, table_name, query, query_args,
+ from_pkey, to_pkey, limit=None,
+ group_by=""):
+ LATEST_ROW = -1
+ order_by = ""
+ if to_pkey > from_pkey:
+ if from_pkey != LATEST_ROW:
+ # e.g. from=5 to=9 >> from 5 to 9 >> id>5 AND id<9
+ query += (" AND %s.id > ? AND %s.id < ?" %
+ (table_name, table_name))
+ query_args.append(from_pkey)
+ query_args.append(to_pkey)
+ else:
+ # e.g. from=-1 to=5 >> from now to 5 >> id>5 ORDER BY id DESC
+ query += " AND %s.id > ? " % table_name
+ order_by = "ORDER BY id DESC"
+ query_args.append(to_pkey)
+ elif from_pkey > to_pkey:
+ if to_pkey != LATEST_ROW:
+ # from=9 to=5 >> from 9 to 5 >> id>5 AND id<9 ORDER BY id DESC
+ query += (" AND %s.id > ? AND %s.id < ? " %
+ (table_name, table_name))
+ order_by = "ORDER BY id DESC"
+ query_args.append(to_pkey)
+ query_args.append(from_pkey)
+ else:
+ # from=5 to=-1 >> from 5 to now >> id>5
+ query += " AND %s.id > ?" % table_name
+ query_args.append(from_pkey)
+
+ query += group_by + order_by
+
+ if limit and limit > 0:
+ query += " LIMIT ?"
+ query_args.append(str(limit))
+
+ return (query, query_args)
+
+ def _as_events(self, cursor, table, from_pkey):
+ data_entries = table.decode_results(cursor)
+ last_pkey = from_pkey
+ if data_entries:
+ last_pkey = data_entries[-1].id
+
+ events = [
+ entry.as_event(self.event_factory).get_dict()
+ for entry in data_entries
+ ]
+
+ return (events, last_pkey)
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
new file mode 100644
index 0000000000..aa41e2ad7f
--- /dev/null
+++ b/synapse/storage/transactions.py
@@ -0,0 +1,287 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore, Table
+from .pdu import PdusTable
+
+from collections import namedtuple
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class TransactionStore(SQLBaseStore):
+ """A collection of queries for handling PDUs.
+ """
+
+ def get_received_txn_response(self, transaction_id, origin):
+ """For an incoming transaction from a given origin, check if we have
+ already responded to it. If so, return the response code and response
+ body (as a dict).
+
+ Args:
+ transaction_id (str)
+ origin(str)
+
+ Returns:
+ tuple: None if we have not previously responded to
+ this transaction or a 2-tuple of (int, dict)
+ """
+
+ return self._db_pool.runInteraction(
+ self._get_received_txn_response, transaction_id, origin
+ )
+
+ def _get_received_txn_response(self, txn, transaction_id, origin):
+ where_clause = "transaction_id = ? AND origin = ?"
+ query = ReceivedTransactionsTable.select_statement(where_clause)
+
+ txn.execute(query, (transaction_id, origin))
+
+ results = ReceivedTransactionsTable.decode_results(txn.fetchall())
+
+ if results and results[0].response_code:
+ return (results[0].response_code, results[0].response_json)
+ else:
+ return None
+
+ def set_received_txn_response(self, transaction_id, origin, code,
+ response_dict):
+ """Persist the response we returened for an incoming transaction, and
+ should return for subsequent transactions with the same transaction_id
+ and origin.
+
+ Args:
+ txn
+ transaction_id (str)
+ origin (str)
+ code (int)
+ response_json (str)
+ """
+
+ return self._db_pool.runInteraction(
+ self._set_received_txn_response,
+ transaction_id, origin, code, response_dict
+ )
+
+ def _set_received_txn_response(self, txn, transaction_id, origin, code,
+ response_json):
+ query = (
+ "UPDATE %s "
+ "SET response_code = ?, response_json = ? "
+ "WHERE transaction_id = ? AND origin = ?"
+ ) % ReceivedTransactionsTable.table_name
+
+ txn.execute(query, (code, response_json, transaction_id, origin))
+
+ def prep_send_transaction(self, transaction_id, destination, ts, pdu_list):
+ """Persists an outgoing transaction and calculates the values for the
+ previous transaction id list.
+
+ This should be called before sending the transaction so that it has the
+ correct value for the `prev_ids` key.
+
+ Args:
+ transaction_id (str)
+ destination (str)
+ ts (int)
+ pdu_list (list)
+
+ Returns:
+ list: A list of previous transaction ids.
+ """
+
+ return self._db_pool.runInteraction(
+ self._prep_send_transaction,
+ transaction_id, destination, ts, pdu_list
+ )
+
+ def _prep_send_transaction(self, txn, transaction_id, destination, ts,
+ pdu_list):
+
+ # First we find out what the prev_txs should be.
+ # Since we know that we are only sending one transaction at a time,
+ # we can simply take the last one.
+ query = "%s ORDER BY id DESC LIMIT 1" % (
+ SentTransactions.select_statement("destination = ?"),
+ )
+
+ results = txn.execute(query, (destination,))
+ results = SentTransactions.decode_results(results)
+
+ prev_txns = [r.transaction_id for r in results]
+
+ # Actually add the new transaction to the sent_transactions table.
+
+ query = SentTransactions.insert_statement()
+ txn.execute(query, SentTransactions.EntryType(
+ None,
+ transaction_id=transaction_id,
+ destination=destination,
+ ts=ts,
+ response_code=0,
+ response_json=None
+ ))
+
+ # Update the tx id -> pdu id mapping
+
+ values = [
+ (transaction_id, destination, pdu[0], pdu[1])
+ for pdu in pdu_list
+ ]
+
+ logger.debug("Inserting: %s", repr(values))
+
+ query = TransactionsToPduTable.insert_statement()
+ txn.executemany(query, values)
+
+ return prev_txns
+
+ def delivered_txn(self, transaction_id, destination, code, response_dict):
+ """Persists the response for an outgoing transaction.
+
+ Args:
+ transaction_id (str)
+ destination (str)
+ code (int)
+ response_json (str)
+ """
+ return self._db_pool.runInteraction(
+ self._delivered_txn,
+ transaction_id, destination, code, response_dict
+ )
+
+ def _delivered_txn(cls, txn, transaction_id, destination,
+ code, response_json):
+ query = (
+ "UPDATE %s "
+ "SET response_code = ?, response_json = ? "
+ "WHERE transaction_id = ? AND destination = ?"
+ ) % SentTransactions.table_name
+
+ txn.execute(query, (code, response_json, transaction_id, destination))
+
+ def get_transactions_after(self, transaction_id, destination):
+ """Get all transactions after a given local transaction_id.
+
+ Args:
+ transaction_id (str)
+ destination (str)
+
+ Returns:
+ list: A list of `ReceivedTransactionsTable.EntryType`
+ """
+ return self._db_pool.runInteraction(
+ self._get_transactions_after, transaction_id, destination
+ )
+
+ def _get_transactions_after(cls, txn, transaction_id, destination):
+ where = (
+ "destination = ? AND id > (select id FROM %s WHERE "
+ "transaction_id = ? AND destination = ?)"
+ ) % (
+ SentTransactions.table_name
+ )
+ query = SentTransactions.select_statement(where)
+
+ txn.execute(query, (destination, transaction_id, destination))
+
+ return ReceivedTransactionsTable.decode_results(txn.fetchall())
+
+ def get_pdus_after_transaction(self, transaction_id, destination):
+ """For a given local transaction_id that we sent to a given destination
+ home server, return a list of PDUs that were sent to that destination
+ after it.
+
+ Args:
+ txn
+ transaction_id (str)
+ destination (str)
+
+ Returns
+ list: A list of PduTuple
+ """
+ return self._db_pool.runInteraction(
+ self._get_pdus_after_transaction,
+ transaction_id, destination
+ )
+
+ def _get_pdus_after_transaction(self, txn, transaction_id, destination):
+
+ # Query that first get's all transaction_ids with an id greater than
+ # the one given from the `sent_transactions` table. Then JOIN on this
+ # from the `tx->pdu` table to get a list of (pdu_id, origin) that
+ # specify the pdus that were sent in those transactions.
+ query = (
+ "SELECT pdu_id, pdu_origin FROM %(tx_pdu)s as tp "
+ "INNER JOIN %(sent_tx)s as st "
+ "ON tp.transaction_id = st.transaction_id "
+ "AND tp.destination = st.destination "
+ "WHERE st.id > ("
+ "SELECT id FROM %(sent_tx)s "
+ "WHERE transaction_id = ? AND destination = ?"
+ ) % {
+ "tx_pdu": TransactionsToPduTable.table_name,
+ "sent_tx": SentTransactions.table_name,
+ }
+
+ txn.execute(query, (transaction_id, destination))
+
+ pdus = PdusTable.decode_results(txn.fetchall())
+
+ return self._get_pdu_tuples(txn, pdus)
+
+
+class ReceivedTransactionsTable(Table):
+ table_name = "received_transactions"
+
+ fields = [
+ "transaction_id",
+ "origin",
+ "ts",
+ "response_code",
+ "response_json",
+ "has_been_referenced",
+ ]
+
+ EntryType = namedtuple("ReceivedTransactionsEntry", fields)
+
+
+class SentTransactions(Table):
+ table_name = "sent_transactions"
+
+ fields = [
+ "id",
+ "transaction_id",
+ "destination",
+ "ts",
+ "response_code",
+ "response_json",
+ ]
+
+ EntryType = namedtuple("SentTransactionsEntry", fields)
+
+
+class TransactionsToPduTable(Table):
+ table_name = "transaction_id_to_pdu"
+
+ fields = [
+ "transaction_id",
+ "destination",
+ "pdu_id",
+ "pdu_origin",
+ ]
+
+ EntryType = namedtuple("TransactionsToPduEntry", fields)
diff --git a/synapse/types.py b/synapse/types.py
new file mode 100644
index 0000000000..1adc95bbb0
--- /dev/null
+++ b/synapse/types.py
@@ -0,0 +1,79 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.api.errors import SynapseError
+
+from collections import namedtuple
+
+
+class DomainSpecificString(
+ namedtuple("DomainSpecificString", ("localpart", "domain", "is_mine"))
+):
+ """Common base class among ID/name strings that have a local part and a
+ domain name, prefixed with a sigil.
+
+ Has the fields:
+
+ 'localpart' : The local part of the name (without the leading sigil)
+ 'domain' : The domain part of the name
+ 'is_mine' : Boolean indicating if the domain name is recognised by the
+ HomeServer as being its own
+ """
+
+ @classmethod
+ def from_string(cls, s, hs):
+ """Parse the string given by 's' into a structure object."""
+ if s[0] != cls.SIGIL:
+ raise SynapseError(400, "Expected %s string to start with '%s'" % (
+ cls.__name__, cls.SIGIL,
+ ))
+
+ parts = s[1:].split(':', 1)
+ if len(parts) != 2:
+ raise SynapseError(
+ 400, "Expected %s of the form '%slocalname:domain'" % (
+ cls.__name__, cls.SIGIL,
+ )
+ )
+
+ domain = parts[1]
+
+ # This code will need changing if we want to support multiple domain
+ # names on one HS
+ is_mine = domain == hs.hostname
+ return cls(localpart=parts[0], domain=domain, is_mine=is_mine)
+
+ def to_string(self):
+ """Return a string encoding the fields of the structure object."""
+ return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)
+
+ @classmethod
+ def create_local(cls, localpart, hs):
+ """Create a structure on the local domain"""
+ return cls(localpart=localpart, domain=hs.hostname, is_mine=True)
+
+
+class UserID(DomainSpecificString):
+ """Structure representing a user ID."""
+ SIGIL = "@"
+
+
+class RoomAlias(DomainSpecificString):
+ """Structure representing a room name."""
+ SIGIL = "#"
+
+
+class RoomID(DomainSpecificString):
+ """Structure representing a room id. """
+ SIGIL = "!"
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
new file mode 100644
index 0000000000..5361cb7ec2
--- /dev/null
+++ b/synapse/util/__init__.py
@@ -0,0 +1,40 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import reactor
+
+import time
+
+
+class Clock(object):
+ """A small utility that obtains current time-of-day so that time may be
+ mocked during unit-tests.
+
+ TODO(paul): Also move the sleep() functionallity into it
+ """
+
+ def time(self):
+ """Returns the current system time in seconds since epoch."""
+ return time.time()
+
+ def time_msec(self):
+ """Returns the current system time in miliseconds since epoch."""
+ return self.time() * 1000
+
+ def call_later(self, delay, callback):
+ return reactor.callLater(delay, callback)
+
+ def cancel_call_later(self, timer):
+ timer.cancel()
diff --git a/synapse/util/async.py b/synapse/util/async.py
new file mode 100644
index 0000000000..e04db8e285
--- /dev/null
+++ b/synapse/util/async.py
@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer, reactor
+
+
+def sleep(seconds):
+ d = defer.Deferred()
+ reactor.callLater(seconds, d.callback, seconds)
+ return d
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
new file mode 100644
index 0000000000..32d19402b4
--- /dev/null
+++ b/synapse/util/distributor.py
@@ -0,0 +1,108 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class Distributor(object):
+ """A central dispatch point for loosely-connected pieces of code to
+ register, observe, and fire signals.
+
+ Signals are named simply by strings.
+
+ TODO(paul): It would be nice to give signals stronger object identities,
+ so we can attach metadata, docstrings, detect typoes, etc... But this
+ model will do for today.
+ """
+
+ def __init__(self):
+ self.signals = {}
+ self.pre_registration = {}
+
+ def declare(self, name):
+ if name in self.signals:
+ raise KeyError("%r already has a signal named %s" % (self, name))
+
+ self.signals[name] = Signal(name)
+
+ if name in self.pre_registration:
+ signal = self.signals[name]
+ for observer in self.pre_registration[name]:
+ signal.observe(observer)
+
+ def observe(self, name, observer):
+ if name in self.signals:
+ self.signals[name].observe(observer)
+ else:
+ # TODO: Avoid strong ordering dependency by allowing people to
+ # pre-register observations on signals that don't exist yet.
+ if name not in self.pre_registration:
+ self.pre_registration[name] = []
+ self.pre_registration[name].append(observer)
+
+ def fire(self, name, *args, **kwargs):
+ if name not in self.signals:
+ raise KeyError("%r does not have a signal named %s" % (self, name))
+
+ return self.signals[name].fire(*args, **kwargs)
+
+
+class Signal(object):
+ """A Signal is a dispatch point that stores a list of callables as
+ observers of it.
+
+ Signals can be "fired", meaning that every callable observing it is
+ invoked. Firing a signal does not change its state; it can be fired again
+ at any later point. Firing a signal passes any arguments from the fire
+ method into all of the observers.
+ """
+
+ def __init__(self, name):
+ self.name = name
+ self.observers = []
+
+ def observe(self, observer):
+ """Adds a new callable to the observer list which will be invoked by
+ the 'fire' method.
+
+ Each observer callable may return a Deferred."""
+ self.observers.append(observer)
+
+ def fire(self, *args, **kwargs):
+ """Invokes every callable in the observer list, passing in the args and
+ kwargs. Exceptions thrown by observers are logged but ignored. It is
+ not an error to fire a signal with no observers.
+
+ Returns a Deferred that will complete when all the observers have
+ completed."""
+ deferreds = []
+ for observer in self.observers:
+ d = defer.maybeDeferred(observer, *args, **kwargs)
+
+ def eb(failure):
+ logger.warning(
+ "%s signal observer %s failed: %r",
+ self.name, observer, failure,
+ exc_info=(
+ failure.type,
+ failure.value,
+ failure.getTracebackObject()))
+ deferreds.append(d.addErrback(eb))
+
+ return defer.DeferredList(deferreds)
diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
new file mode 100644
index 0000000000..190a80a322
--- /dev/null
+++ b/synapse/util/jsonobject.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+class JsonEncodedObject(object):
+ """ A common base class for defining protocol units that are represented
+ as JSON.
+
+ Attributes:
+ unrecognized_keys (dict): A dict containing all the key/value pairs we
+ don't recognize.
+ """
+
+ valid_keys = [] # keys we will store
+ """A list of strings that represent keys we know about
+ and can handle. If we have values for these keys they will be
+ included in the `dictionary` instance variable.
+ """
+
+ internal_keys = [] # keys to ignore while building dict
+ """A list of strings that should *not* be encoded into JSON.
+ """
+
+ required_keys = []
+ """A list of strings that we require to exist. If they are not given upon
+ construction it raises an exception.
+ """
+
+ def __init__(self, **kwargs):
+ """ Takes the dict of `kwargs` and loads all keys that are *valid*
+ (i.e., are included in the `valid_keys` list) into the dictionary`
+ instance variable.
+
+ Any keys that aren't recognized are added to the `unrecognized_keys`
+ attribute.
+
+ Args:
+ **kwargs: Attributes associated with this protocol unit.
+ """
+ for required_key in self.required_keys:
+ if required_key not in kwargs:
+ raise RuntimeError("Key %s is required" % required_key)
+
+ self.unrecognized_keys = {} # Keys we were given not listed as valid
+ for k, v in kwargs.items():
+ if k in self.valid_keys or k in self.internal_keys:
+ self.__dict__[k] = v
+ else:
+ self.unrecognized_keys[k] = v
+
+ def get_dict(self):
+ """ Converts this protocol unit into a :py:class:`dict`, ready to be
+ encoded as JSON.
+
+ The keys it encodes are: `valid_keys` - `internal_keys`
+
+ Returns
+ dict
+ """
+ d = {
+ k: _encode(v) for (k, v) in self.__dict__.items()
+ if k in self.valid_keys and k not in self.internal_keys
+ }
+ d.update(self.unrecognized_keys)
+ return copy.deepcopy(d)
+
+ def get_full_dict(self):
+ d = {
+ k: v for (k, v) in self.__dict__.items()
+ if k in self.valid_keys or k in self.internal_keys
+ }
+ d.update(self.unrecognized_keys)
+ return copy.deepcopy(d)
+
+ def __str__(self):
+ return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
+
+def _encode(obj):
+ if type(obj) is list:
+ return [_encode(o) for o in obj]
+
+ if isinstance(obj, JsonEncodedObject):
+ return obj.get_dict()
+
+ return obj
diff --git a/synapse/util/lockutils.py b/synapse/util/lockutils.py
new file mode 100644
index 0000000000..e4d609d84e
--- /dev/null
+++ b/synapse/util/lockutils.py
@@ -0,0 +1,67 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class Lock(object):
+
+ def __init__(self, deferred):
+ self._deferred = deferred
+ self.released = False
+
+ def release(self):
+ self.released = True
+ self._deferred.callback(None)
+
+ def __del__(self):
+ if not self.released:
+ logger.critical("Lock was destructed but never released!")
+ self.release()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.release()
+
+
+class LockManager(object):
+ """ Utility class that allows us to lock based on a `key` """
+
+ def __init__(self):
+ self._lock_deferreds = {}
+
+ @defer.inlineCallbacks
+ def lock(self, key):
+ """ Allows us to block until it is our turn.
+ Args:
+ key (str)
+ Returns:
+ Lock
+ """
+ new_deferred = defer.Deferred()
+ old_deferred = self._lock_deferreds.get(key)
+ self._lock_deferreds[key] = new_deferred
+
+ if old_deferred:
+ yield old_deferred
+
+ defer.returnValue(Lock(new_deferred))
diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py
new file mode 100644
index 0000000000..08d5aafca4
--- /dev/null
+++ b/synapse/util/logutils.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from inspect import getcallargs
+
+import logging
+
+
+def log_function(f):
+ """ Function decorator that logs every call to that function.
+ """
+ func_name = f.__name__
+ lineno = f.func_code.co_firstlineno
+ pathname = f.func_code.co_filename
+
+ def wrapped(*args, **kwargs):
+ name = f.__module__
+ logger = logging.getLogger(name)
+ level = logging.DEBUG
+
+ if logger.isEnabledFor(level):
+ bound_args = getcallargs(f, *args, **kwargs)
+
+ def format(value):
+ r = str(value)
+ if len(r) > 50:
+ r = r[:50] + "..."
+ return r
+
+ func_args = [
+ "%s=%s" % (k, format(v)) for k, v in bound_args.items()
+ ]
+
+ msg_args = {
+ "func_name": func_name,
+ "args": ", ".join(func_args)
+ }
+
+ record = logging.LogRecord(
+ name=name,
+ level=level,
+ pathname=pathname,
+ lineno=lineno,
+ msg="Invoked '%(func_name)s' with args: %(args)s",
+ args=msg_args,
+ exc_info=None
+ )
+
+ logger.handle(record)
+
+ return f(*args, **kwargs)
+
+ return wrapped
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
new file mode 100644
index 0000000000..91550583a4
--- /dev/null
+++ b/synapse/util/stringutils.py
@@ -0,0 +1,24 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import random
+import string
+
+
+def origin_from_ucid(ucid):
+ return ucid.split("@", 1)[1]
+
+
+def random_string(length):
+ return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
|