diff options
Diffstat (limited to 'synapse')
83 files changed, 11941 insertions, 0 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py new file mode 100644 index 0000000000..aa760fb341 --- /dev/null +++ b/synapse/__init__.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" This is a reference implementation of a synapse home server. +""" diff --git a/synapse/api/__init__.py b/synapse/api/__init__.py new file mode 100644 index 0000000000..fe8a073cd3 --- /dev/null +++ b/synapse/api/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/api/auth.py b/synapse/api/auth.py new file mode 100644 index 0000000000..5c66a7261f --- /dev/null +++ b/synapse/api/auth.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This module contains classes for authenticating the user.""" +from twisted.internet import defer + +from synapse.api.constants import Membership +from synapse.api.errors import AuthError, StoreError +from synapse.api.events.room import (RoomTopicEvent, RoomMemberEvent, + MessageEvent, FeedbackEvent) + +import logging + +logger = logging.getLogger(__name__) + + +class Auth(object): + + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def check(self, event, raises=False): + """ Checks if this event is correctly authed. + + Returns: + True if the auth checks pass. + Raises: + AuthError if there was a problem authorising this event. This will + be raised only if raises=True. 
+ """ + try: + if event.type in [RoomTopicEvent.TYPE, MessageEvent.TYPE, + FeedbackEvent.TYPE]: + yield self.check_joined_room(event.room_id, event.user_id) + defer.returnValue(True) + elif event.type == RoomMemberEvent.TYPE: + allowed = yield self.is_membership_change_allowed(event) + defer.returnValue(allowed) + else: + raise AuthError(500, "Unknown event type %s" % event.type) + except AuthError as e: + logger.info("Event auth check failed on event %s with msg: %s", + event, e.msg) + if raises: + raise e + defer.returnValue(False) + + @defer.inlineCallbacks + def check_joined_room(self, room_id, user_id): + try: + member = yield self.store.get_room_member( + room_id=room_id, + user_id=user_id + ) + if not member or member.membership != Membership.JOIN: + raise AuthError(403, "User %s not in room %s" % + (user_id, room_id)) + defer.returnValue(member) + except AttributeError: + pass + defer.returnValue(None) + + @defer.inlineCallbacks + def is_membership_change_allowed(self, event): + # does this room even exist + room = yield self.store.get_room(event.room_id) + if not room: + raise AuthError(403, "Room does not exist") + + # get info about the caller + try: + caller = yield self.store.get_room_member( + user_id=event.user_id, + room_id=event.room_id) + except: + caller = None + caller_in_room = caller and caller.membership == "join" + + # get info about the target + try: + target = yield self.store.get_room_member( + user_id=event.target_user_id, + room_id=event.room_id) + except: + target = None + target_in_room = target and target.membership == "join" + + membership = event.content["membership"] + + if Membership.INVITE == membership: + # Invites are valid iff caller is in the room and target isn't. + if not caller_in_room: # caller isn't joined + raise AuthError(403, "You are not in room %s." % event.room_id) + elif target_in_room: # the target is already in the room. + raise AuthError(403, "%s is already in the room." 
% + event.target_user_id) + elif Membership.JOIN == membership: + # Joins are valid iff caller == target and they were: + # invited: They are accepting the invitation + # joined: It's a NOOP + if event.user_id != event.target_user_id: + raise AuthError(403, "Cannot force another user to join.") + elif room.is_public: + pass # anyone can join public rooms. + elif (not caller or caller.membership not in + [Membership.INVITE, Membership.JOIN]): + raise AuthError(403, "You are not invited to this room.") + elif Membership.LEAVE == membership: + if not caller_in_room: # trying to leave a room you aren't joined + raise AuthError(403, "You are not in room %s." % event.room_id) + elif event.target_user_id != event.user_id: + # trying to force another user to leave + raise AuthError(403, "Cannot force %s to leave." % + event.target_user_id) + else: + raise AuthError(500, "Unknown membership %s" % membership) + + defer.returnValue(True) + + def get_user_by_req(self, request): + """ Get a registered user's ID. + + Args: + request - An HTTP request with an access_token query parameter. + Returns: + UserID : User ID object of the user making the request + Raises: + AuthError if no user by that token exists or the token is invalid. + """ + # Can optionally look elsewhere in the request (e.g. headers) + try: + return self.get_user_by_token(request.args["access_token"][0]) + except KeyError: + raise AuthError(403, "Missing access token.") + + @defer.inlineCallbacks + def get_user_by_token(self, token): + """ Get a registered user's ID. + + Args: + token (str)- The access token to get the user by. + Returns: + UserID : User ID object of the user who has that access token. + Raises: + AuthError if no user by that token exists or the token is invalid. 
+ """ + try: + user_id = yield self.store.get_user_by_token(token=token) + defer.returnValue(self.hs.parse_userid(user_id)) + except StoreError: + raise AuthError(403, "Unrecognised access token.") diff --git a/synapse/api/constants.py b/synapse/api/constants.py new file mode 100644 index 0000000000..37bf41bfb3 --- /dev/null +++ b/synapse/api/constants.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains constants from the specification.""" + + +class Membership(object): + + """Represents the membership states of a user in a room.""" + INVITE = u"invite" + JOIN = u"join" + KNOCK = u"knock" + LEAVE = u"leave" + + +class Feedback(object): + + """Represents the types of feedback a user can send in response to a + message.""" + + DELIVERED = u"d" + READ = u"r" + LIST = (DELIVERED, READ) + + +class PresenceState(object): + """Represents the presence state of a user.""" + OFFLINE = 0 + BUSY = 1 + ONLINE = 2 + FREE_FOR_CHAT = 3 diff --git a/synapse/api/errors.py b/synapse/api/errors.py new file mode 100644 index 0000000000..7ad4d636c2 --- /dev/null +++ b/synapse/api/errors.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains exceptions and error codes.""" + +import logging + + +class Codes(object): + FORBIDDEN = "M_FORBIDDEN" + BAD_JSON = "M_BAD_JSON" + NOT_JSON = "M_NOT_JSON" + USER_IN_USE = "M_USER_IN_USE" + ROOM_IN_USE = "M_ROOM_IN_USE" + BAD_PAGINATION = "M_BAD_PAGINATION" + UNKNOWN = "M_UNKNOWN" + NOT_FOUND = "M_NOT_FOUND" + + +class CodeMessageException(Exception): + """An exception with integer code and message string attributes.""" + + def __init__(self, code, msg): + logging.error("%s: %s, %s", type(self).__name__, code, msg) + super(CodeMessageException, self).__init__("%d: %s" % (code, msg)) + self.code = code + self.msg = msg + + +class SynapseError(CodeMessageException): + """A base error which can be caught for all synapse events.""" + def __init__(self, code, msg, errcode=""): + """Constructs a synapse error. + + Args: + code (int): The integer error code (typically an HTTP response code) + msg (str): The human-readable error message. 
+ err (str): The error code e.g 'M_FORBIDDEN' + """ + super(SynapseError, self).__init__(code, msg) + self.errcode = errcode + + +class RoomError(SynapseError): + """An error raised when a room event fails.""" + pass + + +class RegistrationError(SynapseError): + """An error raised when a registration event fails.""" + pass + + +class AuthError(SynapseError): + """An error raised when there was a problem authorising an event.""" + + def __init__(self, *args, **kwargs): + if "errcode" not in kwargs: + kwargs["errcode"] = Codes.FORBIDDEN + super(AuthError, self).__init__(*args, **kwargs) + + +class EventStreamError(SynapseError): + """An error raised when there a problem with the event stream.""" + pass + + +class LoginError(SynapseError): + """An error raised when there was a problem logging in.""" + pass + + +class StoreError(SynapseError): + """An error raised when there was a problem storing some data.""" + pass + + +def cs_exception(exception): + if isinstance(exception, SynapseError): + return cs_error( + exception.msg, + Codes.UNKNOWN if not exception.errcode else exception.errcode) + elif isinstance(exception, CodeMessageException): + return cs_error(exception.msg) + else: + logging.error("Unknown exception type: %s", type(exception)) + + +def cs_error(msg, code=Codes.UNKNOWN, **kwargs): + """ Utility method for constructing an error response for client-server + interactions. + + Args: + msg (str): The error message. + code (int): The error code. + kwargs : Additional keys to add to the response. + Returns: + A dict representing the error response JSON. 
+ """ + err = {"error": msg, "errcode": code} + for key, value in kwargs.iteritems(): + err[key] = value + return err diff --git a/synapse/api/events/__init__.py b/synapse/api/events/__init__.py new file mode 100644 index 0000000000..bc2daf3361 --- /dev/null +++ b/synapse/api/events/__init__.py @@ -0,0 +1,152 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.api.errors import SynapseError, Codes +from synapse.util.jsonobject import JsonEncodedObject + + +class SynapseEvent(JsonEncodedObject): + + """Base class for Synapse events. These are JSON objects which must abide + by a certain well-defined structure. + """ + + # Attributes that are currently assumed by the federation side: + # Mandatory: + # - event_id + # - room_id + # - type + # - is_state + # + # Optional: + # - state_key (mandatory when is_state is True) + # - prev_events (these can be filled out by the federation layer itself.) 
+ # - prev_state + + valid_keys = [ + "event_id", + "type", + "room_id", + "user_id", # sender/initiator + "content", # HTTP body, JSON + ] + + internal_keys = [ + "is_state", + "state_key", + "prev_events", + "prev_state", + "depth", + "destinations", + "origin", + ] + + required_keys = [ + "event_id", + "room_id", + "content", + ] + + def __init__(self, raises=True, **kwargs): + super(SynapseEvent, self).__init__(**kwargs) + if "content" in kwargs: + self.check_json(self.content, raises=raises) + + def get_content_template(self): + """ Retrieve the JSON template for this event as a dict. + + The template must be a dict representing the JSON to match. Only + required keys should be present. The values of the keys in the template + are checked via type() to the values of the same keys in the actual + event JSON. + + NB: If loading content via json.loads, you MUST define strings as + unicode. + + For example: + Content: + { + "name": u"bob", + "age": 18, + "friends": [u"mike", u"jill"] + } + Template: + { + "name": u"string", + "age": 0, + "friends": [u"string"] + } + The values "string" and 0 could be anything, so long as the types + are the same as the content. + """ + raise NotImplementedError("get_content_template not implemented.") + + def check_json(self, content, raises=True): + """Checks the given JSON content abides by the rules of the template. + + Args: + content : A JSON object to check. + raises: True to raise a SynapseError if the check fails. + Returns: + True if the content passes the template. Returns False if the check + fails and raises=False. + Raises: + SynapseError if the check fails and raises=True. + """ + # recursively call to inspect each layer + err_msg = self._check_json(content, self.get_content_template()) + if err_msg: + if raises: + raise SynapseError(400, err_msg, Codes.BAD_JSON) + else: + return False + else: + return True + + def _check_json(self, content, template): + """Check content and template matches. 
+ + If the template is a dict, each key in the dict will be validated with + the content, else it will just compare the types of content and + template. This basic type check is required because this function will + be recursively called and could be called with just strs or ints. + + Args: + content: The content to validate. + template: The validation template. + Returns: + str: An error message if the validation fails, else None. + """ + if type(content) != type(template): + return "Mismatched types: %s" % template + + if type(template) == dict: + for key in template: + if key not in content: + return "Missing %s key" % key + + if type(content[key]) != type(template[key]): + return "Key %s is of the wrong type." % key + + if type(content[key]) == dict: + # we must go deeper + msg = self._check_json(content[key], template[key]) + if msg: + return msg + elif type(content[key]) == list: + # make sure each item type in content matches the template + for entry in content[key]: + msg = self._check_json(entry, template[key][0]) + if msg: + return msg diff --git a/synapse/api/events/factory.py b/synapse/api/events/factory.py new file mode 100644 index 0000000000..ea7afa234e --- /dev/null +++ b/synapse/api/events/factory.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from synapse.api.events.room import ( + RoomTopicEvent, MessageEvent, RoomMemberEvent, FeedbackEvent, + InviteJoinEvent, RoomConfigEvent +) + +from synapse.util.stringutils import random_string + + +class EventFactory(object): + + _event_classes = [ + RoomTopicEvent, + MessageEvent, + RoomMemberEvent, + FeedbackEvent, + InviteJoinEvent, + RoomConfigEvent + ] + + def __init__(self): + self._event_list = {} # dict of TYPE to event class + for event_class in EventFactory._event_classes: + self._event_list[event_class.TYPE] = event_class + + def create_event(self, etype=None, **kwargs): + kwargs["type"] = etype + if "event_id" not in kwargs: + kwargs["event_id"] = random_string(10) + + try: + handler = self._event_list[etype] + except KeyError: # unknown event type + # TODO allow custom event types. + raise NotImplementedError("Unknown etype=%s" % etype) + + return handler(**kwargs) diff --git a/synapse/api/events/room.py b/synapse/api/events/room.py new file mode 100644 index 0000000000..b31cd19f4b --- /dev/null +++ b/synapse/api/events/room.py @@ -0,0 +1,99 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from . 
import SynapseEvent + + +class RoomTopicEvent(SynapseEvent): + TYPE = "m.room.topic" + + def __init__(self, **kwargs): + kwargs["state_key"] = "" + super(RoomTopicEvent, self).__init__(**kwargs) + + def get_content_template(self): + return {"topic": u"string"} + + +class RoomMemberEvent(SynapseEvent): + TYPE = "m.room.member" + + valid_keys = SynapseEvent.valid_keys + [ + "target_user_id", # target + "membership", # action + ] + + def __init__(self, **kwargs): + if "target_user_id" in kwargs: + kwargs["state_key"] = kwargs["target_user_id"] + super(RoomMemberEvent, self).__init__(**kwargs) + + def get_content_template(self): + return {"membership": u"string"} + + +class MessageEvent(SynapseEvent): + TYPE = "m.room.message" + + valid_keys = SynapseEvent.valid_keys + [ + "msg_id", # unique per room + user combo + ] + + def __init__(self, **kwargs): + super(MessageEvent, self).__init__(**kwargs) + + def get_content_template(self): + return {"msgtype": u"string"} + + +class FeedbackEvent(SynapseEvent): + TYPE = "m.room.message.feedback" + + valid_keys = SynapseEvent.valid_keys + [ + "msg_id", # the message ID being acknowledged + "msg_sender_id", # person who is sending the feedback is 'user_id' + "feedback_type", # the type of feedback (delivery, read, etc) + ] + + def __init__(self, **kwargs): + super(FeedbackEvent, self).__init__(**kwargs) + + def get_content_template(self): + return {} + + +class InviteJoinEvent(SynapseEvent): + TYPE = "m.room.invite_join" + + valid_keys = SynapseEvent.valid_keys + [ + "target_user_id", + "target_host", + ] + + def __init__(self, **kwargs): + super(InviteJoinEvent, self).__init__(**kwargs) + + def get_content_template(self): + return {} + + +class RoomConfigEvent(SynapseEvent): + TYPE = "m.room.config" + + def __init__(self, **kwargs): + kwargs["state_key"] = "" + super(RoomConfigEvent, self).__init__(**kwargs) + + def get_content_template(self): + return {} diff --git a/synapse/api/notifier.py b/synapse/api/notifier.py new file 
mode 100644 index 0000000000..974f7f0ba0 --- /dev/null +++ b/synapse/api/notifier.py @@ -0,0 +1,186 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.api.constants import Membership +from synapse.api.events.room import RoomMemberEvent + +from twisted.internet import defer +from twisted.internet import reactor + +import logging + +logger = logging.getLogger(__name__) + + +class Notifier(object): + + def __init__(self, hs): + self.store = hs.get_datastore() + self.hs = hs + self.stored_event_listeners = {} + + @defer.inlineCallbacks + def on_new_room_event(self, event, store_id): + """Called when there is a new room event which may potentially be sent + down listening users' event streams. + + This function looks for interested *users* who may want to be notified + for this event. This is different to users requesting from the event + stream which looks for interested *events* for this user. + + Args: + event (SynapseEvent): The new event, which must have a room_id + store_id (int): The ID of this event after it was stored with the + data store. + '""" + member_list = yield self.store.get_room_members(room_id=event.room_id, + membership="join") + if not member_list: + member_list = [] + + member_list = [u.user_id for u in member_list] + + # invites MUST prod the person being invited, who won't be in the room. 
+ if (event.type == RoomMemberEvent.TYPE and + event.content["membership"] == Membership.INVITE): + member_list.append(event.target_user_id) + + for user_id in member_list: + if user_id in self.stored_event_listeners: + self._notify_and_callback( + user_id=user_id, + event_data=event.get_dict(), + stream_type=event.type, + store_id=store_id) + + def on_new_user_event(self, user_id, event_data, stream_type, store_id): + if user_id in self.stored_event_listeners: + self._notify_and_callback( + user_id=user_id, + event_data=event_data, + stream_type=stream_type, + store_id=store_id + ) + + def _notify_and_callback(self, user_id, event_data, stream_type, store_id): + logger.debug( + "Notifying %s of a new event.", + user_id + ) + + stream_ids = list(self.stored_event_listeners[user_id]) + for stream_id in stream_ids: + self._notify_and_callback_stream(user_id, stream_id, event_data, + stream_type, store_id) + + if not self.stored_event_listeners[user_id]: + del self.stored_event_listeners[user_id] + + def _notify_and_callback_stream(self, user_id, stream_id, event_data, + stream_type, store_id): + + event_listener = self.stored_event_listeners[user_id].pop(stream_id) + return_event_object = { + k: event_listener[k] for k in ["start", "chunk", "end"] + } + + # work out the new end token + token = event_listener["start"] + end = self._next_token(stream_type, store_id, token) + return_event_object["end"] = end + + # add the event to the chunk + chunk = event_listener["chunk"] + chunk.append(event_data) + + # callback the defer. We know this can't have been resolved before as + # we always remove the event_listener from the map before resolving. 
+ event_listener["defer"].callback(return_event_object) + + def _next_token(self, stream_type, store_id, current_token): + stream_handler = self.hs.get_handlers().event_stream_handler + return stream_handler.get_event_stream_token( + stream_type, + store_id, + current_token + ) + + def store_events_for(self, user_id=None, stream_id=None, from_tok=None): + """Store all incoming events for this user. This should be paired with + get_events_for to return chunked data. + + Args: + user_id (str): The user to monitor incoming events for. + stream (object): The stream that is receiving events + from_tok (str): The token to monitor incoming events from. + """ + event_listener = { + "start": from_tok, + "chunk": [], + "end": from_tok, + "defer": defer.Deferred(), + } + + if user_id not in self.stored_event_listeners: + self.stored_event_listeners[user_id] = {stream_id: event_listener} + else: + self.stored_event_listeners[user_id][stream_id] = event_listener + + def purge_events_for(self, user_id=None, stream_id=None): + """Purges any stored events for this user. + + Args: + user_id (str): The user to purge stored events for. + """ + try: + del self.stored_event_listeners[user_id][stream_id] + if not self.stored_event_listeners[user_id]: + del self.stored_event_listeners[user_id] + except KeyError: + pass + + def get_events_for(self, user_id=None, stream_id=None, timeout=0): + """Retrieve stored events for this user, waiting if necessary. + + It is advisable to wrap this call in a maybeDeferred. + + Args: + user_id (str): The user to get events for. + timeout (int): The time in seconds to wait before giving up. + Returns: + A Deferred or a dict containing the chunk data, depending on if + there was data to return yet. The Deferred callback may be None if + there were no events before the timeout expired. 
+ """ + logger.debug("%s is listening for events.", user_id) + + if len(self.stored_event_listeners[user_id][stream_id]["chunk"]) > 0: + logger.debug("%s returning existing chunk.", user_id) + return self.stored_event_listeners[user_id][stream_id] + + reactor.callLater( + (timeout / 1000.0), self._timeout, user_id, stream_id + ) + return self.stored_event_listeners[user_id][stream_id]["defer"] + + def _timeout(self, user_id, stream_id): + try: + # We remove the event_listener from the map so that we can't + # resolve the deferred twice. + event_listeners = self.stored_event_listeners[user_id] + event_listener = event_listeners.pop(stream_id) + event_listener["defer"].callback(None) + logger.debug("%s event listening timed out.", user_id) + except KeyError: + pass diff --git a/synapse/api/streams/__init__.py b/synapse/api/streams/__init__.py new file mode 100644 index 0000000000..08137c1e79 --- /dev/null +++ b/synapse/api/streams/__init__.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from synapse.api.errors import SynapseError + + +class PaginationConfig(object): + + """A configuration object which stores pagination parameters.""" + + def __init__(self, from_tok=None, to_tok=None, limit=0): + self.from_tok = from_tok + self.to_tok = to_tok + self.limit = limit + + @classmethod + def from_request(cls, request, raise_invalid_params=True): + params = { + "from_tok": PaginationStream.TOK_START, + "to_tok": PaginationStream.TOK_END, + "limit": 0 + } + + query_param_mappings = [ # 3-tuple of qp_key, attribute, rules + ("from", "from_tok", lambda x: type(x) == str), + ("to", "to_tok", lambda x: type(x) == str), + ("limit", "limit", lambda x: x.isdigit()) + ] + + for qp, attr, is_valid in query_param_mappings: + if qp in request.args: + if is_valid(request.args[qp][0]): + params[attr] = request.args[qp][0] + elif raise_invalid_params: + raise SynapseError(400, "%s parameter is invalid." % qp) + + return PaginationConfig(**params) + + +class PaginationStream(object): + + """ An interface for streaming data as chunks. """ + + TOK_START = "START" + TOK_END = "END" + + def get_chunk(self, config=None): + """ Return the next chunk in the stream. + + Args: + config (PaginationConfig): The config to aid which chunk to get. + Returns: + A dict containing the new start token "start", the new end token + "end" and the data "chunk" as a list. + """ + raise NotImplementedError() + + +class StreamData(object): + + """ An interface for obtaining streaming data from a table. """ + + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + + def get_rows(self, user_id, from_pkey, to_pkey, limit): + """ Get event stream data between the specified pkeys. + + Args: + user_id : The user's ID + from_pkey : The starting pkey. + to_pkey : The end pkey. May be -1 to mean "latest". + limit: The max number of results to return. + Returns: + A tuple containing the list of event stream data and the last pkey. 
+ """ + raise NotImplementedError() + + def max_token(self): + """ Get the latest currently-valid token. + + Returns: + The latest token.""" + raise NotImplementedError() diff --git a/synapse/api/streams/event.py b/synapse/api/streams/event.py new file mode 100644 index 0000000000..0cc1a3e36a --- /dev/null +++ b/synapse/api/streams/event.py @@ -0,0 +1,247 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This module contains classes for streaming from the event stream: /events. 
+""" +from twisted.internet import defer + +from synapse.api.errors import EventStreamError +from synapse.api.events.room import ( + RoomMemberEvent, MessageEvent, FeedbackEvent, RoomTopicEvent +) +from synapse.api.streams import PaginationStream, StreamData + +import logging + +logger = logging.getLogger(__name__) + + +class MessagesStreamData(StreamData): + EVENT_TYPE = MessageEvent.TYPE + + def __init__(self, hs, room_id=None, feedback=False): + super(MessagesStreamData, self).__init__(hs) + self.room_id = room_id + self.with_feedback = feedback + + @defer.inlineCallbacks + def get_rows(self, user_id, from_key, to_key, limit): + (data, latest_ver) = yield self.store.get_message_stream( + user_id=user_id, + from_key=from_key, + to_key=to_key, + limit=limit, + room_id=self.room_id, + with_feedback=self.with_feedback + ) + defer.returnValue((data, latest_ver)) + + @defer.inlineCallbacks + def max_token(self): + val = yield self.store.get_max_message_id() + defer.returnValue(val) + + +class RoomMemberStreamData(StreamData): + EVENT_TYPE = RoomMemberEvent.TYPE + + @defer.inlineCallbacks + def get_rows(self, user_id, from_key, to_key, limit): + (data, latest_ver) = yield self.store.get_room_member_stream( + user_id=user_id, + from_key=from_key, + to_key=to_key + ) + + defer.returnValue((data, latest_ver)) + + @defer.inlineCallbacks + def max_token(self): + val = yield self.store.get_max_room_member_id() + defer.returnValue(val) + + +class FeedbackStreamData(StreamData): + EVENT_TYPE = FeedbackEvent.TYPE + + def __init__(self, hs, room_id=None): + super(FeedbackStreamData, self).__init__(hs) + self.room_id = room_id + + @defer.inlineCallbacks + def get_rows(self, user_id, from_key, to_key, limit): + (data, latest_ver) = yield self.store.get_feedback_stream( + user_id=user_id, + from_key=from_key, + to_key=to_key, + limit=limit, + room_id=self.room_id + ) + defer.returnValue((data, latest_ver)) + + @defer.inlineCallbacks + def max_token(self): + val = yield 
self.store.get_max_feedback_id() + defer.returnValue(val) + + +class RoomDataStreamData(StreamData): + EVENT_TYPE = RoomTopicEvent.TYPE # TODO need multiple event types + + def __init__(self, hs, room_id=None): + super(RoomDataStreamData, self).__init__(hs) + self.room_id = room_id + + @defer.inlineCallbacks + def get_rows(self, user_id, from_key, to_key, limit): + (data, latest_ver) = yield self.store.get_room_data_stream( + user_id=user_id, + from_key=from_key, + to_key=to_key, + limit=limit, + room_id=self.room_id + ) + defer.returnValue((data, latest_ver)) + + @defer.inlineCallbacks + def max_token(self): + val = yield self.store.get_max_room_data_id() + defer.returnValue(val) + + +class EventStream(PaginationStream): + + SEPARATOR = '_' + + def __init__(self, user_id, stream_data_list): + super(EventStream, self).__init__() + self.user_id = user_id + self.stream_data = stream_data_list + + @defer.inlineCallbacks + def fix_tokens(self, pagination_config): + pagination_config.from_tok = yield self.fix_token( + pagination_config.from_tok) + pagination_config.to_tok = yield self.fix_token( + pagination_config.to_tok) + defer.returnValue(pagination_config) + + @defer.inlineCallbacks + def fix_token(self, token): + """Fixes unknown values in a token to known values. + + Args: + token (str): The token to fix up. + Returns: + The fixed-up token, which may == token. + """ + # replace TOK_START and TOK_END with 0_0_0 or -1_-1_-1 depending. + replacements = [ + (PaginationStream.TOK_START, "0"), + (PaginationStream.TOK_END, "-1") + ] + for magic_token, key in replacements: + if magic_token == token: + token = EventStream.SEPARATOR.join( + [key] * len(self.stream_data) + ) + + # replace -1 values with an actual pkey + token_segments = self._split_token(token) + for i, tok in enumerate(token_segments): + if tok == -1: + # add 1 to the max token because results are EXCLUSIVE from the + # latest version. 
+ token_segments[i] = 1 + (yield self.stream_data[i].max_token()) + defer.returnValue(EventStream.SEPARATOR.join( + str(x) for x in token_segments + )) + + @defer.inlineCallbacks + def get_chunk(self, config=None): + # no support for limit on >1 streams, makes no sense. + if config.limit and len(self.stream_data) > 1: + raise EventStreamError( + 400, "Limit not supported on multiplexed streams." + ) + + (chunk_data, next_tok) = yield self._get_chunk_data(config.from_tok, + config.to_tok, + config.limit) + + defer.returnValue({ + "chunk": chunk_data, + "start": config.from_tok, + "end": next_tok + }) + + @defer.inlineCallbacks + def _get_chunk_data(self, from_tok, to_tok, limit): + """ Get event data between the two tokens. + + Tokens are SEPARATOR separated values representing pkey values of + certain tables, and the position determines the StreamData invoked + according to the STREAM_DATA list. + + The magic value '-1' can be used to get the latest value. + + Args: + from_tok - The token to start from. + to_tok - The token to end at. Must have values > from_tok or be -1. + Returns: + A list of event data. + Raises: + EventStreamError if something went wrong. + """ + # sanity check + if (from_tok.count(EventStream.SEPARATOR) != + to_tok.count(EventStream.SEPARATOR) or + (from_tok.count(EventStream.SEPARATOR) + 1) != + len(self.stream_data)): + raise EventStreamError(400, "Token lengths don't match.") + + chunk = [] + next_ver = [] + for i, (from_pkey, to_pkey) in enumerate(zip( + self._split_token(from_tok), + self._split_token(to_tok) + )): + if from_pkey == to_pkey: + # tokens are the same, we have nothing to do. 
+ next_ver.append(str(to_pkey)) + continue + + (event_chunk, max_pkey) = yield self.stream_data[i].get_rows( + self.user_id, from_pkey, to_pkey, limit + ) + + chunk += event_chunk + next_ver.append(str(max_pkey)) + + defer.returnValue((chunk, EventStream.SEPARATOR.join(next_ver))) + + def _split_token(self, token): + """Splits the given token into a list of pkeys. + + Args: + token (str): The token with SEPARATOR values. + Returns: + A list of ints. + """ + segments = token.split(EventStream.SEPARATOR) + try: + int_segments = [int(x) for x in segments] + except ValueError: + raise EventStreamError(400, "Bad token: %s" % token) + return int_segments diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py new file mode 100644 index 0000000000..fe8a073cd3 --- /dev/null +++ b/synapse/app/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py new file mode 100644 index 0000000000..5708b3ad95 --- /dev/null +++ b/synapse/app/homeserver.py @@ -0,0 +1,172 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env python

from synapse.storage import read_schema

from synapse.server import HomeServer

from twisted.internet import reactor
from twisted.enterprise import adbapi
from twisted.python.log import PythonLoggingObserver
from synapse.http.server import TwistedHttpServer
from synapse.http.client import TwistedHttpClient

from daemonize import Daemonize

import argparse
import logging
import logging.config
import sqlite3

logger = logging.getLogger(__name__)


class SynapseHomeServer(HomeServer):
    """HomeServer wired up with the Twisted HTTP server/client and SQLite."""

    def build_http_server(self):
        """Return a new TwistedHttpServer."""
        return TwistedHttpServer()

    def build_http_client(self):
        """Return a new TwistedHttpClient."""
        return TwistedHttpClient()

    def build_db_pool(self):
        """ Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
        don't have to worry about overwriting existing content.

        Returns:
            The adbapi.ConnectionPool for `self.db_name`.
        """
        logger.info("Preparing database: %s...", self.db_name)
        pool = adbapi.ConnectionPool(
            'sqlite3', self.db_name, check_same_thread=False,
            cp_min=1, cp_max=1)

        schemas = [
            "transactions",
            "pdu",
            "users",
            "profiles",
            "presence",
            "im",
            "room_aliases",
        ]

        # `with sqlite3.connect(...)` only scopes a transaction; it does not
        # close the connection. Use one explicit connection for every schema
        # and always close it.
        db_conn = sqlite3.connect(self.db_name)
        try:
            for sql_loc in schemas:
                sql_script = read_schema(sql_loc)

                cursor = db_conn.cursor()
                cursor.executescript(sql_script)
                cursor.close()
                db_conn.commit()
        finally:
            db_conn.close()

        logger.info("Database prepared in %s.", self.db_name)

        return pool


def setup_logging(verbosity=0, filename=None, config_path=None):
    """ Sets up logging with verbosity levels.

    Args:
        verbosity: The verbosity level. Falsy means WARNING, 1 means INFO and
            anything else means DEBUG. Ignored when config_path is given.
        filename: Log to the given file rather than to the console.
        config_path: Path to a python logging config file.
    """

    if config_path is None:
        log_format = (
            '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s'
        )

        if not verbosity:
            level = logging.WARNING
        elif verbosity == 1:
            level = logging.INFO
        else:
            level = logging.DEBUG

        logging.basicConfig(level=level, filename=filename, format=log_format)
    else:
        logging.config.fileConfig(config_path)

    # Route Twisted's own log events into the stdlib logging tree.
    observer = PythonLoggingObserver()
    observer.start()


def run():
    """Run the Twisted reactor."""
    reactor.run()


def setup():
    """Parse the command line, build the home server and start serving."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-p", "--port", dest="port", type=int, default=8080,
                        help="The port to listen on.")
    parser.add_argument("-d", "--database", dest="db", default="homeserver.db",
                        help="The database name.")
    parser.add_argument("-H", "--host", dest="host", default="localhost",
                        help="The hostname of the server.")
    parser.add_argument('-v', '--verbose', dest="verbose", action='count',
                        help="The verbosity level.")
    parser.add_argument('-f', '--log-file', dest="log_file", default=None,
                        help="File to log to.")
    parser.add_argument('--log-config', dest="log_config", default=None,
                        help="Python logging config")
    parser.add_argument('-D', '--daemonize', action='store_true',
                        default=False, help="Daemonize the home server")
    parser.add_argument('--pid-file', dest="pid", help="When running as a "
                        "daemon, the file to store the pid in",
                        default="hs.pid")
    args = parser.parse_args()

    # action='count' leaves `verbose` as None when -v was never given.
    verbosity = int(args.verbose) if args.verbose else None

    setup_logging(
        verbosity=verbosity,
        filename=args.log_file,
        config_path=args.log_config,
    )

    logger.info("Server hostname: %s", args.host)

    hs = SynapseHomeServer(
        args.host,
        db_name=args.db
    )

    # This object doesn't need to be saved because it's set as the handler for
    # the replication layer
    hs.get_federation()

    hs.register_servlets()

    hs.get_http_server().start_listening(args.port)

    # NOTE(review): the pool returned here is discarded; presumably only the
    # schema preparation side effect is wanted at this point -- confirm.
    hs.build_db_pool()

    if args.daemonize:
        daemon = Daemonize(
            app="synapse-homeserver",
            pid=args.pid,
            action=run,
            auto_close_fds=False,
            verbose=True,
            logger=logger,
        )

        daemon.start()
    else:
        run()


if __name__ == '__main__':
    setup()
# diff --git a/synapse/crypto/__init__.py b/synapse/crypto/__init__.py
# new file mode 100644
# index 0000000000..fe8a073cd3
# --- /dev/null
# +++ b/synapse/crypto/__init__.py
# @@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# diff --git a/synapse/crypto/config.py b/synapse/crypto/config.py
# new file mode 100644
# index 0000000000..801dfd8656
# --- /dev/null
# +++ b/synapse/crypto/config.py
# @@ -0,0 +1,159 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ConfigParser as configparser
import argparse
import socket
import sys
import os
from OpenSSL import crypto
import nacl.signing
from syutil.base64util import encode_base64
import subprocess


def _add_server_arguments(parser, bind_port_required=False):
    """Register the key server options shared by both entry points."""
    parser.add_argument("--server-name", default=socket.getfqdn(),
                        help="The name of the server")
    parser.add_argument("--signing-key-path",
                        help="The signing key to sign responses with")
    parser.add_argument("--tls-certificate-path",
                        help="PEM encoded X509 certificate for TLS")
    parser.add_argument("--tls-private-key-path",
                        help="PEM encoded private key for TLS")
    parser.add_argument("--tls-dh-params-path",
                        help="PEM encoded dh parameters for ephemeral keys")
    parser.add_argument("--bind-port", type=int, required=bind_port_required,
                        help="TCP port to listen on")
    parser.add_argument("--bind-host", default="",
                        help="Local interface to listen on")


def load_config(description, argv):
    """Build the key server configuration from argv and an optional file.

    Values are taken, in increasing priority, from the built-in argument
    defaults, the "KeyServer" section of the file named by -c/--config-path,
    and the command line.

    Args:
        description: Text shown in the parser's help output.
        argv: Command line arguments, without the program name.
    Returns:
        A dict of option values (without "config_path").
    """
    config_parser = argparse.ArgumentParser(add_help=False)
    config_parser.add_argument("-c", "--config-path", metavar="CONFIG_FILE",
                               help="Specify config file")
    config_args, remaining_args = config_parser.parse_known_args(argv)
    if config_args.config_path:
        config = configparser.SafeConfigParser()
        config.read([config_args.config_path])
        defaults = dict(config.items("KeyServer"))
    else:
        defaults = {}
    parser = argparse.ArgumentParser(
        parents=[config_parser],
        description=description,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    _add_server_arguments(parser)
    # Apply the file's values only AFTER the arguments exist. An explicit
    # `default=` on add_argument beats parser-level defaults set earlier, so
    # setting them first silently dropped server_name and bind_host from the
    # config file.
    parser.set_defaults(**defaults)

    args = parser.parse_args(remaining_args)

    server_config = vars(args)
    del server_config["config_path"]
    return server_config


def generate_config(argv):
    """Write a key server config file, creating any missing key material.

    Paths that are not given default to files named after the server name,
    next to the config file. Existing signing key, TLS key, certificate and
    DH parameter files are reused rather than overwritten.

    Args:
        argv: Command line arguments, without the program name.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config-path", help="Specify config file",
                        metavar="CONFIG_FILE", required=True)
    _add_server_arguments(parser, bind_port_required=True)

    args = parser.parse_args(argv)

    dir_name = os.path.dirname(args.config_path)
    base_key_name = os.path.join(dir_name, args.server_name)

    if args.signing_key_path is None:
        args.signing_key_path = base_key_name + ".signing.key"

    if args.tls_certificate_path is None:
        args.tls_certificate_path = base_key_name + ".tls.crt"

    if args.tls_private_key_path is None:
        args.tls_private_key_path = base_key_name + ".tls.key"

    if args.tls_dh_params_path is None:
        args.tls_dh_params_path = base_key_name + ".tls.dh"

    # Key material is generated before each file is opened, so a failure
    # cannot leave an empty file that a later run would mistake for a key.
    # NOTE(review): key files are created with the process's default
    # permissions -- consider restricting them.
    if not os.path.exists(args.signing_key_path):
        key = nacl.signing.SigningKey.generate()
        with open(args.signing_key_path, "w") as signing_key_file:
            signing_key_file.write(encode_base64(key.encode()))

    if not os.path.exists(args.tls_private_key_path):
        tls_private_key = crypto.PKey()
        tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
        private_key_pem = crypto.dump_privatekey(
            crypto.FILETYPE_PEM, tls_private_key
        )
        with open(args.tls_private_key_path, "w") as private_key_file:
            private_key_file.write(private_key_pem)
    else:
        with open(args.tls_private_key_path) as private_key_file:
            private_key_pem = private_key_file.read()
        tls_private_key = crypto.load_privatekey(
            crypto.FILETYPE_PEM, private_key_pem
        )

    if not os.path.exists(args.tls_certificate_path):
        # Self-signed certificate for the server name, valid for ten years.
        cert = crypto.X509()
        subject = cert.get_subject()
        subject.CN = args.server_name

        cert.set_serial_number(1000)
        cert.gmtime_adj_notBefore(0)
        cert.gmtime_adj_notAfter(10 * 365 * 24 * 60 * 60)
        cert.set_issuer(cert.get_subject())
        cert.set_pubkey(tls_private_key)

        cert.sign(tls_private_key, 'sha256')

        cert_pem = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)

        with open(args.tls_certificate_path, "w") as certificate_file:
            certificate_file.write(cert_pem)

    if not os.path.exists(args.tls_dh_params_path):
        subprocess.check_call([
            "openssl", "dhparam",
            "-outform", "PEM",
            "-out", args.tls_dh_params_path,
            "2048"
        ])

    config = configparser.SafeConfigParser()
    config.add_section("KeyServer")
    for key, value in vars(args).items():
        if key != "config_path":
            config.set("KeyServer", key, str(value))

    with open(args.config_path, "w") as config_file:
        config.write(config_file)


if __name__ == "__main__":
    generate_config(sys.argv[1:])
# diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
# new file mode 100644
# index 0000000000..b53d1c572b
# --- /dev/null
# +++ b/synapse/crypto/keyclient.py
# @@ -0,0 +1,118 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.web.http import HTTPClient
from twisted.internet import defer, reactor
from twisted.internet.protocol import ClientFactory
from twisted.names.srvconnect import SRVConnector
import json
import logging


logger = logging.getLogger(__name__)


@defer.inlineCallbacks
def fetch_server_key(server_name, ssl_context_factory):
    """Fetch the keys for a remote server.

    Args:
        server_name: Name of the remote server, resolved via a "matrix" SRV
            lookup with port 443 as the default.
        ssl_context_factory: Context factory handed to connectSSL.
    Returns:
        Deferred firing with a (parsed JSON response, peer certificate)
        tuple, or failing with SynapseKeyClientError once the retries are
        exhausted.
    """

    factory = SynapseKeyClientFactory()

    SRVConnector(
        reactor, "matrix", server_name, factory,
        protocol="tcp", connectFuncName="connectSSL", defaultPort=443,
        connectFuncKwArgs=dict(contextFactory=ssl_context_factory)).connect()

    server_key, server_certificate = yield factory.remote_key

    defer.returnValue((server_key, server_certificate))


class SynapseKeyClientError(Exception):
    """The key wasn't retrieved from the remote server."""
    pass


class SynapseKeyClientProtocol(HTTPClient):
    """Low level HTTPS client which retrieves an application/json response from
    the server and extracts the X.509 certificate for the remote peer from the
    SSL connection."""

    def connectionMade(self):
        """Send "GET /key" and arm the response timeout."""
        logger.debug("Connected to %s", self.transport.getHost())
        self.sendCommand(b"GET", b"/key")
        self.endHeaders()
        self.timer = reactor.callLater(
            self.factory.timeout_seconds,
            self.on_timeout
        )

    def _cancel_timer(self):
        """Cancel the response timeout if it is still pending."""
        timer = getattr(self, "timer", None)
        # cancel() raises on a call that already fired or was cancelled, so
        # only cancel while the call is still active.
        if timer is not None and timer.active():
            timer.cancel()

    def handleStatus(self, version, status, message):
        """Abort the connection on anything but a 200 status."""
        if status != b"200":
            logger.info("Non-200 response from %s: %s %s",
                        self.transport.getHost(), status, message)
            self._cancel_timer()
            self.transport.abortConnection()

    def handleResponse(self, response_body_bytes):
        """Parse the body and hand the key plus peer certificate on."""
        try:
            json_response = json.loads(response_body_bytes)
        except ValueError:
            logger.info("Invalid JSON response from %s",
                        self.transport.getHost())
            self._cancel_timer()
            self.transport.abortConnection()
            return

        certificate = self.transport.getPeerCertificate()
        self.factory.on_remote_key((json_response, certificate))
        self.transport.abortConnection()
        self._cancel_timer()

    def on_timeout(self):
        """Abort the connection when no response arrived in time."""
        logger.debug("Timeout waiting for response from %s",
                     self.transport.getHost())
        self.transport.abortConnection()


class SynapseKeyClientFactory(ClientFactory):
    """Client factory that retries until a key arrives or retries run out.

    `remote_key` is a Deferred fired with the fetched key on success, or
    failed with SynapseKeyClientError after `max_retries` attempts.
    """
    protocol = SynapseKeyClientProtocol
    max_retries = 5
    timeout_seconds = 30

    def __init__(self):
        self.succeeded = False
        self.retries = 0
        self.remote_key = defer.Deferred()

    def on_remote_key(self, key):
        """Record success and fire `remote_key` with the fetched key."""
        self.succeeded = True
        self.remote_key.callback(key)

    def retry_connection(self, connector):
        """Reconnect, or fail `remote_key` once max_retries is reached."""
        self.retries += 1
        if self.retries < self.max_retries:
            connector.connector = None
            connector.connect()
        else:
            self.remote_key.errback(
                SynapseKeyClientError("Max retries exceeded"))

    def clientConnectionFailed(self, connector, reason):
        """Retry after a failed connection attempt."""
        logger.info("Connection failed %s", reason)
        self.retry_connection(connector)

    def clientConnectionLost(self, connector, reason):
        """Retry after a lost connection unless the key was already fetched."""
        logger.info("Connection lost %s", reason)
        if not self.succeeded:
            self.retry_connection(connector)
# diff --git a/synapse/crypto/keyserver.py b/synapse/crypto/keyserver.py
# new file mode 100644
# index 0000000000..48bd380781
# --- /dev/null
# +++ b/synapse/crypto/keyserver.py
# @@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet import reactor, ssl
from twisted.web import server
from twisted.web.resource import Resource
from twisted.python.log import PythonLoggingObserver

from synapse.crypto.resource.key import LocalKey
from synapse.crypto.config import load_config

from syutil.base64util import decode_base64

from OpenSSL import crypto, SSL

import logging
import nacl.signing
import sys


class KeyServerSSLContextFactory(ssl.ContextFactory):
    """Factory for PyOpenSSL SSL contexts that are used to handle incoming
    connections and to make connections to remote servers."""

    def __init__(self, key_server):
        ssl_context = SSL.Context(SSL.SSLv23_METHOD)
        self._context = ssl_context
        self.configure_context(ssl_context, key_server)

    @staticmethod
    def configure_context(context, key_server):
        """Apply the key server's TLS material and protocol limits."""
        disabled_protocols = SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3
        context.set_options(disabled_protocols)
        context.use_certificate(key_server.tls_certificate)
        context.use_privatekey(key_server.tls_private_key)
        context.load_tmp_dh(key_server.tls_dh_params_path)
        context.set_cipher_list("!ADH:HIGH+kEDH:!AECDH:HIGH+kEECDH")

    def getContext(self):
        """Return the single context built at construction time."""
        return self._context


class KeyServer(object):
    """An HTTPS server serving LocalKey and RemoteKey resources."""

    def __init__(self, server_name, tls_certificate_path, tls_private_key_path,
                 tls_dh_params_path, signing_key_path, bind_host, bind_port):
        self.server_name = server_name
        # Key material is loaded from disk eagerly, in this order.
        self.tls_certificate = self.read_tls_certificate(tls_certificate_path)
        self.tls_private_key = self.read_tls_private_key(tls_private_key_path)
        self.tls_dh_params_path = tls_dh_params_path
        self.signing_key = self.read_signing_key(signing_key_path)
        self.bind_host = bind_host
        self.bind_port = int(bind_port)
        self.ssl_context_factory = KeyServerSSLContextFactory(self)

    @staticmethod
    def read_tls_certificate(cert_path):
        """Load a PEM encoded X.509 certificate from `cert_path`."""
        with open(cert_path) as pem_file:
            pem_text = pem_file.read()
        return crypto.load_certificate(crypto.FILETYPE_PEM, pem_text)

    @staticmethod
    def read_tls_private_key(private_key_path):
        """Load a PEM encoded private key from `private_key_path`."""
        with open(private_key_path) as pem_file:
            pem_text = pem_file.read()
        return crypto.load_privatekey(crypto.FILETYPE_PEM, pem_text)

    @staticmethod
    def read_signing_key(signing_key_path):
        """Load a base64 encoded NACL signing key from `signing_key_path`."""
        with open(signing_key_path) as key_file:
            encoded_key = key_file.read()
        return nacl.signing.SigningKey(decode_base64(encoded_key))

    def run(self):
        """Serve the "key" resource over TLS and run the reactor."""
        root = Resource()
        root.putChild("key", LocalKey(self))
        reactor.listenSSL(
            self.bind_port,
            server.Site(root),
            self.ssl_context_factory,
            interface=self.bind_host
        )

        logging.basicConfig(level=logging.DEBUG)
        PythonLoggingObserver().start()

        reactor.run()


def main():
    """Build a KeyServer from the command line and run it."""
    server_config = load_config(__doc__, sys.argv[1:])
    KeyServer(**server_config).run()


if __name__ == "__main__":
    main()
# diff --git a/synapse/crypto/resource/__init__.py b/synapse/crypto/resource/__init__.py
# new file mode 100644
# index 0000000000..fe8a073cd3
# --- /dev/null
# +++ b/synapse/crypto/resource/__init__.py
# @@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# diff --git a/synapse/crypto/resource/key.py b/synapse/crypto/resource/key.py
# new file mode 100644
# index 0000000000..6ce6e0b034
# --- /dev/null
# +++ b/synapse/crypto/resource/key.py
# @@ -0,0 +1,160 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
from synapse.http.server import respond_with_json_bytes
from synapse.crypto.keyclient import fetch_server_key
from syutil.crypto.jsonsign import sign_json, verify_signed_json
from syutil.base64util import encode_base64, decode_base64
from syutil.jsonutil import encode_canonical_json
from OpenSSL import crypto
from nacl.signing import VerifyKey
import logging


logger = logging.getLogger(__name__)


class LocalKey(Resource):
    """HTTP resource encoding the TLS X.509 certificate and NACL
    signature verification keys for this server::

        GET /key HTTP/1.1

        HTTP/1.1 200 OK
        Content-Type: application/json
        {
            "server_name": "this.server.example.com"
            "signature_verify_key": # base64 encoded NACL verification key.
            "tls_certificate": # base64 ASN.1 DER encoded X.509 tls cert.
            "signatures": {
                "this.server.example.com": # NACL signature for this server.
            }
        }
    """

    def __init__(self, key_server):
        self.key_server = key_server
        # The signed response never changes, so it is encoded once up front.
        self.response_body = encode_canonical_json(
            self.response_json_object(key_server)
        )
        Resource.__init__(self)

    @staticmethod
    def response_json_object(key_server):
        """Build the signed JSON object describing `key_server`'s keys."""
        verify_key_bytes = key_server.signing_key.verify_key.encode()
        x509_certificate_bytes = crypto.dump_certificate(
            crypto.FILETYPE_ASN1,
            key_server.tls_certificate
        )
        json_object = {
            u"server_name": key_server.server_name,
            u"signature_verify_key": encode_base64(verify_key_bytes),
            u"tls_certificate": encode_base64(x509_certificate_bytes)
        }
        signed_json = sign_json(
            json_object,
            key_server.server_name,
            key_server.signing_key
        )
        return signed_json

    def getChild(self, name, request):
        """Serve this resource for "/key/" and a RemoteKey for "/key/<name>"."""
        logger.info("getChild %s %s", name, request)
        if name == '':
            return self
        else:
            return RemoteKey(name, self.key_server)

    def render_GET(self, request):
        """Respond with the pre-encoded signed key JSON."""
        return respond_with_json_bytes(request, 200, self.response_body)


class RemoteKey(Resource):
    """HTTP resource for retrieving the TLS certificate and NACL signature
    verification keys for another server. Checks that the reported X.509 TLS
    certificate matches the one used in the HTTPS connection. Checks that the
    NACL signature for the remote server is valid. Returns JSON signed by both
    the remote server and by this server.

        GET /key/remote.server.example.com HTTP/1.1

        HTTP/1.1 200 OK
        Content-Type: application/json
        {
            "server_name": "remote.server.example.com"
            "signature_verify_key": # base64 encoded NACL verification key.
            "tls_certificate": # base64 ASN.1 DER encoded X.509 tls cert.
            "signatures": {
                "remote.server.example.com": # NACL signature for remote server.
                "this.server.example.com": # NACL signature for this server.
            }
        }
    """

    isLeaf = True

    def __init__(self, server_name, key_server):
        self.server_name = server_name
        self.key_server = key_server
        Resource.__init__(self)

    def render_GET(self, request):
        """Start the asynchronous fetch; the response is written later."""
        self._async_render_GET(request)
        return NOT_DONE_YET

    @defer.inlineCallbacks
    def _async_render_GET(self, request):
        """Fetch, verify and counter-sign the remote server's keys.

        Responds with the doubly signed JSON on success, or with a 502 JSON
        error body when anything goes wrong.
        """
        try:
            server_keys, certificate = yield fetch_server_key(
                self.server_name,
                self.key_server.ssl_context_factory
            )

            resp_server_name = server_keys[u"server_name"]
            verify_key_b64 = server_keys[u"signature_verify_key"]
            tls_certificate_b64 = server_keys[u"tls_certificate"]
            verify_key = VerifyKey(decode_base64(verify_key_b64))

            if resp_server_name != self.server_name:
                raise ValueError("Wrong server name '%s' != '%s'" %
                                 (resp_server_name, self.server_name))

            # The certificate reported in the JSON must be the one actually
            # presented on the TLS connection.
            x509_certificate_bytes = crypto.dump_certificate(
                crypto.FILETYPE_ASN1,
                certificate
            )

            if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
                raise ValueError("TLS certificate doesn't match")

            verify_signed_json(server_keys, self.server_name, verify_key)

            signed_json = sign_json(
                server_keys,
                self.key_server.server_name,
                self.key_server.signing_key
            )

            json_bytes = encode_canonical_json(signed_json)
            respond_with_json_bytes(request, 200, json_bytes)

        except Exception as e:
            # Deliberately broad: every failure becomes a 502 for the client,
            # but it is logged so the cause is not lost.
            logger.exception(
                "Failed to fetch key for %s", self.server_name
            )
            # `message` is deprecated and empty for multi-argument
            # exceptions, so fall back to str(e).
            error_message = getattr(e, "message", None) or str(e)
            json_bytes = encode_canonical_json({
                u"error": {u"code": 502, u"message": error_message}
            })
            respond_with_json_bytes(request, 502, json_bytes)
# diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py
# new file mode 100644
# index 0000000000..b4d95ed5ac
# --- /dev/null
# +++ b/synapse/federation/__init__.py
# @@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" This package includes all the federation specific logic.
"""

from .replication import ReplicationLayer
from .transport import TransportLayer


def initialize_http_replication(homeserver):
    """Build a ReplicationLayer that talks over the home server's HTTP stack.

    Args:
        homeserver: Supplies the hostname plus the HTTP server and client the
            transport layer is built on.
    Returns:
        A ReplicationLayer bound to `homeserver` and the new TransportLayer.
    """
    hostname = homeserver.hostname
    http_server = homeserver.get_http_server()
    http_client = homeserver.get_http_client()

    transport = TransportLayer(
        hostname, server=http_server, client=http_client
    )

    return ReplicationLayer(homeserver, transport)
# diff --git a/synapse/federation/handler.py b/synapse/federation/handler.py
# new file mode 100644
# index 0000000000..31e8470b33
# --- /dev/null
# +++ b/synapse/federation/handler.py
# @@ -0,0 +1,148 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer

from .pdu_codec import PduCodec

from synapse.api.errors import AuthError
from synapse.util.logutils import log_function

import logging


logger = logging.getLogger(__name__)


class FederationEventHandler(object):
    """ Responsible for:
        a) handling received Pdus before handing them on as Events to the rest
        of the home server (including auth and state conflict resolution)
        b) converting events that were produced by local clients that may need
        to be sent to remote home servers.
    """

    def __init__(self, hs):
        """
        Args:
            hs: The home server. It is asked for the datastore, replication
                layer, state handler, federation handler, room lock manager
                and hostname.
        """
        self.store = hs.get_datastore()
        self.replication_layer = hs.get_replication_layer()
        self.state_handler = hs.get_state_handler()
        self.event_handler = hs.get_handlers().federation_handler
        self.server_name = hs.hostname

        self.lock_manager = hs.get_room_lock_manager()

        # Register this object as the replication layer's handler.
        self.replication_layer.set_handler(self)

        self.pdu_codec = PduCodec(hs)

    @log_function
    @defer.inlineCallbacks
    def handle_new_event(self, event):
        """ Takes in an event from the client to server side, that has already
        been authed and handled by the state module, and sends it to any
        remote home servers that may be interested.

        Args:
            event

        Returns:
            Deferred: Resolved when it has successfully been queued for
            processing.
        """
        yield self._fill_out_prev_events(event)

        pdu = self.pdu_codec.pdu_from_event(event)

        # Normalise a missing or empty `destinations` to an empty list.
        if not hasattr(pdu, "destinations") or not pdu.destinations:
            pdu.destinations = []

        yield self.replication_layer.send_pdu(pdu)

    @log_function
    def backfill(self, room_id, limit):
        """ Placeholder for requesting older PDUs of a room; currently does
        nothing.

        Args:
            room_id: Unused for now.
            limit: Unused for now.

        Returns:
            Deferred: Already fired with `None`.
        """
        # TODO: Work out which destinations to ask for pagination
        # self.replication_layer.paginate(dest, room_id, limit)

        # NOTE: this is deliberately not wrapped in `defer.inlineCallbacks`.
        # That decorator requires a generator function, and a body without
        # any `yield` makes it raise TypeError when called.
        return defer.succeed(None)

    @log_function
    def get_state_for_room(self, destination, room_id):
        """ Asks the replication layer for the state of `room_id` as known by
        `destination`.

        Returns:
            Whatever `replication_layer.get_state_for_context` returns.
        """
        return self.replication_layer.get_state_for_context(
            destination, room_id
        )

    @log_function
    @defer.inlineCallbacks
    def on_receive_pdu(self, pdu):
        """ Called by the ReplicationLayer when we have a new pdu. We need to
        do auth checks and put it through the StateHandler.

        Raises:
            AuthError: Propagated unchanged from the handlers.
        """
        event = self.pdu_codec.event_from_pdu(pdu)

        try:
            # Hold the lock for the pdu's context while it is processed.
            with (yield self.lock_manager.lock(pdu.context)):
                if event.is_state:
                    is_new_state = yield self.state_handler.handle_new_state(
                        pdu
                    )
                    # State that is not new is not passed on any further.
                    if not is_new_state:
                        return
                else:
                    is_new_state = False

            yield self.event_handler.on_receive(event, is_new_state)

        except AuthError:
            # TODO: Implement something in federation that allows us to
            # respond to PDU.
            raise

    @defer.inlineCallbacks
    def _on_new_state(self, pdu, new_state_event):
        """ Records `pdu` as the current state in the store and forwards
        `new_state_event` to the event handler.
        """
        # TODO: Do any store stuff here. Notify C2S about this new
        # state.

        yield self.store.update_current_state(
            pdu_id=pdu.pdu_id,
            origin=pdu.origin,
            context=pdu.context,
            pdu_type=pdu.pdu_type,
            state_key=pdu.state_key
        )

        # NOTE(review): on_receive is called with two arguments in
        # on_receive_pdu but with one here - confirm its signature.
        yield self.event_handler.on_receive(new_state_event)

    @defer.inlineCallbacks
    def _fill_out_prev_events(self, event):
        """ Sets `prev_events` and `depth` on `event` from the latest PDUs the
        store holds for the event's room.

        Events that already carry a `prev_events` attribute are left alone.
        """
        if hasattr(event, "prev_events"):
            return

        results = yield self.store.get_latest_pdus_in_context(
            event.room_id
        )

        es = [
            "%s@%s" % (p_id, origin) for p_id, origin, _ in results
        ]

        # An event never lists itself as one of its predecessors.
        event.prev_events = [e for e in es if e != event.event_id]

        if results:
            event.depth = max(int(v) for _, _, v in results) + 1
        else:
            event.depth = 0
import copy


def decode_event_id(event_id, server_name):
    """ Splits an event id of the form `<pdu_id>@<origin>` into its parts.

    The split happens at the first "@" only, so an origin that itself
    contains "@" is returned intact and `decode_event_id` is the inverse of
    `encode_event_id`.

    Args:
        event_id (str): The event id to split.
        server_name (str): Origin to use when `event_id` contains no "@".

    Returns:
        tuple: `(pdu_id, origin)`.
    """
    pdu_id, separator, origin = event_id.partition("@")
    if not separator:
        return (event_id, server_name)
    return (pdu_id, origin)


def encode_event_id(pdu_id, origin):
    """ Joins a pdu id and an origin into an event id `<pdu_id>@<origin>`.
    """
    return "%s@%s" % (pdu_id, origin)


class PduCodec(object):
    """ Converts between `Pdu`s and events.
    """

    def __init__(self, hs):
        """
        Args:
            hs: The home server; supplies the hostname, event factory and
                clock.
        """
        self.server_name = hs.hostname
        self.event_factory = hs.get_event_factory()
        self.clock = hs.get_clock()

    def event_from_pdu(self, pdu):
        """ Builds an event from the given `Pdu` using the event factory.

        Returns:
            The result of `event_factory.create_event`.
        """
        kwargs = {}

        kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin)
        kwargs["room_id"] = pdu.context
        kwargs["etype"] = pdu.pdu_type
        kwargs["prev_events"] = [
            encode_event_id(p[0], p[1]) for p in pdu.prev_pdus
        ]

        if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"):
            kwargs["prev_state"] = encode_event_id(
                pdu.prev_state_id, pdu.prev_state_origin
            )

        # Keys translated above are not copied over under their PDU names.
        translated_keys = (
            "pdu_id",
            "context",
            "pdu_type",
            "prev_pdus",
            "prev_state_id",
            "prev_state_origin",
        )
        kwargs.update({
            k: v
            for k, v in pdu.get_full_dict().items()
            if k not in translated_keys
        })

        return self.event_factory.create_event(**kwargs)

    def pdu_from_event(self, event):
        """ Builds a `Pdu` from the given event.

        Event ids without an "@" are attributed to this server. A `ts` taken
        from the clock is added when the event does not supply one.

        Returns:
            Pdu
        """
        # Imported at call time: Pdu is only needed here, and this keeps the
        # id helpers above importable on their own.
        from .units import Pdu

        d = event.get_full_dict()

        d["pdu_id"], d["origin"] = decode_event_id(
            event.event_id, self.server_name
        )
        d["context"] = event.room_id
        d["pdu_type"] = event.type

        if hasattr(event, "prev_events"):
            d["prev_pdus"] = [
                decode_event_id(e, self.server_name)
                for e in event.prev_events
            ]

        if hasattr(event, "prev_state"):
            d["prev_state_id"], d["prev_state_origin"] = (
                decode_event_id(event.prev_state, self.server_name)
            )

        if hasattr(event, "state_key"):
            d["is_state"] = True

        # Start from a copy of the unrecognized keys so that the event's own
        # dict is never mutated.
        kwargs = copy.deepcopy(event.unrecognized_keys)
        kwargs.update({
            k: v for k, v in d.items()
            if k not in ("event_id", "room_id", "type", "prev_events")
        })

        if "ts" not in kwargs:
            kwargs["ts"] = int(self.clock.time_msec())

        return Pdu(**kwargs)
""" This module contains all the persistence actions done by the federation
package.

These actions are mostly only used by the :py:mod:`.replication` module.
"""

from twisted.internet import defer

from .units import Pdu

from synapse.util.logutils import log_function

import copy
import json
import logging


logger = logging.getLogger(__name__)


def _pdus_from_tuples(pdu_tuples):
    """ Converts an iterable of stored pdu tuples into a list of `Pdu`s.
    """
    return [Pdu.from_pdu_tuple(pdu_tuple) for pdu_tuple in pdu_tuples]


def _require_transaction_id(transaction):
    """ Raises RuntimeError if `transaction` has no transaction_id.
    """
    if not transaction.transaction_id:
        raise RuntimeError(
            "Cannot persist a transaction with no transaction_id"
        )


class PduActions(object):
    """ Persistence actions for PDUs, backed by the given datastore.
    """

    def __init__(self, datastore):
        self.store = datastore

    @log_function
    def persist_received(self, pdu):
        """ Stores a `Pdu` that a remote home server sent to us.

        Returns:
            Deferred
        """
        return self._persist(pdu)

    @defer.inlineCallbacks
    @log_function
    def persist_outgoing(self, pdu):
        """ Stores a `Pdu` that was created by this home server.

        Returns:
            Deferred
        """
        persisted = yield self._persist(pdu)

        defer.returnValue(persisted)

    @log_function
    def mark_as_processed(self, pdu):
        """ Records in the store that the given `Pdu` is fully processed.

        Returns:
            Deferred
        """
        return self.store.mark_pdu_as_processed(pdu.pdu_id, pdu.origin)

    @defer.inlineCallbacks
    @log_function
    def populate_previous_pdus(self, pdu):
        """ Fills in `prev_pdus` and `depth` on an outgoing `Pdu` from the
        latest PDUs the store holds for the pdu's context.

        The depth is one more than the largest stored depth, or 0 when the
        store returns nothing.

        Returns:
            Deferred
        """
        latest = yield self.store.get_latest_pdus_in_context(pdu.context)

        pdu.prev_pdus = [(p_id, origin) for p_id, origin, _ in latest]

        depths = [int(depth) for _, _, depth in latest]
        pdu.depth = max(depths) + 1 if depths else 0

    @defer.inlineCallbacks
    @log_function
    def after_transaction(self, transaction_id, destination, origin):
        """ Looks up the `Pdu`s sent to `destination` after the given
        transaction id.

        `origin` is accepted but not used in the lookup.

        Returns:
            Deferred: Results in a list of `Pdu`s
        """
        rows = yield self.store.get_pdus_after_transaction(
            transaction_id,
            destination
        )

        defer.returnValue(_pdus_from_tuples(rows))

    @defer.inlineCallbacks
    @log_function
    def get_all_pdus_from_context(self, context):
        """ Fetches every stored PDU for `context`.

        Returns:
            Deferred: Results in a list of `Pdu`s
        """
        rows = yield self.store.get_all_pdus_from_context(context)
        defer.returnValue(_pdus_from_tuples(rows))

    @defer.inlineCallbacks
    @log_function
    def paginate(self, context, pdu_list, limit):
        """ Asks the store to paginate within `context`, starting from the
        given list of PDU ids and origins and bounded by `limit`.

        Returns:
            Deferred: Results in a list of `Pdu`s.
        """
        rows = yield self.store.get_pagination(
            context, pdu_list, limit
        )

        defer.returnValue(_pdus_from_tuples(rows))

    @log_function
    def is_new(self, pdu):
        """ Asks the store whether a `Pdu` received from a remote home server
        is `new`, i.e. not just a historic PDU that we have not seen because
        we have not paginated back that far.

        Returns:
            Deferred: Results in a `bool`
        """
        return self.store.is_pdu_new(
            pdu_id=pdu.pdu_id,
            origin=pdu.origin,
            context=pdu.context,
            depth=pdu.depth
        )

    @defer.inlineCallbacks
    @log_function
    def _persist(self, pdu):
        """ Writes `pdu` to the store and updates the context's min depth.

        The pdu's attributes become keyword arguments for the store, with
        `content` replaced by `content_json` and `unrecognized_keys`
        serialised as JSON.

        Returns:
            Deferred: Results in whatever the store's persist call returned.
        """
        row = copy.copy(pdu.__dict__)
        unrecognized = copy.copy(pdu.unrecognized_keys)
        del row["content"]
        row["content_json"] = json.dumps(pdu.content)
        row["unrecognized_keys"] = json.dumps(unrecognized)

        logger.debug("Persisting: %s", repr(row))

        # State PDUs and ordinary PDUs go through different store methods.
        if pdu.is_state:
            persist = self.store.persist_state
        else:
            persist = self.store.persist_pdu
        result = yield persist(**row)

        yield self.store.update_min_depth_for_context(
            pdu.context, pdu.depth
        )

        defer.returnValue(result)


class TransactionActions(object):
    """ Persistence actions for Transactions, backed by the given datastore.
    """

    def __init__(self, datastore):
        self.store = datastore

    @log_function
    def have_responded(self, transaction):
        """ Looks up any response already stored for a transaction with the
        same id and origin.

        Returns:
            Deferred: Results in `None` if we have not previously responded to
            this transaction or a 2-tuple of `(int, dict)` representing the
            response code and response body.

        Raises:
            RuntimeError: If the transaction has no transaction_id.
        """
        _require_transaction_id(transaction)

        return self.store.get_received_txn_response(
            transaction.transaction_id, transaction.origin
        )

    @log_function
    def set_response(self, transaction, code, response):
        """ Stores the response code and JSON-encoded body we answered a
        transaction with.

        Returns:
            Deferred

        Raises:
            RuntimeError: If the transaction has no transaction_id.
        """
        _require_transaction_id(transaction)

        return self.store.set_received_txn_response(
            transaction.transaction_id,
            transaction.origin,
            code,
            json.dumps(response)
        )

    @defer.inlineCallbacks
    @log_function
    def prepare_to_send(self, transaction):
        """ Stores the `Transaction` we are about to send and sets its
        `prev_ids` to the value the store hands back.

        Returns:
            Deferred
        """
        pdu_keys = [(p["pdu_id"], p["origin"]) for p in transaction.pdus]

        transaction.prev_ids = yield self.store.prep_send_transaction(
            transaction.transaction_id,
            transaction.destination,
            transaction.ts,
            pdu_keys
        )

    @log_function
    def delivered(self, transaction, response_code, response_dict):
        """ Records that the given `Transaction` was delivered to the remote
        homeserver, together with the response code and JSON-encoded body.

        Returns:
            Deferred
        """
        return self.store.delivered_txn(
            transaction.transaction_id,
            transaction.destination,
            response_code,
            json.dumps(response_dict)
        )
+"""This layer is responsible for replicating with remote home servers using +a given transport. +""" + +from twisted.internet import defer + +from .units import Transaction, Pdu, Edu + +from .persistence import PduActions, TransactionActions + +from synapse.util.logutils import log_function + +import logging + + +logger = logging.getLogger(__name__) + + +class ReplicationLayer(object): + """This layer is responsible for replicating with remote home servers over + the given transport. I.e., does the sending and receiving of PDUs to + remote home servers. + + The layer communicates with the rest of the server via a registered + ReplicationHandler. + + In more detail, the layer: + * Receives incoming data and processes it into transactions and pdus. + * Fetches any PDUs it thinks it might have missed. + * Keeps the current state for contexts up to date by applying the + suitable conflict resolution. + * Sends outgoing pdus wrapped in transactions. + * Fills out the references to previous pdus/transactions appropriately + for outgoing data. + """ + + def __init__(self, hs, transport_layer): + self.server_name = hs.hostname + + self.transport_layer = transport_layer + self.transport_layer.register_received_handler(self) + self.transport_layer.register_request_handler(self) + + self.store = hs.get_datastore() + self.pdu_actions = PduActions(self.store) + self.transaction_actions = TransactionActions(self.store) + + self._transaction_queue = _TransactionQueue( + hs, self.transaction_actions, transport_layer + ) + + self.handler = None + self.edu_handlers = {} + + self._order = 0 + + self._clock = hs.get_clock() + + def set_handler(self, handler): + """Sets the handler that the replication layer will use to communicate + receipt of new PDUs from other home servers. The required methods are + documented on :py:class:`.ReplicationHandler`. 
+ """ + self.handler = handler + + def register_edu_handler(self, edu_type, handler): + if edu_type in self.edu_handlers: + raise KeyError("Already have an EDU handler for %s" % (edu_type)) + + self.edu_handlers[edu_type] = handler + + @defer.inlineCallbacks + @log_function + def send_pdu(self, pdu): + """Informs the replication layer about a new PDU generated within the + home server that should be transmitted to others. + + This will fill out various attributes on the PDU object, e.g. the + `prev_pdus` key. + + *Note:* The home server should always call `send_pdu` even if it knows + that it does not need to be replicated to other home servers. This is + in case e.g. someone else joins via a remote home server and then + paginates. + + TODO: Figure out when we should actually resolve the deferred. + + Args: + pdu (Pdu): The new Pdu. + + Returns: + Deferred: Completes when we have successfully processed the PDU + and replicated it to any interested remote home servers. + """ + order = self._order + self._order += 1 + + logger.debug("[%s] Persisting PDU", pdu.pdu_id) + + #yield self.pdu_actions.populate_previous_pdus(pdu) + + # Save *before* trying to send + yield self.pdu_actions.persist_outgoing(pdu) + + logger.debug("[%s] Persisted PDU", pdu.pdu_id) + logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id) + + # TODO, add errback, etc. + self._transaction_queue.enqueue_pdu(pdu, order) + + logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.pdu_id) + + @log_function + def send_edu(self, destination, edu_type, content): + edu = Edu( + origin=self.server_name, + destination=destination, + edu_type=edu_type, + content=content, + ) + + # TODO, add errback, etc. + self._transaction_queue.enqueue_edu(edu) + + @defer.inlineCallbacks + @log_function + def paginate(self, dest, context, limit): + """Requests some more historic PDUs for the given context from the + given destination server. + + Args: + dest (str): The remote home server to ask. 
+ context (str): The context to paginate back on. + limit (int): The maximum number of PDUs to return. + + Returns: + Deferred: Results in the received PDUs. + """ + extremities = yield self.store.get_oldest_pdus_in_context(context) + + logger.debug("paginate extrem=%s", extremities) + + # If there are no extremeties then we've (probably) reached the start. + if not extremities: + return + + transaction_data = yield self.transport_layer.paginate( + dest, context, extremities, limit) + + logger.debug("paginate transaction_data=%s", repr(transaction_data)) + + transaction = Transaction(**transaction_data) + + pdus = [Pdu(outlier=False, **p) for p in transaction.pdus] + for pdu in pdus: + yield self._handle_new_pdu(pdu) + + defer.returnValue(pdus) + + @defer.inlineCallbacks + @log_function + def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False): + """Requests the PDU with given origin and ID from the remote home + server. + + This will persist the PDU locally upon receipt. + + Args: + destination (str): Which home server to query + pdu_origin (str): The home server that originally sent the pdu. + pdu_id (str) + outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if + it's from an arbitary point in the context as opposed to part + of the current block of PDUs. Defaults to `False` + + Returns: + Deferred: Results in the requested PDU. + """ + + transaction_data = yield self.transport_layer.get_pdu( + destination, pdu_origin, pdu_id) + + transaction = Transaction(**transaction_data) + + pdu_list = [Pdu(outlier=outlier, **p) for p in transaction.pdus] + + pdu = None + if pdu_list: + pdu = pdu_list[0] + yield self._handle_new_pdu(pdu) + + defer.returnValue(pdu) + + @defer.inlineCallbacks + @log_function + def get_state_for_context(self, destination, context): + """Requests all of the `current` state PDUs for a given context from + a remote home server. + + Args: + destination (str): The remote homeserver to query for the state. 
+ context (str): The context we're interested in. + + Returns: + Deferred: Results in a list of PDUs. + """ + + transaction_data = yield self.transport_layer.get_context_state( + destination, context) + + transaction = Transaction(**transaction_data) + + pdus = [Pdu(outlier=True, **p) for p in transaction.pdus] + for pdu in pdus: + yield self._handle_new_pdu(pdu) + + defer.returnValue(pdus) + + @defer.inlineCallbacks + @log_function + def on_context_pdus_request(self, context): + pdus = yield self.pdu_actions.get_all_pdus_from_context( + context + ) + defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) + + @defer.inlineCallbacks + @log_function + def on_paginate_request(self, context, versions, limit): + + pdus = yield self.pdu_actions.paginate(context, versions, limit) + + defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) + + @defer.inlineCallbacks + @log_function + def on_incoming_transaction(self, transaction_data): + transaction = Transaction(**transaction_data) + + logger.debug("[%s] Got transaction", transaction.transaction_id) + + response = yield self.transaction_actions.have_responded(transaction) + + if response: + logger.debug("[%s] We've already responed to this request", + transaction.transaction_id) + defer.returnValue(response) + return + + logger.debug("[%s] Transacition is new", transaction.transaction_id) + + pdu_list = [Pdu(**p) for p in transaction.pdus] + + dl = [] + for pdu in pdu_list: + dl.append(self._handle_new_pdu(pdu)) + + if hasattr(transaction, "edus"): + for edu in [Edu(**x) for x in transaction.edus]: + self.received_edu(edu.origin, edu.edu_type, edu.content) + + results = yield defer.DeferredList(dl) + + ret = [] + for r in results: + if r[0]: + ret.append({}) + else: + logger.exception(r[1]) + ret.append({"error": str(r[1])}) + + logger.debug("Returning: %s", str(ret)) + + yield self.transaction_actions.set_response( + transaction, + 200, response + ) + defer.returnValue((200, response)) + + 
def received_edu(self, origin, edu_type, content): + if edu_type in self.edu_handlers: + self.edu_handlers[edu_type](origin, content) + else: + logger.warn("Received EDU of type %s with no handler", edu_type) + + @defer.inlineCallbacks + @log_function + def on_context_state_request(self, context): + results = yield self.store.get_current_state_for_context( + context + ) + + logger.debug("Context returning %d results", len(results)) + + pdus = [Pdu.from_pdu_tuple(p) for p in results] + defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) + + @defer.inlineCallbacks + @log_function + def on_pdu_request(self, pdu_origin, pdu_id): + pdu = yield self._get_persisted_pdu(pdu_id, pdu_origin) + + if pdu: + defer.returnValue( + (200, self._transaction_from_pdus([pdu]).get_dict()) + ) + else: + defer.returnValue((404, "")) + + @defer.inlineCallbacks + @log_function + def on_pull_request(self, origin, versions): + transaction_id = max([int(v) for v in versions]) + + response = yield self.pdu_actions.after_transaction( + transaction_id, + origin, + self.server_name + ) + + if not response: + response = [] + + defer.returnValue( + (200, self._transaction_from_pdus(response).get_dict()) + ) + + @defer.inlineCallbacks + @log_function + def _get_persisted_pdu(self, pdu_id, pdu_origin): + """ Get a PDU from the database with given origin and id. + + Returns: + Deferred: Results in a `Pdu`. + """ + pdu_tuple = yield self.store.get_pdu(pdu_id, pdu_origin) + + defer.returnValue(Pdu.from_pdu_tuple(pdu_tuple)) + + def _transaction_from_pdus(self, pdu_list): + """Returns a new Transaction containing the given PDUs suitable for + transmission. 
+ """ + return Transaction( + pdus=[p.get_dict() for p in pdu_list], + origin=self.server_name, + ts=int(self._clock.time_msec()), + destination=None, + ) + + @defer.inlineCallbacks + @log_function + def _handle_new_pdu(self, pdu): + # We reprocess pdus when we have seen them only as outliers + existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin) + + if existing and (not existing.outlier or pdu.outlier): + logger.debug("Already seen pdu %s %s", pdu.pdu_id, pdu.origin) + defer.returnValue({}) + return + + # Get missing pdus if necessary. + is_new = yield self.pdu_actions.is_new(pdu) + if is_new and not pdu.outlier: + # We only paginate backwards to the min depth. + min_depth = yield self.store.get_min_depth_for_context(pdu.context) + + if min_depth and pdu.depth > min_depth: + for pdu_id, origin in pdu.prev_pdus: + exists = yield self._get_persisted_pdu(pdu_id, origin) + + if not exists: + logger.debug("Requesting pdu %s %s", pdu_id, origin) + + try: + yield self.get_pdu( + pdu.origin, + pdu_id=pdu_id, + pdu_origin=origin + ) + logger.debug("Processed pdu %s %s", pdu_id, origin) + except: + # TODO(erikj): Do some more intelligent retries. + logger.exception("Failed to get PDU") + + # Persist the Pdu, but don't mark it as processed yet. + yield self.pdu_actions.persist_received(pdu) + + ret = yield self.handler.on_receive_pdu(pdu) + + yield self.pdu_actions.mark_as_processed(pdu) + + defer.returnValue(ret) + + def __str__(self): + return "<ReplicationLayer(%s)>" % self.server_name + + +class ReplicationHandler(object): + """This defines the methods that the :py:class:`.ReplicationLayer` will + use to communicate with the rest of the home server. + """ + def on_receive_pdu(self, pdu): + raise NotImplementedError("on_receive_pdu") + + +class _TransactionQueue(object): + """This class makes sure we only have one transaction in flight at + a time for a given destination. + + It batches pending PDUs into single transactions. 
+ """ + + def __init__(self, hs, transaction_actions, transport_layer): + + self.server_name = hs.hostname + self.transaction_actions = transaction_actions + self.transport_layer = transport_layer + + self._clock = hs.get_clock() + + # Is a mapping from destinations -> deferreds. Used to keep track + # of which destinations have transactions in flight and when they are + # done + self.pending_transactions = {} + + # Is a mapping from destination -> list of + # tuple(pending pdus, deferred, order) + self.pending_pdus_by_dest = {} + # destination -> list of tuple(edu, deferred) + self.pending_edus_by_dest = {} + + # HACK to get unique tx id + self._next_txn_id = int(self._clock.time_msec()) + + @defer.inlineCallbacks + @log_function + def enqueue_pdu(self, pdu, order): + # We loop through all destinations to see whether we already have + # a transaction in progress. If we do, stick it in the pending_pdus + # table and we'll get back to it later. + + destinations = [ + d for d in pdu.destinations + if d != self.server_name + ] + + logger.debug("Sending to: %s", str(destinations)) + + if not destinations: + return + + deferreds = [] + + for destination in destinations: + deferred = defer.Deferred() + self.pending_pdus_by_dest.setdefault(destination, []).append( + (pdu, deferred, order) + ) + + self._attempt_new_transaction(destination) + + deferreds.append(deferred) + + yield defer.DeferredList(deferreds) + + # NO inlineCallbacks + def enqueue_edu(self, edu): + destination = edu.destination + + deferred = defer.Deferred() + self.pending_edus_by_dest.setdefault(destination, []).append( + (edu, deferred) + ) + + def eb(failure): + deferred.errback(failure) + self._attempt_new_transaction(destination).addErrback(eb) + + return deferred + + @defer.inlineCallbacks + @log_function + def _attempt_new_transaction(self, destination): + if destination in self.pending_transactions: + return + + # list of (pending_pdu, deferred, order) + pending_pdus = 
self.pending_pdus_by_dest.pop(destination, []) + pending_edus = self.pending_edus_by_dest.pop(destination, []) + + if not pending_pdus and not pending_edus: + return + + logger.debug("TX [%s] Attempting new transaction", destination) + + # Sort based on the order field + pending_pdus.sort(key=lambda t: t[2]) + + pdus = [x[0] for x in pending_pdus] + edus = [x[0] for x in pending_edus] + deferreds = [x[1] for x in pending_pdus + pending_edus] + + try: + self.pending_transactions[destination] = 1 + + logger.debug("TX [%s] Persisting transaction...", destination) + + transaction = Transaction.create_new( + ts=self._clock.time_msec(), + transaction_id=self._next_txn_id, + origin=self.server_name, + destination=destination, + pdus=pdus, + edus=edus, + ) + + self._next_txn_id += 1 + + yield self.transaction_actions.prepare_to_send(transaction) + + logger.debug("TX [%s] Persisted transaction", destination) + logger.debug("TX [%s] Sending transaction...", destination) + + # Actually send the transaction + code, response = yield self.transport_layer.send_transaction( + transaction + ) + + logger.debug("TX [%s] Sent transaction", destination) + logger.debug("TX [%s] Marking as delivered...", destination) + + yield self.transaction_actions.delivered( + transaction, code, response + ) + + logger.debug("TX [%s] Marked as delivered", destination) + logger.debug("TX [%s] Yielding to callbacks...", destination) + + for deferred in deferreds: + if code == 200: + deferred.callback(None) + else: + deferred.errback(RuntimeError("Got status %d" % code)) + + # Ensures we don't continue until all callbacks on that + # deferred have fired + yield deferred + + logger.debug("TX [%s] Yielded to callbacks", destination) + + except Exception as e: + logger.error("TX Problem in _attempt_transaction") + + # We capture this here as there as nothing actually listens + # for this finishing functions deferred. 
+ logger.exception(e) + + for deferred in deferreds: + deferred.errback(e) + yield deferred + + finally: + # We want to be *very* sure we delete this after we stop processing + self.pending_transactions.pop(destination, None) + + # Check to see if there is anything else to send. + self._attempt_new_transaction(destination) diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py new file mode 100644 index 0000000000..2136adf8d7 --- /dev/null +++ b/synapse/federation/transport.py @@ -0,0 +1,454 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The transport layer is responsible for both sending transactions to remote +home servers and receiving a variety of requests from other home servers. + +Typically, this is done over HTTP (and all home servers are required to +support HTTP), however individual pairings of servers may decide to communicate +over a different (albeit still reliable) protocol. +""" + +from twisted.internet import defer + +from synapse.util.logutils import log_function + +import logging +import json +import re + + +logger = logging.getLogger(__name__) + + +class TransportLayer(object): + """This is a basic implementation of the transport layer that translates + transactions and other requests to/from HTTP. 
+ + Attributes: + server_name (str): Local home server host + + server (synapse.http.server.HttpServer): the http server to + register listeners on + + client (synapse.http.client.HttpClient): the http client used to + send requests + + request_handler (TransportRequestHandler): The handler to fire when we + receive requests for data. + + received_handler (TransportReceivedHandler): The handler to fire when + we receive data. + """ + + def __init__(self, server_name, server, client): + """ + Args: + server_name (str): Local home server host + server (synapse.protocol.http.HttpServer): the http server to + register listeners on + client (synapse.protocol.http.HttpClient): the http client used to + send requests + """ + self.server_name = server_name + self.server = server + self.client = client + self.request_handler = None + self.received_handler = None + + @log_function + def get_context_state(self, destination, context): + """ Requests all state for a given context (i.e. room) from the + given server. + + Args: + destination (str): The host name of the remote home server we want + to get the state from. + context (str): The name of the context we want the state of + + Returns: + Deferred: Results in a dict received from the remote homeserver. + """ + logger.debug("get_context_state dest=%s, context=%s", + destination, context) + + path = "/state/%s/" % context + + return self._do_request_for_transaction(destination, path) + + @log_function + def get_pdu(self, destination, pdu_origin, pdu_id): + """ Requests the pdu with give id and origin from the given server. + + Args: + destination (str): The host name of the remote home server we want + to get the state from. + pdu_origin (str): The home server which created the PDU. + pdu_id (str): The id of the PDU being requested. + + Returns: + Deferred: Results in a dict received from the remote homeserver. 
+ """ + logger.debug("get_pdu dest=%s, pdu_origin=%s, pdu_id=%s", + destination, pdu_origin, pdu_id) + + path = "/pdu/%s/%s/" % (pdu_origin, pdu_id) + + return self._do_request_for_transaction(destination, path) + + @log_function + def paginate(self, dest, context, pdu_tuples, limit): + """ Requests `limit` previous PDUs in a given context before list of + PDUs. + + Args: + dest (str) + context (str) + pdu_tuples (list) + limt (int) + + Returns: + Deferred: Results in a dict received from the remote homeserver. + """ + logger.debug( + "paginate dest=%s, context=%s, pdu_tuples=%s, limit=%s", + dest, context, repr(pdu_tuples), str(limit) + ) + + if not pdu_tuples: + return + + path = "/paginate/%s/" % context + + args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]} + args["limit"] = limit + + return self._do_request_for_transaction( + dest, + path, + args=args, + ) + + @defer.inlineCallbacks + @log_function + def send_transaction(self, transaction): + """ Sends the given Transaction to it's destination + + Args: + transaction (Transaction) + + Returns: + Deferred: Results of the deferred is a tuple in the form of + (response_code, response_body) where the response_body is a + python dict decoded from json + """ + logger.debug( + "send_data dest=%s, txid=%s", + transaction.destination, transaction.transaction_id + ) + + if transaction.destination == self.server_name: + raise RuntimeError("Transport layer cannot send to itself!") + + data = transaction.get_dict() + + code, response = yield self.client.put_json( + transaction.destination, + path="/send/%s/" % transaction.transaction_id, + data=data + ) + + logger.debug( + "send_data dest=%s, txid=%s, got response: %d", + transaction.destination, transaction.transaction_id, code + ) + + defer.returnValue((code, response)) + + @log_function + def register_received_handler(self, handler): + """ Register a handler that will be fired when we receive data. 
@log_function
def register_request_handler(self, handler):
    """ Register a handler that will be fired when we get asked for data.

    Registers the GET routes /pull/, /pdu/<origin>/<id>/, /state/<context>/,
    /paginate/<context>/ and /context/<context>/ on `self.server`.

    Args:
        handler (TransportRequestHandler)
    """
    self.request_handler = handler

    # TODO(markjh): Namespace the federation URI paths

    # This is for when someone asks us for everything since version X
    # NOTE(review): TransportRequestHandler.on_pull_request is declared with
    # a single `versions` parameter, but this route calls it with
    # (origin, versions) -- confirm which signature is intended.
    self.server.register_path(
        "GET",
        re.compile("^/pull/$"),
        lambda request: handler.on_pull_request(
            request.args["origin"][0],
            request.args["v"]
        )
    )

    # This is when someone asks for a data item for a given server
    # data_id pair.
    self.server.register_path(
        "GET",
        re.compile("^/pdu/([^/]*)/([^/]*)/$"),
        lambda request, pdu_origin, pdu_id: handler.on_pdu_request(
            pdu_origin, pdu_id
        )
    )

    # This is when someone asks for all data for a given context.
    self.server.register_path(
        "GET",
        re.compile("^/state/([^/]*)/$"),
        lambda request, context: handler.on_context_state_request(
            context
        )
    )

    # `limit` is looked up with .get() so that a request without it reaches
    # _on_paginate_request, which answers 400, instead of raising KeyError
    # here.
    self.server.register_path(
        "GET",
        re.compile("^/paginate/([^/]*)/$"),
        lambda request, context: self._on_paginate_request(
            context, request.args["v"],
            request.args.get("limit")
        )
    )

    self.server.register_path(
        "GET",
        re.compile("^/context/([^/]*)/$"),
        lambda request, context: handler.on_context_pdus_request(context)
    )
@defer.inlineCallbacks
@log_function
def _do_request_for_transaction(self, destination, path, args=None):
    """GET `path` from `destination` and dress the reply up as a transaction.

    Args:
        destination (str): Host name of the remote home server.
        path (str): Request path.
        args (dict): Query arguments, passed directly to the HttpClient.
            Defaults to an empty dict.

    Returns:
        Deferred: Results in the dict returned by the client, updated with
        `origin`, `destination` and `transaction_id` (None) keys.
    """
    if args is None:
        # A fresh dict per call; a `{}` default would be shared by all calls.
        args = {}

    data = yield self.client.get_json(
        destination,
        path=path,
        args=args,
    )

    # Add certain keys to the JSON, ready for decoding as a Transaction
    data.update(
        origin=destination,
        destination=self.server_name,
        transaction_id=None
    )

    defer.returnValue(data)


@log_function
def _on_paginate_request(self, context, v_list, limits):
    """Parse the query values of GET /paginate/<context>/ and dispatch.

    Args:
        context (str): The context taken from the request path.
        v_list (list): Raw `v` query values, each split once on ",".
        limits (list): Raw `limit` query values; the last one is used.

    Returns:
        Deferred: Results in `(400, {"error": ...})` when the limit is
        missing or not an integer, otherwise whatever the request handler's
        `on_paginate_request` returns.
    """
    if not limits:
        return defer.succeed(
            (400, {"error": "Did not include limit param"})
        )

    try:
        limit = int(limits[-1])
    except ValueError:
        # The value comes straight from a remote server; reject it cleanly
        # rather than letting the exception escape.
        return defer.succeed(
            (400, {"error": "Invalid limit param"})
        )

    versions = [v.split(",", 1) for v in v_list]

    return self.request_handler.on_paginate_request(
        context, versions, limit)
def on_pdu_request(self, pdu_origin, pdu_id):
    """ Called on GET /pdu/<pdu_origin>/<pdu_id>/

    Someone wants a particular PDU. This PDU may or may not have originated
    from us.

    This base implementation does nothing and returns None.

    Args:
        pdu_origin (str)
        pdu_id (str)

    Returns:
        Deferred: Results in a tuple in the form of
        `(response_code, response_body)`, where `response_body` is a python
        dict that will get serialized to JSON.

        On errors, the dict should have an `error` key with a brief message
        of what went wrong.
    """


def on_context_state_request(self, context):
    """ Called on GET /state/<context>/

    Gets hit when someone wants all the *current* state for a given
    context.

    This base implementation does nothing and returns None.

    Args:
        context (str): The name of the context that we're interested in.

    Returns:
        twisted.internet.defer.Deferred: A deferred that gets fired when
        the request has finished being processed.

        The result should be a tuple in the form of
        `(response_code, response_body)`, where `response_body` is a python
        dict that will get serialized to JSON.

        On errors, the dict should have an `error` key with a brief message
        of what went wrong.
    """
+ + On errors, the dict should have an `error` key with a brief message + of what went wrong. + """ + pass diff --git a/synapse/federation/units.py b/synapse/federation/units.py new file mode 100644 index 0000000000..0efea7b768 --- /dev/null +++ b/synapse/federation/units.py @@ -0,0 +1,236 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Defines the JSON structure of the protocol units used by the server to +server protocol. +""" + +from synapse.util.jsonobject import JsonEncodedObject + +import logging +import json +import copy + + +logger = logging.getLogger(__name__) + + +class Pdu(JsonEncodedObject): + """ A Pdu represents a piece of data sent from a server and is associated + with a context. + + A Pdu can be classified as "state". For a given context, we can efficiently + retrieve all state pdu's that haven't been clobbered. Clobbering is done + via a unique constraint on the tuple (context, pdu_type, state_key). A pdu + is a state pdu if `is_state` is True. + + Example pdu:: + + { + "pdu_id": "78c", + "ts": 1404835423000, + "origin": "bar", + "prev_ids": [ + ["23b", "foo"], + ["56a", "bar"], + ], + "content": { ... }, + } + + """ + + valid_keys = [ + "pdu_id", + "context", + "origin", + "ts", + "pdu_type", + "destinations", + "transaction_id", + "prev_pdus", + "depth", + "content", + "outlier", + "is_state", # Below this are keys valid only for State Pdus. 
def __init__(self, destinations=None, is_state=False, prev_pdus=None,
             outlier=False, **kwargs):
    """Build a Pdu, requiring `state_key` when it is a state Pdu.

    Args:
        destinations (list): Defaults to a new empty list.
        is_state (bool): Whether this is a state Pdu.
        prev_pdus (list): Defaults to a new empty list.
        outlier (bool)
        **kwargs: Remaining keys, passed to the base class unchanged.

    Raises:
        RuntimeError: If `is_state` is true and `state_key` is missing.
    """
    if is_state:
        for required_key in ["state_key"]:
            if required_key not in kwargs:
                raise RuntimeError("Key %s is required" % required_key)

    # Build the default lists per instance: a `[]` default in the signature
    # is created once and would be shared by every Pdu constructed without
    # these arguments.
    super(Pdu, self).__init__(
        destinations=[] if destinations is None else destinations,
        is_state=is_state,
        prev_pdus=[] if prev_pdus is None else prev_pdus,
        outlier=outlier,
        **kwargs
    )


@classmethod
def from_pdu_tuple(cls, pdu_tuple):
    """ Converts a PduTuple to a Pdu

    Args:
        pdu_tuple (synapse.persistence.transactions.PduTuple): The tuple to
            convert

    Returns:
        Pdu: An instance of `cls`, or None if `pdu_tuple` is falsy.
    """
    if not pdu_tuple:
        return None

    d = copy.copy(pdu_tuple.pdu_entry._asdict())

    # The stored row carries the content as a JSON string.
    d["content"] = json.loads(d.pop("content_json"))

    args = {f: d[f] for f in cls.valid_keys if f in d}
    if d.get("unrecognized_keys"):
        args.update(json.loads(d["unrecognized_keys"]))

    # Construct through `cls` so subclasses get an instance of themselves.
    return cls(
        prev_pdus=pdu_tuple.prev_pdu_list,
        **args
    )
class Transaction(JsonEncodedObject):
    """ A transaction is a list of Pdus and Edus to be sent to a remote home
    server with some extra metadata.

    Example transaction::

        {
            "origin": "foo",
            "previous_ids": ["abc", "def"],
            "pdus": [
                ...
            ],
         }

    """

    valid_keys = [
        "transaction_id",
        "origin",
        "destination",
        "ts",
        "previous_ids",
        "pdus",
        "edus",
    ]

    internal_keys = [
        "transaction_id",
        "destination",
    ]

    required_keys = [
        "transaction_id",
        "origin",
        "destination",
        "ts",
        "pdus",
    ]

    def __init__(self, transaction_id=None, pdus=None, **kwargs):
        """Build a transaction from its keys.

        `pdus` defaults to a new empty list and is handed to the base class
        as given. An `edus` keyword with a falsy value is dropped.
        """

        # If there's no EDUs then remove the arg
        if "edus" in kwargs and not kwargs["edus"]:
            del kwargs["edus"]

        # A `[]` default in the signature would be one list shared by every
        # Transaction built without pdus.
        super(Transaction, self).__init__(
            transaction_id=transaction_id,
            pdus=[] if pdus is None else pdus,
            **kwargs
        )

    @staticmethod
    def create_new(pdus, **kwargs):
        """ Used to create a new transaction from Pdu objects.

        `ts` and `transaction_id` must be supplied by the caller; they are
        not filled in here. Each pdu has its `transaction_id` attribute set
        and is then replaced by its `get_dict()` form.

        Raises:
            KeyError: If `ts` or `transaction_id` is missing from kwargs.
        """
        if "ts" not in kwargs:
            raise KeyError("Require 'ts' to construct a Transaction")
        if "transaction_id" not in kwargs:
            raise KeyError(
                "Require 'transaction_id' to construct a Transaction"
            )

        for p in pdus:
            p.transaction_id = kwargs["transaction_id"]

        kwargs["pdus"] = [p.get_dict() for p in pdus]

        return Transaction(**kwargs)
class Handlers(object):

    """ A collection of all the event handlers.

    Nothing is created lazily: every handler is built from the same `hs`
    when this object is constructed.
    """

    def __init__(self, hs):
        # (attribute name, handler class), built one after another in this
        # order.
        handler_classes = (
            ("registration_handler", RegistrationHandler),
            ("message_handler", MessageHandler),
            ("room_creation_handler", RoomCreationHandler),
            ("room_member_handler", RoomMemberHandler),
            ("event_stream_handler", EventStreamHandler),
            ("federation_handler", FederationHandler),
            ("profile_handler", ProfileHandler),
            ("presence_handler", PresenceHandler),
            ("room_list_handler", RoomListHandler),
            ("login_handler", LoginHandler),
            ("directory_handler", DirectoryHandler),
        )

        for attr_name, handler_cls in handler_classes:
            setattr(self, attr_name, handler_cls(hs))
class BaseHandler(object):
    """Base class for handlers; keeps references to services taken from the
    home server object passed to the constructor."""

    def __init__(self, hs):
        self.hs = hs

        # Services are fetched from `hs` once, at construction time.
        datastore = hs.get_datastore()
        event_factory = hs.get_event_factory()
        auth = hs.get_auth()
        notifier = hs.get_notifier()
        room_lock = hs.get_room_lock_manager()
        state_handler = hs.get_state_handler()

        self.store = datastore
        self.event_factory = event_factory
        self.auth = auth
        self.notifier = notifier
        self.room_lock = room_lock
        self.state_handler = state_handler
@defer.inlineCallbacks
def get_association(self, room_alias, local_only=False):
    """Look up the room id and servers associated with a room alias.

    A local alias is looked up in the store. Any other alias is fetched
    over HTTP from the alias's own domain, unless `local_only` is set.

    Args:
        room_alias: The alias to resolve; `is_mine`, `domain` and
            `to_string()` are read from it.
        local_only (bool): If true, never ask a remote server.

    Returns:
        Deferred: Results in `{"room_id": ..., "servers": ...}`, or in an
        empty dict when no association was found.
    """
    # TODO(erikj): Do auth

    room_id = None
    servers = None

    if room_alias.is_mine:
        result = yield self.store.get_association_from_room_alias(
            room_alias
        )

        if result:
            room_id = result.room_id
            servers = result.servers
    elif not local_only:
        path = "%s/ds/room/%s?local_only=1" % (
            PREFIX,
            urllib.quote(room_alias.to_string())
        )

        result = None
        try:
            result = yield self.http_client.get_json(
                destination=room_alias.domain,
                path=path,
            )
        except Exception:
            # Best effort: a failed remote lookup is logged and treated as
            # "not found". Catch Exception rather than everything, so
            # interpreter-level exits are not swallowed.
            # TODO(erikj): Handle this better?
            logger.exception("Failed to get remote room alias")

        if result and "room_id" in result and "servers" in result:
            room_id = result["room_id"]
            servers = result["servers"]

    if not room_id:
        # returnValue raises, so nothing after it runs.
        defer.returnValue({})

    defer.returnValue({
        "room_id": room_id,
        "servers": servers,
    })
@defer.inlineCallbacks
def get_stream(self, auth_user_id, pagin_config, timeout=0):
    """Gets events as an event stream for this user.

    This function looks for interesting *events* for this user. This is
    different from the notifier, which looks for interested *users* who may
    want to know about a single event.

    Args:
        auth_user_id (str): The user requesting their event stream.
        pagin_config (synapse.api.streams.PaginationConfig): The config to
            use when obtaining the stream.
        timeout (int): The max time to wait for an incoming event; 0 means
            do not wait. Passed unchanged to the notifier.
            NOTE(review): documented here as ms but described as seconds in
            the body of the original -- confirm the unit.
    Returns:
        A pagination stream API dict
    """
    auth_user = self.hs.parse_userid(auth_user_id)

    stream_id = object()

    # Count this stream *before* entering the try block. The finally clause
    # below always decrements the counter, so if this bookkeeping sat inside
    # the try and failed before the increment, the counter would go to -1
    # and never be cleaned up.
    if auth_user not in self._streams_per_user:
        self._streams_per_user[auth_user] = 0
        if auth_user in self._stop_timer_per_user:
            # The user came back within the grace period: cancel the
            # pending "stopped" signal instead of announcing a new start.
            self.clock.cancel_call_later(
                self._stop_timer_per_user.pop(auth_user))
        else:
            self.distributor.fire(
                "started_user_eventstream", auth_user
            )
    self._streams_per_user[auth_user] += 1

    try:
        # construct an event stream with the correct data ordering
        stream_data_list = []
        for stream_class in EventStreamHandler.stream_data_classes:
            stream_data_list.append(stream_class(self.hs))
        event_stream = EventStream(auth_user_id, stream_data_list)

        # fix unknown tokens to known tokens
        pagin_config = yield event_stream.fix_tokens(pagin_config)

        # register interest in receiving new events
        self.notifier.store_events_for(user_id=auth_user_id,
                                       stream_id=stream_id,
                                       from_tok=pagin_config.from_tok)

        # see if we can grab a chunk now
        data_chunk = yield event_stream.get_chunk(config=pagin_config)

        # if there are previous events, return those. If not, wait on the
        # new events for up to 'timeout'.
        if len(data_chunk["chunk"]) == 0 and timeout != 0:
            results = yield defer.maybeDeferred(
                self.notifier.get_events_for,
                user_id=auth_user_id,
                stream_id=stream_id,
                timeout=timeout
            )
            if results:
                defer.returnValue(results)

        defer.returnValue(data_chunk)
    finally:
        # cleanup
        self.notifier.purge_events_for(user_id=auth_user_id,
                                       stream_id=stream_id)

        self._streams_per_user[auth_user] -= 1
        if not self._streams_per_user[auth_user]:
            del self._streams_per_user[auth_user]

            # Grace period to allow the client to reconnect again before we
            # think they're gone. The delay passed to call_later is 5; the
            # original comment said 10 seconds -- TODO confirm which value
            # is intended.
            def _later():
                self.distributor.fire(
                    "stopped_user_eventstream", auth_user
                )
                del self._stop_timer_per_user[auth_user]

            self._stop_timer_per_user[auth_user] = (
                self.clock.call_later(5, _later)
            )
class FederationHandler(BaseHandler):

    """Handles events that originated from federation."""

    @log_function
    @defer.inlineCallbacks
    def on_receive(self, event, is_new_state):
        """Process an event received over federation.

        State events that are not new state are ignored. An invite/join
        event aimed at this server is turned into a room-member JOIN for
        the sender; any other event is persisted under the room lock and
        then passed to the notifier.

        Args:
            event: The received event.
            is_new_state (bool): Whether a state event is new state.
        """
        is_old_state = hasattr(event, "state_key") and not is_new_state
        if is_old_state:
            logger.debug("Ignoring old state.")
            return

        target_is_mine = (
            hasattr(event, "target_host")
            and event.target_host == self.hs.hostname
        )

        if event.type == InviteJoinEvent.TYPE:
            if not target_is_mine:
                logger.debug("Ignoring invite/join event %s", event)
                return

            # If we receive an invite/join event then we need to join the
            # sender to the given room.
            # TODO: We should probably auth this or some such
            join_content = event.content
            join_content["membership"] = Membership.JOIN

            join_event = self.event_factory.create_event(
                etype=RoomMemberEvent.TYPE,
                target_user_id=event.user_id,
                room_id=event.room_id,
                user_id=event.user_id,
                membership=Membership.JOIN,
                content=join_content
            )

            member_handler = self.hs.get_handlers().room_member_handler
            yield member_handler.change_membership(join_event, True)

        else:
            # Only the persist happens while the room lock is held.
            with (yield self.room_lock.lock(event.room_id)):
                store_id = yield self.store.persist_event(event)

            yield self.notifier.on_new_room_event(event, store_id)
@defer.inlineCallbacks
def login(self, user, password):
    """Login as the specified user with the specified password.

    Args:
        user (str): The user ID.
        password (str): The password.
    Returns:
        The newly allocated access token.
    Raises:
        StoreError if there was a problem storing the token.
        LoginError if there was an authentication problem.
    """
    # TODO do this better, it can't go in __init__ else it cyclic loops
    if not hasattr(self, "reg_handler"):
        self.reg_handler = self.hs.get_handlers().registration_handler

    # pull out the hash for this user if they exist
    user_info = yield self.store.get_user_by_id(user_id=user)
    if not user_info:
        logger.warning(
            "Attempted to login as %s but they do not exist.", user
        )
        raise LoginError(403, "")

    stored_hash = user_info[0]["password_hash"]
    if not bcrypt.checkpw(password, stored_hash):
        logger.warning("Failed password login for user %s", user)
        raise LoginError(403, "")

    # generate an access token and store it.
    token = self.reg_handler._generate_token(user)
    # The token is a bearer credential, so it must not be written to the
    # log: anyone who can read the log could act as the user.
    logger.info("Adding access token for user %s", user)
    yield self.store.add_access_token_to_user(user, token)
    defer.returnValue(token)
# TODO(paul): Maybe there's one of these I can steal from somewhere
def partition(l, func):
    """Group the elements of `l` by the value `func` returns for each.

    Returns a dict mapping each distinct result of `func` to the list of
    elements that produced it, in their original order.
    """
    groups = {}

    for item in l:
        groups.setdefault(func(item), []).append(item)

    return groups


def partitionbool(l, func):
    """Split `l` in two by the truthiness of `func(element)`.

    Returns a 2-tuple `(truthy_elements, falsy_elements)`; either list is
    empty when no element falls on that side.
    """
    groups = partition(l, lambda item: bool(func(item)))
    truthy = groups.get(True, [])
    falsy = groups.get(False, [])
    return truthy, falsy
def _get_or_make_usercache(self, user):
    """Return the cache entry for `user`, first creating and storing a new
    UserPresenceCache if there is none yet."""
    try:
        return self._user_cachemap[user]
    except KeyError:
        fresh = UserPresenceCache()
        self._user_cachemap[user] = fresh
        return fresh


def _get_or_offline_usercache(self, user):
    """Return the cache entry for `user`, or, if there is none, a new entry
    marked OFFLINE that is *not* stored into the cache."""
    if user not in self._user_cachemap:
        offline = UserPresenceCache()
        # NOTE(review): `user` is passed as the second positional argument
        # here, while changed_presencelike_data calls
        # update(state, serial=...) -- confirm against UserPresenceCache.
        offline.update({"state": PresenceState.OFFLINE}, user)
        return offline

    return self._user_cachemap[user]
self.is_presence_visible(observer_user=auth_user, + observed_user=target_user + ) + + if visible: + state = yield self.store.get_presence_state( + target_user.localpart + ) + defer.returnValue(state) + else: + raise SynapseError(404, "Presence information not visible") + else: + # TODO(paul): Have remote server send us permissions set + defer.returnValue( + self._get_or_offline_usercache(target_user).get_state() + ) + + @defer.inlineCallbacks + def set_state(self, target_user, auth_user, state): + if not target_user.is_mine: + raise SynapseError(400, "User is not hosted on this Home Server") + + if target_user != auth_user: + raise AuthError(400, "Cannot set another user's displayname") + + # TODO(paul): Sanity-check 'state' + if "status_msg" not in state: + state["status_msg"] = None + + for k in state.keys(): + if k not in ("state", "status_msg"): + raise SynapseError( + 400, "Unexpected presence state key '%s'" % (k,) + ) + + logger.debug("Updating presence state of %s to %s", + target_user.localpart, state["state"]) + + state_to_store = dict(state) + + yield defer.DeferredList([ + self.store.set_presence_state( + target_user.localpart, state_to_store + ), + self.distributor.fire( + "collect_presencelike_data", target_user, state + ), + ]) + + now_online = state["state"] != PresenceState.OFFLINE + was_polling = target_user in self._user_cachemap + + if now_online and not was_polling: + self.start_polling_presence(target_user, state=state) + elif not now_online and was_polling: + self.stop_polling_presence(target_user) + + # TODO(paul): perform a presence push as part of start/stop poll so + # we don't have to do this all the time + self.changed_presencelike_data(target_user, state) + + if not now_online: + del self._user_cachemap[target_user] + + def changed_presencelike_data(self, user, state): + statuscache = self._get_or_make_usercache(user) + + self._user_cachemap_latest_serial += 1 + statuscache.update(state, serial=self._user_cachemap_latest_serial) + + 
self.push_presence(user, statuscache=statuscache) + + def started_user_eventstream(self, user): + # TODO(paul): Use "last online" state + self.set_state(user, user, {"state": PresenceState.ONLINE}) + + def stopped_user_eventstream(self, user): + # TODO(paul): Save current state as "last online" state + self.set_state(user, user, {"state": PresenceState.OFFLINE}) + + @defer.inlineCallbacks + def user_joined_room(self, user, room_id): + localusers = set() + remotedomains = set() + + rm_handler = self.homeserver.get_handlers().room_member_handler + yield rm_handler.fetch_room_distributions_into(room_id, + localusers=localusers, remotedomains=remotedomains, + ignore_user=user) + + if user.is_mine: + yield self._send_presence_to_distribution(srcuser=user, + localusers=localusers, remotedomains=remotedomains, + statuscache=self._get_or_offline_usercache(user), + ) + + for srcuser in localusers: + yield self._send_presence(srcuser=srcuser, destuser=user, + statuscache=self._get_or_offline_usercache(srcuser), + ) + + @defer.inlineCallbacks + def send_invite(self, observer_user, observed_user): + if not observer_user.is_mine: + raise SynapseError(400, "User is not hosted on this Home Server") + + yield self.store.add_presence_list_pending( + observer_user.localpart, observed_user.to_string() + ) + + if observed_user.is_mine: + yield self.invite_presence(observed_user, observer_user) + else: + yield self.federation.send_edu( + destination=observed_user.domain, + edu_type="m.presence_invite", + content={ + "observed_user": observed_user.to_string(), + "observer_user": observer_user.to_string(), + } + ) + + @defer.inlineCallbacks + def _should_accept_invite(self, observed_user, observer_user): + if not observed_user.is_mine: + defer.returnValue(False) + + row = yield self.store.has_presence_state(observed_user.localpart) + if not row: + defer.returnValue(False) + + # TODO(paul): Eventually we'll ask the user's permission for this + # before accepting. 
For now just accept any invite request + defer.returnValue(True) + + @defer.inlineCallbacks + def invite_presence(self, observed_user, observer_user): + accept = yield self._should_accept_invite(observed_user, observer_user) + + if accept: + yield self.store.allow_presence_visible( + observed_user.localpart, observer_user.to_string() + ) + + if observer_user.is_mine: + if accept: + yield self.accept_presence(observed_user, observer_user) + else: + yield self.deny_presence(observed_user, observer_user) + else: + edu_type = "m.presence_accept" if accept else "m.presence_deny" + + yield self.federation.send_edu( + destination=observer_user.domain, + edu_type=edu_type, + content={ + "observed_user": observed_user.to_string(), + "observer_user": observer_user.to_string(), + } + ) + + @defer.inlineCallbacks + def accept_presence(self, observed_user, observer_user): + yield self.store.set_presence_list_accepted( + observer_user.localpart, observed_user.to_string() + ) + + self.start_polling_presence(observer_user, target_user=observed_user) + + @defer.inlineCallbacks + def deny_presence(self, observed_user, observer_user): + yield self.store.del_presence_list( + observer_user.localpart, observed_user.to_string() + ) + + # TODO(paul): Inform the user somehow? 
+ + @defer.inlineCallbacks + def drop(self, observed_user, observer_user): + if not observer_user.is_mine: + raise SynapseError(400, "User is not hosted on this Home Server") + + yield self.store.del_presence_list( + observer_user.localpart, observed_user.to_string() + ) + + self.stop_polling_presence(observer_user, target_user=observed_user) + + @defer.inlineCallbacks + def get_presence_list(self, observer_user, accepted=None): + if not observer_user.is_mine: + raise SynapseError(400, "User is not hosted on this Home Server") + + presence = yield self.store.get_presence_list( + observer_user.localpart, accepted=accepted + ) + + for p in presence: + observed_user = self.hs.parse_userid(p.pop("observed_user_id")) + p["observed_user"] = observed_user + p.update(self._get_or_offline_usercache(observed_user).get_state()) + + defer.returnValue(presence) + + @defer.inlineCallbacks + def start_polling_presence(self, user, target_user=None, state=None): + logger.debug("Start polling for presence from %s", user) + + if target_user: + target_users = [target_user] + else: + presence = yield self.store.get_presence_list( + user.localpart, accepted=True + ) + target_users = [ + self.hs.parse_userid(x["observed_user_id"]) for x in presence + ] + + if state is None: + state = yield self.store.get_presence_state(user.localpart) + + localusers, remoteusers = partitionbool( + target_users, + lambda u: u.is_mine + ) + + for target_user in localusers: + self._start_polling_local(user, target_user) + + deferreds = [] + remoteusers_by_domain = partition(remoteusers, lambda u: u.domain) + for domain in remoteusers_by_domain: + remoteusers = remoteusers_by_domain[domain] + + deferreds.append(self._start_polling_remote( + user, domain, remoteusers + )) + + yield defer.DeferredList(deferreds) + + def _start_polling_local(self, user, target_user): + target_localpart = target_user.localpart + + if not self.is_presence_visible(observer_user=user, + observed_user=target_user): + return + + if 
target_localpart not in self._local_pushmap: + self._local_pushmap[target_localpart] = set() + + self._local_pushmap[target_localpart].add(user) + + self.push_update_to_clients( + observer_user=user, + observed_user=target_user, + statuscache=self._get_or_offline_usercache(target_user), + ) + + def _start_polling_remote(self, user, domain, remoteusers): + for u in remoteusers: + if u not in self._remote_recvmap: + self._remote_recvmap[u] = set() + + self._remote_recvmap[u].add(user) + + return self.federation.send_edu( + destination=domain, + edu_type="m.presence", + content={"poll": [u.to_string() for u in remoteusers]} + ) + + def stop_polling_presence(self, user, target_user=None): + logger.debug("Stop polling for presence from %s", user) + + if not target_user or target_user.is_mine: + self._stop_polling_local(user, target_user=target_user) + + deferreds = [] + + if target_user: + raise NotImplementedError("TODO: remove one user") + + remoteusers = [u for u in self._remote_recvmap + if user in self._remote_recvmap[u]] + remoteusers_by_domain = partition(remoteusers, lambda u: u.domain) + + for domain in remoteusers_by_domain: + remoteusers = remoteusers_by_domain[domain] + + deferreds.append( + self._stop_polling_remote(user, domain, remoteusers) + ) + + return defer.DeferredList(deferreds) + + def _stop_polling_local(self, user, target_user): + for localpart in self._local_pushmap.keys(): + if target_user and localpart != target_user.localpart: + continue + + if user in self._local_pushmap[localpart]: + self._local_pushmap[localpart].remove(user) + + if not self._local_pushmap[localpart]: + del self._local_pushmap[localpart] + + def _stop_polling_remote(self, user, domain, remoteusers): + for u in remoteusers: + self._remote_recvmap[u].remove(user) + + if not self._remote_recvmap[u]: + del self._remote_recvmap[u] + + return self.federation.send_edu( + destination=domain, + edu_type="m.presence", + content={"unpoll": [u.to_string() for u in remoteusers]} + ) + 
+ @defer.inlineCallbacks + def push_presence(self, user, statuscache): + assert(user.is_mine) + + logger.debug("Pushing presence update from %s", user) + + localusers = set(self._local_pushmap.get(user.localpart, set())) + remotedomains = set(self._remote_sendmap.get(user.localpart, set())) + + # Reflect users' status changes back to themselves, so UIs look nice + # and also user is informed of server-forced pushes + localusers.add(user) + + rm_handler = self.homeserver.get_handlers().room_member_handler + room_ids = yield rm_handler.get_rooms_for_user(user) + + for room_id in room_ids: + yield rm_handler.fetch_room_distributions_into( + room_id, localusers=localusers, remotedomains=remotedomains, + ignore_user=user, + ) + + if not localusers and not remotedomains: + defer.returnValue(None) + + yield self._send_presence_to_distribution(user, + localusers=localusers, remotedomains=remotedomains, + statuscache=statuscache + ) + + def _send_presence(self, srcuser, destuser, statuscache): + if destuser.is_mine: + self.push_update_to_clients( + observer_user=destuser, + observed_user=srcuser, + statuscache=statuscache) + return defer.succeed(None) + else: + return self._push_presence_remote(srcuser, destuser.domain, + state=statuscache.get_state() + ) + + @defer.inlineCallbacks + def _send_presence_to_distribution(self, srcuser, localusers=set(), + remotedomains=set(), statuscache=None): + + for u in localusers: + logger.debug(" | push to local user %s", u) + self.push_update_to_clients( + observer_user=u, + observed_user=srcuser, + statuscache=statuscache, + ) + + deferreds = [] + for domain in remotedomains: + logger.debug(" | push to remote domain %s", domain) + deferreds.append(self._push_presence_remote(srcuser, domain, + state=statuscache.get_state()) + ) + + yield defer.DeferredList(deferreds) + + @defer.inlineCallbacks + def _push_presence_remote(self, user, destination, state=None): + if state is None: + state = yield 
self.store.get_presence_state(user.localpart) + yield self.distributor.fire( + "collect_presencelike_data", user, state + ) + + yield self.federation.send_edu( + destination=destination, + edu_type="m.presence", + content={ + "push": [ + dict(user_id=user.to_string(), **state), + ], + } + ) + + @defer.inlineCallbacks + def incoming_presence(self, origin, content): + deferreds = [] + + for push in content.get("push", []): + user = self.hs.parse_userid(push["user_id"]) + + logger.debug("Incoming presence update from %s", user) + + observers = set(self._remote_recvmap.get(user, set())) + + rm_handler = self.homeserver.get_handlers().room_member_handler + room_ids = yield rm_handler.get_rooms_for_user(user) + + for room_id in room_ids: + yield rm_handler.fetch_room_distributions_into( + room_id, localusers=observers, ignore_user=user + ) + + if not observers: + break + + state = dict(push) + del state["user_id"] + + statuscache = self._get_or_make_usercache(user) + + self._user_cachemap_latest_serial += 1 + statuscache.update(state, serial=self._user_cachemap_latest_serial) + + for observer_user in observers: + self.push_update_to_clients( + observer_user=observer_user, + observed_user=user, + statuscache=statuscache, + ) + + if state["state"] == PresenceState.OFFLINE: + del self._user_cachemap[user] + + for poll in content.get("poll", []): + user = self.hs.parse_userid(poll) + + if not user.is_mine: + continue + + # TODO(paul) permissions checks + + if not user in self._remote_sendmap: + self._remote_sendmap[user] = set() + + self._remote_sendmap[user].add(origin) + + deferreds.append(self._push_presence_remote(user, origin)) + + for unpoll in content.get("unpoll", []): + user = self.hs.parse_userid(unpoll) + + if not user.is_mine: + continue + + if user in self._remote_sendmap: + self._remote_sendmap[user].remove(origin) + + if not self._remote_sendmap[user]: + del self._remote_sendmap[user] + + yield defer.DeferredList(deferreds) + + def push_update_to_clients(self, 
observer_user, observed_user, + statuscache): + self.notifier.on_new_user_event( + observer_user.to_string(), + event_data=statuscache.make_event(user=observed_user), + stream_type=PresenceStreamData, + store_id=statuscache.serial + ) + + +class PresenceStreamData(StreamData): + def __init__(self, hs): + super(PresenceStreamData, self).__init__(hs) + self.presence = hs.get_handlers().presence_handler + + def get_rows(self, user_id, from_key, to_key, limit): + cachemap = self.presence._user_cachemap + + # TODO(paul): limit, and filter by visibility + updates = [(k, cachemap[k]) for k in cachemap + if from_key < cachemap[k].serial <= to_key] + + if updates: + latest_serial = max([x[1].serial for x in updates]) + data = [x[1].make_event(user=x[0]) for x in updates] + return ((data, latest_serial)) + else: + return (([], self.presence._user_cachemap_latest_serial)) + + def max_token(self): + return self.presence._user_cachemap_latest_serial + +PresenceStreamData.EVENT_TYPE = PresenceStreamData + + +class UserPresenceCache(object): + """Store an observed user's state and status message. + + Includes the update timestamp. 
+ """ + def __init__(self): + self.state = {} + self.serial = None + + def update(self, state, serial): + self.state.update(state) + # Delete keys that are now 'None' + for k in self.state.keys(): + if self.state[k] is None: + del self.state[k] + + self.serial = serial + + if "status_msg" in state: + self.status_msg = state["status_msg"] + else: + self.status_msg = None + + def get_state(self): + # clone it so caller can't break our cache + return dict(self.state) + + def make_event(self, user): + content = self.get_state() + content["user_id"] = user.to_string() + + return {"type": "m.presence", "content": content} diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py new file mode 100644 index 0000000000..a27206b002 --- /dev/null +++ b/synapse/handlers/profile.py @@ -0,0 +1,169 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from twisted.internet import defer + +from synapse.api.errors import SynapseError, AuthError + +from synapse.api.errors import CodeMessageException + +from ._base import BaseHandler + +import logging + + +logger = logging.getLogger(__name__) + +PREFIX = "/matrix/client/api/v1" + + +class ProfileHandler(BaseHandler): + + def __init__(self, hs): + super(ProfileHandler, self).__init__(hs) + + self.client = hs.get_http_client() + + distributor = hs.get_distributor() + self.distributor = distributor + + distributor.observe("registered_user", self.registered_user) + + distributor.observe( + "collect_presencelike_data", self.collect_presencelike_data + ) + + def registered_user(self, user): + self.store.create_profile(user.localpart) + + @defer.inlineCallbacks + def get_displayname(self, target_user, local_only=False): + if target_user.is_mine: + displayname = yield self.store.get_profile_displayname( + target_user.localpart + ) + + defer.returnValue(displayname) + elif not local_only: + # TODO(paul): This should use the server-server API to ask another + # HS. 
For now we'll just have it use the http client to talk to the + # other HS's REST client API + path = PREFIX + "/profile/%s/displayname?local_only=1" % ( + target_user.to_string() + ) + + try: + result = yield self.client.get_json( + destination=target_user.domain, + path=path + ) + except CodeMessageException as e: + if e.code != 404: + logger.exception("Failed to get displayname") + + raise + except: + logger.exception("Failed to get displayname") + + defer.returnValue(result["displayname"]) + else: + raise SynapseError(400, "User is not hosted on this Home Server") + + @defer.inlineCallbacks + def set_displayname(self, target_user, auth_user, new_displayname): + """target_user is the user whose displayname is to be changed; + auth_user is the user attempting to make this change.""" + if not target_user.is_mine: + raise SynapseError(400, "User is not hosted on this Home Server") + + if target_user != auth_user: + raise AuthError(400, "Cannot set another user's displayname") + + yield self.store.set_profile_displayname( + target_user.localpart, new_displayname + ) + + yield self.distributor.fire( + "changed_presencelike_data", target_user, { + "displayname": new_displayname, + } + ) + + @defer.inlineCallbacks + def get_avatar_url(self, target_user, local_only=False): + if target_user.is_mine: + avatar_url = yield self.store.get_profile_avatar_url( + target_user.localpart + ) + + defer.returnValue(avatar_url) + elif not local_only: + # TODO(paul): This should use the server-server API to ask another + # HS. 
For now we'll just have it use the http client to talk to the + # other HS's REST client API + destination = target_user.domain + path = PREFIX + "/profile/%s/avatar_url?local_only=1" % ( + target_user.to_string(), + ) + + try: + result = yield self.client.get_json( + destination=destination, + path=path + ) + except CodeMessageException as e: + if e.code != 404: + logger.exception("Failed to get avatar_url") + raise + except: + logger.exception("Failed to get avatar_url") + + defer.returnValue(result["avatar_url"]) + else: + raise SynapseError(400, "User is not hosted on this Home Server") + + @defer.inlineCallbacks + def set_avatar_url(self, target_user, auth_user, new_avatar_url): + """target_user is the user whose avatar_url is to be changed; + auth_user is the user attempting to make this change.""" + if not target_user.is_mine: + raise SynapseError(400, "User is not hosted on this Home Server") + + if target_user != auth_user: + raise AuthError(400, "Cannot set another user's avatar_url") + + yield self.store.set_profile_avatar_url( + target_user.localpart, new_avatar_url + ) + + yield self.distributor.fire( + "changed_presencelike_data", target_user, { + "avatar_url": new_avatar_url, + } + ) + + @defer.inlineCallbacks + def collect_presencelike_data(self, user, state): + if not user.is_mine: + defer.returnValue(None) + + (displayname, avatar_url) = yield defer.gatherResults([ + self.store.get_profile_displayname(user.localpart), + self.store.get_profile_avatar_url(user.localpart), + ]) + + state["displayname"] = displayname + state["avatar_url"] = avatar_url + + defer.returnValue(None) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py new file mode 100644 index 0000000000..246c1f6530 --- /dev/null +++ b/synapse/handlers/register.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the 
License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains functions for registering clients.""" +from twisted.internet import defer + +from synapse.types import UserID +from synapse.api.errors import SynapseError, RegistrationError +from ._base import BaseHandler +import synapse.util.stringutils as stringutils + +import base64 +import bcrypt + + +class RegistrationHandler(BaseHandler): + + def __init__(self, hs): + super(RegistrationHandler, self).__init__(hs) + + self.distributor = hs.get_distributor() + self.distributor.declare("registered_user") + + @defer.inlineCallbacks + def register(self, localpart=None, password=None): + """Registers a new client on the server. + + Args: + localpart : The local part of the user ID to register. If None, + one will be randomly generated. + password (str) : The password to assign to this user so they can + login again. + Returns: + A tuple of (user_id, access_token). + Raises: + RegistrationError if there was a problem registering. 
+ """ + password_hash = None + if password: + password_hash = bcrypt.hashpw(password, bcrypt.gensalt()) + + if localpart: + user = UserID(localpart, self.hs.hostname, True) + user_id = user.to_string() + + token = self._generate_token(user_id) + yield self.store.register(user_id=user_id, + token=token, + password_hash=password_hash) + + self.distributor.fire("registered_user", user) + defer.returnValue((user_id, token)) + else: + # autogen a random user ID + attempts = 0 + user_id = None + token = None + while not user_id and not token: + try: + localpart = self._generate_user_id() + user = UserID(localpart, self.hs.hostname, True) + user_id = user.to_string() + + token = self._generate_token(user_id) + yield self.store.register( + user_id=user_id, + token=token, + password_hash=password_hash) + + self.distributor.fire("registered_user", user) + defer.returnValue((user_id, token)) + except SynapseError: + # if user id is taken, just generate another + user_id = None + token = None + attempts += 1 + if attempts > 5: + raise RegistrationError( + 500, "Cannot generate user ID.") + + def _generate_token(self, user_id): + # urlsafe variant uses _ and - so use . as the separator and replace + # all =s with .s so http clients don't quote =s when it is used as + # query params. + return (base64.urlsafe_b64encode(user_id).replace('=', '.') + '.' + + stringutils.random_string(18)) + + def _generate_user_id(self): + return "-" + stringutils.random_string(18) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py new file mode 100644 index 0000000000..4d82b33993 --- /dev/null +++ b/synapse/handlers/room.py @@ -0,0 +1,808 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains functions for performing events on rooms.""" +from twisted.internet import defer + +from synapse.types import UserID, RoomAlias, RoomID +from synapse.api.constants import Membership +from synapse.api.errors import RoomError, StoreError, SynapseError +from synapse.api.events.room import ( + RoomTopicEvent, MessageEvent, InviteJoinEvent, RoomMemberEvent, + RoomConfigEvent +) +from synapse.api.streams.event import EventStream, MessagesStreamData +from synapse.util import stringutils +from ._base import BaseHandler + +import logging +import json + +logger = logging.getLogger(__name__) + + +class MessageHandler(BaseHandler): + + def __init__(self, hs): + super(MessageHandler, self).__init__(hs) + self.hs = hs + self.clock = hs.get_clock() + self.event_factory = hs.get_event_factory() + + @defer.inlineCallbacks + def get_message(self, msg_id=None, room_id=None, sender_id=None, + user_id=None): + """ Retrieve a message. + + Args: + msg_id (str): The message ID to obtain. + room_id (str): The room where the message resides. + sender_id (str): The user ID of the user who sent the message. + user_id (str): The user ID of the user making this request. + Returns: + The message, or None if no message exists. + Raises: + SynapseError if something went wrong. 
+ """ + yield self.auth.check_joined_room(room_id, user_id) + + # Pull out the message from the db + msg = yield self.store.get_message(room_id=room_id, + msg_id=msg_id, + user_id=sender_id) + + if msg: + defer.returnValue(msg) + defer.returnValue(None) + + @defer.inlineCallbacks + def send_message(self, event=None, suppress_auth=False, stamp_event=True): + """ Send a message. + + Args: + event : The message event to store. + suppress_auth (bool) : True to suppress auth for this message. This + is primarily so the home server can inject messages into rooms at + will. + stamp_event (bool) : True to stamp event content with server keys. + Raises: + SynapseError if something went wrong. + """ + if stamp_event: + event.content["hsob_ts"] = int(self.clock.time_msec()) + + with (yield self.room_lock.lock(event.room_id)): + if not suppress_auth: + yield self.auth.check(event, raises=True) + + # store message in db + store_id = yield self.store.persist_event(event) + + event.destinations = yield self.store.get_joined_hosts_for_room( + event.room_id + ) + + yield self.hs.get_federation().handle_new_event(event) + + self.notifier.on_new_room_event(event, store_id) + + @defer.inlineCallbacks + def get_messages(self, user_id=None, room_id=None, pagin_config=None, + feedback=False): + """Get messages in a room. + + Args: + user_id (str): The user requesting messages. + room_id (str): The room they want messages from. + pagin_config (synapse.api.streams.PaginationConfig): The pagination + config rules to apply, if any. 
+ feedback (bool): True to get compressed feedback with the messages + Returns: + dict: Pagination API results + """ + yield self.auth.check_joined_room(room_id, user_id) + + data_source = [MessagesStreamData(self.hs, room_id=room_id, + feedback=feedback)] + event_stream = EventStream(user_id, data_source) + pagin_config = yield event_stream.fix_tokens(pagin_config) + data_chunk = yield event_stream.get_chunk(config=pagin_config) + defer.returnValue(data_chunk) + + @defer.inlineCallbacks + def store_room_data(self, event=None, stamp_event=True): + """ Stores data for a room. + + Args: + event : The room path event + stamp_event (bool) : True to stamp event content with server keys. + Raises: + SynapseError if something went wrong. + """ + + with (yield self.room_lock.lock(event.room_id)): + yield self.auth.check(event, raises=True) + + if stamp_event: + event.content["hsob_ts"] = int(self.clock.time_msec()) + + yield self.state_handler.handle_new_event(event) + + # store in db + store_id = yield self.store.store_room_data( + room_id=event.room_id, + etype=event.type, + state_key=event.state_key, + content=json.dumps(event.content) + ) + + event.destinations = yield self.store.get_joined_hosts_for_room( + event.room_id + ) + self.notifier.on_new_room_event(event, store_id) + + yield self.hs.get_federation().handle_new_event(event) + + @defer.inlineCallbacks + def get_room_data(self, user_id=None, room_id=None, + event_type=None, state_key="", + public_room_rules=[], + private_room_rules=["join"]): + """ Get data from a room. + + Args: + event : The room path event + public_room_rules : A list of membership states the user can be in, + in order to read this data IN A PUBLIC ROOM. An empty list means + 'any state'. + private_room_rules : A list of membership states the user can be + in, in order to read this data IN A PRIVATE ROOM. An empty list + means 'any state'. + Returns: + The path data content. + Raises: + SynapseError if something went wrong. 
+ """ + if event_type == RoomTopicEvent.TYPE: + # anyone invited/joined can read the topic + private_room_rules = ["invite", "join"] + + # does this room exist + room = yield self.store.get_room(room_id) + if not room: + raise RoomError(403, "Room does not exist.") + + # does this user exist in this room + member = yield self.store.get_room_member( + room_id=room_id, + user_id="" if not user_id else user_id) + + member_state = member.membership if member else None + + if room.is_public and public_room_rules: + # make sure the user meets public room rules + if member_state not in public_room_rules: + raise RoomError(403, "Member does not meet public room rules.") + elif not room.is_public and private_room_rules: + # make sure the user meets private room rules + if member_state not in private_room_rules: + raise RoomError( + 403, "Member does not meet private room rules.") + + data = yield self.store.get_room_data(room_id, event_type, state_key) + defer.returnValue(data) + + @defer.inlineCallbacks + def get_feedback(self, room_id=None, msg_sender_id=None, msg_id=None, + user_id=None, fb_sender_id=None, fb_type=None): + yield self.auth.check_joined_room(room_id, user_id) + + # Pull out the feedback from the db + fb = yield self.store.get_feedback( + room_id=room_id, msg_id=msg_id, msg_sender_id=msg_sender_id, + fb_sender_id=fb_sender_id, fb_type=fb_type + ) + + if fb: + defer.returnValue(fb) + defer.returnValue(None) + + @defer.inlineCallbacks + def send_feedback(self, event, stamp_event=True): + if stamp_event: + event.content["hsob_ts"] = int(self.clock.time_msec()) + + with (yield self.room_lock.lock(event.room_id)): + yield self.auth.check(event, raises=True) + + # store message in db + store_id = yield self.store.persist_event(event) + + event.destinations = yield self.store.get_joined_hosts_for_room( + event.room_id + ) + yield self.hs.get_federation().handle_new_event(event) + + self.notifier.on_new_room_event(event, store_id) + + @defer.inlineCallbacks + def 
snapshot_all_rooms(self, user_id=None, pagin_config=None, + feedback=False): + """Retrieve a snapshot of all rooms the user is invited or has joined. + + This snapshot may include messages for all rooms where the user is + joined, depending on the pagination config. + + Args: + user_id (str): The ID of the user making the request. + pagin_config (synapse.api.streams.PaginationConfig): The pagination + config used to determine how many messages *PER ROOM* to return. + feedback (bool): True to get feedback along with these messages. + Returns: + A list of dicts with "room_id" and "membership" keys for all rooms + the user is currently invited or joined in on. Rooms where the user + is joined on, may return a "messages" key with messages, depending + on the specified PaginationConfig. + """ + room_list = yield self.store.get_rooms_for_user_where_membership_is( + user_id=user_id, + membership_list=[Membership.INVITE, Membership.JOIN] + ) + for room_info in room_list: + if room_info["membership"] != Membership.JOIN: + continue + try: + event_chunk = yield self.get_messages( + user_id=user_id, + pagin_config=pagin_config, + feedback=feedback, + room_id=room_info["room_id"] + ) + room_info["messages"] = event_chunk + except: + pass + defer.returnValue(room_list) + + +class RoomCreationHandler(BaseHandler): + + @defer.inlineCallbacks + def create_room(self, user_id, room_id, config): + """ Creates a new room. + + Args: + user_id (str): The ID of the user creating the new room. + room_id (str): The proposed ID for the new room. Can be None, in + which case one will be created for you. + config (dict) : A dict of configuration options. + Returns: + The new room ID. + Raises: + SynapseError if the room ID was taken, couldn't be stored, or + something went horribly wrong. 
+ """ + + if "room_alias_name" in config: + room_alias = RoomAlias.create_local( + config["room_alias_name"], + self.hs + ) + mapping = yield self.store.get_association_from_room_alias( + room_alias + ) + + if mapping: + raise SynapseError(400, "Room alias already taken") + else: + room_alias = None + + if room_id: + # Ensure room_id is the correct type + room_id_obj = RoomID.from_string(room_id, self.hs) + if not room_id_obj.is_mine: + raise SynapseError(400, "Room id must be local") + + yield self.store.store_room( + room_id=room_id, + room_creator_user_id=user_id, + is_public=config["visibility"] == "public" + ) + else: + # autogen room IDs and try to create it. We may clash, so just + # try a few times till one goes through, giving up eventually. + attempts = 0 + room_id = None + while attempts < 5: + try: + random_string = stringutils.random_string(18) + gen_room_id = RoomID.create_local(random_string, self.hs) + yield self.store.store_room( + room_id=gen_room_id.to_string(), + room_creator_user_id=user_id, + is_public=config["visibility"] == "public" + ) + room_id = gen_room_id.to_string() + break + except StoreError: + attempts += 1 + if not room_id: + raise StoreError(500, "Couldn't generate a room ID.") + + config_event = self.event_factory.create_event( + etype=RoomConfigEvent.TYPE, + room_id=room_id, + user_id=user_id, + content=config, + ) + + if room_alias: + yield self.store.create_room_alias_association( + room_id=room_id, + room_alias=room_alias, + servers=[self.hs.hostname], + ) + + yield self.state_handler.handle_new_event(config_event) + # store_id = persist... 
@defer.inlineCallbacks
def fetch_room_distributions_into(self, room_id, localusers=None,
                                  remotedomains=None, ignore_user=None):
    """Sort the members of a room into local users and remote domains.

    The members come from get_room_members(room_id). A member whose is_mine
    attribute is true is added to 'localusers'; any other member contributes
    its 'domain' to 'remotedomains'. Either collection may be None, in which
    case that half of the result is not gathered.

    Args:
        room_id: The room whose membership is inspected.
        localusers: Optional collection with an add() method (a set() is
            expected) that receives the is_mine members.
        remotedomains: Optional collection with an add() method that receives
            the domains of the remaining members.
        ignore_user: Optional member to leave out; compared with ==.
    Returns:
        Nothing. The result is the side effect on the two collections, so
        several rooms can be accumulated into the same sets.
    """
    collect_local = localusers is not None
    collect_remote = remotedomains is not None
    have_ignored = ignore_user is not None

    room_members = yield self.get_room_members(room_id)
    for user in room_members:
        if have_ignored and user == ignore_user:
            continue

        if user.is_mine:
            if collect_local:
                localusers.add(user)
        elif collect_remote:
            remotedomains.add(user.domain)
+ """ + yield self.auth.check_joined_room(room_id, auth_user_id) + + member = yield self.store.get_room_member(user_id=member_user_id, + room_id=room_id) + defer.returnValue(member) + + @defer.inlineCallbacks + def change_membership(self, event=None, broadcast_msg=False, do_auth=True): + """ Change the membership status of a user in a room. + + Args: + event (SynapseEvent): The membership event + broadcast_msg (bool): True to inject a membership message into this + room on success. + Raises: + SynapseError if there was a problem changing the membership. + """ + + #broadcast_msg = False + + prev_state = yield self.store.get_room_member( + event.target_user_id, event.room_id + ) + + if prev_state and prev_state.membership == event.membership: + # treat this event as a NOOP. + if do_auth: # This is mainly to fix a unit test. + yield self.auth.check(event, raises=True) + defer.returnValue({}) + return + + room_id = event.room_id + + # If we're trying to join a room then we have to do this differently + # if this HS is not currently in the room, i.e. we have to do the + # invite/join dance. + if event.membership == Membership.JOIN: + yield self._do_join( + event, do_auth=do_auth, broadcast_msg=broadcast_msg + ) + else: + # This is not a JOIN, so we can handle it normally. + if do_auth: + yield self.auth.check(event, raises=True) + + prev_state = yield self.store.get_room_member( + event.target_user_id, event.room_id + ) + if prev_state and prev_state.membership == event.membership: + # double same action, treat this event as a NOOP. 
@defer.inlineCallbacks
def join_room_alias(self, joinee, room_alias, do_auth=True, content=None):
    """Join the room that a room alias currently maps to.

    The alias is resolved through the directory handler, a membership JOIN
    event is built for 'joinee', and the join is carried out via _do_join
    using the first server listed for the alias as the room host.

    Args:
        joinee: The joining user; must provide to_string().
        room_alias: The alias to resolve, passed to the directory handler's
            get_association().
        do_auth (bool): Accepted for interface compatibility.
            NOTE(review): this value is not forwarded; the join is always
            performed with do_auth=True, as before. Confirm whether callers
            passing False expect auth to be skipped before changing this.
        content (dict): Optional event content. A caller-supplied dict is
            updated in place with the "membership" key; when omitted a fresh
            dict is created for each call.
    Returns:
        A dict of the form {"room_id": <room id>}.
    Raises:
        SynapseError: 404 if the alias has no mapping, or the mapping lists
            no servers.
    """
    if content is None:
        # A "content={}" default is created once and shared by every call;
        # since it is mutated below, keys would leak between joins.
        content = {}

    directory_handler = self.hs.get_handlers().directory_handler
    mapping = yield directory_handler.get_association(room_alias)

    if not mapping:
        raise SynapseError(404, "No such room alias")

    room_id = mapping["room_id"]
    hosts = mapping["servers"]
    if not hosts:
        raise SynapseError(404, "No known servers")

    host = hosts[0]

    content.update({"membership": Membership.JOIN})
    new_event = self.event_factory.create_event(
        etype=RoomMemberEvent.TYPE,
        target_user_id=joinee.to_string(),
        room_id=room_id,
        user_id=joinee.to_string(),
        membership=Membership.JOIN,
        content=content,
    )

    yield self._do_join(new_event, room_host=host, do_auth=True)

    defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks
def _fill_out_join_content(self, user_id, content):
    """Add profile details to join-event content when they are missing.

    If 'content' has no "displayname" / "avatar_url" key, the value is
    looked up through the profile handler and stored under that key when the
    lookup returns something truthy. Keys already present are left alone.
    Lookups are best effort: a failure is logged and the key is left unset.

    Args:
        user_id: The joining user, passed as-is to the profile handler.
        content (dict): The event content; modified in place.
    Returns:
        Nothing. The result is the side effect on 'content'.
    """
    profile_handler = self.hs.get_handlers().profile_handler

    if "displayname" not in content:
        try:
            display_name = yield profile_handler.get_displayname(
                user_id
            )

            if display_name:
                content["displayname"] = display_name
        except Exception:
            logger.exception("Failed to set display_name")

    if "avatar_url" not in content:
        try:
            avatar_url = yield profile_handler.get_avatar_url(
                user_id
            )

            if avatar_url:
                content["avatar_url"] = avatar_url
        except Exception:
            # Previously logged as a display_name failure, which pointed
            # whoever read the log at the wrong lookup.
            logger.exception("Failed to set avatar_url")
@defer.inlineCallbacks
def _inject_membership_msg(self, room_id=None, source=None, target=None,
                           membership=None):
    """Post an m.text message into a room announcing a membership change.

    Args:
        room_id: The room the message is sent to.
        source: The user who caused the change (only named in the text for
            invites).
        target: The user whose membership changed.
        membership: One of Membership.INVITE, Membership.JOIN or
            Membership.LEAVE.
    Raises:
        RoomError: 500 if 'membership' is none of the values above.
    """
    # TODO this should be a different type of message, not m.text
    announcements = (
        (Membership.INVITE,
         lambda: "%s invited %s to the room." % (source, target)),
        (Membership.JOIN, lambda: "%s joined the room." % (target)),
        (Membership.LEAVE, lambda: "%s left the room." % (target)),
    )
    for state, render in announcements:
        if membership == state:
            text = render()
            break
    else:
        raise RoomError(500, "Unknown membership value %s" % membership)

    # The message id is derived from the current clock reading.
    generated_id = "m%s" % int(self.clock.time_msec())

    announcement = self.event_factory.create_event(
        etype=MessageEvent.TYPE,
        room_id=room_id,
        user_id="_homeserver_",
        msg_id=generated_id,
        content={
            "msgtype": u"m.text",
            "body": text,
            "membership_source": source,
            "membership_target": target,
            "membership": membership,
        }
    )

    message_handler = self.hs.get_handlers().message_handler
    yield message_handler.send_message(announcement, suppress_auth=True)
diff --git a/synapse/http/client.py b/synapse/http/client.py new file mode 100644 index 0000000000..bb22b0ee9a --- /dev/null +++ b/synapse/http/client.py @@ -0,0 +1,246 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer, reactor +from twisted.web.client import _AgentBase, _URI, readBody +from twisted.web.http_headers import Headers + +from synapse.http.endpoint import matrix_endpoint +from synapse.util.async import sleep + +from syutil.jsonutil import encode_canonical_json + +from synapse.api.errors import CodeMessageException + +import json +import logging +import urllib + + +logger = logging.getLogger(__name__) + + +_destination_mappings = { + "red": "localhost:8080", + "blue": "localhost:8081", + "green": "localhost:8082", +} + + +class HttpClient(object): + """ Interface for talking json over http + """ + + def put_json(self, destination, path, data): + """ Sends the specifed json data using PUT + + Args: + destination (str): The remote server to send the HTTP request + to. + path (str): The HTTP path. + data (dict): A dict containing the data that will be used as + the request body. This will be encoded as JSON. + + Returns: + Deferred: Succeeds when we get *any* HTTP response. + + The result of the deferred is a tuple of `(code, response)`, + where `response` is a dict representing the decoded JSON body. 
class TwistedHttpClient(HttpClient):
    """ Wrapper around the twisted HTTP client api.

    Attributes:
        agent (twisted.web.client.Agent): The twisted Agent used to send the
            requests.
    """

    def __init__(self):
        self.agent = MatrixHttpAgent(reactor)

    @defer.inlineCallbacks
    def put_json(self, destination, path, data):
        """ Sends the given data, encoded as JSON, using PUT.

        Args:
            destination (str): The remote server to send the request to. Names
                found in _destination_mappings are translated first.
            path (str): The HTTP path.
            data (dict): The request body; encoded by _JsonProducer.

        Returns:
            Deferred: fires with a tuple `(code, body)`, where `body` is the
            response body exactly as produced by readBody (it is not decoded
            here).
        """
        if destination in _destination_mappings:
            destination = _destination_mappings[destination]

        response = yield self._create_request(
            destination.encode("ascii"),
            "PUT",
            path.encode("ascii"),
            producer=_JsonProducer(data),
            headers_dict={"Content-Type": ["application/json"]}
        )

        logger.debug("Getting resp body")
        body = yield readBody(response)
        logger.debug("Got resp body")

        defer.returnValue((response.code, body))

    @defer.inlineCallbacks
    def get_json(self, destination, path, args=None):
        """ Gets some JSON from the given destination and path.

        Args:
            destination (str): The remote server to send the request to. Names
                found in _destination_mappings are translated first.
            path (str): The HTTP path.
            args (dict): Used to build the query string; defaults to no
                arguments. Values are encoded with doseq enabled, so each
                value may be a sequence.

        Returns:
            Deferred: fires with the response body decoded by json.loads.
        """
        if args is None:
            args = {}

        if destination in _destination_mappings:
            destination = _destination_mappings[destination]

        logger.debug("get_json args: %s", args)
        query_bytes = urllib.urlencode(args, True)

        # Must be passed by keyword: the fourth positional parameter of
        # _create_request is param_bytes, so a positional query string ended
        # up in the ";params" part of the URI instead of after the "?".
        response = yield self._create_request(
            destination.encode("ascii"),
            "GET",
            path.encode("ascii"),
            query_bytes=query_bytes
        )

        body = yield readBody(response)

        defer.returnValue(json.loads(body))

    @defer.inlineCallbacks
    def _create_request(self, destination, method, path_bytes, param_bytes=b"",
                        query_bytes=b"", producer=None, headers_dict=None):
        """ Creates and sends a request to the given destination.

        The request is retried when sending raises: up to five retries, with
        the value passed to sleep() doubling from 2**0 to 2**4. Once the
        retries are used up the last exception propagates.

        Args:
            destination (bytes): The server to contact; also used as the
                Host header.
            method (str): The HTTP method.
            path_bytes (bytes): The request path.
            param_bytes (bytes): The ";params" part of the URI.
            query_bytes (bytes): The "?query" part of the URI.
            producer: Optional body producer handed to the agent.
            headers_dict (dict): Optional extra headers, mapping a name to a
                list of values. It is copied, so the caller's dict is not
                modified.

        Returns:
            Deferred: fires with the twisted response object.

        Raises:
            CodeMessageException: if the response code is not 2xx.
        """
        # Copy instead of mutating: a "headers_dict={}" default is a single
        # dict shared by every call that does not pass its own headers.
        headers_dict = dict(headers_dict) if headers_dict else {}
        headers_dict[b"User-Agent"] = [b"Synapse"]
        headers_dict[b"Host"] = [destination]

        logger.debug("Sending request to %s: %s %s;%s?%s",
                     destination, method, path_bytes, param_bytes, query_bytes)

        logger.debug(
            "Types: %s",
            [
                type(destination), type(method), type(path_bytes),
                type(param_bytes),
                type(query_bytes)
            ]
        )

        retries_left = 5

        # TODO: setup and pass in an ssl_context to enable TLS
        endpoint = matrix_endpoint(reactor, destination, timeout=10)

        while True:
            try:
                response = yield self.agent.request(
                    destination,
                    endpoint,
                    method,
                    path_bytes,
                    param_bytes,
                    query_bytes,
                    Headers(headers_dict),
                    producer
                )

                logger.debug("Got response to %s", method)
                break
            except Exception as e:
                logger.exception("Got error in _create_request")
                _print_ex(e)

                if retries_left:
                    yield sleep(2 ** (5 - retries_left))
                    retries_left -= 1
                else:
                    raise

        if 200 <= response.code < 300:
            # We need to update the transactions table to say it was sent?
            pass
        else:
            # :'(
            # Update transactions table?
            logger.error(
                "Got response %d %s", response.code, response.phrase
            )
            raise CodeMessageException(
                response.code, response.phrase
            )

        defer.returnValue(response)
class SRVClientEndpoint(object):
    """An endpoint which looks up SRV records for a service.
    Cycles through the list of servers starting with each call to connect
    picking the next server.
    Implements twisted.internet.interfaces.IStreamClientEndpoint.
    """

    # Field order matters: a plain sort() of these tuples orders servers by
    # priority first, then weight.
    _Server = collections.namedtuple(
        "_Server", "priority weight host port"
    )

    def __init__(self, reactor, service, domain, protocol="tcp",
                 default_port=None, endpoint=TCP4ClientEndpoint,
                 endpoint_kw_args=None):
        """
        Args:
            reactor: Twisted reactor, handed to the created endpoints.
            service (str): Service name; combined with protocol and domain
                into "_<service>._<protocol>.<domain>".
            domain (str): The domain to look up; also the host of the
                fallback server.
            protocol (str): Protocol label used in the SRV name.
            default_port (int): If not None, <domain>:<default_port> is used
                when no SRV servers are known.
            endpoint: Factory called as endpoint(reactor, host, port, **kw).
            endpoint_kw_args (dict): Extra keyword arguments for 'endpoint'.
                Copied, so instances never share or modify the caller's dict.
        """
        self.reactor = reactor
        self.service_name = "_%s._%s.%s" % (service, protocol, domain)

        if default_port is not None:
            self.default_server = self._Server(
                host=domain,
                port=default_port,
                priority=0,
                weight=0
            )
        else:
            self.default_server = None

        self.endpoint = endpoint
        self.endpoint_kw_args = (
            dict(endpoint_kw_args) if endpoint_kw_args else {}
        )

        # None means "not looked up yet"; connect() triggers the lookup.
        self.servers = None
        self.used_servers = None

    @defer.inlineCallbacks
    def fetch_servers(self):
        """Look up the SRV records and (re)build the sorted server list.

        A DNSNameError from the lookup is treated as "no records".

        Raises:
            ConnectError: if the only answer is an SRV record whose target is
                ".".
        """
        try:
            answers, _auth, _add = yield client.lookupService(
                self.service_name
            )
        except DNSNameError:
            answers = []

        if (len(answers) == 1
                and answers[0].type == dns.SRV
                and answers[0].payload
                and answers[0].payload.target == dns.Name('.')):
            # Interpolate here: exceptions do not apply logging-style
            # arguments, so the "%s" used to reach the user verbatim.
            raise ConnectError(
                "Service %s unavailable" % (self.service_name,)
            )

        self.servers = []
        self.used_servers = []

        for answer in answers:
            if answer.type != dns.SRV or not answer.payload:
                continue
            payload = answer.payload
            self.servers.append(self._Server(
                host=str(payload.target),
                port=int(payload.port),
                priority=int(payload.priority),
                weight=int(payload.weight)
            ))

        self.servers.sort()

    def pick_server(self):
        """Choose the next server to try and move it to the used list.

        When every server has been used the used list is recycled. With no
        servers at all the default server is returned if there is one.
        Among the servers sharing the lowest priority value the choice is
        random, biased by weight.

        Returns:
            A _Server tuple.
        Raises:
            ConnectError: if there are no servers and no default server.
        """
        if not self.servers:
            if self.used_servers:
                self.servers = self.used_servers
                self.used_servers = []
                self.servers.sort()
            elif self.default_server:
                return self.default_server
            else:
                raise ConnectError(
                    "No server available for %s" % (self.service_name,)
                )

        min_priority = self.servers[0].priority
        # weight + 1 keeps servers that advertise weight 0 selectable.
        weight_indexes = [
            (index, server.weight + 1)
            for index, server in enumerate(self.servers)
            if server.priority == min_priority
        ]

        total_weight = sum(weight for _index, weight in weight_indexes)
        target_weight = random.randint(0, total_weight)

        for index, weight in weight_indexes:
            target_weight -= weight
            if target_weight <= 0:
                server = self.servers[index]
                del self.servers[index]
                self.used_servers.append(server)
                return server

    @defer.inlineCallbacks
    def connect(self, protocolFactory):
        """Connect 'protocolFactory' to the next server from pick_server().

        The SRV lookup is performed on first use.

        Returns:
            Deferred: fires with the result of the underlying endpoint's
            connect().
        """
        if self.servers is None:
            yield self.fetch_servers()
        server = self.pick_server()
        logger.info("Connecting to %s:%s", server.host, server.port)
        endpoint = self.endpoint(
            self.reactor, server.host, server.port, **self.endpoint_kw_args
        )
        connection = yield endpoint.connect(protocolFactory)
        defer.returnValue(connection)
+ + If the regex contains groups these get's passed to the calback via + an unpacked tuple. + + Args: + method (str): The method to listen to. + path_pattern (str): The regex used to match requests. + callback (function): The function to fire if we receive a matched + request. The first argument will be the request object and + subsequent arguments will be any matched groups from the regex. + This should return a tuple of (code, response). + """ + pass + + +# The actual HTTP server impl, using twisted http server +class TwistedHttpServer(HttpServer, resource.Resource): + """ This wraps the twisted HTTP server, and triggers the correct callbacks + on the transport_layer. + + Register callbacks via register_path() + """ + + isLeaf = True + + _PathEntry = collections.namedtuple("_PathEntry", ["pattern", "callback"]) + + def __init__(self): + resource.Resource.__init__(self) + + self.path_regexs = {} + + def register_path(self, method, path_pattern, callback): + self.path_regexs.setdefault(method, []).append( + self._PathEntry(path_pattern, callback) + ) + + def start_listening(self, port): + """ Registers the http server with the twisted reactor. + + Args: + port (int): The port to listen on. + + """ + reactor.listenTCP(port, server.Site(self)) + + # Gets called by twisted + def render(self, request): + """ This get's called by twisted every time someone sends us a request. + """ + self._async_render(request) + return server.NOT_DONE_YET + + @defer.inlineCallbacks + def _async_render(self, request): + """ This get's called by twisted every time someone sends us a request. + This checks if anyone has registered a callback for that method and + path. + """ + try: + # Loop through all the registered callbacks to check if the method + # and path regex match + for path_entry in self.path_regexs.get(request.method, []): + m = path_entry.pattern.match(request.path) + if m: + # We found a match! Trigger callback and then return the + # returned response. 
def respond_with_json_bytes(request, code, json_bytes, send_cors=False):
    """Write an already-encoded JSON body as the response and finish it.

    Args:
        request (twisted.web.http.Request): The http request to respond to.
        code (int): The HTTP response code.
        json_bytes (bytes): The json bytes to use as the response body.
        send_cors (bool): Whether to send Cross-Origin Resource Sharing headers
            http://www.w3.org/TR/cors/
    Returns:
        twisted.web.server.NOT_DONE_YET"""

    request.setResponseCode(code)
    request.setHeader(b"Content-Type", b"application/json")

    # The CORS headers are only emitted on request, in this fixed order.
    cors_headers = (
        ("Access-Control-Allow-Origin", "*"),
        ("Access-Control-Allow-Methods",
         "GET, POST, PUT, DELETE, OPTIONS"),
        ("Access-Control-Allow-Headers",
         "Origin, X-Requested-With, Content-Type, Accept"),
    ) if send_cors else ()
    for header_name, header_value in cors_headers:
        request.setHeader(header_name, header_value)

    request.write(json_bytes)
    request.finish()
    return NOT_DONE_YET
+ room.register_servlets(hs, http_server) + events.register_servlets(hs, http_server) + register.register_servlets(hs, http_server) + profile.register_servlets(hs, http_server) + public.register_servlets(hs, http_server) + presence.register_servlets(hs, http_server) + im.register_servlets(hs, http_server) + directory.register_servlets(hs, http_server) + + diff --git a/synapse/rest/base.py b/synapse/rest/base.py new file mode 100644 index 0000000000..d90ac611fe --- /dev/null +++ b/synapse/rest/base.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" This module contains base REST classes for constructing REST servlets. """ +import re + + +def client_path_pattern(path_regex): + """Creates a regex compiled client path with the correct client path + prefix. + + Args: + path_regex (str): The regex string to match. This should NOT have a ^ + as this will be prefixed. + Returns: + SRE_Pattern + """ + return re.compile("^/matrix/client/api/v1" + path_regex) + + +class RestServletFactory(object): + + """ A factory for creating REST servlets. + + These REST servlets represent the entire client-server REST API. Generally + speaking, they serve as wrappers around events and the handlers that + process them. + + See synapse.api.events for information on synapse events. 
+ """ + + def __init__(self, hs): + http_server = hs.get_http_server() + + # You get import errors if you try to import before the classes in this + # file are defined, hence importing here instead. + + import room + room.register_servlets(hs, http_server) + + import events + events.register_servlets(hs, http_server) + + import register + register.register_servlets(hs, http_server) + + import profile + profile.register_servlets(hs, http_server) + + import public + public.register_servlets(hs, http_server) + + import presence + presence.register_servlets(hs, http_server) + + import im + im.register_servlets(hs, http_server) + + import login + login.register_servlets(hs, http_server) + + +class RestServlet(object): + + """ A Synapse REST Servlet. + + An implementing class can either provide its own custom 'register' method, + or use the automatic pattern handling provided by the base class. + + To use this latter, the implementing class instead provides a `PATTERN` + class attribute containing a pre-compiled regular expression. The automatic + register method will then use this method to register any of the following + instance methods associated with the corresponding HTTP method: + + on_GET + on_PUT + on_POST + on_DELETE + on_OPTIONS + + Automatically handles turning CodeMessageExceptions thrown by these methods + into the appropriate HTTP response. + """ + + def __init__(self, hs): + self.hs = hs + + self.handlers = hs.get_handlers() + self.event_factory = hs.get_event_factory() + self.auth = hs.get_auth() + + def register(self, http_server): + """ Register this servlet with the given HTTP server. 
""" + if hasattr(self, "PATTERN"): + pattern = self.PATTERN + + for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"): + if hasattr(self, "on_%s" % (method)): + method_handler = getattr(self, "on_%s" % (method)) + http_server.register_path(method, pattern, method_handler) + else: + raise NotImplementedError("RestServlet must register something.") diff --git a/synapse/rest/directory.py b/synapse/rest/directory.py new file mode 100644 index 0000000000..a426003a38 --- /dev/null +++ b/synapse/rest/directory.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from twisted.internet import defer + +from synapse.types import RoomAlias, RoomID +from base import RestServlet, client_path_pattern + +import json +import logging +import urllib + + +logger = logging.getLogger(__name__) + + +def register_servlets(hs, http_server): + ClientDirectoryServer(hs).register(http_server) + + +class ClientDirectoryServer(RestServlet): + PATTERN = client_path_pattern("/ds/room/(?P<room_alias>[^/]*)$") + + @defer.inlineCallbacks + def on_GET(self, request, room_alias): + # TODO(erikj): Handle request + local_only = "local_only" in request.args + + room_alias = urllib.unquote(room_alias) + room_alias_obj = RoomAlias.from_string(room_alias, self.hs) + + dir_handler = self.handlers.directory_handler + res = yield dir_handler.get_association( + room_alias_obj, + local_only=local_only + ) + + defer.returnValue((200, res)) + + @defer.inlineCallbacks + def on_PUT(self, request, room_alias): + # TODO(erikj): Exceptions + content = json.loads(request.content.read()) + + logger.debug("Got content: %s", content) + + room_alias = urllib.unquote(room_alias) + room_alias_obj = RoomAlias.from_string(room_alias, self.hs) + + logger.debug("Got room name: %s", room_alias_obj.to_string()) + + room_id = content["room_id"] + servers = content["servers"] + + logger.debug("Got room_id: %s", room_id) + logger.debug("Got servers: %s", servers) + + # TODO(erikj): Check types. 
+ # TODO(erikj): Check that room exists + + dir_handler = self.handlers.directory_handler + + try: + yield dir_handler.create_association( + room_alias_obj, room_id, servers + ) + except: + logger.exception("Failed to create association") + + defer.returnValue((200, {})) diff --git a/synapse/rest/events.py b/synapse/rest/events.py new file mode 100644 index 0000000000..147257a940 --- /dev/null +++ b/synapse/rest/events.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""This module contains REST servlets to do with event streaming, /events.""" +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.api.streams import PaginationConfig +from synapse.rest.base import RestServlet, client_path_pattern + + +class EventStreamRestServlet(RestServlet): + PATTERN = client_path_pattern("/events$") + + DEFAULT_LONGPOLL_TIME_MS = 5000 + + @defer.inlineCallbacks + def on_GET(self, request): + auth_user = yield self.auth.get_user_by_req(request) + + handler = self.handlers.event_stream_handler + pagin_config = PaginationConfig.from_request(request) + timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS + if "timeout" in request.args: + try: + timeout = int(request.args["timeout"][0]) + except ValueError: + raise SynapseError(400, "timeout must be in milliseconds.") + + chunk = yield handler.get_stream(auth_user.to_string(), pagin_config, + timeout=timeout) + defer.returnValue((200, chunk)) + + def on_OPTIONS(self, request): + return (200, {}) + + +def register_servlets(hs, http_server): + EventStreamRestServlet(hs).register(http_server) diff --git a/synapse/rest/im.py b/synapse/rest/im.py new file mode 100644 index 0000000000..39f2dbd749 --- /dev/null +++ b/synapse/rest/im.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from twisted.internet import defer + +from synapse.api.streams import PaginationConfig +from base import RestServlet, client_path_pattern + + +class ImSyncRestServlet(RestServlet): + PATTERN = client_path_pattern("/im/sync$") + + @defer.inlineCallbacks + def on_GET(self, request): + user = yield self.auth.get_user_by_req(request) + with_feedback = "feedback" in request.args + pagination_config = PaginationConfig.from_request(request) + handler = self.handlers.message_handler + content = yield handler.snapshot_all_rooms( + user_id=user.to_string(), + pagin_config=pagination_config, + feedback=with_feedback) + + defer.returnValue((200, content)) + + +def register_servlets(hs, http_server): + ImSyncRestServlet(hs).register(http_server) diff --git a/synapse/rest/login.py b/synapse/rest/login.py new file mode 100644 index 0000000000..0284e125b4 --- /dev/null +++ b/synapse/rest/login.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from twisted.internet import defer + +from synapse.api.errors import SynapseError +from base import RestServlet, client_path_pattern + +import json + + +class LoginRestServlet(RestServlet): + PATTERN = client_path_pattern("/login$") + PASS_TYPE = "m.login.password" + + def on_GET(self, request): + return (200, {"type": LoginRestServlet.PASS_TYPE}) + + def on_OPTIONS(self, request): + return (200, {}) + + @defer.inlineCallbacks + def on_POST(self, request): + login_submission = _parse_json(request) + try: + if login_submission["type"] == LoginRestServlet.PASS_TYPE: + result = yield self.do_password_login(login_submission) + defer.returnValue(result) + else: + raise SynapseError(400, "Bad login type.") + except KeyError: + raise SynapseError(400, "Missing JSON keys.") + + @defer.inlineCallbacks + def do_password_login(self, login_submission): + handler = self.handlers.login_handler + token = yield handler.login( + user=login_submission["user"], + password=login_submission["password"]) + + result = { + "access_token": token, + "home_server": self.hs.hostname, + } + + defer.returnValue((200, result)) + + +class LoginFallbackRestServlet(RestServlet): + PATTERN = client_path_pattern("/login/fallback$") + + def on_GET(self, request): + # TODO(kegan): This should be returning some HTML which is capable of + # hitting LoginRestServlet + return (200, "") + + +def _parse_json(request): + try: + content = json.loads(request.content.read()) + if type(content) != dict: + raise SynapseError(400, "Content must be a JSON object.") + return content + except ValueError: + raise SynapseError(400, "Content not JSON.") + + +def register_servlets(hs, http_server): + LoginRestServlet(hs).register(http_server) diff --git a/synapse/rest/presence.py b/synapse/rest/presence.py new file mode 100644 index 0000000000..e4925c20a5 --- /dev/null +++ b/synapse/rest/presence.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 
2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" This module contains REST servlets to do with presence: /presence/<paths> +""" +from twisted.internet import defer + +from base import RestServlet, client_path_pattern + +import json +import logging + + +logger = logging.getLogger(__name__) + + +class PresenceStatusRestServlet(RestServlet): + PATTERN = client_path_pattern("/presence/(?P<user_id>[^/]*)/status") + + @defer.inlineCallbacks + def on_GET(self, request, user_id): + auth_user = yield self.auth.get_user_by_req(request) + user = self.hs.parse_userid(user_id) + + state = yield self.handlers.presence_handler.get_state( + target_user=user, auth_user=auth_user) + + defer.returnValue((200, state)) + + @defer.inlineCallbacks + def on_PUT(self, request, user_id): + auth_user = yield self.auth.get_user_by_req(request) + user = self.hs.parse_userid(user_id) + + state = {} + try: + content = json.loads(request.content.read()) + + state["state"] = content.pop("state") + + if "status_msg" in content: + state["status_msg"] = content.pop("status_msg") + + if content: + raise KeyError() + except: + defer.returnValue((400, "Unable to parse state")) + + yield self.handlers.presence_handler.set_state( + target_user=user, auth_user=auth_user, state=state) + + defer.returnValue((200, "")) + + def on_OPTIONS(self, request): + return (200, {}) + + +class PresenceListRestServlet(RestServlet): + PATTERN = client_path_pattern("/presence_list/(?P<user_id>[^/]*)") + + @defer.inlineCallbacks + def on_GET(self, request, user_id): + 
auth_user = yield self.auth.get_user_by_req(request) + user = self.hs.parse_userid(user_id) + + if not user.is_mine: + defer.returnValue((400, "User not hosted on this Home Server")) + + if auth_user != user: + defer.returnValue((400, "Cannot get another user's presence list")) + + presence = yield self.handlers.presence_handler.get_presence_list( + observer_user=user, accepted=True) + + for p in presence: + observed_user = p.pop("observed_user") + p["user_id"] = observed_user.to_string() + + defer.returnValue((200, presence)) + + @defer.inlineCallbacks + def on_POST(self, request, user_id): + auth_user = yield self.auth.get_user_by_req(request) + user = self.hs.parse_userid(user_id) + + if not user.is_mine: + defer.returnValue((400, "User not hosted on this Home Server")) + + if auth_user != user: + defer.returnValue(( + 400, "Cannot modify another user's presence list")) + + try: + content = json.loads(request.content.read()) + except: + logger.exception("JSON parse error") + defer.returnValue((400, "Unable to parse content")) + + deferreds = [] + + if "invite" in content: + for u in content["invite"]: + invited_user = self.hs.parse_userid(u) + deferreds.append(self.handlers.presence_handler.send_invite( + observer_user=user, observed_user=invited_user)) + + if "drop" in content: + for u in content["drop"]: + dropped_user = self.hs.parse_userid(u) + deferreds.append(self.handlers.presence_handler.drop( + observer_user=user, observed_user=dropped_user)) + + yield defer.DeferredList(deferreds) + + defer.returnValue((200, "")) + + def on_OPTIONS(self, request): + return (200, {}) + + +def register_servlets(hs, http_server): + PresenceStatusRestServlet(hs).register(http_server) + PresenceListRestServlet(hs).register(http_server) diff --git a/synapse/rest/profile.py b/synapse/rest/profile.py new file mode 100644 index 0000000000..f384227c29 --- /dev/null +++ b/synapse/rest/profile.py @@ -0,0 +1,93 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed 
under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" This module contains REST servlets to do with profile: /profile/<paths> """ +from twisted.internet import defer + +from base import RestServlet, client_path_pattern + +import json + + +class ProfileDisplaynameRestServlet(RestServlet): + PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/displayname") + + @defer.inlineCallbacks + def on_GET(self, request, user_id): + user = self.hs.parse_userid(user_id) + + displayname = yield self.handlers.profile_handler.get_displayname( + user, + local_only="local_only" in request.args + ) + + defer.returnValue((200, {"displayname": displayname})) + + @defer.inlineCallbacks + def on_PUT(self, request, user_id): + auth_user = yield self.auth.get_user_by_req(request) + user = self.hs.parse_userid(user_id) + + try: + content = json.loads(request.content.read()) + new_name = content["displayname"] + except: + defer.returnValue((400, "Unable to parse name")) + + yield self.handlers.profile_handler.set_displayname( + user, auth_user, new_name) + + defer.returnValue((200, "")) + + def on_OPTIONS(self, request, user_id): + return (200, {}) + + +class ProfileAvatarURLRestServlet(RestServlet): + PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/avatar_url") + + @defer.inlineCallbacks + def on_GET(self, request, user_id): + user = self.hs.parse_userid(user_id) + + avatar_url = yield self.handlers.profile_handler.get_avatar_url( + user, + local_only="local_only" in request.args + ) + + 
defer.returnValue((200, {"avatar_url": avatar_url})) + + @defer.inlineCallbacks + def on_PUT(self, request, user_id): + auth_user = yield self.auth.get_user_by_req(request) + user = self.hs.parse_userid(user_id) + + try: + content = json.loads(request.content.read()) + new_name = content["avatar_url"] + except: + defer.returnValue((400, "Unable to parse name")) + + yield self.handlers.profile_handler.set_avatar_url( + user, auth_user, new_name) + + defer.returnValue((200, "")) + + def on_OPTIONS(self, request, user_id): + return (200, {}) + + +def register_servlets(hs, http_server): + ProfileDisplaynameRestServlet(hs).register(http_server) + ProfileAvatarURLRestServlet(hs).register(http_server) diff --git a/synapse/rest/public.py b/synapse/rest/public.py new file mode 100644 index 0000000000..6fd1731a61 --- /dev/null +++ b/synapse/rest/public.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""This module contains REST servlets to do with public paths: /public""" +from twisted.internet import defer + +from base import RestServlet, client_path_pattern + + +class PublicRoomListRestServlet(RestServlet): + PATTERN = client_path_pattern("/public/rooms$") + + @defer.inlineCallbacks + def on_GET(self, request): + handler = self.handlers.room_list_handler + data = yield handler.get_public_room_list() + defer.returnValue((200, data)) + + +def register_servlets(hs, http_server): + PublicRoomListRestServlet(hs).register(http_server) diff --git a/synapse/rest/register.py b/synapse/rest/register.py new file mode 100644 index 0000000000..f1cbce5c67 --- /dev/null +++ b/synapse/rest/register.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""This module contains REST servlets to do with registration: /register""" +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from base import RestServlet, client_path_pattern + +import json +import urllib + + +class RegisterRestServlet(RestServlet): + PATTERN = client_path_pattern("/register$") + + @defer.inlineCallbacks + def on_POST(self, request): + desired_user_id = None + password = None + try: + register_json = json.loads(request.content.read()) + if "password" in register_json: + password = register_json["password"] + + if type(register_json["user_id"]) == unicode: + desired_user_id = register_json["user_id"] + if urllib.quote(desired_user_id) != desired_user_id: + raise SynapseError( + 400, + "User ID must only contain characters which do not " + + "require URL encoding.") + except ValueError: + defer.returnValue((400, "No JSON object.")) + except KeyError: + pass # user_id is optional + + handler = self.handlers.registration_handler + (user_id, token) = yield handler.register( + localpart=desired_user_id, + password=password) + + result = { + "user_id": user_id, + "access_token": token, + "home_server": self.hs.hostname, + } + defer.returnValue( + (200, result) + ) + + def on_OPTIONS(self, request): + return (200, {}) + + +def register_servlets(hs, http_server): + RegisterRestServlet(hs).register(http_server) diff --git a/synapse/rest/room.py b/synapse/rest/room.py new file mode 100644 index 0000000000..c96de5e65d --- /dev/null +++ b/synapse/rest/room.py @@ -0,0 +1,394 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" This module contains REST servlets to do with rooms: /rooms/<paths> """ +from twisted.internet import defer + +from base import RestServlet, client_path_pattern +from synapse.api.errors import SynapseError, Codes +from synapse.api.events.room import (RoomTopicEvent, MessageEvent, + RoomMemberEvent, FeedbackEvent) +from synapse.api.constants import Feedback, Membership +from synapse.api.streams import PaginationConfig +from synapse.types import RoomAlias + +import json +import logging +import urllib + + +logger = logging.getLogger(__name__) + + +class RoomCreateRestServlet(RestServlet): + # No PATTERN; we have custom dispatch rules here + + def register(self, http_server): + # /rooms OR /rooms/<roomid> + http_server.register_path("POST", + client_path_pattern("/rooms$"), + self.on_POST) + http_server.register_path("PUT", + client_path_pattern( + "/rooms/(?P<room_id>[^/]*)$"), + self.on_PUT) + # define CORS for all of /rooms in RoomCreateRestServlet for simplicity + http_server.register_path("OPTIONS", + client_path_pattern("/rooms(?:/.*)?$"), + self.on_OPTIONS) + + @defer.inlineCallbacks + def on_PUT(self, request, room_id): + room_id = urllib.unquote(room_id) + auth_user = yield self.auth.get_user_by_req(request) + + if not room_id: + raise SynapseError(400, "PUT must specify a room ID") + + room_config = self.get_room_config(request) + info = yield self.make_room(room_config, auth_user, room_id) + room_config.update(info) + defer.returnValue((200, info)) + + @defer.inlineCallbacks + def on_POST(self, request): + auth_user = yield 
self.auth.get_user_by_req(request) + + room_config = self.get_room_config(request) + info = yield self.make_room(room_config, auth_user, None) + room_config.update(info) + defer.returnValue((200, info)) + + @defer.inlineCallbacks + def make_room(self, room_config, auth_user, room_id): + handler = self.handlers.room_creation_handler + info = yield handler.create_room( + user_id=auth_user.to_string(), + room_id=room_id, + config=room_config + ) + defer.returnValue(info) + + def get_room_config(self, request): + try: + user_supplied_config = json.loads(request.content.read()) + if "visibility" not in user_supplied_config: + # default visibility + user_supplied_config["visibility"] = "public" + return user_supplied_config + except (ValueError, TypeError): + raise SynapseError(400, "Body must be JSON.", + errcode=Codes.BAD_JSON) + + def on_OPTIONS(self, request): + return (200, {}) + + +class RoomTopicRestServlet(RestServlet): + PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/topic$") + + def get_event_type(self): + return RoomTopicEvent.TYPE + + @defer.inlineCallbacks + def on_GET(self, request, room_id): + user = yield self.auth.get_user_by_req(request) + + msg_handler = self.handlers.message_handler + data = yield msg_handler.get_room_data( + user_id=user.to_string(), + room_id=urllib.unquote(room_id), + event_type=RoomTopicEvent.TYPE, + state_key="", + ) + + if not data: + raise SynapseError(404, "Topic not found.", errcode=Codes.NOT_FOUND) + defer.returnValue((200, json.loads(data.content))) + + @defer.inlineCallbacks + def on_PUT(self, request, room_id): + user = yield self.auth.get_user_by_req(request) + + content = _parse_json(request) + + event = self.event_factory.create_event( + etype=self.get_event_type(), + content=content, + room_id=urllib.unquote(room_id), + user_id=user.to_string(), + ) + + msg_handler = self.handlers.message_handler + yield msg_handler.store_room_data( + event=event + ) + defer.returnValue((200, "")) + + +class 
JoinRoomAliasServlet(RestServlet): + PATTERN = client_path_pattern("/join/(?P<room_alias>[^/]+)$") + + @defer.inlineCallbacks + def on_PUT(self, request, room_alias): + user = yield self.auth.get_user_by_req(request) + + if not user: + defer.returnValue((403, "Unrecognized user")) + + logger.debug("room_alias: %s", room_alias) + + room_alias = RoomAlias.from_string( + urllib.unquote(room_alias), + self.hs + ) + + handler = self.handlers.room_member_handler + ret_dict = yield handler.join_room_alias(user, room_alias) + + defer.returnValue((200, ret_dict)) + + +class RoomMemberRestServlet(RestServlet): + PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/members/" + + "(?P<target_user_id>[^/]*)/state$") + + def get_event_type(self): + return RoomMemberEvent.TYPE + + @defer.inlineCallbacks + def on_GET(self, request, room_id, target_user_id): + room_id = urllib.unquote(room_id) + user = yield self.auth.get_user_by_req(request) + + handler = self.handlers.room_member_handler + member = yield handler.get_room_member(room_id, target_user_id, + user.to_string()) + if not member: + raise SynapseError(404, "Member not found.", + errcode=Codes.NOT_FOUND) + defer.returnValue((200, json.loads(member.content))) + + @defer.inlineCallbacks + def on_DELETE(self, request, roomid, target_user_id): + user = yield self.auth.get_user_by_req(request) + + event = self.event_factory.create_event( + etype=self.get_event_type(), + target_user_id=target_user_id, + room_id=urllib.unquote(roomid), + user_id=user.to_string(), + membership=Membership.LEAVE, + content={"membership": Membership.LEAVE} + ) + + handler = self.handlers.room_member_handler + yield handler.change_membership(event, broadcast_msg=True) + defer.returnValue((200, "")) + + @defer.inlineCallbacks + def on_PUT(self, request, roomid, target_user_id): + user = yield self.auth.get_user_by_req(request) + + content = _parse_json(request) + if "membership" not in content: + raise SynapseError(400, "No membership key.", + 
errcode=Codes.BAD_JSON) + + valid_membership_values = [Membership.JOIN, Membership.INVITE] + if (content["membership"] not in valid_membership_values): + raise SynapseError(400, "Membership value must be %s." % ( + valid_membership_values,), errcode=Codes.BAD_JSON) + + event = self.event_factory.create_event( + etype=self.get_event_type(), + target_user_id=target_user_id, + room_id=urllib.unquote(roomid), + user_id=user.to_string(), + membership=content["membership"], + content=content + ) + + handler = self.handlers.room_member_handler + result = yield handler.change_membership(event, broadcast_msg=True) + defer.returnValue((200, result)) + + +class MessageRestServlet(RestServlet): + PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/messages/" + + "(?P<sender_id>[^/]*)/(?P<msg_id>[^/]*)$") + + def get_event_type(self): + return MessageEvent.TYPE + + @defer.inlineCallbacks + def on_GET(self, request, room_id, sender_id, msg_id): + user = yield self.auth.get_user_by_req(request) + + msg_handler = self.handlers.message_handler + msg = yield msg_handler.get_message(room_id=urllib.unquote(room_id), + sender_id=sender_id, + msg_id=msg_id, + user_id=user.to_string(), + ) + + if not msg: + raise SynapseError(404, "Message not found.", + errcode=Codes.NOT_FOUND) + + defer.returnValue((200, json.loads(msg.content))) + + @defer.inlineCallbacks + def on_PUT(self, request, room_id, sender_id, msg_id): + user = yield self.auth.get_user_by_req(request) + + if user.to_string() != sender_id: + raise SynapseError(403, "Must send messages as yourself.", + errcode=Codes.FORBIDDEN) + + content = _parse_json(request) + + event = self.event_factory.create_event( + etype=self.get_event_type(), + room_id=urllib.unquote(room_id), + user_id=user.to_string(), + msg_id=msg_id, + content=content + ) + + msg_handler = self.handlers.message_handler + yield msg_handler.send_message(event) + + defer.returnValue((200, "")) + + +class FeedbackRestServlet(RestServlet): + PATTERN = 
client_path_pattern( + "/rooms/(?P<room_id>[^/]*)/messages/" + + "(?P<msg_sender_id>[^/]*)/(?P<msg_id>[^/]*)/feedback/" + + "(?P<sender_id>[^/]*)/(?P<feedback_type>[^/]*)$" + ) + + def get_event_type(self): + return FeedbackEvent.TYPE + + @defer.inlineCallbacks + def on_GET(self, request, room_id, msg_sender_id, msg_id, fb_sender_id, + feedback_type): + user = yield (self.auth.get_user_by_req(request)) + + if feedback_type not in Feedback.LIST: + raise SynapseError(400, "Bad feedback type.", + errcode=Codes.BAD_JSON) + + msg_handler = self.handlers.message_handler + feedback = yield msg_handler.get_feedback( + room_id=urllib.unquote(room_id), + msg_sender_id=msg_sender_id, + msg_id=msg_id, + user_id=user.to_string(), + fb_sender_id=fb_sender_id, + fb_type=feedback_type + ) + + if not feedback: + raise SynapseError(404, "Feedback not found.", + errcode=Codes.NOT_FOUND) + + defer.returnValue((200, json.loads(feedback.content))) + + @defer.inlineCallbacks + def on_PUT(self, request, room_id, sender_id, msg_id, fb_sender_id, + feedback_type): + user = yield (self.auth.get_user_by_req(request)) + + if user.to_string() != fb_sender_id: + raise SynapseError(403, "Must send feedback as yourself.", + errcode=Codes.FORBIDDEN) + + if feedback_type not in Feedback.LIST: + raise SynapseError(400, "Bad feedback type.", + errcode=Codes.BAD_JSON) + + content = _parse_json(request) + + event = self.event_factory.create_event( + etype=self.get_event_type(), + room_id=urllib.unquote(room_id), + msg_sender_id=sender_id, + msg_id=msg_id, + user_id=user.to_string(), # user sending the feedback + feedback_type=feedback_type, + content=content + ) + + msg_handler = self.handlers.message_handler + yield msg_handler.send_feedback(event) + + defer.returnValue((200, "")) + + +class RoomMemberListRestServlet(RestServlet): + PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/members/list$") + + @defer.inlineCallbacks + def on_GET(self, request, room_id): + # TODO support Pagination stream 
API (limit/tokens) + user = yield self.auth.get_user_by_req(request) + handler = self.handlers.room_member_handler + members = yield handler.get_room_members_as_pagination_chunk( + room_id=urllib.unquote(room_id), + user_id=user.to_string()) + + defer.returnValue((200, members)) + + +class RoomMessageListRestServlet(RestServlet): + PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/messages/list$") + + @defer.inlineCallbacks + def on_GET(self, request, room_id): + user = yield self.auth.get_user_by_req(request) + pagination_config = PaginationConfig.from_request(request) + with_feedback = "feedback" in request.args + handler = self.handlers.message_handler + msgs = yield handler.get_messages( + room_id=urllib.unquote(room_id), + user_id=user.to_string(), + pagin_config=pagination_config, + feedback=with_feedback) + + defer.returnValue((200, msgs)) + + +def _parse_json(request): + try: + content = json.loads(request.content.read()) + if type(content) != dict: + raise SynapseError(400, "Content must be a JSON object.", + errcode=Codes.NOT_JSON) + return content + except ValueError: + raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) + + +def register_servlets(hs, http_server): + RoomTopicRestServlet(hs).register(http_server) + RoomMemberRestServlet(hs).register(http_server) + MessageRestServlet(hs).register(http_server) + FeedbackRestServlet(hs).register(http_server) + RoomCreateRestServlet(hs).register(http_server) + RoomMemberListRestServlet(hs).register(http_server) + RoomMessageListRestServlet(hs).register(http_server) + JoinRoomAliasServlet(hs).register(http_server) diff --git a/synapse/server.py b/synapse/server.py new file mode 100644 index 0000000000..0aff75f399 --- /dev/null +++ b/synapse/server.py @@ -0,0 +1,176 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class BaseHomeServer(object):
    """A basic homeserver object without lazy component builders.

    This will need all of the components it requires to either be passed as
    constructor arguments, or the relevant methods overriding to create them.
    Typically this would only be used for unit tests.

    For every dependency in the DEPENDENCIES list below, this class creates one
    method,
        def get_DEPENDENCY(self)
    which returns the value of that dependency. If no value has yet been set
    nor was provided to the constructor, it will attempt to call a lazy builder
    method called
        def build_DEPENDENCY(self)
    which must be implemented by the subclass. This code may call any of the
    required "get" methods on the instance to obtain the sub-dependencies that
    one requires.
    """

    DEPENDENCIES = [
        'clock',
        'http_server',
        'http_client',
        'db_pool',
        'persistence_service',
        'federation',
        'replication_layer',
        'datastore',
        'event_factory',
        'handlers',
        'auth',
        'rest_servlet_factory',
        'state_handler',
        'room_lock_manager',
        'notifier',
        'distributor',
    ]

    def __init__(self, hostname, **kwargs):
        """
        Args:
            hostname : The hostname for the server.
            **kwargs : Explicit dependencies; each is stored as an attribute
                named after its keyword, so the matching get_* accessor
                returns it without calling a builder.
        """
        self.hostname = hostname
        # Names of the dependencies whose builder is currently running; used
        # to detect builders that (indirectly) depend on themselves.
        self._building = {}

        # Other kwargs are explicit dependencies
        for depname in kwargs:
            setattr(self, depname, kwargs[depname])

    @classmethod
    def _make_dependency_method(cls, depname):
        """Install a ``get_<depname>`` accessor on BaseHomeServer.

        The accessor returns the attribute if it is already set, otherwise
        calls ``build_<depname>`` once and caches the result.

        Raises (from the generated accessor):
            ValueError: the builder re-entered its own accessor.
            NotImplementedError: neither a value nor a builder exists.
        """
        def _get(self):
            if hasattr(self, depname):
                return getattr(self, depname)

            if hasattr(self, "build_%s" % (depname)):
                # Prevent cyclic dependencies from deadlocking
                if depname in self._building:
                    raise ValueError("Cyclic dependency while building %s" % (
                        depname,
                    ))
                self._building[depname] = 1

                try:
                    builder = getattr(self, "build_%s" % (depname))
                    dep = builder()
                    setattr(self, depname, dep)
                finally:
                    # Always clear the marker: if the builder raised, a later
                    # retry must not be misreported as a cyclic dependency.
                    del self._building[depname]

                return dep

            raise NotImplementedError(
                "%s has no %s nor a builder for it" % (
                    type(self).__name__, depname,
                )
            )

        setattr(BaseHomeServer, "get_%s" % (depname), _get)

    # Other utility methods
    def parse_userid(self, s):
        """Parse the string given by 's' as a User ID and return a UserID
        object."""
        return UserID.from_string(s, hs=self)

# Build magic accessors for every dependency
for depname in BaseHomeServer.DEPENDENCIES:
    BaseHomeServer._make_dependency_method(depname)


class HomeServer(BaseHomeServer):
    """A homeserver object that will construct most of its dependencies as
    required.

    It still requires the following to be specified by the caller:
        http_server
        http_client
        db_pool
    """

    def build_clock(self):
        return Clock()

    def build_replication_layer(self):
        return initialize_http_replication(self)

    def build_federation(self):
        return FederationEventHandler(self)

    def build_datastore(self):
        return DataStore(self)

    def build_event_factory(self):
        return EventFactory()

    def build_handlers(self):
        return Handlers(self)

    def build_notifier(self):
        return Notifier(self)

    def build_auth(self):
        return Auth(self)

    def build_rest_servlet_factory(self):
        return RestServletFactory(self)

    def build_state_handler(self):
        return StateHandler(self)

    def build_room_lock_manager(self):
        return LockManager()

    def build_distributor(self):
        return Distributor()

    def register_servlets(self):
        """Simply building the ServletFactory is sufficient to have it
        register."""
        self.get_rest_servlet_factory()
logger = logging.getLogger(__name__)


def _get_state_key_from_event(event):
    """Return the ``state_key`` attribute of a state event."""
    return event.state_key


KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))


class StateHandler(object):
    """ Responsible for doing state conflict resolution.
    """

    def __init__(self, hs):
        self.store = hs.get_datastore()
        self._replication = hs.get_replication_layer()
        self.server_name = hs.hostname

    @defer.inlineCallbacks
    @log_function
    def handle_new_event(self, event):
        """ Given an event this works out if a) we have sufficient power level
        to update the state and b) works out what the prev_state should be.

        Events without a `state_key` attribute are ignored.

        Returns:
            Deferred: Resolved with True once the state has been updated, or
            with None if the event is not a state event.

        Raised:
            AuthError
        """
        # This needs to be done in a transaction.

        if not hasattr(event, "state_key"):
            return

        key = KeyStateTuple(
            event.room_id,
            event.type,
            _get_state_key_from_event(event)
        )

        # Now I need to fill out the prev state and work out if it has auth
        # (w.r.t. to power levels)

        results = yield self.store.get_latest_pdus_in_context(
            event.room_id
        )

        event.prev_events = [
            encode_event_id(p_id, origin) for p_id, origin, _ in results
        ]
        # An event must never list itself as one of its own predecessors.
        event.prev_events = [
            e for e in event.prev_events if e != event.event_id
        ]

        if results:
            event.depth = max([int(v) for _, _, v in results]) + 1
        else:
            event.depth = 0

        current_state = yield self.store.get_current_state(
            key.context, key.type, key.state_key
        )

        if current_state:
            event.prev_state = encode_event_id(
                current_state.pdu_id, current_state.origin
            )

        # TODO check current_state to see if the min power level is less
        # than the power level of the user
        # power_level = self._get_power_level_for_event(event)

        yield self.store.update_current_state(
            pdu_id=event.event_id,
            origin=self.server_name,
            context=key.context,
            pdu_type=key.type,
            state_key=key.state_key
        )

        defer.returnValue(True)

    @defer.inlineCallbacks
    @log_function
    def handle_new_state(self, new_pdu):
        """ Apply conflict resolution to `new_pdu`.

        This should be called on every new state pdu, regardless of whether or
        not there is a conflict.

        This function is safe against the race of it getting called with two
        `PDU`s trying to update the same state.

        Returns:
            Deferred: Resolved with a boolean indicating whether `new_pdu`
            became the current state.
        """

        # This needs to be done in a transaction.

        is_new = yield self._handle_new_state(new_pdu)

        if is_new:
            yield self.store.update_current_state(
                pdu_id=new_pdu.pdu_id,
                origin=new_pdu.origin,
                context=new_pdu.context,
                pdu_type=new_pdu.pdu_type,
                state_key=new_pdu.state_key
            )

        defer.returnValue(is_new)

    def _get_power_level_for_event(self, event):
        # return self._persistence.get_power_level_for_user(event.room_id,
        #     event.sender)
        return event.power_level

    @defer.inlineCallbacks
    @log_function
    def _handle_new_state(self, new_pdu):
        """Decide whether `new_pdu` should replace the current state.

        Returns:
            Deferred: Resolved with True if the new branch wins, False if the
            current branch wins.
        """
        tree = yield self.store.get_unresolved_state_tree(new_pdu)
        new_branch, current_branch = tree

        logger.debug(
            "_handle_new_state new=%s, current=%s",
            new_branch, current_branch
        )

        if not current_branch:
            # There is no current state. returnValue() raises, so nothing
            # after it in this branch would ever run.
            defer.returnValue(True)

        if new_branch[-1] == current_branch[-1]:
            # We have all the PDUs we need, so we can just do the conflict
            # resolution.

            if len(current_branch) == 1:
                # This is a direct clobber so we can just...
                defer.returnValue(True)

            # Tie-breakers, applied in order until one of them disagrees.
            conflict_res = [
                self._do_power_level_conflict_res,
                self._do_chain_length_conflict_res,
                self._do_hash_conflict_res,
            ]

            for algo in conflict_res:
                new_res, curr_res = algo(new_branch, current_branch)

                if new_res < curr_res:
                    defer.returnValue(False)
                elif new_res > curr_res:
                    defer.returnValue(True)

            raise Exception("Conflict resolution failed.")

        else:
            # We need to ask for PDUs.
            missing_prev = max(
                new_branch[-1], current_branch[-1],
                key=lambda x: x.depth
            )

            yield self._replication.get_pdu(
                destination=missing_prev.origin,
                pdu_origin=missing_prev.prev_state_origin,
                pdu_id=missing_prev.prev_state_id,
                outlier=True
            )

            # Retry now that the missing ancestor has been requested.
            updated_current = yield self._handle_new_state(new_pdu)
            defer.returnValue(updated_current)

    def _do_power_level_conflict_res(self, new_branch, current_branch):
        """Return (new, current) maximum power levels, ignoring the last
        entry of each branch."""
        max_power_new = max(
            new_branch[:-1],
            key=lambda t: t.power_level
        ).power_level

        max_power_current = max(
            current_branch[:-1],
            key=lambda t: t.power_level
        ).power_level

        return (max_power_new, max_power_current)

    def _do_chain_length_conflict_res(self, new_branch, current_branch):
        """Return (new, current) branch lengths."""
        return (len(new_branch), len(current_branch))

    def _do_hash_conflict_res(self, new_branch, current_branch):
        """Return (new, current) SHA-1 hex digests of each branch's
        concatenated ``pdu_id + origin`` strings."""
        def branch_digest(branch):
            joined = "".join([p.pdu_id + p.origin for p in branch])
            # hashlib only accepts bytes; encode text explicitly so that
            # non-ASCII identifiers do not raise. Byte strings (Python 2
            # `str`) are hashed unchanged, exactly as before.
            if not isinstance(joined, bytes):
                joined = joined.encode("utf-8")
            return hashlib.sha1(joined).hexdigest()

        return (branch_digest(new_branch), branch_digest(current_branch))
class DataStore(RoomDataStore, RoomMemberStore, MessageStore, RoomStore,
                RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
                PresenceStore, PduStore, StatePduStore, TransactionStore,
                DirectoryStore):
    """Aggregate of all the per-table stores, plus event persistence."""

    def __init__(self, hs):
        super(DataStore, self).__init__(hs)
        self.event_factory = hs.get_event_factory()
        self.hs = hs

    def persist_event(self, event):
        """Store `event` using the store method matching its type.

        Returns:
            Whatever the selected store method returns. A RoomConfigEvent
            without a "visibility" key stores nothing and returns None.

        Raises:
            NotImplementedError: the event type is not handled here.
        """
        if event.type == MessageEvent.TYPE:
            return self.store_message(
                user_id=event.user_id,
                room_id=event.room_id,
                msg_id=event.msg_id,
                content=json.dumps(event.content)
            )
        elif event.type == RoomMemberEvent.TYPE:
            return self.store_room_member(
                user_id=event.target_user_id,
                sender=event.user_id,
                room_id=event.room_id,
                content=event.content,
                membership=event.content["membership"]
            )
        elif event.type == FeedbackEvent.TYPE:
            return self.store_feedback(
                room_id=event.room_id,
                msg_id=event.msg_id,
                msg_sender_id=event.msg_sender_id,
                fb_sender_id=event.user_id,
                fb_type=event.feedback_type,
                content=json.dumps(event.content)
            )
        elif event.type == RoomTopicEvent.TYPE:
            return self.store_room_data(
                room_id=event.room_id,
                etype=event.type,
                state_key=event.state_key,
                content=json.dumps(event.content)
            )
        elif event.type == RoomConfigEvent.TYPE:
            if "visibility" in event.content:
                visibility = event.content["visibility"]
                return self.store_room_config(
                    room_id=event.room_id,
                    visibility=visibility
                )

        else:
            raise NotImplementedError(
                "Don't know how to persist type=%s" % event.type
            )


def schema_path(schema):
    """ Get a filesystem path for the named database schema

    Args:
        schema: Name of the database schema.
    Returns:
        A filesystem path pointing at a ".sql" file.

    """
    dir_path = os.path.dirname(__file__)
    schemaPath = os.path.join(dir_path, "schema", schema + ".sql")
    return schemaPath


def read_schema(schema):
    """ Read the named database schema.

    Args:
        schema: Name of the database schema.
    Returns:
        A string containing the database schema.
    """
    with open(schema_path(schema)) as schema_file:
        return schema_file.read()


logger = logging.getLogger(__name__)


class SQLBaseStore(object):
    """Base class for stores; wraps the homeserver's database pool."""

    def __init__(self, hs):
        self._db_pool = hs.get_db_pool()

    def cursor_to_dict(self, cursor):
        """Converts a SQL cursor into a list of dicts.

        Args:
            cursor : The DBAPI cursor which has executed a query.
        Returns:
            A list of dicts where the key is the column header.
        """
        col_headers = list(column[0] for column in cursor.description)
        results = list(
            dict(zip(col_headers, row)) for row in cursor.fetchall()
        )
        return results

    def _execute(self, decoder, query, *args):
        """Runs a single query for a result set.

        Args:
            decoder - The function which can resolve the cursor results to
                something meaningful.
            query - The query string to execute
            *args - Query args.
        Returns:
            The result of decoder(results)
        """
        logger.debug(
            "[SQL] %s Args=%s Func=%s", query, args, decoder.__name__
        )

        def interaction(txn):
            cursor = txn.execute(query, args)
            return decoder(cursor)
        return self._db_pool.runInteraction(interaction)

    # "Simple" SQL API methods that operate on a single table with no JOINs,
    # no complex WHERE clauses, just a dict of values for columns.

    def _simple_insert(self, table, values):
        """Executes an INSERT query on the named table.

        Args:
            table : string giving the table name
            values : dict of new column names and values for them
        Returns:
            The result of runInteraction, carrying the cursor's lastrowid.
        """
        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
            table,
            ", ".join(k for k in values),
            ", ".join("?" for k in values)
        )

        def func(txn):
            # DB-API parameters must be a sequence; a dict view is not one.
            txn.execute(sql, list(values.values()))
            return txn.lastrowid
        return self._db_pool.runInteraction(func)

    def _simple_select_one(self, table, keyvalues, retcols,
                           allow_none=False):
        """Executes a SELECT query on the named table, which is expected to
        return a single row, returning the requested columns from it as a
        dict.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
            retcols : list of strings giving the names of the columns to return

            allow_none : If true, return None instead of failing if the SELECT
              statement returns no rows
        """
        return self._simple_selectupdate_one(
            table, keyvalues, retcols=retcols, allow_none=allow_none
        )

    @defer.inlineCallbacks
    def _simple_select_one_onecol(self, table, keyvalues, retcol,
                                  allow_none=False):
        """Executes a SELECT query on the named table, which is expected to
        return a single row, returning a single column from it.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
            retcol : string giving the name of the column to return
            allow_none : If true, resolve to None instead of failing if the
              SELECT statement returns no rows
        """
        ret = yield self._simple_select_one(
            table=table,
            keyvalues=keyvalues,
            retcols=[retcol],
            allow_none=allow_none
        )

        if ret:
            defer.returnValue(ret[retcol])
        else:
            defer.returnValue(None)

    @defer.inlineCallbacks
    def _simple_select_onecol(self, table, keyvalues, retcol):
        """Executes a SELECT query on the named table, which returns a list
        comprising of the values of the named column from the selected rows.

        Args:
            table (str): table name
            keyvalues (dict): column names and values to select the rows with
            retcol (str): column whose value we wish to retrieve.

        Returns:
            Deferred: Results in a list
        """
        sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
            "retcol": retcol,
            "table": table,
            "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
        }

        def func(txn):
            txn.execute(sql, list(keyvalues.values()))
            return txn.fetchall()

        res = yield self._db_pool.runInteraction(func)

        defer.returnValue([r[0] for r in res])

    def _simple_select_list(self, table, keyvalues, retcols):
        """Executes a SELECT query on the named table, which may return zero or
        more rows, returning the result as a list of dicts.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the rows with
            retcols : list of strings giving the names of the columns to return
        """
        sql = "SELECT %s FROM %s WHERE %s" % (
            ", ".join(retcols),
            table,
            " AND ".join("%s = ?" % (k) for k in keyvalues)
        )

        def func(txn):
            txn.execute(sql, list(keyvalues.values()))
            return self.cursor_to_dict(txn)

        return self._db_pool.runInteraction(func)

    def _simple_update_one(self, table, keyvalues, updatevalues,
                           retcols=None):
        """Executes an UPDATE query on the named table, setting new values for
        columns in a row matching the key values.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
            updatevalues : dict giving column names and values to update
            retcols : optional list of column names to return

        If present, retcols gives a list of column names on which to perform
        a SELECT statement *before* performing the UPDATE statement. The values
        of these will be returned in a dict.

        These are performed within the same transaction, allowing an atomic
        get-and-set.  This can be used to implement compare-and-set by putting
        the update column in the 'keyvalues' dict as well.
        """
        return self._simple_selectupdate_one(table, keyvalues, updatevalues,
                                             retcols=retcols)

    def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
                                 retcols=None, allow_none=False):
        """ Combined SELECT then UPDATE.

        Both statements run inside one runInteraction call. The SELECT is
        issued only when `retcols` is given and the UPDATE only when
        `updatevalues` is given.

        Raises (inside the interaction):
            StoreError: 404 if no row matched (unless `allow_none` applies to
                the SELECT), 500 if more than one row matched.
        """
        if retcols:
            select_sql = "SELECT %s FROM %s WHERE %s" % (
                ", ".join(retcols),
                table,
                " AND ".join("%s = ?" % (k) for k in keyvalues)
            )

        if updatevalues:
            update_sql = "UPDATE %s SET %s WHERE %s" % (
                table,
                ", ".join("%s = ?" % (k) for k in updatevalues),
                " AND ".join("%s = ?" % (k) for k in keyvalues)
            )

        def func(txn):
            ret = None
            if retcols:
                txn.execute(select_sql, list(keyvalues.values()))

                row = txn.fetchone()
                if not row:
                    if allow_none:
                        return None
                    raise StoreError(404, "No row found")
                # NOTE(review): some DB-API drivers report rowcount as -1
                # after a SELECT, in which case this check never fires --
                # confirm against the driver in use.
                if txn.rowcount > 1:
                    raise StoreError(500, "More than one row matched")

                ret = dict(zip(retcols, row))

            if updatevalues:
                # Parameters follow placeholder order: SET values first, then
                # the WHERE values. list() keeps "+" valid for dict views.
                txn.execute(
                    update_sql,
                    list(updatevalues.values()) + list(keyvalues.values())
                )

                if txn.rowcount == 0:
                    raise StoreError(404, "No row found")
                if txn.rowcount > 1:
                    raise StoreError(500, "More than one row matched")

            return ret
        return self._db_pool.runInteraction(func)

    def _simple_delete_one(self, table, keyvalues):
        """Executes a DELETE query on the named table, expecting to delete a
        single row.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
        """
        sql = "DELETE FROM %s WHERE %s" % (
            table,
            " AND ".join("%s = ?" % (k) for k in keyvalues)
        )

        def func(txn):
            txn.execute(sql, list(keyvalues.values()))
            if txn.rowcount == 0:
                raise StoreError(404, "No row found")
            if txn.rowcount > 1:
                raise StoreError(500, "more than one row matched")
        return self._db_pool.runInteraction(func)

    def _simple_max_id(self, table):
        """Executes a SELECT query on the named table, expecting to return the
        max value for the column "id".

        Args:
            table : string giving the table name
        Returns:
            The maximum id, or 0 if MAX(id) is NULL (e.g. an empty table).
        """
        sql = "SELECT MAX(id) AS id FROM %s" % table

        def func(txn):
            txn.execute(sql)
            max_id = self.cursor_to_dict(txn)[0]["id"]
            if max_id is None:
                return 0
            return max_id

        return self._db_pool.runInteraction(func)


class Table(object):
    """ A base class used to store information about a particular table.
    """

    table_name = None
    """ str: The name of the table """

    fields = None
    """ list: The field names """

    EntryType = None
    """ Type: A tuple type used to decode the results """

    _select_where_clause = "SELECT %s FROM %s WHERE %s"
    _select_clause = "SELECT %s FROM %s"
    _insert_clause = "INSERT OR REPLACE INTO %s (%s) VALUES (%s)"

    @classmethod
    def select_statement(cls, where_clause=None):
        """
        Args:
            where_clause (str): The WHERE clause to use.

        Returns:
            str: An SQL statement to select rows from the table with the given
            WHERE clause.
        """
        if where_clause:
            return cls._select_where_clause % (
                ", ".join(cls.fields),
                cls.table_name,
                where_clause
            )
        else:
            return cls._select_clause % (
                ", ".join(cls.fields),
                cls.table_name,
            )

    @classmethod
    def insert_statement(cls):
        """Return an INSERT OR REPLACE statement with one "?" placeholder per
        field."""
        return cls._insert_clause % (
            cls.table_name,
            ", ".join(cls.fields),
            ", ".join(["?"] * len(cls.fields)),
        )

    @classmethod
    def decode_single_result(cls, results):
        """ Given an iterable of tuples, return a single instance of
            `EntryType` or None if the iterable is empty
        Args:
            results (list): The results list to convert to `EntryType`
        Returns:
            EntryType: An instance of `EntryType`
        """
        results = list(results)
        if results:
            return cls.EntryType(*results[0])
        else:
            return None

    @classmethod
    def decode_results(cls, results):
        """ Given an iterable of tuples, return a list of `EntryType`
        Args:
            results (list): The results list to convert to `EntryType`

        Returns:
            list: A list of `EntryType`
        """
        return [cls.EntryType(*row) for row in results]

    @classmethod
    def get_fields_string(cls, prefix=None):
        """Return the field names joined by ", ", each qualified as
        "prefix.field" when `prefix` is given."""
        if prefix:
            to_join = ("%s.%s" % (prefix, f) for f in cls.fields)
        else:
            to_join = cls.fields

        return ", ".join(to_join)


class JoinHelper(object):
    """ Used to help do joins on tables by looking at the tables' fields and
    creating a list of unique fields to use with SELECTs and a namedtuple
    to dump the results into.

    Attributes:
        tables (list): List of `Table` classes
        EntryType (type)
    """

    def __init__(self, *tables):
        self.tables = tables

        # Union of all field names, keeping first-seen order.
        res = []
        for table in self.tables:
            res += [f for f in table.fields if f not in res]

        self.EntryType = collections.namedtuple("JoinHelperEntry", res)

    def get_fields(self, **prefixes):
        """Get a string representing a list of fields for use in SELECT
        statements with the given prefixes applied to each.

        For example::

            JoinHelper(PdusTable, StateTable).get_fields(
                PdusTable="pdus",
                StateTable="state"
            )
        """
        res = []
        for field in self.EntryType._fields:
            # Qualify each field with the prefix of the first table that
            # declares it.
            for table in self.tables:
                if field in table.fields:
                    res.append("%s.%s" % (prefixes[table.__name__], field))
                    break

        return ", ".join(res)

    def decode_results(self, rows):
        """Convert each row tuple into an `EntryType` instance."""
        return [self.EntryType(*row) for row in rows]
RoomAliasMapping = namedtuple(
    "RoomAliasMapping",
    ("room_id", "room_alias", "servers",)
)


class DirectoryStore(SQLBaseStore):
    """Storage for the mapping from room aliases to room IDs and servers."""

    @defer.inlineCallbacks
    def get_association_from_room_alias(self, room_alias):
        """ Gets the room_id and server list for a given room_alias

        Args:
            room_alias (RoomAlias)

        Returns:
            Deferred: results in namedtuple with keys "room_id" and
            "servers" or None if no association can be found
        """
        alias = room_alias.to_string()

        room_id = yield self._simple_select_one_onecol(
            table="room_aliases",
            keyvalues={"room_alias": alias},
            retcol="room_id",
            allow_none=True,
        )

        # returnValue() raises to leave the generator, so no explicit
        # `return` is needed after it.
        if not room_id:
            defer.returnValue(None)

        servers = yield self._simple_select_onecol(
            table="room_alias_servers",
            keyvalues={"room_alias": alias},
            retcol="server",
        )

        if not servers:
            defer.returnValue(None)

        mapping = RoomAliasMapping(room_id, alias, servers)
        defer.returnValue(mapping)

    @defer.inlineCallbacks
    def create_room_alias_association(self, room_alias, room_id, servers):
        """ Creates an association between a room alias and room_id/servers

        Args:
            room_alias (RoomAlias)
            room_id (str)
            servers (list)

        Returns:
            Deferred
        """
        alias = room_alias.to_string()

        alias_row = {
            "room_alias": alias,
            "room_id": room_id,
        }
        yield self._simple_insert("room_aliases", alias_row)

        # TODO(erikj): Fix this to bulk insert
        for server in servers:
            server_row = {
                "room_alias": alias,
                "server": server,
            }
            yield self._simple_insert("room_alias_servers", server_row)
class FeedbackStore(SQLBaseStore):
    """Queries for storing and retrieving message feedback."""

    def store_feedback(self, room_id, msg_id, msg_sender_id,
                       fb_sender_id, fb_type, content):
        """Insert a feedback row.

        Args:
            room_id (str): The room the message was sent in.
            msg_id (str): The ID of the message the feedback refers to.
            msg_sender_id (str): The user who sent that message.
            fb_sender_id (str): The user sending the feedback.
            fb_type (str): The feedback type.
            content (str): The feedback content (JSON).
        """
        # The column is "feedback_type" (see FeedbackTable.fields and the
        # WHERE clause in get_feedback), not "fb_type".
        return self._simple_insert(FeedbackTable.table_name, dict(
            room_id=room_id,
            msg_id=msg_id,
            msg_sender_id=msg_sender_id,
            fb_sender_id=fb_sender_id,
            feedback_type=fb_type,
            content=content,
        ))

    def get_feedback(self, room_id=None, msg_id=None, msg_sender_id=None,
                     fb_sender_id=None, fb_type=None):
        """Fetch the most recent feedback row matching all of the given
        values, decoded as a `FeedbackTable.EntryType`, or None if there is
        no such row."""
        query = FeedbackTable.select_statement(
            "msg_sender_id = ? AND room_id = ? AND msg_id = ? " +
            "AND fb_sender_id = ? AND feedback_type = ? " +
            "ORDER BY id DESC LIMIT 1")
        return self._execute(
            FeedbackTable.decode_single_result,
            query, msg_sender_id, room_id, msg_id, fb_sender_id, fb_type,
        )

    def get_max_feedback_id(self):
        """Return the largest id in the feedback table."""
        return self._simple_max_id(FeedbackTable.table_name)


class FeedbackTable(Table):
    table_name = "feedback"

    fields = [
        "id",
        "content",
        "feedback_type",
        "fb_sender_id",
        "msg_id",
        "room_id",
        "msg_sender_id"
    ]

    class EntryType(collections.namedtuple("FeedbackEntry", fields)):

        def as_event(self, event_factory):
            """Build a FeedbackEvent from this row via `event_factory`."""
            return event_factory.create_event(
                etype=FeedbackEvent.TYPE,
                room_id=self.room_id,
                msg_id=self.msg_id,
                msg_sender_id=self.msg_sender_id,
                user_id=self.fb_sender_id,
                feedback_type=self.feedback_type,
                content=json.loads(self.content),
            )


class MessageStore(SQLBaseStore):
    """Queries for storing and retrieving room messages."""

    def get_message(self, user_id, room_id, msg_id):
        """Get a message from the store.

        Args:
            user_id (str): The ID of the user who sent the message.
            room_id (str): The room the message was sent in.
            msg_id (str): The unique ID for this user/room combo.
        """
        query = MessagesTable.select_statement(
            "user_id = ? AND room_id = ? AND msg_id = ? " +
            "ORDER BY id DESC LIMIT 1")
        return self._execute(
            MessagesTable.decode_single_result,
            query, user_id, room_id, msg_id,
        )

    def store_message(self, user_id, room_id, msg_id, content):
        """Store a message in the store.

        Args:
            user_id (str): The ID of the user who sent the message.
            room_id (str): The room the message was sent in.
            msg_id (str): The unique ID for this user/room combo.
            content (str): The content of the message (JSON)
        """
        return self._simple_insert(MessagesTable.table_name, dict(
            user_id=user_id,
            room_id=room_id,
            msg_id=msg_id,
            content=content,
        ))

    def get_max_message_id(self):
        """Return the largest id in the messages table."""
        return self._simple_max_id(MessagesTable.table_name)


class MessagesTable(Table):
    table_name = "messages"

    fields = [
        "id",
        "user_id",
        "room_id",
        "msg_id",
        "content"
    ]

    class EntryType(collections.namedtuple("MessageEntry", fields)):

        def as_event(self, event_factory):
            """Build a MessageEvent from this row via `event_factory`."""
            return event_factory.create_event(
                etype=MessageEvent.TYPE,
                room_id=self.room_id,
                user_id=self.user_id,
                msg_id=self.msg_id,
                content=json.loads(self.content),
            )
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, Table, JoinHelper

from synapse.util.logutils import log_function

from collections import namedtuple

import logging

logger = logging.getLogger(__name__)


class PduStore(SQLBaseStore):
    """A collection of queries for handling PDUs.

    Public methods schedule the matching `_`-prefixed method through
    `self._db_pool.runInteraction`, which supplies the `txn` argument.
    """

    def get_pdu(self, pdu_id, origin):
        """Given a pdu_id and origin, get a PDU.

        Args:
            pdu_id (str)
            origin (str)

        Returns:
            PduTuple: If the pdu does not exist in the database, returns None
        """

        return self._db_pool.runInteraction(
            self._get_pdu_tuple, pdu_id, origin
        )

    def _get_pdu_tuple(self, txn, pdu_id, origin):
        """Return the single PduTuple for (pdu_id, origin), or None."""
        res = self._get_pdu_tuples(txn, [(pdu_id, origin)])
        return res[0] if res else None

    def _get_pdu_tuples(self, txn, pdu_id_tuples):
        """Load a PduTuple for each (pdu_id, origin) pair that exists.

        For every pair this reads the pair's rows from the edges table (to
        build the prev-pdu list) and then the pdus row LEFT JOINed with the
        state_pdus row. Pairs with no pdus row are silently skipped, so the
        result may be shorter than the input.

        Args:
            txn: Database cursor supplied by `runInteraction`.
            pdu_id_tuples: Iterable of (pdu_id, origin) pairs.

        Returns:
            list: PduTuple(PduEntry, [(prev_pdu_id, prev_origin), ...])
        """
        results = []
        for pdu_id, origin in pdu_id_tuples:
            txn.execute(
                PduEdgesTable.select_statement("pdu_id = ? AND origin = ?"),
                (pdu_id, origin)
            )

            edges = [
                (r.prev_pdu_id, r.prev_origin)
                for r in PduEdgesTable.decode_results(txn.fetchall())
            ]

            query = (
                "SELECT %(fields)s FROM %(pdus)s as p "
                "LEFT JOIN %(state)s as s "
                "ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
                "WHERE p.pdu_id = ? AND p.origin = ? "
            ) % {
                "fields": _pdu_state_joiner.get_fields(
                    PdusTable="p", StatePdusTable="s"),
                "pdus": PdusTable.table_name,
                "state": StatePdusTable.table_name,
            }

            txn.execute(query, (pdu_id, origin))

            row = txn.fetchone()
            if row:
                results.append(PduTuple(PduEntry(*row), edges))

        return results

    def get_current_state_for_context(self, context):
        """Get a list of PDUs that represent the current state for a given
        context

        Args:
            context (str)

        Returns:
            list: A list of PduTuples
        """

        return self._db_pool.runInteraction(
            self._get_current_state_for_context,
            context
        )

    def _get_current_state_for_context(self, txn, context):
        """Read (pdu_id, origin) from current_state and load the PduTuples."""
        query = (
            "SELECT pdu_id, origin FROM %s WHERE context = ?"
            % CurrentStateTable.table_name
        )

        logger.debug("get_current_state %s, Args=%s", query, context)
        txn.execute(query, (context,))

        res = txn.fetchall()

        logger.debug("get_current_state %d results", len(res))

        return self._get_pdu_tuples(txn, res)

    def persist_pdu(self, prev_pdus, **cols):
        """Inserts a (non-state) PDU into the database.

        Args:
            prev_pdus (list): (pdu_id, origin) pairs this PDU references.
            **cols: The columns to insert into the PdusTable.
        """
        return self._db_pool.runInteraction(
            self._persist_pdu, prev_pdus, cols
        )

    def _persist_pdu(self, txn, prev_pdus, cols):
        """Insert the pdus row, then update edges and extremities."""
        # Columns missing from `cols` are inserted as None.
        entry = PdusTable.EntryType(
            **{k: cols.get(k, None) for k in PdusTable.fields}
        )

        txn.execute(PdusTable.insert_statement(), entry)

        self._handle_prev_pdus(
            txn, entry.outlier, entry.pdu_id, entry.origin,
            prev_pdus, entry.context
        )

    def mark_pdu_as_processed(self, pdu_id, pdu_origin):
        """Mark a received PDU as processed.

        Args:
            pdu_id (str)
            pdu_origin (str)
        """

        return self._db_pool.runInteraction(
            self._mark_as_processed, pdu_id, pdu_origin
        )

    def _mark_as_processed(self, txn, pdu_id, pdu_origin):
        """Set `have_processed = 1` on the row for (pdu_id, pdu_origin)."""
        # Restrict the UPDATE to the given PDU; without the WHERE clause every
        # row in the table would be flagged as processed.
        txn.execute(
            "UPDATE %s SET have_processed = 1 "
            "WHERE pdu_id = ? AND origin = ?" % PdusTable.table_name,
            (pdu_id, pdu_origin)
        )

    def get_all_pdus_from_context(self, context):
        """Get a list of all PDUs for a given context."""
        return self._db_pool.runInteraction(
            self._get_all_pdus_from_context, context,
        )

    def _get_all_pdus_from_context(self, txn, context):
        """Select every (pdu_id, origin) in `context` and load PduTuples."""
        query = (
            "SELECT pdu_id, origin FROM %s "
            "WHERE context = ?"
        ) % PdusTable.table_name

        txn.execute(query, (context,))

        return self._get_pdu_tuples(txn, txn.fetchall())

    def get_pagination(self, context, pdu_list, limit):
        """Get a list of Pdus for a given topic that occurred before (and
        including) the pdus in pdu_list. Return a list of max size `limit`.

        Args:
            context (str)
            pdu_list (list): (pdu_id, origin) pairs to start from.
            limit (int)

        Return:
            list: A list of PduTuples
        """
        return self._db_pool.runInteraction(
            self._get_paginate, context, pdu_list, limit
        )

    def _get_paginate(self, txn, context, pdu_list, limit):
        """Walk the edges table backwards from `pdu_list`.

        NOTE(review): the per-row LIMIT is computed before `pdu_results` is
        extended for the current front, so with several pdus in one front the
        total may exceed `limit` -- confirm whether that matters to callers.
        """
        logger.debug(
            "paginate: %s, %s, %s",
            context, repr(pdu_list), limit
        )

        # We seed the pdu_results with the things from the pdu_list. Copy it
        # so that extending the results never mutates the caller's list.
        pdu_results = list(pdu_list)

        front = pdu_list

        query = (
            "SELECT prev_pdu_id, prev_origin FROM %(edges_table)s "
            "WHERE context = ? AND pdu_id = ? AND origin = ? "
            "LIMIT ?"
        ) % {
            "edges_table": PduEdgesTable.table_name,
        }

        # We iterate through all pdu_ids in `front` to select their previous
        # pdus. These are dumped in `new_front`. We continue until we reach
        # the limit *or* new_front is empty (i.e., we've run out of things to
        # select).
        while front and len(pdu_results) < limit:

            new_front = []
            for pdu_id, origin in front:
                logger.debug(
                    "_paginate_interaction: i=%s, o=%s",
                    pdu_id, origin
                )

                txn.execute(
                    query,
                    (context, pdu_id, origin, limit - len(pdu_results))
                )

                for row in txn.fetchall():
                    logger.debug(
                        "_paginate_interaction: got i=%s, o=%s",
                        *row
                    )
                    new_front.append(row)

            front = new_front
            pdu_results += new_front

        # We also want to update the `prev_pdus` attributes before returning.
        return self._get_pdu_tuples(txn, pdu_results)

    def get_min_depth_for_context(self, context):
        """Get the current minimum depth for a context

        Args:
            context (str)
        """
        return self._db_pool.runInteraction(
            self._get_min_depth_for_context, context
        )

    def _get_min_depth_for_context(self, txn, context):
        """Interaction wrapper around `_get_min_depth_interaction`."""
        return self._get_min_depth_interaction(txn, context)

    def _get_min_depth_interaction(self, txn, context):
        """Return the stored `min_depth` for `context`, or None if no row."""
        txn.execute(
            "SELECT min_depth FROM %s WHERE context = ?"
            % ContextDepthTable.table_name,
            (context,)
        )

        row = txn.fetchone()

        return row[0] if row else None

    def update_min_depth_for_context(self, context, depth):
        """Update the minimum `depth` of the given context, which is the line
        where we stop paginating backwards on.

        Args:
            context (str)
            depth (int)
        """
        return self._db_pool.runInteraction(
            self._update_min_depth_for_context, context, depth
        )

    def _update_min_depth_for_context(self, txn, context, depth):
        """Store `depth` if nothing is stored yet or it is lower."""
        min_depth = self._get_min_depth_interaction(txn, context)

        # Compare against None explicitly: a stored minimum of 0 is a real
        # value and must not be overwritten by a larger depth.
        do_insert = min_depth is None or depth < min_depth

        if do_insert:
            txn.execute(
                "INSERT OR REPLACE INTO %s (context, min_depth) "
                "VALUES (?,?)" % ContextDepthTable.table_name,
                (context, depth)
            )

    def get_latest_pdus_in_context(self, context):
        """Gets a list of the most current pdus for a given context. This is
        used when we are sending a Pdu and need to fill out the `prev_pdus`
        key

        Args:
            context
        """
        return self._db_pool.runInteraction(
            self._get_latest_pdus_in_context, context
        )

    def _get_latest_pdus_in_context(self, txn, context):
        """Return (pdu_id, origin, depth) for each forward extremity."""
        query = (
            "SELECT p.pdu_id, p.origin, p.depth FROM %(pdus)s as p "
            "INNER JOIN %(forward)s as f ON p.pdu_id = f.pdu_id "
            "AND f.origin = p.origin "
            "WHERE f.context = ?"
        ) % {
            "pdus": PdusTable.table_name,
            "forward": PduForwardExtremitiesTable.table_name,
        }

        logger.debug("get_prev query: %s", query)

        txn.execute(
            query,
            (context, )
        )

        results = txn.fetchall()

        return [(row[0], row[1], row[2]) for row in results]

    def get_oldest_pdus_in_context(self, context):
        """Get the ids held in the backward extremities table for a context.
        This list is used when we want to paginate backwards and is the list
        we send to the remote server.

        Args:
            context (str)

        Returns:
            list: A list of PduIdTuple.
        """
        return self._db_pool.runInteraction(
            self._get_oldest_pdus_in_context, context
        )

    def _get_oldest_pdus_in_context(self, txn, context):
        """Select (pdu_id, origin) from the backward extremities table."""
        txn.execute(
            "SELECT pdu_id, origin FROM %(back)s WHERE context = ?"
            % {"back": PduBackwardExtremitiesTable.table_name, },
            (context,)
        )
        return [PduIdTuple(i, o) for i, o in txn.fetchall()]

    def is_pdu_new(self, pdu_id, origin, context, depth):
        """For a given Pdu, try and figure out if it's 'new', i.e., if it's
        not something we got randomly from the past, for example when we
        request the current state of the room that will probably return a
        bunch of pdus from before we joined.

        Args:
            pdu_id (str)
            origin (str)
            context (str)
            depth (int)

        Returns:
            bool
        """

        return self._db_pool.runInteraction(
            self._is_pdu_new,
            pdu_id=pdu_id,
            origin=origin,
            context=context,
            depth=depth
        )

    def _is_pdu_new(self, txn, pdu_id, origin, context, depth):
        """Interaction body of `is_pdu_new`."""
        # If depth > min depth in back table, then we classify it as new.
        # OR if there is nothing in the back table, then it kinda needs to
        # be a new thing.
        query = (
            "SELECT min(p.depth) FROM %(edges)s as e "
            "INNER JOIN %(back)s as b "
            "ON e.prev_pdu_id = b.pdu_id AND e.prev_origin = b.origin "
            "INNER JOIN %(pdus)s as p "
            "ON e.pdu_id = p.pdu_id AND p.origin = e.origin "
            "WHERE p.context = ?"
        ) % {
            "pdus": PdusTable.table_name,
            "edges": PduEdgesTable.table_name,
            "back": PduBackwardExtremitiesTable.table_name,
        }

        txn.execute(query, (context,))

        min_depth, = txn.fetchone()

        # NOTE(review): `not min_depth` is also true for a minimum of 0, in
        # which case every pdu is classified as new -- confirm intended.
        if not min_depth or depth > int(min_depth):
            logger.debug(
                "is_new true: id=%s, o=%s, d=%s min_depth=%s",
                pdu_id, origin, depth, min_depth
            )
            return True

        # If this pdu is in the forwards table, then it also is a new one
        query = (
            "SELECT * FROM %(forward)s WHERE pdu_id = ? AND origin = ?"
        ) % {
            "forward": PduForwardExtremitiesTable.table_name,
        }

        txn.execute(query, (pdu_id, origin))

        # Did we get anything?
        if txn.fetchall():
            logger.debug(
                "is_new true: id=%s, o=%s, d=%s was forward",
                pdu_id, origin, depth
            )
            return True

        logger.debug(
            "is_new false: id=%s, o=%s, d=%s",
            pdu_id, origin, depth
        )

        # FINE THEN. It's probably old.
        return False

    @staticmethod
    @log_function
    def _handle_prev_pdus(txn, outlier, pdu_id, origin, prev_pdus,
                          context):
        """Record the edges of a new pdu and maintain the extremity tables.

        Args:
            txn: Database cursor supplied by `runInteraction`.
            outlier: If truthy, only the edges are written and the extremity
                tables are left untouched.
            pdu_id (str), origin (str): Identify the new pdu.
            prev_pdus (list): (pdu_id, origin) pairs the new pdu references.
            context (str)
        """
        txn.executemany(
            PduEdgesTable.insert_statement(),
            [(pdu_id, origin, p[0], p[1], context) for p in prev_pdus]
        )

        # Update the extremities table if this is not an outlier.
        if not outlier:

            # First, the pdus referenced by the new one stop being forward
            # extremities, so delete them from that table.
            query = (
                "DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
                % PduForwardExtremitiesTable.table_name
            )
            txn.executemany(query, prev_pdus)

            # We only insert as a forward extremity the new pdu if there are
            # no other pdus that reference it as a prev pdu
            query = (
                "INSERT INTO %(table)s (pdu_id, origin, context) "
                "SELECT ?, ?, ? WHERE NOT EXISTS ("
                "SELECT 1 FROM %(pdu_edges)s WHERE "
                "prev_pdu_id = ? AND prev_origin = ?"
                ")"
            ) % {
                "table": PduForwardExtremitiesTable.table_name,
                "pdu_edges": PduEdgesTable.table_name
            }

            logger.debug("query: %s", query)

            txn.execute(query, (pdu_id, origin, context, pdu_id, origin))

            # Insert all the prev_pdus as a backwards thing, they'll get
            # deleted in a second if they're incorrect anyway.
            txn.executemany(
                PduBackwardExtremitiesTable.insert_statement(),
                [(i, o, context) for i, o in prev_pdus]
            )

            # Also delete from the backwards extremities table all ones that
            # reference pdus that we have already seen
            query = (
                "DELETE FROM %(pdu_back)s WHERE EXISTS ("
                "SELECT 1 FROM %(pdus)s AS pdus "
                "WHERE "
                "%(pdu_back)s.pdu_id = pdus.pdu_id "
                "AND %(pdu_back)s.origin = pdus.origin "
                "AND not pdus.outlier "
                ")"
            ) % {
                "pdu_back": PduBackwardExtremitiesTable.table_name,
                "pdus": PdusTable.table_name,
            }
            txn.execute(query)


class StatePduStore(SQLBaseStore):
    """A collection of queries for handling state PDUs.

    NOTE(review): `_persist_state` calls `self._handle_prev_pdus`, which is
    defined on `PduStore`, so this class presumably has to be combined with
    `PduStore` in one concrete store -- confirm where that happens.
    """

    def persist_state(self, prev_pdus, **cols):
        """Inserts a state PDU into the database

        Args:
            prev_pdus (list)
            **cols: The columns to insert into the PdusTable and StatePdusTable
        """

        return self._db_pool.runInteraction(
            self._persist_state, prev_pdus, cols
        )

    def _persist_state(self, txn, prev_pdus, cols):
        """Insert the pdus and state_pdus rows, then update edges."""
        # Columns missing from `cols` are inserted as None in both tables.
        pdu_entry = PdusTable.EntryType(
            **{k: cols.get(k, None) for k in PdusTable.fields}
        )
        state_entry = StatePdusTable.EntryType(
            **{k: cols.get(k, None) for k in StatePdusTable.fields}
        )

        logger.debug("Inserting pdu: %s", repr(pdu_entry))
        logger.debug("Inserting state: %s", repr(state_entry))

        txn.execute(PdusTable.insert_statement(), pdu_entry)
        txn.execute(StatePdusTable.insert_statement(), state_entry)

        self._handle_prev_pdus(
            txn,
            pdu_entry.outlier, pdu_entry.pdu_id, pdu_entry.origin, prev_pdus,
            pdu_entry.context
        )

    def get_unresolved_state_tree(self, new_state_pdu):
        """Return the two state branches between `new_state_pdu` and the
        current state; see `_get_unresolved_state_tree`.
        """
        return self._db_pool.runInteraction(
            self._get_unresolved_state_tree, new_state_pdu
        )

    @log_function
    def _get_unresolved_state_tree(self, txn, new_pdu):
        """Collect the pdus on each branch back towards a common ancestor.

        Returns:
            StateReturnType: `new_branch` starts with `new_pdu`,
            `current_branch` starts with the current state (or is empty when
            there is none). Collection stops at the first missing pdu.
        """
        current = self._get_current_interaction(
            txn,
            new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
        )

        ReturnType = namedtuple(
            "StateReturnType", ["new_branch", "current_branch"]
        )
        return_value = ReturnType([new_pdu], [])

        if not current:
            logger.debug("get_unresolved_state_tree No current state.")
            return return_value

        return_value.current_branch.append(current)

        enum_branches = self._enumerate_state_branches(
            txn, new_pdu, current
        )

        # `branch` is 0 for the new branch and 1 for the current branch, which
        # matches the field order of ReturnType.
        for branch, prev_state, state in enum_branches:
            if state:
                return_value[branch].append(state)
            else:
                break

        return return_value

    def update_current_state(self, pdu_id, origin, context, pdu_type,
                             state_key):
        """Insert or replace the current_state row with the given values."""
        return self._db_pool.runInteraction(
            self._update_current_state,
            pdu_id, origin, context, pdu_type, state_key
        )

    def _update_current_state(self, txn, pdu_id, origin, context, pdu_type,
                              state_key):
        """Interaction body of `update_current_state`."""
        query = (
            "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
        ) % {
            "curr": CurrentStateTable.table_name,
            "fields": CurrentStateTable.get_fields_string(),
            "qs": ", ".join(["?"] * len(CurrentStateTable.fields))
        }

        query_args = CurrentStateTable.EntryType(
            pdu_id=pdu_id,
            origin=origin,
            context=context,
            pdu_type=pdu_type,
            state_key=state_key
        )

        txn.execute(query, query_args)

    def get_current_state(self, context, pdu_type, state_key):
        """For a given context, pdu_type, state_key 3-tuple, return what is
        currently considered the current state.

        Args:
            context (str)
            pdu_type (str)
            state_key (str)

        Returns:
            PduEntry
        """

        return self._db_pool.runInteraction(
            self._get_current_state, context, pdu_type, state_key
        )

    def _get_current_state(self, txn, context, pdu_type, state_key):
        """Interaction wrapper around `_get_current_interaction`."""
        return self._get_current_interaction(txn, context, pdu_type, state_key)

    def _get_current_interaction(self, txn, context, pdu_type, state_key):
        """Return the current-state PduEntry for the 3-tuple, or None."""
        logger.debug(
            "_get_current_interaction %s %s %s",
            context, pdu_type, state_key
        )

        fields = _pdu_state_joiner.get_fields(
            PdusTable="p", StatePdusTable="s")

        current_query = (
            "SELECT %(fields)s FROM %(state)s as s "
            "INNER JOIN %(pdus)s as p "
            "ON s.pdu_id = p.pdu_id AND s.origin = p.origin "
            "INNER JOIN %(curr)s as c "
            "ON s.pdu_id = c.pdu_id AND s.origin = c.origin "
            "WHERE s.context = ? AND s.pdu_type = ? AND s.state_key = ? "
        ) % {
            "fields": fields,
            "curr": CurrentStateTable.table_name,
            "state": StatePdusTable.table_name,
            "pdus": PdusTable.table_name,
        }

        txn.execute(
            current_query,
            (context, pdu_type, state_key)
        )

        row = txn.fetchone()

        result = PduEntry(*row) if row else None

        if not result:
            logger.debug("_get_current_interaction not found")
        else:
            logger.debug(
                "_get_current_interaction found %s %s",
                result.pdu_id, result.origin
            )

        return result

    def get_next_missing_pdu(self, new_pdu):
        """When we get a new state pdu we need to check whether we need to do
        any conflict resolution, if we do then we need to check if we need
        to go back and request some more state pdus that we haven't seen yet.

        Args:
            new_pdu

        Returns:
            PduIdTuple: A pdu that we are missing, or None if we have all the
                pdus required to do the conflict resolution.
        """
        return self._db_pool.runInteraction(
            self._get_next_missing_pdu, new_pdu
        )

    def _get_next_missing_pdu(self, txn, new_pdu):
        """Interaction body of `get_next_missing_pdu`."""
        logger.debug(
            "get_next_missing_pdu %s %s",
            new_pdu.pdu_id, new_pdu.origin
        )

        current = self._get_current_interaction(
            txn,
            new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
        )

        if (not current or not current.prev_state_id
                or not current.prev_state_origin):
            return None

        # Oh look, it's a straight clobber, so wooooo almost no-op.
        if (new_pdu.prev_state_id == current.pdu_id
                and new_pdu.prev_state_origin == current.origin):
            return None

        enum_branches = self._enumerate_state_branches(txn, new_pdu, current)
        for branch, prev_state, state in enum_branches:
            if not state:
                # `prev_state` points at a pdu we could not load.
                return PduIdTuple(
                    prev_state.prev_state_id,
                    prev_state.prev_state_origin
                )

        return None

    def handle_new_state(self, new_pdu):
        """Actually perform conflict resolution on the new_pdu on the
        assumption we have all the pdus required to perform it.

        Args:
            new_pdu

        Returns:
            bool: True if the new_pdu clobbered the current state, False if not

        Raises:
            RuntimeError: if a state pdu on either branch cannot be found.
        """
        return self._db_pool.runInteraction(
            self._handle_new_state, new_pdu
        )

    def _handle_new_state(self, txn, new_pdu):
        """Interaction body of `handle_new_state`."""
        logger.debug(
            "handle_new_state %s %s",
            new_pdu.pdu_id, new_pdu.origin
        )

        current = self._get_current_interaction(
            txn,
            new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
        )

        is_current = False

        if (not current or not current.prev_state_id
                or not current.prev_state_origin):
            # Oh, we don't have any state for this yet.
            is_current = True
        elif (current.pdu_id == new_pdu.prev_state_id
                and current.origin == new_pdu.prev_state_origin):
            # Oh! A direct clobber. Just do it.
            is_current = True
        else:
            ##
            # Ok, now loop through until we get to a common ancestor.
            max_new = int(new_pdu.power_level)
            max_current = int(current.power_level)

            enum_branches = self._enumerate_state_branches(
                txn, new_pdu, current
            )
            for branch, prev_state, state in enum_branches:
                if not state:
                    raise RuntimeError(
                        "Could not find state_pdu %s %s" %
                        (
                            prev_state.prev_state_id,
                            prev_state.prev_state_origin
                        )
                    )

                # NOTE(review): the maxima are seeded from `power_level` but
                # updated from `depth`; this looks like it may be meant to
                # compare power levels throughout -- confirm before changing.
                if branch == 0:
                    max_new = max(int(state.depth), max_new)
                else:
                    max_current = max(int(state.depth), max_current)

            is_current = max_new > max_current

        if is_current:
            logger.debug("handle_new_state make current")

            # Right, this is a new thing, so woo, just insert it.
            txn.execute(
                "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
                % {
                    "curr": CurrentStateTable.table_name,
                    "fields": CurrentStateTable.get_fields_string(),
                    "qs": ", ".join(["?"] * len(CurrentStateTable.fields))
                },
                CurrentStateTable.EntryType(
                    *(new_pdu.__dict__[k] for k in CurrentStateTable.fields)
                )
            )
        else:
            logger.debug("handle_new_state not current")

        logger.debug("handle_new_state done")

        return is_current

    @classmethod
    @log_function
    def _enumerate_state_branches(cls, txn, pdu_a, pdu_b):
        """Step two state pdus back along `prev_state_*` until they meet.

        This is a generator. On each step it follows the previous-state
        pointer of one branch (the deeper one when both can advance) and
        yields `(branch, prev_branch, pdu)` where `branch` is 0 for `pdu_a`'s
        side and 1 for `pdu_b`'s side, `prev_branch` is the pdu stepped from
        and `pdu` is the PduEntry loaded, or None if it is not in the
        database. Iteration ends on a common ancestor, on a missing pdu, or
        when neither branch has a previous state.
        """
        branch_a = pdu_a
        branch_b = pdu_b

        get_query = (
            "SELECT %(fields)s FROM %(pdus)s as p "
            "LEFT JOIN %(state)s as s "
            "ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
            "WHERE p.pdu_id = ? AND p.origin = ? "
        ) % {
            "fields": _pdu_state_joiner.get_fields(
                PdusTable="p", StatePdusTable="s"),
            "pdus": PdusTable.table_name,
            "state": StatePdusTable.table_name,
        }

        while True:
            if (branch_a.pdu_id == branch_b.pdu_id
                    and branch_a.origin == branch_b.origin):
                # Woo! We found a common ancestor
                logger.debug("_enumerate_state_branches Found common ancestor")
                break

            do_branch_a = (
                hasattr(branch_a, "prev_state_id") and
                branch_a.prev_state_id
            )

            do_branch_b = (
                hasattr(branch_b, "prev_state_id") and
                branch_b.prev_state_id
            )

            logger.debug(
                "do_branch_a=%s, do_branch_b=%s",
                do_branch_a, do_branch_b
            )

            # When both can advance, step the deeper branch first.
            if do_branch_a and do_branch_b:
                do_branch_a = int(branch_a.depth) > int(branch_b.depth)

            if do_branch_a:
                pdu_tuple = PduIdTuple(
                    branch_a.prev_state_id,
                    branch_a.prev_state_origin
                )

                logger.debug("getting branch_a prev %s", pdu_tuple)
                txn.execute(get_query, pdu_tuple)

                prev_branch = branch_a

                res = txn.fetchone()
                branch_a = PduEntry(*res) if res else None

                logger.debug("branch_a=%s", branch_a)

                yield (0, prev_branch, branch_a)

                if not branch_a:
                    break
            elif do_branch_b:
                pdu_tuple = PduIdTuple(
                    branch_b.prev_state_id,
                    branch_b.prev_state_origin
                )
                txn.execute(get_query, pdu_tuple)

                logger.debug("getting branch_b prev %s", pdu_tuple)

                prev_branch = branch_b

                res = txn.fetchone()
                branch_b = PduEntry(*res) if res else None

                logger.debug("branch_b=%s", branch_b)

                yield (1, prev_branch, branch_b)

                if not branch_b:
                    break
            else:
                break


class PdusTable(Table):
    table_name = "pdus"

    fields = [
        "pdu_id",
        "origin",
        "context",
        "pdu_type",
        "ts",
        "depth",
        "is_state",
        "content_json",
        "unrecognized_keys",
        "outlier",
        "have_processed",
    ]

    EntryType = namedtuple("PdusEntry", fields)


class PduDestinationsTable(Table):
    table_name = "pdu_destinations"

    fields = [
        "pdu_id",
        "origin",
        "destination",
        "delivered_ts",
    ]

    EntryType = namedtuple("PduDestinationsEntry", fields)


class PduEdgesTable(Table):
    table_name = "pdu_edges"

    fields = [
        "pdu_id",
        "origin",
        "prev_pdu_id",
        "prev_origin",
        "context"
    ]

    EntryType = namedtuple("PduEdgesEntry", fields)


class PduForwardExtremitiesTable(Table):
    table_name = "pdu_forward_extremities"

    fields = [
        "pdu_id",
        "origin",
        "context",
    ]

    EntryType = namedtuple("PduForwardExtremitiesEntry", fields)


class PduBackwardExtremitiesTable(Table):
    table_name = "pdu_backward_extremities"

    fields = [
        "pdu_id",
        "origin",
        "context",
    ]

    EntryType = namedtuple("PduBackwardExtremitiesEntry", fields)


class ContextDepthTable(Table):
    table_name = "context_depth"

    fields = [
        "context",
        "min_depth",
    ]

    EntryType = namedtuple("ContextDepthEntry", fields)


class StatePdusTable(Table):
    table_name = "state_pdus"

    fields = [
        "pdu_id",
        "origin",
        "context",
        "pdu_type",
        "state_key",
        "power_level",
        "prev_state_id",
        "prev_state_origin",
    ]

    EntryType = namedtuple("StatePdusEntry", fields)


class CurrentStateTable(Table):
    table_name = "current_state"

    fields = [
        "pdu_id",
        "origin",
        "context",
        "pdu_type",
        "state_key",
    ]

    EntryType = namedtuple("CurrentStateEntry", fields)

_pdu_state_joiner = JoinHelper(PdusTable, StatePdusTable)


# TODO: These should probably be put somewhere more sensible
PduIdTuple = namedtuple("PduIdTuple", ("pdu_id", "origin"))

PduEntry = _pdu_state_joiner.EntryType
""" We are always interested in the join of the PdusTable and StatePdusTable,
rather than just the PdusTable.

This does not include a prev_pdus key.
"""

PduTuple = namedtuple(
    "PduTuple",
    ("pdu_entry", "prev_pdu_list")
)
""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
the `prev_pdus` key of a PDU.
"""


# diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
# new file mode 100644
# index 0000000000..e57ddaf149
# --- /dev/null
# +++ b/synapse/storage/presence.py
# @@ -0,0 +1,103 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore


class PresenceStore(SQLBaseStore):
    """Queries against the `presence`, `presence_allow_inbound` and
    `presence_list` tables, built on the `_simple_*` helpers of the base
    store.
    """

    def create_presence(self, user_localpart):
        """Insert a `presence` row whose `user_id` is `user_localpart`."""
        return self._simple_insert(
            table="presence",
            values={"user_id": user_localpart},
        )

    def has_presence_state(self, user_localpart):
        """Select `user_id` from `presence` for the user (allow_none=True)."""
        return self._simple_select_one(
            table="presence",
            keyvalues={"user_id": user_localpart},
            retcols=["user_id"],
            allow_none=True,
        )

    def get_presence_state(self, user_localpart):
        """Select `state` and `status_msg` from `presence` for the user."""
        return self._simple_select_one(
            table="presence",
            keyvalues={"user_id": user_localpart},
            retcols=["state", "status_msg"],
        )

    def set_presence_state(self, user_localpart, new_state):
        """Update `state` and `status_msg` for the user.

        Args:
            user_localpart: Value matched against `user_id`.
            new_state (dict): Must contain the keys "state" and "status_msg".
        """
        return self._simple_update_one(
            table="presence",
            keyvalues={"user_id": user_localpart},
            updatevalues={"state": new_state["state"],
                          "status_msg": new_state["status_msg"]},
            retcols=["state"],
        )

    def allow_presence_visible(self, observed_localpart, observer_userid):
        """Insert an (observed, observer) row into `presence_allow_inbound`."""
        return self._simple_insert(
            table="presence_allow_inbound",
            values={"observed_user_id": observed_localpart,
                    "observer_user_id": observer_userid},
        )

    def disallow_presence_visible(self, observed_localpart, observer_userid):
        """Delete the (observed, observer) row from `presence_allow_inbound`.
        """
        return self._simple_delete_one(
            table="presence_allow_inbound",
            keyvalues={"observed_user_id": observed_localpart,
                       "observer_user_id": observer_userid},
        )

    def is_presence_visible(self, observed_localpart, observer_userid):
        """Look up the (observed, observer) row in `presence_allow_inbound`
        (allow_none=True).
        """
        return self._simple_select_one(
            table="presence_allow_inbound",
            keyvalues={"observed_user_id": observed_localpart,
                       "observer_user_id": observer_userid},
            allow_none=True,
        )

    def add_presence_list_pending(self, observer_localpart, observed_userid):
        """Insert a `presence_list` row with `accepted` set to False."""
        return self._simple_insert(
            table="presence_list",
            values={"user_id": observer_localpart,
                    "observed_user_id": observed_userid,
                    "accepted": False},
        )

    def set_presence_list_accepted(self, observer_localpart, observed_userid):
        """Set `accepted` to True on the matching `presence_list` row."""
        return self._simple_update_one(
            table="presence_list",
            keyvalues={"user_id": observer_localpart,
                       "observed_user_id": observed_userid},
            updatevalues={"accepted": True},
        )

    def get_presence_list(self, observer_localpart, accepted=None):
        """Select `observed_user_id` and `accepted` from `presence_list`.

        Args:
            observer_localpart: Value matched against `user_id`.
            accepted: When not None, additionally filter on `accepted`.
        """
        keyvalues = {"user_id": observer_localpart}
        if accepted is not None:
            keyvalues["accepted"] = accepted

        return self._simple_select_list(
            table="presence_list",
            keyvalues=keyvalues,
            retcols=["observed_user_id", "accepted"],
        )

    def del_presence_list(self, observer_localpart, observed_userid):
        """Delete the matching (observer, observed) `presence_list` row."""
        return self._simple_delete_one(
            table="presence_list",
            keyvalues={"user_id": observer_localpart,
                       "observed_user_id": observed_userid},
        )


# diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
# new file mode 100644
# index 0000000000..d2f24930c1
# --- /dev/null
# +++ b/synapse/storage/profile.py
# @@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore


class ProfileStore(SQLBaseStore):
    """Accessors for the `profiles` table, keyed on `user_id`."""

    def create_profile(self, user_localpart):
        """Insert a `profiles` row whose `user_id` is `user_localpart`."""
        new_row = {"user_id": user_localpart}
        return self._simple_insert(table="profiles", values=new_row)

    def get_profile_displayname(self, user_localpart):
        """Select the `displayname` column for the user."""
        key = {"user_id": user_localpart}
        return self._simple_select_one_onecol(
            table="profiles", keyvalues=key, retcol="displayname"
        )

    def set_profile_displayname(self, user_localpart, new_displayname):
        """Update the `displayname` column for the user."""
        key = {"user_id": user_localpart}
        changes = {"displayname": new_displayname}
        return self._simple_update_one(
            table="profiles", keyvalues=key, updatevalues=changes
        )

    def get_profile_avatar_url(self, user_localpart):
        """Select the `avatar_url` column for the user."""
        key = {"user_id": user_localpart}
        return self._simple_select_one_onecol(
            table="profiles", keyvalues=key, retcol="avatar_url"
        )

    def set_profile_avatar_url(self, user_localpart, new_avatar_url):
        """Update the `avatar_url` column for the user."""
        key = {"user_id": user_localpart}
        changes = {"avatar_url": new_avatar_url}
        return self._simple_update_one(
            table="profiles", keyvalues=key, updatevalues=changes
        )


# diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
# new file mode 100644
# index 0000000000..4a970dd546
# --- /dev/null
# +++ b/synapse/storage/registration.py
# @@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer

from sqlite3 import IntegrityError

from synapse.api.errors import StoreError

from ._base import SQLBaseStore


class RegistrationStore(SQLBaseStore):
    """Queries against the `users` and `access_tokens` tables."""

    def __init__(self, hs):
        super(RegistrationStore, self).__init__(hs)

        # Used by _register to timestamp newly created users.
        self.clock = hs.get_clock()

    @defer.inlineCallbacks
    def add_access_token_to_user(self, user_id, token):
        """Adds an access token for the given user.

        Args:
            user_id (str): The user ID.
            token (str): The new access token to add.
        Raises:
            StoreError if there was a problem adding this.
        """
        user_row = yield self._simple_select_one(
            "users", {"name": user_id}, ["id"]
        )
        if not user_row:
            raise StoreError(400, "Bad user ID supplied.")

        # Tokens reference the numeric users.id, not the user name.
        token_row = {
            "user_id": user_row["id"],
            "token": token
        }
        yield self._simple_insert("access_tokens", token_row)

    @defer.inlineCallbacks
    def register(self, user_id, token, password_hash):
        """Attempts to register an account.

        Args:
            user_id (str): The desired user ID to register.
            token (str): The desired access token to use for this user.
            password_hash (str): Optional. The password hash for this user.
        Raises:
            StoreError if the user_id could not be registered.
        """
        interaction = self._db_pool.runInteraction(
            self._register, user_id, token, password_hash
        )
        yield interaction

    def _register(self, txn, user_id, token, password_hash):
        """Insert the `users` row and its first `access_tokens` row."""
        now = int(self.clock.time())

        insert_user = ("INSERT INTO users(name, password_hash, creation_ts) "
                       "VALUES (?,?,?)")
        try:
            txn.execute(insert_user, [user_id, password_hash, now])
        except IntegrityError:
            raise StoreError(400, "User ID already taken.")

        # it's possible for this to get a conflict, but only for a single user
        # since tokens are namespaced based on their user ID
        insert_token = ("INSERT INTO access_tokens(user_id, token) " +
                        "VALUES (?,?)")
        # lastrowid is the id of the users row inserted just above.
        txn.execute(insert_token, [txn.lastrowid, token])

    def get_user_by_id(self, user_id):
        """Select `name` and `password_hash` from `users` by name, decoded
        with `self.cursor_to_dict`.
        """
        sql = ("SELECT users.name, users.password_hash FROM users "
               "WHERE users.name = ?")
        return self._execute(self.cursor_to_dict, sql, user_id)

    @defer.inlineCallbacks
    def get_user_by_token(self, token):
        """Get a user from the given access token.

        Args:
            token (str): The access token of a user.
        Returns:
            str: The user ID of the user.
        Raises:
            StoreError if no user was found.
        """
        name = yield self._db_pool.runInteraction(
            self._query_for_auth, token
        )
        defer.returnValue(name)

    def _query_for_auth(self, txn, token):
        """Return `users.name` for `token`; raise StoreError(404) if absent."""
        sql = ("SELECT users.name FROM access_tokens LEFT JOIN users" +
               " ON users.id = access_tokens.user_id WHERE token = ?")
        txn.execute(sql, [token])

        found = txn.fetchone()
        if not found:
            raise StoreError(404, "Token not found.")

        return found[0]


# diff --git a/synapse/storage/room.py b/synapse/storage/room.py
# new file mode 100644
# index 0000000000..174cbcf3d8
# --- /dev/null
# +++ b/synapse/storage/room.py
# @@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
from twisted.internet import defer

from sqlite3 import IntegrityError

from synapse.api.errors import StoreError
from synapse.api.events.room import RoomTopicEvent

from ._base import SQLBaseStore, Table

import collections
import json
import logging

logger = logging.getLogger(__name__)


class RoomStore(SQLBaseStore):
    """Storage for the ``rooms`` table and the public room listing."""

    @defer.inlineCallbacks
    def store_room(self, room_id, room_creator_user_id, is_public):
        """Stores a room.

        Args:
            room_id (str): The desired room ID, can be None.
            room_creator_user_id (str): The user ID of the room creator.
            is_public (bool): True to indicate that this room should appear in
                public room lists.
        Raises:
            StoreError (409) if the insert hits an IntegrityError, or
            StoreError (500) for any other failure (which is logged first).
        """
        try:
            yield self._simple_insert(RoomsTable.table_name, dict(
                room_id=room_id,
                creator=room_creator_user_id,
                is_public=is_public
            ))
        except IntegrityError:
            raise StoreError(409, "Room ID in use.")
        except Exception as e:
            logger.error("store_room with room_id=%s failed: %s", room_id, e)
            raise StoreError(500, "Problem creating room.")

    def store_room_config(self, room_id, visibility):
        """Set the ``is_public`` column of room ``room_id`` to ``visibility``."""
        return self._simple_update_one(
            table=RoomsTable.table_name,
            keyvalues={"room_id": room_id},
            updatevalues={"is_public": visibility}
        )

    def get_room(self, room_id):
        """Retrieve a room.

        Args:
            room_id (str): The ID of the room to retrieve.
        Returns:
            The result of ``RoomsTable.decode_single_result`` applied to the
            query for this room_id.
            NOTE(review): earlier docs describe this as a namedtuple of the
            room, or an empty list when absent - confirm against Table.
        """
        query = RoomsTable.select_statement("room_id=?")
        return self._execute(
            RoomsTable.decode_single_result, query, room_id,
        )

    @defer.inlineCallbacks
    def get_rooms(self, is_public, with_topics):
        """Retrieve the list of rooms with the requested visibility.

        Args:
            is_public (bool): True if the rooms returned should be public.
            with_topics (bool): Intended to request the current topic in the
                response. This method does not read it: the topic is always
                looked up and included when one is set.
        Returns:
            A list of room dicts containing at least a "room_id" key, and a
            "topic" key if the latest topic content could be decoded.
        """
        room_data_type = RoomTopicEvent.TYPE
        public = 1 if is_public else 0

        # Highest room_data id per room for the topic event type.
        latest_topic = ("SELECT max(room_data.id) FROM room_data WHERE " +
                        "room_data.type = ? GROUP BY room_id")

        # LEFT JOIN plus the IS NULL arm keeps rooms with no room_data row.
        query = ("SELECT rooms.*, room_data.content FROM rooms LEFT JOIN " +
                 "room_data ON rooms.room_id = room_data.room_id WHERE " +
                 "(room_data.id IN (" + latest_topic + ") " +
                 "OR room_data.id IS NULL) AND rooms.is_public = ?")

        res = yield self._execute(
            self.cursor_to_dict, query, room_data_type, public
        )

        # return only the keys the specification expects
        ret_keys = ("room_id", "topic")

        rooms = []
        for room_row in res:
            # extract topic from the json (icky) FIXME
            try:
                room_row["topic"] = json.loads(room_row["content"])["topic"]
            except (KeyError, TypeError, ValueError):
                # No topic set: content is NULL (TypeError), not valid JSON
                # (ValueError), or has no usable "topic" entry. Anything else
                # is a real error and is no longer swallowed.
                pass
            # filter the dict based on ret_keys
            rooms.append(
                {k: v for k, v in room_row.items() if k in ret_keys}
            )

        defer.returnValue(rooms)


class RoomsTable(Table):
    """Column description of the ``rooms`` table."""

    table_name = "rooms"

    fields = [
        "room_id",
        "is_public",
        "creator"
    ]

    # One decoded row; attribute names follow ``fields``.
    EntryType = collections.namedtuple("RoomEntry", fields)
from twisted.internet import defer

from synapse.types import UserID
from synapse.api.constants import Membership
from synapse.api.events.room import RoomMemberEvent

from ._base import SQLBaseStore, Table


import collections
import json
import logging

logger = logging.getLogger(__name__)


# ---- originally synapse/storage/roomdata.py ----


class RoomDataStore(SQLBaseStore):

    """Provides various CRUD operations for Room Events. """

    def get_room_data(self, room_id, etype, state_key=""):
        """Retrieve the data stored under this type and state_key.

        Only the row with the highest id for the path is selected.

        Args:
            room_id (str)
            etype (str)
            state_key (str)
        Returns:
            namedtuple: Or None if nothing exists at this path.
        """
        query = RoomDataTable.select_statement(
            "room_id = ? AND type = ? AND state_key = ? "
            "ORDER BY id DESC LIMIT 1"
        )
        return self._execute(
            RoomDataTable.decode_single_result,
            query, room_id, etype, state_key,
        )

    def store_room_data(self, room_id, etype, state_key="", content=None):
        """Stores room specific data.

        Args:
            room_id (str)
            etype (str): Written to the ``type`` column.
            state_key (str)
            content (str): The data to store for this path in JSON.
        Returns:
            Whatever ``_simple_insert`` returns.
            NOTE(review): earlier docs say "the store ID for this data" -
            confirm against SQLBaseStore.
        """
        # The column is called "type" (see RoomDataTable.fields and the
        # room_data schema); only the Python argument is named etype.
        return self._simple_insert(RoomDataTable.table_name, {
            "type": etype,
            "state_key": state_key,
            "room_id": room_id,
            "content": content,
        })

    def get_max_room_data_id(self):
        """Return ``_simple_max_id`` for the room_data table."""
        return self._simple_max_id(RoomDataTable.table_name)


class RoomDataTable(Table):
    """Column description of the ``room_data`` table."""

    table_name = "room_data"

    fields = [
        "id",
        "room_id",
        "type",
        "state_key",
        "content"
    ]

    class EntryType(collections.namedtuple("RoomDataEntry", fields)):
        """One decoded room_data row."""

        def as_event(self, event_factory):
            """Build an event from this row; ``content`` is parsed as JSON."""
            return event_factory.create_event(
                etype=self.type,
                room_id=self.room_id,
                content=json.loads(self.content),
            )


# ---- originally synapse/storage/roommember.py ----


class RoomMemberStore(SQLBaseStore):
    """Storage for room membership rows (``room_memberships``)."""

    def get_room_member(self, user_id, room_id):
        """Retrieve the current state of a room member.

        "Current" is the row with the highest id for this room and user.

        Args:
            user_id (str): The member's user ID.
            room_id (str): The room the member is in.
        Returns:
            namedtuple: The room member from the database, or None if this
            member does not exist.
        """
        query = RoomMemberTable.select_statement(
            "room_id = ? AND user_id = ? ORDER BY id DESC LIMIT 1")
        return self._execute(
            RoomMemberTable.decode_single_result,
            query, room_id, user_id,
        )

    def store_room_member(self, user_id, sender, room_id, membership, content):
        """Store a room member in the database.

        Args:
            user_id (str): The member's user ID.
            sender (str): Stored in the ``sender`` column; ``as_event`` later
                passes it on as the event's ``user_id``.
            room_id (str): The room in relation to the member.
            membership (synapse.api.constants.Membership): The new membership
                state.
            content (dict): The content of the membership; serialised to JSON
                before it is stored.
        """
        content_json = json.dumps(content)
        return self._simple_insert(RoomMemberTable.table_name, dict(
            user_id=user_id,
            sender=sender,
            room_id=room_id,
            membership=membership,
            content=content_json,
        ))

    @defer.inlineCallbacks
    def get_room_members(self, room_id, membership=None):
        """Retrieve the current room member list for a room.

        Args:
            room_id (str): The room to get the list of members.
            membership (synapse.api.constants.Membership): The filter to apply
                to this list, or None to return all members with some state
                associated with this room.
        Returns:
            list of namedtuples representing the members in this room.
        """
        # Latest (MAX(id)) membership row of every user seen in the room.
        query = RoomMemberTable.select_statement(
            "id IN (SELECT MAX(id) FROM " + RoomMemberTable.table_name
            + " WHERE room_id = ? GROUP BY user_id)"
        )
        res = yield self._execute(
            RoomMemberTable.decode_results, query, room_id,
        )
        # strip memberships which don't match
        if membership:
            res = [entry for entry in res if entry.membership == membership]
        defer.returnValue(res)

    def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
        """ Get all the rooms for this user where the membership for this user
        matches one in the membership list.

        Args:
            user_id (str): The user ID.
            membership_list (list): A list of synapse.api.constants.Membership
                values which the user must be in.
        Returns:
            A list of dicts with "room_id" and "membership" keys, or a
            deferred that fires with None when membership_list is empty.
        """
        if not membership_list:
            return defer.succeed(None)

        # One "membership=?" placeholder per requested membership value.
        where_membership = (
            "(" + " OR ".join(["membership=?"] * len(membership_list)) + ")"
        )
        args = [user_id] + list(membership_list)

        # NOTE(review): the membership filter is applied to every stored row
        # before grouping, not only to each room's latest row - confirm that
        # is the intended meaning of "current" here.
        query = ("SELECT room_id, membership FROM room_memberships"
                 + " WHERE user_id=? AND " + where_membership
                 + " GROUP BY room_id ORDER BY id DESC")
        return self._execute(
            self.cursor_to_dict, query, *args
        )

    @defer.inlineCallbacks
    def get_joined_hosts_for_room(self, room_id):
        """Return the domain of every user whose latest membership is JOIN.

        Args:
            room_id (str): The room to inspect.
        Returns:
            list: One domain per joined member (duplicates are not removed).
        """
        query = RoomMemberTable.select_statement(
            "id IN (SELECT MAX(id) FROM " + RoomMemberTable.table_name
            + " WHERE room_id = ? GROUP BY user_id)"
        )

        res = yield self._execute(
            RoomMemberTable.decode_results, query, room_id,
        )

        def host_from_user_id_string(user_id):
            # Parse the argument itself rather than reaching for the
            # comprehension variable of the enclosing function.
            return UserID.from_string(user_id, self.hs).domain

        # strip memberships which don't match
        hosts = [
            host_from_user_id_string(entry.user_id)
            for entry in res
            if entry.membership == Membership.JOIN
        ]

        logger.debug("Returning hosts: %s from results: %s", hosts, res)

        defer.returnValue(hosts)

    def get_max_room_member_id(self):
        """Return ``_simple_max_id`` for the room_memberships table."""
        return self._simple_max_id(RoomMemberTable.table_name)


class RoomMemberTable(Table):
    """Column description of the ``room_memberships`` table."""

    table_name = "room_memberships"

    fields = [
        "id",
        "user_id",
        "sender",
        "room_id",
        "membership",
        "content"
    ]

    class EntryType(collections.namedtuple("RoomMemberEntry", fields)):
        """One decoded room_memberships row."""

        def as_event(self, event_factory):
            """Build a RoomMemberEvent: the row's user_id is the target and
            its sender becomes the event's user_id."""
            return event_factory.create_event(
                etype=RoomMemberEvent.TYPE,
                room_id=self.room_id,
                target_user_id=self.user_id,
                user_id=self.sender,
                content=json.loads(self.content),
            )
/* Copyright 2014 matrix.org
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

-- Edge PDUs recorded per context; a (pdu_id, origin) pair may appear once.
CREATE TABLE IF NOT EXISTS context_edge_pdus(
    id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this
    pdu_id TEXT,
    origin TEXT,
    context TEXT,
    CONSTRAINT context_edge_pdu_id_origin UNIQUE (pdu_id, origin)
);

-- Same shape without the context column; (pdu_id, origin) is unique.
CREATE TABLE IF NOT EXISTS origin_edge_pdus(
    id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this
    pdu_id TEXT,
    origin TEXT,
    CONSTRAINT origin_edge_pdu_id_origin UNIQUE (pdu_id, origin)
);

-- Explicit indexes over the same columns as the UNIQUE constraints above.
CREATE INDEX IF NOT EXISTS context_edge_pdu_id ON context_edge_pdus(pdu_id, origin);
CREATE INDEX IF NOT EXISTS origin_edge_pdu_id ON origin_edge_pdus(pdu_id, origin);
-- One row per room; room_id is the primary key.
CREATE TABLE IF NOT EXISTS rooms(
    room_id TEXT PRIMARY KEY NOT NULL,
    is_public INTEGER,
    creator TEXT
);

-- Membership rows are only ever inserted by the storage code in this
-- changeset; readers treat the highest id per (room_id, user_id) as current.
CREATE TABLE IF NOT EXISTS room_memberships(
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    user_id TEXT NOT NULL, -- no foreign key to users table, it could be an id belonging to another home server
    sender TEXT NOT NULL,
    room_id TEXT NOT NULL,
    membership TEXT NOT NULL,
    content TEXT NOT NULL
);

-- Messages sent to rooms; the stream queries order and bound them by id.
CREATE TABLE IF NOT EXISTS messages(
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    user_id TEXT,
    room_id TEXT,
    msg_id TEXT,
    content TEXT
);

-- Feedback on a message; the stream query joins it to messages on
-- (room_id, msg_id, msg_sender_id = messages.user_id).
CREATE TABLE IF NOT EXISTS feedback(
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    content TEXT,
    feedback_type TEXT,
    fb_sender_id TEXT,
    msg_id TEXT,
    room_id TEXT,
    msg_sender_id TEXT
);

-- Per-room data addressed by (room_id, type, state_key); readers take the
-- row with the highest id for a path.
CREATE TABLE IF NOT EXISTS room_data(
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    room_id TEXT NOT NULL,
    type TEXT NOT NULL,
    state_key TEXT NOT NULL,
    content TEXT
);
-- Stores pdus and their content
CREATE TABLE IF NOT EXISTS pdus(
    pdu_id TEXT,
    origin TEXT,
    context TEXT,
    pdu_type TEXT,
    ts INTEGER,
    depth INTEGER DEFAULT 0 NOT NULL,
    is_state BOOL,
    content_json TEXT,
    unrecognized_keys TEXT,
    outlier BOOL NOT NULL,
    have_processed BOOL,
    CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
);

-- Stores what the current state pdu is for a given (context, pdu_type, key) tuple
CREATE TABLE IF NOT EXISTS state_pdus(
    pdu_id TEXT,
    origin TEXT,
    context TEXT,
    pdu_type TEXT,
    state_key TEXT,
    power_level TEXT,
    prev_state_id TEXT,
    prev_state_origin TEXT,
    -- Table constraints are comma-separated; SQLite tolerates omitting the
    -- comma but standard SQL does not.
    CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin),
    CONSTRAINT prev_pdu_id_origin UNIQUE (prev_state_id, prev_state_origin)
);

-- At most one row per (context, pdu_type, state_key); a conflicting insert
-- replaces the existing row.
CREATE TABLE IF NOT EXISTS current_state(
    pdu_id TEXT,
    origin TEXT,
    context TEXT,
    pdu_type TEXT,
    state_key TEXT,
    CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin),
    CONSTRAINT uniqueness UNIQUE (context, pdu_type, state_key) ON CONFLICT REPLACE
);

-- Stores where each pdu we want to send should be sent and the delivery status.
CREATE TABLE IF NOT EXISTS pdu_destinations(
    pdu_id TEXT,
    origin TEXT,
    destination TEXT,
    delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
    CONSTRAINT uniqueness UNIQUE (pdu_id, origin, destination) ON CONFLICT REPLACE
);

CREATE TABLE IF NOT EXISTS pdu_forward_extremities(
    pdu_id TEXT,
    origin TEXT,
    context TEXT,
    CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
);

CREATE TABLE IF NOT EXISTS pdu_backward_extremities(
    pdu_id TEXT,
    origin TEXT,
    context TEXT,
    CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
);

-- One row per (pdu, previous pdu) pair within a context.
CREATE TABLE IF NOT EXISTS pdu_edges(
    pdu_id TEXT,
    origin TEXT,
    prev_pdu_id TEXT,
    prev_origin TEXT,
    context TEXT,
    CONSTRAINT uniqueness UNIQUE (pdu_id, origin, prev_pdu_id, prev_origin, context)
);

-- One min_depth value per context.
CREATE TABLE IF NOT EXISTS context_depth(
    context TEXT,
    min_depth INTEGER,
    CONSTRAINT uniqueness UNIQUE (context)
);

CREATE INDEX IF NOT EXISTS context_depth_context ON context_depth(context);


CREATE INDEX IF NOT EXISTS pdu_id ON pdus(pdu_id, origin);

CREATE INDEX IF NOT EXISTS dests_id ON pdu_destinations (pdu_id, origin);
-- CREATE INDEX IF NOT EXISTS dests ON pdu_destinations (destination);

CREATE INDEX IF NOT EXISTS pdu_extrem_context ON pdu_forward_extremities(context);
CREATE INDEX IF NOT EXISTS pdu_extrem_id ON pdu_forward_extremities(pdu_id, origin);

CREATE INDEX IF NOT EXISTS pdu_edges_id ON pdu_edges(pdu_id, origin);

CREATE INDEX IF NOT EXISTS pdu_b_extrem_context ON pdu_backward_extremities(context);
-- Presence state of a local user; user_id references the numeric users.id.
CREATE TABLE IF NOT EXISTS presence(
    user_id INTEGER NOT NULL,
    state INTEGER,
    status_msg TEXT,
    FOREIGN KEY(user_id) REFERENCES users(id)
);

-- For each of /my/ users which possibly-remote users are allowed to see their
-- presence state
CREATE TABLE IF NOT EXISTS presence_allow_inbound(
    observed_user_id INTEGER NOT NULL,
    observer_user_id TEXT, -- a UserID,
    FOREIGN KEY(observed_user_id) REFERENCES users(id)
);

-- For each of /my/ users (watcher), which possibly-remote users are they
-- watching?
CREATE TABLE IF NOT EXISTS presence_list(
    user_id INTEGER NOT NULL,
    observed_user_id TEXT, -- a UserID,
    accepted BOOLEAN,
    FOREIGN KEY(user_id) REFERENCES users(id)
);
-- ---- schema/profiles.sql ----

-- Display name and avatar URL per user.
-- NOTE(review): user_id is declared INTEGER referencing users(id), while the
-- ProfileStore in this changeset keys rows on the user localpart - confirm
-- which is intended.
CREATE TABLE IF NOT EXISTS profiles(
    user_id INTEGER NOT NULL,
    displayname TEXT,
    avatar_url TEXT,
    FOREIGN KEY(user_id) REFERENCES users(id)
);

-- ---- schema/room_aliases.sql ----

-- Maps a room alias to a room ID.
CREATE TABLE IF NOT EXISTS room_aliases(
    room_alias TEXT NOT NULL,
    room_id TEXT NOT NULL
);

-- Servers recorded against a room alias.
CREATE TABLE IF NOT EXISTS room_alias_servers(
    room_alias TEXT NOT NULL,
    server TEXT NOT NULL
);
-- Stores what transaction ids we have received and what our response was
CREATE TABLE IF NOT EXISTS received_transactions(
    transaction_id TEXT,
    origin TEXT,
    ts INTEGER,
    response_code INTEGER,
    response_json TEXT,
    has_been_referenced BOOL default 0, -- Whether this has been referenced by a prev_tx
    CONSTRAINT uniquesss UNIQUE (transaction_id, origin) ON CONFLICT REPLACE
);

-- Covers the same columns as the UNIQUE constraint on the table above.
CREATE UNIQUE INDEX IF NOT EXISTS transactions_txid ON received_transactions(transaction_id, origin);
CREATE INDEX IF NOT EXISTS transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0;


-- Stores what transactions we've sent, what their response was (if we got one) and whether we have
-- since referenced the transaction in another outgoing transaction
CREATE TABLE IF NOT EXISTS sent_transactions(
    id INTEGER PRIMARY KEY AUTOINCREMENT, -- This is used to apply insertion ordering
    transaction_id TEXT,
    destination TEXT,
    response_code INTEGER DEFAULT 0,
    response_json TEXT,
    ts INTEGER
);

CREATE INDEX IF NOT EXISTS sent_transaction_dest ON sent_transactions(destination);
-- NOTE(review): indexes the same single column as sent_transaction_dest
-- above - confirm whether a second column was intended.
CREATE INDEX IF NOT EXISTS sent_transaction_dest_referenced ON sent_transactions(
    destination
);
-- So that we can do an efficient look up of all transactions that have yet to be successfully
-- sent.
CREATE INDEX IF NOT EXISTS sent_transaction_sent ON sent_transactions(response_code);


-- For sent transactions only.
CREATE TABLE IF NOT EXISTS transaction_id_to_pdu(
    transaction_id INTEGER,
    destination TEXT,
    pdu_id TEXT,
    pdu_origin TEXT
);

CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_tx ON transaction_id_to_pdu(transaction_id, destination);
CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination);
-- NOTE(review): same columns as transaction_id_to_pdu_tx above.
CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_index ON transaction_id_to_pdu(transaction_id, destination);
from ._base import SQLBaseStore
from .message import MessagesTable
from .feedback import FeedbackTable
from .roomdata import RoomDataTable
from .roommember import RoomMemberTable

import json
import logging

logger = logging.getLogger(__name__)


class StreamStore(SQLBaseStore):
    """Queries that page through messages, feedback, room data and room
    membership rows by primary key, restricted to rooms the requesting user
    is currently joined to."""

    def get_message_stream(self, user_id, from_key, to_key, room_id, limit=0,
                           with_feedback=False):
        """Get all messages for this user between the given keys.

        Args:
            user_id (str): The user who is requesting messages.
            from_key (int): The ID to start returning results from (exclusive).
            to_key (int): The ID to stop returning results (exclusive).
            room_id (str): Gets messages only for this room. Can be None, in
                which case all room messages will be returned.
            limit (int): Maximum number of rows; values that are falsy or not
                greater than 0 mean no LIMIT clause is added.
            with_feedback (bool): Include compressed feedback with each
                message. Only honoured when room_id is also given.
        Returns:
            A tuple of rows (list of namedtuples), new_id(int)
        """
        if with_feedback and room_id:  # with fb MUST specify a room ID
            return self._db_pool.runInteraction(
                self._get_message_rows_with_feedback,
                user_id, from_key, to_key, room_id, limit
            )
        else:
            return self._db_pool.runInteraction(
                self._get_message_rows,
                user_id, from_key, to_key, room_id, limit
            )

    def _get_message_rows(self, txn, user_id, from_pkey, to_pkey, room_id,
                          limit):
        """Transaction body for get_message_stream without feedback."""
        # work out which rooms this user is joined in on and join them with
        # the room id on the messages table, bounded by the specified pkeys

        # get all messages where the *current* membership state is 'join' for
        # this user in that room.
        query = ("SELECT messages.* FROM messages WHERE ? IN"
                 + " (SELECT membership from room_memberships WHERE user_id=?"
                 + " AND room_id = messages.room_id ORDER BY id DESC LIMIT 1)")
        query_args = ["join", user_id]

        if room_id:
            query += " AND messages.room_id=?"
            query_args.append(room_id)

        (query, query_args) = self._append_stream_operations(
            "messages", query, query_args, from_pkey, to_pkey, limit=limit
        )

        logger.debug("[SQL] %s : %s", query, query_args)
        cursor = txn.execute(query, query_args)
        return self._as_events(cursor, MessagesTable, from_pkey)

    def _get_message_rows_with_feedback(self, txn, user_id, from_pkey, to_pkey,
                                        room_id, limit):
        """Transaction body for get_message_stream with feedback.

        Returns:
            A tuple (events, latest_pkey). Each event is a dict with msg_id,
            user_id, room_id, content and, when any feedback rows matched,
            a "feedback" list.
        """
        # this col represents the compressed feedback JSON as per spec
        compressed_feedback_col = (
            "'[' || group_concat('{\"sender_id\":\"' || f.fb_sender_id"
            + " || '\",\"feedback_type\":\"' || f.feedback_type"
            + " || '\",\"content\":' || f.content || '}') || ']'"
        )

        global_msg_id_join = ("f.room_id = messages.room_id"
                              + " and f.msg_id = messages.msg_id"
                              + " and messages.user_id = f.msg_sender_id")

        select_query = (
            "SELECT messages.*, f.content AS fb_content, f.fb_sender_id"
            + ", " + compressed_feedback_col + " AS compressed_fb"
            + " FROM messages LEFT JOIN feedback f ON " + global_msg_id_join)

        # Latest membership of the user in the *message's* room. The inner
        # table is aliased rm, so it must be correlated with the outer
        # messages row; comparing room_id with rm.room_id would compare the
        # column with itself and match memberships of every room.
        current_membership_sub_query = (
            "(SELECT membership from room_memberships rm"
            + " WHERE rm.user_id=? AND rm.room_id = messages.room_id"
            + " ORDER BY rm.id DESC LIMIT 1)")

        where = (" WHERE ? IN " + current_membership_sub_query
                 + " AND messages.room_id=?")

        query = select_query + where
        query_args = ["join", user_id, room_id]

        # GROUP BY collapses the joined feedback rows into one row per message
        # so group_concat can build the compressed feedback column.
        (query, query_args) = self._append_stream_operations(
            "messages", query, query_args, from_pkey, to_pkey,
            limit=limit, group_by=" GROUP BY messages.id "
        )

        logger.debug("[SQL] %s : %s", query, query_args)
        cursor = txn.execute(query, query_args)

        # convert the result set into events
        entries = self.cursor_to_dict(cursor)
        events = []
        for entry in entries:
            # TODO we should spec the cursor > event mapping somewhere else.
            event = {}
            straight_mappings = ["msg_id", "user_id", "room_id"]
            for key in straight_mappings:
                event[key] = entry[key]
            event["content"] = json.loads(entry["content"])
            if entry["compressed_fb"]:
                event["feedback"] = json.loads(entry["compressed_fb"])
            events.append(event)

        latest_pkey = from_pkey if len(entries) == 0 else entries[-1]["id"]

        return (events, latest_pkey)

    def get_room_member_stream(self, user_id, from_key, to_key):
        """Get all room membership events for this user between the given keys.

        Args:
            user_id (str): The user who is requesting membership events.
            from_key (int): The ID to start returning results from (exclusive).
            to_key (int): The ID to stop returning results (exclusive).
        Returns:
            A tuple of rows (list of namedtuples), new_id(int)
        """
        return self._db_pool.runInteraction(
            self._get_room_member_rows, user_id, from_key, to_key
        )

    def _get_room_member_rows(self, txn, user_id, from_pkey, to_pkey):
        """Transaction body for get_room_member_stream."""
        # get all room membership events for rooms which the user is
        # *currently* joined in on, or all invite events for this user.
        current_membership_sub_query = (
            "(SELECT membership FROM room_memberships"
            + " WHERE user_id=? AND room_id = rm.room_id"
            + " ORDER BY id DESC LIMIT 1)")

        query = ("SELECT rm.* FROM room_memberships rm "
                 # all membership events for rooms you've currently joined.
                 + " WHERE (? IN " + current_membership_sub_query
                 # all invite membership events for this user
                 + " OR rm.membership=? AND user_id=?)"
                 + " AND rm.id > ?")
        query_args = ["join", user_id, "invite", user_id, from_pkey]

        # -1 means "up to the latest row": no upper bound is added.
        if to_pkey != -1:
            query += " AND rm.id < ?"
            query_args.append(to_pkey)

        cursor = txn.execute(query, query_args)
        return self._as_events(cursor, RoomMemberTable, from_pkey)

    def get_feedback_stream(self, user_id, from_key, to_key, room_id, limit=0):
        """Get feedback rows between the given keys for rooms the user is
        currently joined to, optionally restricted to room_id.

        Returns:
            A tuple (events, last_pkey) as produced by _as_events.
        """
        return self._db_pool.runInteraction(
            self._get_feedback_rows,
            user_id, from_key, to_key, room_id, limit
        )

    def _get_feedback_rows(self, txn, user_id, from_pkey, to_pkey, room_id,
                           limit):
        """Transaction body for get_feedback_stream."""
        # work out which rooms this user is joined in on and join them with
        # the room id on the feedback table, bounded by the specified pkeys

        # get all feedback where the *current* membership state is 'join' for
        # this user in that room.
        query = (
            "SELECT feedback.* FROM feedback WHERE ? IN "
            + "(SELECT membership from room_memberships WHERE user_id=?"
            + " AND room_id = feedback.room_id ORDER BY id DESC LIMIT 1)")
        query_args = ["join", user_id]

        if room_id:
            query += " AND feedback.room_id=?"
            query_args.append(room_id)

        (query, query_args) = self._append_stream_operations(
            "feedback", query, query_args, from_pkey, to_pkey, limit=limit
        )

        logger.debug("[SQL] %s : %s", query, query_args)
        cursor = txn.execute(query, query_args)
        return self._as_events(cursor, FeedbackTable, from_pkey)

    def get_room_data_stream(self, user_id, from_key, to_key, room_id,
                             limit=0):
        """Get room_data rows between the given keys for rooms the user is
        currently joined to, optionally restricted to room_id.

        Returns:
            A tuple (events, last_pkey) as produced by _as_events.
        """
        return self._db_pool.runInteraction(
            self._get_room_data_rows,
            user_id, from_key, to_key, room_id, limit
        )

    def _get_room_data_rows(self, txn, user_id, from_pkey, to_pkey, room_id,
                            limit):
        """Transaction body for get_room_data_stream."""
        # work out which rooms this user is joined in on and join them with
        # the room id on the room_data table, bounded by the specified pkeys

        # get all room data where the *current* membership state is 'join'
        # for this user in that room.
        query = (
            "SELECT room_data.* FROM room_data WHERE ? IN "
            + "(SELECT membership from room_memberships WHERE user_id=?"
            + " AND room_id = room_data.room_id ORDER BY id DESC LIMIT 1)")
        query_args = ["join", user_id]

        if room_id:
            query += " AND room_data.room_id=?"
            query_args.append(room_id)

        (query, query_args) = self._append_stream_operations(
            "room_data", query, query_args, from_pkey, to_pkey, limit=limit
        )

        logger.debug("[SQL] %s : %s", query, query_args)
        cursor = txn.execute(query, query_args)
        return self._as_events(cursor, RoomDataTable, from_pkey)

    def _append_stream_operations(self, table_name, query, query_args,
                                  from_pkey, to_pkey, limit=None,
                                  group_by=""):
        """Append the id bounds, GROUP BY, ORDER BY and LIMIT to a query.

        A key of -1 stands for "the latest row". When both keys are equal no
        id bound is added.

        Args:
            table_name (str): Table whose ``id`` column is bounded.
            query (str): SQL so far; must already contain a WHERE clause
                because the bounds are appended with AND.
            query_args (list): Bind values so far; appended to in place.
            from_pkey (int): Start key (exclusive), or -1.
            to_pkey (int): End key (exclusive), or -1.
            limit (int): Row limit; ignored when falsy or not positive.
            group_by (str): GROUP BY fragment placed before ORDER BY.
        Returns:
            A tuple (query, query_args).
        """
        LATEST_ROW = -1
        order_by = ""
        if to_pkey > from_pkey:
            if from_pkey != LATEST_ROW:
                # e.g. from=5 to=9 >> from 5 to 9 >> id>5 AND id<9
                query += (" AND %s.id > ? AND %s.id < ?" %
                          (table_name, table_name))
                query_args.append(from_pkey)
                query_args.append(to_pkey)
            else:
                # e.g. from=-1 to=5 >> from now to 5 >> id>5 ORDER BY id DESC
                query += " AND %s.id > ? " % table_name
                order_by = "ORDER BY id DESC"
                query_args.append(to_pkey)
        elif from_pkey > to_pkey:
            if to_pkey != LATEST_ROW:
                # from=9 to=5 >> from 9 to 5 >> id>5 AND id<9 ORDER BY id DESC
                query += (" AND %s.id > ? AND %s.id < ? " %
                          (table_name, table_name))
                order_by = "ORDER BY id DESC"
                query_args.append(to_pkey)
                query_args.append(from_pkey)
            else:
                # from=5 to=-1 >> from 5 to now >> id>5
                query += " AND %s.id > ?" % table_name
                query_args.append(from_pkey)

        query += group_by + order_by

        if limit and limit > 0:
            query += " LIMIT ?"
            query_args.append(str(limit))

        return (query, query_args)

    def _as_events(self, cursor, table, from_pkey):
        """Decode a cursor through ``table`` and turn each row into an event
        dict.

        Returns:
            A tuple (events, last_pkey) where last_pkey is the id of the last
            decoded row, or from_pkey when there were no rows.
        """
        data_entries = table.decode_results(cursor)
        last_pkey = from_pkey
        if data_entries:
            last_pkey = data_entries[-1].id

        events = [
            entry.as_event(self.event_factory).get_dict()
            for entry in data_entries
        ]

        return (events, last_pkey)
class TransactionStore(SQLBaseStore):
    """A collection of queries for handling federation transactions and the
    PDUs that were sent inside them.
    """

    def get_received_txn_response(self, transaction_id, origin):
        """For an incoming transaction from a given origin, check if we have
        already responded to it. If so, return the response code and response
        body.

        Args:
            transaction_id (str)
            origin(str)

        Returns:
            The result of the database interaction: None if we have not
            previously responded to this transaction, otherwise a 2-tuple of
            (response_code, response_json).
        """

        return self._db_pool.runInteraction(
            self._get_received_txn_response, transaction_id, origin
        )

    def _get_received_txn_response(self, txn, transaction_id, origin):
        where_clause = "transaction_id = ? AND origin = ?"
        query = ReceivedTransactionsTable.select_statement(where_clause)

        txn.execute(query, (transaction_id, origin))

        results = ReceivedTransactionsTable.decode_results(txn.fetchall())

        # A row with a falsy response_code (e.g. 0 or NULL) means the
        # transaction is known but no response was stored for it.
        if results and results[0].response_code:
            return (results[0].response_code, results[0].response_json)
        else:
            return None

    def set_received_txn_response(self, transaction_id, origin, code,
                                  response_dict):
        """Persist the response we returned for an incoming transaction, and
        should return for subsequent transactions with the same transaction_id
        and origin.

        Args:
            transaction_id (str)
            origin (str)
            code (int)
            response_dict: The response body; it is bound directly as the
                `response_json` column value.
        """

        return self._db_pool.runInteraction(
            self._set_received_txn_response,
            transaction_id, origin, code, response_dict
        )

    def _set_received_txn_response(self, txn, transaction_id, origin, code,
                                   response_json):
        query = (
            "UPDATE %s "
            "SET response_code = ?, response_json = ? "
            "WHERE transaction_id = ? AND origin = ?"
        ) % ReceivedTransactionsTable.table_name

        txn.execute(query, (code, response_json, transaction_id, origin))

    def prep_send_transaction(self, transaction_id, destination, ts, pdu_list):
        """Persists an outgoing transaction and calculates the values for the
        previous transaction id list.

        This should be called before sending the transaction so that it has the
        correct value for the `prev_ids` key.

        Args:
            transaction_id (str)
            destination (str)
            ts (int)
            pdu_list (list): Items are indexed as `pdu[0]` and `pdu[1]`, which
                are stored in the `pdu_id` and `pdu_origin` columns.

        Returns:
            list: A list of previous transaction ids.
        """

        return self._db_pool.runInteraction(
            self._prep_send_transaction,
            transaction_id, destination, ts, pdu_list
        )

    def _prep_send_transaction(self, txn, transaction_id, destination, ts,
                               pdu_list):

        # First we find out what the prev_txs should be.
        # Since we know that we are only sending one transaction at a time,
        # we can simply take the last one.
        query = "%s ORDER BY id DESC LIMIT 1" % (
            SentTransactions.select_statement("destination = ?"),
        )

        results = txn.execute(query, (destination,))
        results = SentTransactions.decode_results(results)

        prev_txns = [r.transaction_id for r in results]

        # Actually add the new transaction to the sent_transactions table.
        # `id` is passed as None and response_code starts at 0 until
        # `delivered_txn` records the real response.
        query = SentTransactions.insert_statement()
        txn.execute(query, SentTransactions.EntryType(
            None,
            transaction_id=transaction_id,
            destination=destination,
            ts=ts,
            response_code=0,
            response_json=None
        ))

        # Update the tx id -> pdu id mapping

        values = [
            (transaction_id, destination, pdu[0], pdu[1])
            for pdu in pdu_list
        ]

        logger.debug("Inserting: %s", repr(values))

        query = TransactionsToPduTable.insert_statement()
        txn.executemany(query, values)

        return prev_txns

    def delivered_txn(self, transaction_id, destination, code, response_dict):
        """Persists the response for an outgoing transaction.

        Args:
            transaction_id (str)
            destination (str)
            code (int)
            response_dict: The response body; it is bound directly as the
                `response_json` column value.
        """
        return self._db_pool.runInteraction(
            self._delivered_txn,
            transaction_id, destination, code, response_dict
        )

    def _delivered_txn(self, txn, transaction_id, destination,
                       code, response_json):
        query = (
            "UPDATE %s "
            "SET response_code = ?, response_json = ? "
            "WHERE transaction_id = ? AND destination = ?"
        ) % SentTransactions.table_name

        txn.execute(query, (code, response_json, transaction_id, destination))

    def get_transactions_after(self, transaction_id, destination):
        """Get all transactions after a given local transaction_id.

        Args:
            transaction_id (str)
            destination (str)

        Returns:
            list: A list of `SentTransactions.EntryType`
        """
        return self._db_pool.runInteraction(
            self._get_transactions_after, transaction_id, destination
        )

    def _get_transactions_after(self, txn, transaction_id, destination):
        where = (
            "destination = ? AND id > (select id FROM %s WHERE "
            "transaction_id = ? AND destination = ?)"
        ) % (
            SentTransactions.table_name
        )
        query = SentTransactions.select_statement(where)

        txn.execute(query, (destination, transaction_id, destination))

        # The rows come from sent_transactions, so they must be decoded with
        # that table's field list (not the received_transactions one).
        return SentTransactions.decode_results(txn.fetchall())

    def get_pdus_after_transaction(self, transaction_id, destination):
        """For a given local transaction_id that we sent to a given destination
        home server, return a list of PDUs that were sent to that destination
        after it.

        Args:
            transaction_id (str)
            destination (str)

        Returns
            list: A list of PduTuple
        """
        return self._db_pool.runInteraction(
            self._get_pdus_after_transaction,
            transaction_id, destination
        )

    def _get_pdus_after_transaction(self, txn, transaction_id, destination):

        # Query that first gets all transaction_ids with an id greater than
        # the one given from the `sent_transactions` table. Then JOIN on this
        # from the `tx->pdu` table to get a list of (pdu_id, origin) that
        # specify the pdus that were sent in those transactions.
        query = (
            "SELECT pdu_id, pdu_origin FROM %(tx_pdu)s as tp "
            "INNER JOIN %(sent_tx)s as st "
            "ON tp.transaction_id = st.transaction_id "
            "AND tp.destination = st.destination "
            "WHERE st.id > ("
            "SELECT id FROM %(sent_tx)s "
            "WHERE transaction_id = ? AND destination = ?"
            ")"  # closes the sub-select; without it the SQL is malformed
        ) % {
            "tx_pdu": TransactionsToPduTable.table_name,
            "sent_tx": SentTransactions.table_name,
        }

        txn.execute(query, (transaction_id, destination))

        # NOTE(review): only (pdu_id, pdu_origin) are selected here but the
        # rows are decoded with PdusTable; confirm that PdusTable's field
        # list and `_get_pdu_tuples` (both defined elsewhere) accept
        # two-column rows.
        pdus = PdusTable.decode_results(txn.fetchall())

        return self._get_pdu_tuples(txn, pdus)


class ReceivedTransactionsTable(Table):
    """Column layout of the `received_transactions` table."""

    table_name = "received_transactions"

    fields = [
        "transaction_id",
        "origin",
        "ts",
        "response_code",
        "response_json",
        "has_been_referenced",
    ]

    EntryType = namedtuple("ReceivedTransactionsEntry", fields)


class SentTransactions(Table):
    """Column layout of the `sent_transactions` table."""

    table_name = "sent_transactions"

    fields = [
        "id",
        "transaction_id",
        "destination",
        "ts",
        "response_code",
        "response_json",
    ]

    EntryType = namedtuple("SentTransactionsEntry", fields)


class TransactionsToPduTable(Table):
    """Column layout of the `transaction_id_to_pdu` mapping table."""

    table_name = "transaction_id_to_pdu"

    fields = [
        "transaction_id",
        "destination",
        "pdu_id",
        "pdu_origin",
    ]

    EntryType = namedtuple("TransactionsToPduEntry", fields)
class DomainSpecificString(
        namedtuple("DomainSpecificString", ("localpart", "domain", "is_mine"))
):
    """Common base class among ID/name strings that have a local part and a
    domain name, prefixed with a sigil.

    Has the fields:

        'localpart' : The local part of the name (without the leading sigil)
        'domain' : The domain part of the name
        'is_mine' : Boolean indicating if the domain name is recognised by the
            HomeServer as being its own

    Subclasses supply the one-character `SIGIL` class attribute; this base
    class does not define one itself.
    """

    @classmethod
    def from_string(cls, s, hs):
        """Parse the string given by 's' into a structure object.

        Args:
            s (str): Text of the form '<SIGIL>localpart:domain'.
            hs: Object whose `hostname` attribute is compared with the parsed
                domain to decide `is_mine`.
        Returns:
            An instance of `cls`.
        Raises:
            SynapseError: (400) if `s` is empty, does not start with the
                sigil, or has no ':' separating localpart from domain.
        """
        # Check emptiness first so that "" is reported as a 400 like any
        # other malformed ID instead of escaping as an IndexError.
        if not s or s[0] != cls.SIGIL:
            raise SynapseError(400, "Expected %s string to start with '%s'" % (
                cls.__name__, cls.SIGIL,
            ))

        # Split on the first ':' only: the domain may itself contain one
        # (e.g. a port number).
        parts = s[1:].split(':', 1)
        if len(parts) != 2:
            raise SynapseError(
                400, "Expected %s of the form '%slocalname:domain'" % (
                    cls.__name__, cls.SIGIL,
                )
            )

        domain = parts[1]

        # This code will need changing if we want to support multiple domain
        # names on one HS
        is_mine = domain == hs.hostname
        return cls(localpart=parts[0], domain=domain, is_mine=is_mine)

    def to_string(self):
        """Return a string encoding the fields of the structure object."""
        return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)

    @classmethod
    def create_local(cls, localpart, hs):
        """Create a structure on the local domain (`hs.hostname`)."""
        return cls(localpart=localpart, domain=hs.hostname, is_mine=True)


class UserID(DomainSpecificString):
    """Structure representing a user ID."""
    SIGIL = "@"


class RoomAlias(DomainSpecificString):
    """Structure representing a room name."""
    SIGIL = "#"


class RoomID(DomainSpecificString):
    """Structure representing a room id. """
    SIGIL = "!"
class Clock(object):
    """A small utility that obtains current time-of-day so that time may be
    mocked during unit-tests.

    TODO(paul): Also move the sleep() functionality into it
    """

    def time(self):
        """Returns the current system time in seconds since epoch."""
        return time.time()

    def time_msec(self):
        """Returns the current system time in milliseconds since epoch.

        The value is `self.time() * 1000`, i.e. a float, not an integer.
        """
        return self.time() * 1000

    def call_later(self, delay, callback, *args, **kwargs):
        """Schedule `callback` on the reactor after `delay` seconds.

        Args:
            delay: Seconds to wait before the call.
            callback: The callable to invoke.
            *args, **kwargs: Optional arguments forwarded to `callback`
                through `reactor.callLater`.
        Returns:
            Whatever `reactor.callLater` returns; pass it to
            `cancel_call_later` to cancel the call.
        """
        return reactor.callLater(delay, callback, *args, **kwargs)

    def cancel_call_later(self, timer):
        """Cancel a call previously scheduled with `call_later`."""
        timer.cancel()
def sleep(seconds):
    """Return a Deferred that fires with `seconds` after `seconds` seconds."""
    d = defer.Deferred()
    reactor.callLater(seconds, d.callback, seconds)
    return d


logger = logging.getLogger(__name__)


class Distributor(object):
    """A central dispatch point for loosely-connected pieces of code to
    register, observe, and fire signals.

    Signals are named simply by strings.

    TODO(paul): It would be nice to give signals stronger object identities,
      so we can attach metadata, docstrings, detect typos, etc... But this
      model will do for today.
    """

    def __init__(self):
        # name -> Signal, for every declared signal
        self.signals = {}
        # name -> [observer, ...], for observers added before declaration
        self.pre_registration = {}

    def declare(self, name):
        """Create the signal `name` and attach any pre-registered observers.

        Raises:
            KeyError: if a signal with this name was already declared.
        """
        if name in self.signals:
            raise KeyError("%r already has a signal named %s" % (self, name))

        signal = Signal(name)
        self.signals[name] = signal

        # Hand over the early observers and drop the pending list: later
        # observers go straight to the Signal, so keeping it would only
        # retain stale references.
        for observer in self.pre_registration.pop(name, []):
            signal.observe(observer)

    def observe(self, name, observer):
        """Attach `observer` to the signal `name`.

        If the signal has not been declared yet the observer is remembered
        and attached when `declare(name)` is eventually called.
        """
        if name in self.signals:
            self.signals[name].observe(observer)
        else:
            # TODO: Avoid strong ordering dependency by allowing people to
            # pre-register observations on signals that don't exist yet.
            self.pre_registration.setdefault(name, []).append(observer)

    def fire(self, name, *args, **kwargs):
        """Fire the signal `name`, passing the arguments to its observers.

        Returns:
            The value returned by `Signal.fire`.
        Raises:
            KeyError: if no signal with this name has been declared.
        """
        if name not in self.signals:
            raise KeyError("%r does not have a signal named %s" % (self, name))

        return self.signals[name].fire(*args, **kwargs)


class Signal(object):
    """A Signal is a dispatch point that stores a list of callables as
    observers of it.

    Signals can be "fired", meaning that every callable observing it is
    invoked. Firing a signal does not change its state; it can be fired again
    at any later point. Firing a signal passes any arguments from the fire
    method into all of the observers.
    """

    def __init__(self, name):
        self.name = name
        self.observers = []

    def observe(self, observer):
        """Adds a new callable to the observer list which will be invoked by
        the 'fire' method.

        Each observer callable may return a Deferred."""
        self.observers.append(observer)

    def fire(self, *args, **kwargs):
        """Invokes every callable in the observer list, passing in the args and
        kwargs. Exceptions thrown by observers are logged but ignored. It is
        not an error to fire a signal with no observers.

        Returns a Deferred that will complete when all the observers have
        completed."""

        def eb(failure, observer):
            # `observer` arrives as an explicit errback argument rather than
            # through the enclosing loop variable: an errback that runs after
            # the loop has moved on would otherwise name the wrong observer.
            logger.warning(
                "%s signal observer %s failed: %r",
                self.name, observer, failure,
                exc_info=(
                    failure.type,
                    failure.value,
                    failure.getTracebackObject()))

        deferreds = []
        for observer in self.observers:
            d = defer.maybeDeferred(observer, *args, **kwargs)
            deferreds.append(d.addErrback(eb, observer))

        return defer.DeferredList(deferreds)


class JsonEncodedObject(object):
    """ A common base class for defining protocol units that are represented
    as JSON.

    Attributes:
        unrecognized_keys (dict): A dict containing all the key/value pairs we
            don't recognize.
    """

    # A list of strings that represent keys we know about and can handle.
    # Values given for these keys are stored as instance attributes and
    # included in the output of `get_dict`.
    valid_keys = []

    # A list of strings that are stored on the instance but should *not* be
    # encoded into JSON by `get_dict`.
    internal_keys = []

    # A list of strings that we require to exist. If they are not given upon
    # construction a RuntimeError is raised.
    required_keys = []

    def __init__(self, **kwargs):
        """ Takes the dict of `kwargs` and stores every key that is listed in
        `valid_keys` or `internal_keys` as an instance attribute.

        Any keys that aren't recognized are added to the `unrecognized_keys`
        attribute.

        Args:
            **kwargs: Attributes associated with this protocol unit.
        Raises:
            RuntimeError: if a key from `required_keys` is missing.
        """
        for required_key in self.required_keys:
            if required_key not in kwargs:
                raise RuntimeError("Key %s is required" % required_key)

        self.unrecognized_keys = {}  # Keys we were given not listed as valid
        for k, v in kwargs.items():
            if k in self.valid_keys or k in self.internal_keys:
                self.__dict__[k] = v
            else:
                self.unrecognized_keys[k] = v

    def get_dict(self):
        """ Converts this protocol unit into a :py:class:`dict`, ready to be
        encoded as JSON.

        The keys it encodes are: `valid_keys` - `internal_keys`, plus any
        unrecognized keys. The result is a deep copy.

        Returns
            dict
        """
        d = {
            k: _encode(v) for (k, v) in self.__dict__.items()
            if k in self.valid_keys and k not in self.internal_keys
        }
        d.update(self.unrecognized_keys)
        return copy.deepcopy(d)

    def get_full_dict(self):
        """ Like `get_dict` but keeps `internal_keys` and leaves the values
        un-encoded. The result is a deep copy.

        Returns
            dict
        """
        d = {
            k: v for (k, v) in self.__dict__.items()
            if k in self.valid_keys or k in self.internal_keys
        }
        d.update(self.unrecognized_keys)
        return copy.deepcopy(d)

    def __str__(self):
        return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))


def _encode(obj):
    """Recursively turn JsonEncodedObjects (and lists of them) into dicts.

    Any other value is returned unchanged.
    """
    if isinstance(obj, list):
        return [_encode(o) for o in obj]

    if isinstance(obj, JsonEncodedObject):
        return obj.get_dict()

    return obj
logger = logging.getLogger(__name__)


class Lock(object):
    """Handle for a lock obtained from `LockManager.lock`.

    Releasing the lock fires the wrapped deferred. It can be used as a
    context manager, in which case it is released on exit from the block.
    """

    def __init__(self, deferred):
        # Fired (with None) on release.
        self._deferred = deferred
        self.released = False

    def release(self):
        """Release the lock; calling it again is a no-op.

        The guard matters because a deferred may only be fired once: an
        explicit `release()` inside a `with` block is followed by the
        implicit one from `__exit__`.
        """
        if self.released:
            return
        self.released = True
        self._deferred.callback(None)

    def __del__(self):
        # Safety net: a lock that is garbage collected while still held is
        # logged and then released.
        if not self.released:
            logger.critical("Lock was destructed but never released!")
            self.release()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.release()


class LockManager(object):
    """ Utility class that allows us to lock based on a `key` """

    def __init__(self):
        # key -> deferred of the most recently requested lock for that key
        self._lock_deferreds = {}

    @defer.inlineCallbacks
    def lock(self, key):
        """ Allows us to block until it is our turn.
        Args:
            key (str)
        Returns:
            Lock
        """
        new_deferred = defer.Deferred()
        old_deferred = self._lock_deferreds.get(key)
        # Register ourselves first, so the next caller for this key queues
        # up behind this lock rather than behind the previous one.
        self._lock_deferreds[key] = new_deferred

        if old_deferred:
            # Wait for the previous lock on this key to be released.
            yield old_deferred

        defer.returnValue(Lock(new_deferred))
def log_function(f):
    """ Function decorator that logs every call to that function.

    When the logger named after the function's module is enabled for DEBUG,
    each call emits a record listing the bound arguments (each truncated to
    50 characters). The record carries the decorated function's own file
    name and first line number. The call's return value is passed through
    unchanged.
    """
    # Imported locally so this block does not depend on the file's import
    # header being changed.
    from functools import wraps

    func_name = f.__name__
    # `__code__` is available on Python 2.6+ and Python 3, unlike the
    # Python-2-only `func_code` alias.
    code = f.__code__
    lineno = code.co_firstlineno
    pathname = code.co_filename

    def _truncate(value):
        # Keep individual argument reprs short in the log line.
        r = str(value)
        if len(r) > 50:
            r = r[:50] + "..."
        return r

    @wraps(f)  # preserve the wrapped function's __name__ and __doc__
    def wrapped(*args, **kwargs):
        name = f.__module__
        logger = logging.getLogger(name)
        level = logging.DEBUG

        # Only pay for argument binding/formatting when it will be logged.
        if logger.isEnabledFor(level):
            bound_args = getcallargs(f, *args, **kwargs)

            func_args = [
                "%s=%s" % (k, _truncate(v)) for k, v in bound_args.items()
            ]

            msg_args = {
                "func_name": func_name,
                "args": ", ".join(func_args)
            }

            # Build the record by hand so that it points at the decorated
            # function rather than at this wrapper.
            record = logging.LogRecord(
                name=name,
                level=level,
                pathname=pathname,
                lineno=lineno,
                msg="Invoked '%(func_name)s' with args: %(args)s",
                args=msg_args,
                exc_info=None
            )

            logger.handle(record)

        return f(*args, **kwargs)

    return wrapped
def origin_from_ucid(ucid):
    """Return everything after the first '@' in `ucid`.

    Raises:
        IndexError: if `ucid` contains no '@'.
    """
    return ucid.split("@", 1)[1]


def random_string(length):
    """Return `length` characters drawn at random from the ASCII letters.

    NOTE(review): this uses the `random` module, which is not a
    cryptographic source; confirm no caller relies on the result being
    unguessable.
    """
    # `range` instead of the Python-2-only `xrange`, so the helper works on
    # both Python 2 and Python 3.
    return ''.join(random.choice(string.ascii_letters) for _ in range(length))