diff options
Diffstat (limited to 'synapse')
95 files changed, 6156 insertions, 1939 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index dacda89865..012425fa60 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" This is a reference implementation of a synapse home server. +""" This is a reference implementation of a Matrix home server. """ __version__ = "0.6.1d" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index e31482cfaa..37e31d2b6f 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -21,6 +21,7 @@ from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import AuthError, StoreError, Codes, SynapseError from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor +from synapse.types import UserID, ClientInfo import logging @@ -101,10 +102,12 @@ class Auth(object): def check_host_in_room(self, room_id, host): curr_state = yield self.state.get_current_state(room_id) + logger.debug("Got curr_state %s", curr_state) + for event in curr_state: if event.type == EventTypes.Member: try: - if self.hs.parse_userid(event.state_key).domain != host: + if UserID.from_string(event.state_key).domain != host: continue except: logger.warn("state_key not user_id: %s", event.state_key) @@ -289,7 +292,9 @@ class Auth(object): Args: request - An HTTP request with an access_token query parameter. Returns: - UserID : User ID object of the user making the request + tuple : of UserID and device string: + User ID object of the user making the request + Client ID object of the client instance the user is using Raises: AuthError if no user by that token exists or the token is invalid. """ @@ -298,6 +303,8 @@ class Auth(object): access_token = request.args["access_token"][0] user_info = yield self.get_user_by_token(access_token) user = user_info["user"] + device_id = user_info["device_id"] + token_id = user_info["token_id"] ip_addr = self.hs.get_ip_from_request(request) user_agent = request.requestHeaders.getRawHeaders( @@ -313,7 +320,7 @@ class Auth(object): user_agent=user_agent ) - defer.returnValue(user) + defer.returnValue((user, ClientInfo(device_id, token_id))) except KeyError: raise AuthError(403, "Missing access token.") @@ -337,7 +344,8 @@ class Auth(object): user_info = { "admin": bool(ret.get("admin", False)), "device_id": ret.get("device_id"), - "user": self.hs.parse_userid(ret.get("name")), + "user": UserID.from_string(ret.get("name")), + "token_id": ret.get("token_id", None), } defer.returnValue(user_info) @@ -352,9 +360,23 @@ class Auth(object): def add_auth_events(self, builder, context): yield run_on_reactor() - if builder.type == EventTypes.Create: - builder.auth_events = [] - return + auth_ids = self.compute_auth_events(builder, context) + + auth_events_entries = yield self.store.add_event_hashes( + auth_ids + ) + + builder.auth_events = auth_events_entries + + context.auth_events = { + k: v + for k, v in context.current_state.items() + if v.event_id in auth_ids + } + + def compute_auth_events(self, event, context): + if event.type == EventTypes.Create: + return [] auth_ids = [] @@ -367,7 +389,7 @@ class Auth(object): key = (EventTypes.JoinRules, "", ) join_rule_event = context.current_state.get(key) - key = (EventTypes.Member, builder.user_id, ) + key = (EventTypes.Member, event.user_id, ) member_event = context.current_state.get(key) key = (EventTypes.Create, "", ) @@ -381,8 +403,8 @@ class Auth(object): else: is_public = False - if builder.type == EventTypes.Member: - e_type = builder.content["membership"] + if event.type == EventTypes.Member: + e_type = event.content["membership"] if e_type in [Membership.JOIN, Membership.INVITE]: if join_rule_event: auth_ids.append(join_rule_event.event_id) @@ -397,17 +419,7 @@ class Auth(object): if member_event.content["membership"] == Membership.JOIN: auth_ids.append(member_event.event_id) - auth_events_entries = yield self.store.add_event_hashes( - auth_ids - ) - - builder.auth_events = auth_events_entries - - context.auth_events = { - k: v - for k, v in context.current_state.items() - if v.event_id in auth_ids - } + return auth_ids @log_function def _can_send_event(self, event, auth_events): @@ -461,7 +473,7 @@ class Auth(object): "You are not allowed to set others state" ) else: - sender_domain = self.hs.parse_userid( + sender_domain = UserID.from_string( event.user_id ).domain @@ -496,7 +508,7 @@ class Auth(object): # Validate users for k, v in user_list.items(): try: - self.hs.parse_userid(k) + UserID.from_string(k) except: raise SynapseError(400, "Not a valid user_id: %s" % (k,)) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 7ee6dcc46e..0d3fc629af 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -74,3 +74,9 @@ class EventTypes(object): Message = "m.room.message" Topic = "m.room.topic" Name = "m.room.name" + + +class RejectedReason(object): + AUTH_ERROR = "auth_error" + REPLACED = "replaced" + NOT_ANCESTOR = "not_ancestor" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 2b049debf3..ad478aa6b7 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -21,6 +21,7 @@ logger = logging.getLogger(__name__) class Codes(object): + UNRECOGNIZED = "M_UNRECOGNIZED" UNAUTHORIZED = "M_UNAUTHORIZED" FORBIDDEN = "M_FORBIDDEN" BAD_JSON = "M_BAD_JSON" @@ -34,6 +35,7 @@ class Codes(object): LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED" CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_INVALID = "M_CAPTCHA_INVALID" + MISSING_PARAM = "M_MISSING_PARAM", TOO_LARGE = "M_TOO_LARGE" @@ -81,6 +83,35 @@ class RegistrationError(SynapseError): pass +class UnrecognizedRequestError(SynapseError): + """An error indicating we don't understand the request you're trying to make""" + def __init__(self, *args, **kwargs): + if "errcode" not in kwargs: + kwargs["errcode"] = Codes.UNRECOGNIZED + message = None + if len(args) == 0: + message = "Unrecognized request" + else: + message = args[0] + super(UnrecognizedRequestError, self).__init__( + 400, + message, + **kwargs + ) + + +class NotFoundError(SynapseError): + """An error indicating we can't find the thing you asked for""" + def __init__(self, *args, **kwargs): + if "errcode" not in kwargs: + kwargs["errcode"] = Codes.NOT_FOUND + super(NotFoundError, self).__init__( + 404, + "Not found", + **kwargs + ) + + class AuthError(SynapseError): """An error raised when there was a problem authorising an event.""" diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py new file mode 100644 index 0000000000..4d570b74f8 --- /dev/null +++ b/synapse/api/filtering.py @@ -0,0 +1,229 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.api.errors import SynapseError +from synapse.types import UserID, RoomID + + +class Filtering(object): + + def __init__(self, hs): + super(Filtering, self).__init__() + self.store = hs.get_datastore() + + def get_user_filter(self, user_localpart, filter_id): + result = self.store.get_user_filter(user_localpart, filter_id) + result.addCallback(Filter) + return result + + def add_user_filter(self, user_localpart, user_filter): + self._check_valid_filter(user_filter) + return self.store.add_user_filter(user_localpart, user_filter) + + # TODO(paul): surely we should probably add a delete_user_filter or + # replace_user_filter at some point? There's no REST API specified for + # them however + + def _check_valid_filter(self, user_filter_json): + """Check if the provided filter is valid. + + This inspects all definitions contained within the filter. + + Args: + user_filter_json(dict): The filter + Raises: + SynapseError: If the filter is not valid. + """ + # NB: Filters are the complete json blobs. "Definitions" are an + # individual top-level key e.g. public_user_data. Filters are made of + # many definitions. + + top_level_definitions = [ + "public_user_data", "private_user_data", "server_data" + ] + + room_level_definitions = [ + "state", "events", "ephemeral" + ] + + for key in top_level_definitions: + if key in user_filter_json: + self._check_definition(user_filter_json[key]) + + if "room" in user_filter_json: + for key in room_level_definitions: + if key in user_filter_json["room"]: + self._check_definition(user_filter_json["room"][key]) + + def _check_definition(self, definition): + """Check if the provided definition is valid. + + This inspects not only the types but also the values to make sure they + make sense. + + Args: + definition(dict): The filter definition + Raises: + SynapseError: If there was a problem with this definition. + """ + # NB: Filters are the complete json blobs. "Definitions" are an + # individual top-level key e.g. public_user_data. Filters are made of + # many definitions. + if type(definition) != dict: + raise SynapseError( + 400, "Expected JSON object, not %s" % (definition,) + ) + + # check rooms are valid room IDs + room_id_keys = ["rooms", "not_rooms"] + for key in room_id_keys: + if key in definition: + if type(definition[key]) != list: + raise SynapseError(400, "Expected %s to be a list." % key) + for room_id in definition[key]: + RoomID.from_string(room_id) + + # check senders are valid user IDs + user_id_keys = ["senders", "not_senders"] + for key in user_id_keys: + if key in definition: + if type(definition[key]) != list: + raise SynapseError(400, "Expected %s to be a list." % key) + for user_id in definition[key]: + UserID.from_string(user_id) + + # TODO: We don't limit event type values but we probably should... + # check types are valid event types + event_keys = ["types", "not_types"] + for key in event_keys: + if key in definition: + if type(definition[key]) != list: + raise SynapseError(400, "Expected %s to be a list." % key) + for event_type in definition[key]: + if not isinstance(event_type, basestring): + raise SynapseError(400, "Event type should be a string") + + if "format" in definition: + event_format = definition["format"] + if event_format not in ["federation", "events"]: + raise SynapseError(400, "Invalid format: %s" % (event_format,)) + + if "select" in definition: + event_select_list = definition["select"] + for select_key in event_select_list: + if select_key not in ["event_id", "origin_server_ts", + "thread_id", "content", "content.body"]: + raise SynapseError(400, "Bad select: %s" % (select_key,)) + + if ("bundle_updates" in definition and + type(definition["bundle_updates"]) != bool): + raise SynapseError(400, "Bad bundle_updates: expected bool.") + + +class Filter(object): + def __init__(self, filter_json): + self.filter_json = filter_json + + def filter_public_user_data(self, events): + return self._filter_on_key(events, ["public_user_data"]) + + def filter_private_user_data(self, events): + return self._filter_on_key(events, ["private_user_data"]) + + def filter_room_state(self, events): + return self._filter_on_key(events, ["room", "state"]) + + def filter_room_events(self, events): + return self._filter_on_key(events, ["room", "events"]) + + def filter_room_ephemeral(self, events): + return self._filter_on_key(events, ["room", "ephemeral"]) + + def _filter_on_key(self, events, keys): + filter_json = self.filter_json + if not filter_json: + return events + + try: + # extract the right definition from the filter + definition = filter_json + for key in keys: + definition = definition[key] + return self._filter_with_definition(events, definition) + except KeyError: + # return all events if definition isn't specified. + return events + + def _filter_with_definition(self, events, definition): + return [e for e in events if self._passes_definition(definition, e)] + + def _passes_definition(self, definition, event): + """Check if the event passes through the given definition. + + Args: + definition(dict): The definition to check against. + event(Event): The event to check. + Returns: + True if the event passes through the filter. + """ + # Algorithm notes: + # For each key in the definition, check the event meets the criteria: + # * For types: Literal match or prefix match (if ends with wildcard) + # * For senders/rooms: Literal match only + # * "not_" checks take presedence (e.g. if "m.*" is in both 'types' + # and 'not_types' then it is treated as only being in 'not_types') + + # room checks + if hasattr(event, "room_id"): + room_id = event.room_id + allow_rooms = definition.get("rooms", None) + reject_rooms = definition.get("not_rooms", None) + if reject_rooms and room_id in reject_rooms: + return False + if allow_rooms and room_id not in allow_rooms: + return False + + # sender checks + if hasattr(event, "sender"): + # Should we be including event.state_key for some event types? + sender = event.sender + allow_senders = definition.get("senders", None) + reject_senders = definition.get("not_senders", None) + if reject_senders and sender in reject_senders: + return False + if allow_senders and sender not in allow_senders: + return False + + # type checks + if "not_types" in definition: + for def_type in definition["not_types"]: + if self._event_matches_type(event, def_type): + return False + if "types" in definition: + included = False + for def_type in definition["types"]: + if self._event_matches_type(event, def_type): + included = True + break + if not included: + return False + + return True + + def _event_matches_type(self, event, def_type): + if def_type.endswith("*"): + type_prefix = def_type[:-1] + return event.type.startswith(type_prefix) + else: + return event.type == def_type diff --git a/synapse/api/urls.py b/synapse/api/urls.py index a299392049..693c0efda6 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -16,6 +16,7 @@ """Contains the URL paths to prefix various aspects of the server with. """ CLIENT_PREFIX = "/_matrix/client/api/v1" +CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha" FEDERATION_PREFIX = "/_matrix/federation/v1" WEB_CLIENT_PREFIX = "/_matrix/client" CONTENT_REPO_PREFIX = "/_matrix/content" diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 5fec8da7ca..f20ccfb5b6 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -18,27 +18,33 @@ from synapse.storage import prepare_database, UpgradeDatabaseException from synapse.server import HomeServer +from synapse.python_dependencies import check_requirements + from twisted.internet import reactor from twisted.enterprise import adbapi from twisted.web.resource import Resource from twisted.web.static import File from twisted.web.server import Site from synapse.http.server import JsonResource, RootRedirect -from synapse.media.v0.content_repository import ContentRepoResource -from synapse.media.v1.media_repository import MediaRepositoryResource +from synapse.rest.media.v0.content_repository import ContentRepoResource +from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.http.server_key_resource import LocalKey from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.api.urls import ( CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, - SERVER_KEY_PREFIX, MEDIA_PREFIX + SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, ) from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.util.logcontext import LoggingContext +from synapse.rest.client.v1 import ClientV1RestResource +from synapse.rest.client.v2_alpha import ClientV2AlphaRestResource from daemonize import Daemonize import twisted.manhole.telnet +import synapse + import logging import os import re @@ -55,7 +61,10 @@ class SynapseHomeServer(HomeServer): return MatrixFederationHttpClient(self) def build_resource_for_client(self): - return JsonResource() + return ClientV1RestResource(self) + + def build_resource_for_client_v2_alpha(self): + return ClientV2AlphaRestResource(self) def build_resource_for_federation(self): return JsonResource() @@ -100,6 +109,7 @@ class SynapseHomeServer(HomeServer): # [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ] desired_tree = [ (CLIENT_PREFIX, self.get_resource_for_client()), + (CLIENT_V2_ALPHA_PREFIX, self.get_resource_for_client_v2_alpha()), (FEDERATION_PREFIX, self.get_resource_for_federation()), (CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()), (SERVER_KEY_PREFIX, self.get_resource_for_server_key()), @@ -198,7 +208,10 @@ def setup(): config.setup_logging() + check_requirements() + logger.info("Server hostname: %s", config.server_name) + logger.info("Server version: %s", synapse.__version__) if re.search(":[0-9]+$", config.server_name): domain_with_port = config.server_name @@ -217,8 +230,6 @@ def setup(): content_addr=config.content_addr, ) - hs.register_servlets() - hs.create_resource_tree( web_client=config.webclient, redirect_root_to_web_client=True, @@ -234,13 +245,20 @@ def setup(): except UpgradeDatabaseException: sys.stderr.write( "\nFailed to upgrade database.\n" - "Have you checked for version specific instructions in UPGRADES.rst?\n" + "Have you checked for version specific instructions in" + " UPGRADES.rst?\n" ) sys.exit(1) logger.info("Database prepared in %s.", db_name) - hs.get_db_pool() + db_pool = hs.get_db_pool() + + if db_name == ":memory:": + # Memory databases will need to be setup each time they are opened. + reactor.callWhenRunning( + db_pool.runWithConnection, prepare_database + ) if config.manhole: f = twisted.manhole.telnet.ShellFactory() @@ -254,6 +272,8 @@ def setup(): bind_port = None hs.start_listening(bind_port, config.unsecure_port) + hs.get_pusherpool().start() + if config.daemonize: print config.pid_file daemon = Daemonize( @@ -277,6 +297,7 @@ def run(): def main(): with LoggingContext("main"): + check_requirements() setup() diff --git a/synapse/config/database.py b/synapse/config/database.py index 0d33583a7d..daa161c952 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -20,7 +20,10 @@ import os class DatabaseConfig(Config): def __init__(self, args): super(DatabaseConfig, self).__init__(args) - self.database_path = self.abspath(args.database_path) + if args.database_path == ":memory:": + self.database_path = ":memory:" + else: + self.database_path = self.abspath(args.database_path) @classmethod def add_arguments(cls, parser): diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 15383b3184..f9568ebd21 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -66,7 +66,10 @@ class LoggingConfig(Config): formatter = logging.Formatter(log_format) if self.log_file: - handler = logging.FileHandler(self.log_file) + # TODO: Customisable file size / backup count + handler = logging.handlers.RotatingFileHandler( + self.log_file, maxBytes=(1000 * 1000 * 100), backupCount=3 + ) else: handler = logging.StreamHandler() handler.setFormatter(formatter) diff --git a/synapse/config/server.py b/synapse/config/server.py index 4f73c85466..31e44cc857 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -47,8 +47,12 @@ class ServerConfig(Config): def add_arguments(cls, parser): super(ServerConfig, cls).add_arguments(parser) server_group = parser.add_argument_group("server") - server_group.add_argument("-H", "--server-name", default="localhost", - help="The name of the server") + server_group.add_argument( + "-H", "--server-name", default="localhost", + help="The domain name of the server, with optional explicit port. " + "This is used by remote servers to connect to this server, " + "e.g. matrix.org, localhost:8080, etc." + ) server_group.add_argument("--signing-key-path", help="The signing key to sign messages with") server_group.add_argument("-p", "--bind-port", metavar="PORT", diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index 9c910fa3fc..cdb6279764 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -61,9 +61,11 @@ class SynapseKeyClientProtocol(HTTPClient): def __init__(self): self.remote_key = defer.Deferred() + self.host = None def connectionMade(self): - logger.debug("Connected to %s", self.transport.getHost()) + self.host = self.transport.getHost() + logger.debug("Connected to %s", self.host) self.sendCommand(b"GET", b"/_matrix/key/v1/") self.endHeaders() self.timer = reactor.callLater( @@ -92,8 +94,7 @@ class SynapseKeyClientProtocol(HTTPClient): self.timer.cancel() def on_timeout(self): - logger.debug("Timeout waiting for response from %s", - self.transport.getHost()) + logger.debug("Timeout waiting for response from %s", self.host) self.remote_key.errback(IOError("Timeout waiting for response")) self.transport.abortConnection() diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 4252e5ab5c..bf07951027 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -18,7 +18,7 @@ from synapse.util.frozenutils import freeze, unfreeze class _EventInternalMetadata(object): def __init__(self, internal_metadata_dict): - self.__dict__ = internal_metadata_dict + self.__dict__ = dict(internal_metadata_dict) def get_dict(self): return dict(self.__dict__) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index d4cb602ebb..9d45bdb892 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -23,22 +23,17 @@ import copy class EventBuilder(EventBase): - def __init__(self, key_values={}): + def __init__(self, key_values={}, internal_metadata_dict={}): signatures = copy.deepcopy(key_values.pop("signatures", {})) unsigned = copy.deepcopy(key_values.pop("unsigned", {})) super(EventBuilder, self).__init__( key_values, signatures=signatures, - unsigned=unsigned + unsigned=unsigned, + internal_metadata_dict=internal_metadata_dict, ) - def update_event_key(self, key, value): - self._event_dict[key] = value - - def update_event_keys(self, other_dict): - self._event_dict.update(other_dict) - def build(self): return FrozenEvent.from_event(self) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 6bbba8d6ba..7e98bdef28 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -20,3 +20,4 @@ class EventContext(object): self.current_state = current_state self.auth_events = auth_events self.state_group = None + self.rejected = False diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 4ad37188ba..1aa952150e 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -45,12 +45,14 @@ def prune_event(event): "membership", ] + event_dict = event.get_dict() + new_content = {} def add_fields(*fields): for field in fields: if field in event.content: - new_content[field] = event.content[field] + new_content[field] = event_dict["content"][field] if event_type == EventTypes.Member: add_fields("membership") @@ -75,7 +77,7 @@ def prune_event(event): allowed_fields = { k: v - for k, v in event.get_dict().items() + for k, v in event_dict.items() if k in allowed_keys } @@ -86,48 +88,78 @@ def prune_event(event): if "age_ts" in event.unsigned: allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"] - return type(event)(allowed_fields) + return type(event)( + allowed_fields, + internal_metadata_dict=event.internal_metadata.get_dict() + ) + + +def format_event_raw(d): + return d + + +def format_event_for_client_v1(d): + d["user_id"] = d.pop("sender", None) + + move_keys = ("age", "redacted_because", "replaces_state", "prev_content") + for key in move_keys: + if key in d["unsigned"]: + d[key] = d["unsigned"][key] + + drop_keys = ( + "auth_events", "prev_events", "hashes", "signatures", "depth", + "unsigned", "origin", "prev_state" + ) + for key in drop_keys: + d.pop(key, None) + return d + + +def format_event_for_client_v2(d): + drop_keys = ( + "auth_events", "prev_events", "hashes", "signatures", "depth", + "origin", "prev_state", + ) + for key in drop_keys: + d.pop(key, None) + return d + + +def format_event_for_client_v2_without_event_id(d): + d = format_event_for_client_v2(d) + d.pop("room_id", None) + d.pop("event_id", None) + return d -def serialize_event(hs, e): +def serialize_event(e, time_now_ms, as_client_event=True, + event_format=format_event_for_client_v1, + token_id=None): # FIXME(erikj): To handle the case of presence events and the like if not isinstance(e, EventBase): return e + time_now_ms = int(time_now_ms) + # Should this strip out None's? d = {k: v for k, v in e.get_dict().items()} + if "age_ts" in d["unsigned"]: - now = int(hs.get_clock().time_msec()) - d["age"] = now - d["unsigned"]["age_ts"] + d["unsigned"]["age"] = time_now_ms - d["unsigned"]["age_ts"] del d["unsigned"]["age_ts"] - d["user_id"] = d.pop("sender", None) - if "redacted_because" in e.unsigned: - d["redacted_because"] = serialize_event( - hs, e.unsigned["redacted_because"] + d["unsigned"]["redacted_because"] = serialize_event( + e.unsigned["redacted_because"], time_now_ms ) - del d["unsigned"]["redacted_because"] - - if "redacted_by" in e.unsigned: - d["redacted_by"] = e.unsigned["redacted_by"] - del d["unsigned"]["redacted_by"] - - if "replaces_state" in e.unsigned: - d["replaces_state"] = e.unsigned["replaces_state"] - del d["unsigned"]["replaces_state"] + if token_id is not None: + if token_id == getattr(e.internal_metadata, "token_id", None): + txn_id = getattr(e.internal_metadata, "txn_id", None) + if txn_id is not None: + d["unsigned"]["transaction_id"] = txn_id - if "prev_content" in e.unsigned: - d["prev_content"] = e.unsigned["prev_content"] - del d["unsigned"]["prev_content"] - - del d["auth_events"] - del d["prev_events"] - del d["hashes"] - del d["signatures"] - d.pop("depth", None) - d.pop("unsigned", None) - d.pop("origin", None) - - return d + if as_client_event: + return event_format(d) + else: + return d diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py new file mode 100644 index 0000000000..e1539bd0e0 --- /dev/null +++ b/synapse/federation/federation_client.py @@ -0,0 +1,409 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from .units import Edu + +from synapse.util.logutils import log_function +from synapse.events import FrozenEvent +from synapse.events.utils import prune_event + +from syutil.jsonutil import encode_canonical_json + +from synapse.crypto.event_signing import check_event_content_hash + +from synapse.api.errors import SynapseError + +import logging + + +logger = logging.getLogger(__name__) + + +class FederationClient(object): + @log_function + def send_pdu(self, pdu, destinations): + """Informs the replication layer about a new PDU generated within the + home server that should be transmitted to others. + + TODO: Figure out when we should actually resolve the deferred. + + Args: + pdu (Pdu): The new Pdu. + + Returns: + Deferred: Completes when we have successfully processed the PDU + and replicated it to any interested remote home servers. + """ + order = self._order + self._order += 1 + + logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id) + + # TODO, add errback, etc. + self._transaction_queue.enqueue_pdu(pdu, destinations, order) + + logger.debug( + "[%s] transaction_layer.enqueue_pdu... done", + pdu.event_id + ) + + @log_function + def send_edu(self, destination, edu_type, content): + edu = Edu( + origin=self.server_name, + destination=destination, + edu_type=edu_type, + content=content, + ) + + # TODO, add errback, etc. + self._transaction_queue.enqueue_edu(edu) + return defer.succeed(None) + + @log_function + def send_failure(self, failure, destination): + self._transaction_queue.enqueue_failure(failure, destination) + return defer.succeed(None) + + @log_function + def make_query(self, destination, query_type, args, + retry_on_dns_fail=True): + """Sends a federation Query to a remote homeserver of the given type + and arguments. + + Args: + destination (str): Domain name of the remote homeserver + query_type (str): Category of the query type; should match the + handler name used in register_query_handler(). + args (dict): Mapping of strings to strings containing the details + of the query request. + + Returns: + a Deferred which will eventually yield a JSON object from the + response + """ + return self.transport_layer.make_query( + destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail + ) + + @defer.inlineCallbacks + @log_function + def backfill(self, dest, context, limit, extremities): + """Requests some more historic PDUs for the given context from the + given destination server. + + Args: + dest (str): The remote home server to ask. + context (str): The context to backfill. + limit (int): The maximum number of PDUs to return. + extremities (list): List of PDU id and origins of the first pdus + we have seen from the context + + Returns: + Deferred: Results in the received PDUs. + """ + logger.debug("backfill extrem=%s", extremities) + + # If there are no extremeties then we've (probably) reached the start. + if not extremities: + return + + transaction_data = yield self.transport_layer.backfill( + dest, context, extremities, limit) + + logger.debug("backfill transaction_data=%s", repr(transaction_data)) + + pdus = [ + self.event_from_pdu_json(p, outlier=False) + for p in transaction_data["pdus"] + ] + + for i, pdu in enumerate(pdus): + pdus[i] = yield self._check_sigs_and_hash(pdu) + + # FIXME: We should handle signature failures more gracefully. + + defer.returnValue(pdus) + + @defer.inlineCallbacks + @log_function + def get_pdu(self, destinations, event_id, outlier=False): + """Requests the PDU with given origin and ID from the remote home + servers. + + Will attempt to get the PDU from each destination in the list until + one succeeds. + + This will persist the PDU locally upon receipt. + + Args: + destinations (list): Which home servers to query + pdu_origin (str): The home server that originally sent the pdu. + event_id (str) + outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if + it's from an arbitary point in the context as opposed to part + of the current block of PDUs. Defaults to `False` + + Returns: + Deferred: Results in the requested PDU. + """ + + # TODO: Rate limit the number of times we try and get the same event. + + pdu = None + for destination in destinations: + try: + transaction_data = yield self.transport_layer.get_event( + destination, event_id + ) + + logger.debug("transaction_data %r", transaction_data) + + pdu_list = [ + self.event_from_pdu_json(p, outlier=outlier) + for p in transaction_data["pdus"] + ] + + if pdu_list: + pdu = pdu_list[0] + + # Check signatures are correct. + pdu = yield self._check_sigs_and_hash(pdu) + + break + + except Exception as e: + logger.info( + "Failed to get PDU %s from %s because %s", + event_id, destination, e, + ) + continue + + defer.returnValue(pdu) + + @defer.inlineCallbacks + @log_function + def get_state_for_room(self, destination, room_id, event_id): + """Requests all of the `current` state PDUs for a given room from + a remote home server. + + Args: + destination (str): The remote homeserver to query for the state. + room_id (str): The id of the room we're interested in. + event_id (str): The id of the event we want the state at. + + Returns: + Deferred: Results in a list of PDUs. + """ + + result = yield self.transport_layer.get_room_state( + destination, room_id, event_id=event_id, + ) + + pdus = [ + self.event_from_pdu_json(p, outlier=True) for p in result["pdus"] + ] + + auth_chain = [ + self.event_from_pdu_json(p, outlier=True) + for p in result.get("auth_chain", []) + ] + + for i, pdu in enumerate(pdus): + pdus[i] = yield self._check_sigs_and_hash(pdu) + + # FIXME: We should handle signature failures more gracefully. + + for i, pdu in enumerate(auth_chain): + auth_chain[i] = yield self._check_sigs_and_hash(pdu) + + # FIXME: We should handle signature failures more gracefully. + + defer.returnValue((pdus, auth_chain)) + + @defer.inlineCallbacks + @log_function + def get_event_auth(self, destination, room_id, event_id): + res = yield self.transport_layer.get_event_auth( + destination, room_id, event_id, + ) + + auth_chain = [ + self.event_from_pdu_json(p, outlier=True) + for p in res["auth_chain"] + ] + + for i, pdu in enumerate(auth_chain): + auth_chain[i] = yield self._check_sigs_and_hash(pdu) + + # FIXME: We should handle signature failures more gracefully. + + auth_chain.sort(key=lambda e: e.depth) + + defer.returnValue(auth_chain) + + @defer.inlineCallbacks + def make_join(self, destination, room_id, user_id): + ret = yield self.transport_layer.make_join( + destination, room_id, user_id + ) + + pdu_dict = ret["event"] + + logger.debug("Got response to make_join: %s", pdu_dict) + + defer.returnValue(self.event_from_pdu_json(pdu_dict)) + + @defer.inlineCallbacks + def send_join(self, destination, pdu): + time_now = self._clock.time_msec() + _, content = yield self.transport_layer.send_join( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + + logger.debug("Got content: %s", content) + + state = [ + self.event_from_pdu_json(p, outlier=True) + for p in content.get("state", []) + ] + + auth_chain = [ + self.event_from_pdu_json(p, outlier=True) + for p in content.get("auth_chain", []) + ] + + for i, pdu in enumerate(state): + state[i] = yield self._check_sigs_and_hash(pdu) + + # FIXME: We should handle signature failures more gracefully. + + for i, pdu in enumerate(auth_chain): + auth_chain[i] = yield self._check_sigs_and_hash(pdu) + + # FIXME: We should handle signature failures more gracefully. + + auth_chain.sort(key=lambda e: e.depth) + + defer.returnValue({ + "state": state, + "auth_chain": auth_chain, + }) + + @defer.inlineCallbacks + def send_invite(self, destination, room_id, event_id, pdu): + time_now = self._clock.time_msec() + code, content = yield self.transport_layer.send_invite( + destination=destination, + room_id=room_id, + event_id=event_id, + content=pdu.get_pdu_json(time_now), + ) + + pdu_dict = content["event"] + + logger.debug("Got response to send_invite: %s", pdu_dict) + + pdu = self.event_from_pdu_json(pdu_dict) + + # Check signatures are correct. + pdu = yield self._check_sigs_and_hash(pdu) + + # FIXME: We should handle signature failures more gracefully. + + defer.returnValue(pdu) + + @defer.inlineCallbacks + def query_auth(self, destination, room_id, event_id, local_auth): + """ + Params: + destination (str) + event_it (str) + local_auth (list) + """ + time_now = self._clock.time_msec() + + send_content = { + "auth_chain": [e.get_pdu_json(time_now) for e in local_auth], + } + + code, content = yield self.transport_layer.send_query_auth( + destination=destination, + room_id=room_id, + event_id=event_id, + content=send_content, + ) + + auth_chain = [ + (yield self._check_sigs_and_hash(self.event_from_pdu_json(e))) + for e in content["auth_chain"] + ] + + ret = { + "auth_chain": auth_chain, + "rejects": content.get("rejects", []), + "missing": content.get("missing", []), + } + + defer.returnValue(ret) + + def event_from_pdu_json(self, pdu_json, outlier=False): + event = FrozenEvent( + pdu_json + ) + + event.internal_metadata.outlier = outlier + + return event + + @defer.inlineCallbacks + def _check_sigs_and_hash(self, pdu): + """Throws a SynapseError if the PDU does not have the correct + signatures. + + Returns: + FrozenEvent: Either the given event or it redacted if it failed the + content hash check. + """ + # Check signatures are correct. + redacted_event = prune_event(pdu) + redacted_pdu_json = redacted_event.get_pdu_json() + + try: + yield self.keyring.verify_json_for_server( + pdu.origin, redacted_pdu_json + ) + except SynapseError: + logger.warn( + "Signature check failed for %s redacted to %s", + encode_canonical_json(pdu.get_pdu_json()), + encode_canonical_json(redacted_pdu_json), + ) + raise + + if not check_event_content_hash(pdu): + logger.warn( + "Event content has been tampered, redacting %s, %s", + pdu.event_id, encode_canonical_json(pdu.get_dict()) + ) + defer.returnValue(redacted_event) + + defer.returnValue(pdu) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py new file mode 100644 index 0000000000..5fbd8b19de --- /dev/null +++ b/synapse/federation/federation_server.py @@ -0,0 +1,462 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from .units import Transaction, Edu + +from synapse.util.logutils import log_function +from synapse.util.logcontext import PreserveLoggingContext +from synapse.events import FrozenEvent +from synapse.events.utils import prune_event + +from syutil.jsonutil import encode_canonical_json + +from synapse.crypto.event_signing import check_event_content_hash + +from synapse.api.errors import FederationError, SynapseError + +import logging + + +logger = logging.getLogger(__name__) + + +class FederationServer(object): + def set_handler(self, handler): + """Sets the handler that the replication layer will use to communicate + receipt of new PDUs from other home servers. The required methods are + documented on :py:class:`.ReplicationHandler`. + """ + self.handler = handler + + def register_edu_handler(self, edu_type, handler): + if edu_type in self.edu_handlers: + raise KeyError("Already have an EDU handler for %s" % (edu_type,)) + + self.edu_handlers[edu_type] = handler + + def register_query_handler(self, query_type, handler): + """Sets the handler callable that will be used to handle an incoming + federation Query of the given type. + + Args: + query_type (str): Category name of the query, which should match + the string used by make_query. + handler (callable): Invoked to handle incoming queries of this type + + handler is invoked as: + result = handler(args) + + where 'args' is a dict mapping strings to strings of the query + arguments. It should return a Deferred that will eventually yield an + object to encode as JSON. + """ + if query_type in self.query_handlers: + raise KeyError( + "Already have a Query handler for %s" % (query_type,) + ) + + self.query_handlers[query_type] = handler + + @defer.inlineCallbacks + @log_function + def on_backfill_request(self, origin, room_id, versions, limit): + pdus = yield self.handler.on_backfill_request( + origin, room_id, versions, limit + ) + + defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) + + @defer.inlineCallbacks + @log_function + def on_incoming_transaction(self, transaction_data): + transaction = Transaction(**transaction_data) + + for p in transaction.pdus: + if "unsigned" in p: + unsigned = p["unsigned"] + if "age" in unsigned: + p["age"] = unsigned["age"] + if "age" in p: + p["age_ts"] = int(self._clock.time_msec()) - int(p["age"]) + del p["age"] + + pdu_list = [ + self.event_from_pdu_json(p) for p in transaction.pdus + ] + + logger.debug("[%s] Got transaction", transaction.transaction_id) + + response = yield self.transaction_actions.have_responded(transaction) + + if response: + logger.debug( + "[%s] We've already responed to this request", + transaction.transaction_id + ) + defer.returnValue(response) + return + + logger.debug("[%s] Transaction is new", transaction.transaction_id) + + with PreserveLoggingContext(): + dl = [] + for pdu in pdu_list: + dl.append(self._handle_new_pdu(transaction.origin, pdu)) + + if hasattr(transaction, "edus"): + for edu in [Edu(**x) for x in transaction.edus]: + self.received_edu( + transaction.origin, + edu.edu_type, + edu.content + ) + + results = yield defer.DeferredList(dl) + + ret = [] + for r in results: + if r[0]: + ret.append({}) + else: + logger.exception(r[1]) + ret.append({"error": str(r[1])}) + + logger.debug("Returning: %s", str(ret)) + + yield self.transaction_actions.set_response( + transaction, + 200, response + ) + defer.returnValue((200, response)) + + def received_edu(self, origin, edu_type, content): + if edu_type in self.edu_handlers: + self.edu_handlers[edu_type](origin, content) + else: + logger.warn("Received EDU of type %s with no handler", edu_type) + + @defer.inlineCallbacks + @log_function + def on_context_state_request(self, origin, room_id, event_id): + if event_id: + pdus = yield self.handler.get_state_for_pdu( + origin, room_id, event_id, + ) + auth_chain = yield self.store.get_auth_chain( + [pdu.event_id for pdu in pdus] + ) + else: + raise NotImplementedError("Specify an event") + + defer.returnValue((200, { + "pdus": [pdu.get_pdu_json() for pdu in pdus], + "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], + })) + + @defer.inlineCallbacks + @log_function + def on_pdu_request(self, origin, event_id): + pdu = yield self._get_persisted_pdu(origin, event_id) + + if pdu: + defer.returnValue( + (200, self._transaction_from_pdus([pdu]).get_dict()) + ) + else: + defer.returnValue((404, "")) + + @defer.inlineCallbacks + @log_function + def on_pull_request(self, origin, versions): + raise NotImplementedError("Pull transactions not implemented") + + @defer.inlineCallbacks + def on_query_request(self, query_type, args): + if query_type in self.query_handlers: + response = yield self.query_handlers[query_type](args) + defer.returnValue((200, response)) + else: + defer.returnValue( + (404, "No handler for Query type '%s'" % (query_type,)) + ) + + @defer.inlineCallbacks + def on_make_join_request(self, room_id, user_id): + pdu = yield self.handler.on_make_join_request(room_id, user_id) + time_now = self._clock.time_msec() + defer.returnValue({"event": pdu.get_pdu_json(time_now)}) + + @defer.inlineCallbacks + def on_invite_request(self, origin, content): + pdu = self.event_from_pdu_json(content) + ret_pdu = yield self.handler.on_invite_request(origin, pdu) + time_now = self._clock.time_msec() + defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)})) + + @defer.inlineCallbacks + def on_send_join_request(self, origin, content): + logger.debug("on_send_join_request: content: %s", content) + pdu = self.event_from_pdu_json(content) + logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) + res_pdus = yield self.handler.on_send_join_request(origin, pdu) + time_now = self._clock.time_msec() + defer.returnValue((200, { + "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], + "auth_chain": [ + p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] + ], + })) + + @defer.inlineCallbacks + def on_event_auth(self, origin, room_id, event_id): + time_now = self._clock.time_msec() + auth_pdus = yield self.handler.on_event_auth(event_id) + defer.returnValue((200, { + "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus], + })) + + @defer.inlineCallbacks + def on_query_auth_request(self, origin, content, event_id): + """ + Content is a dict with keys:: + auth_chain (list): A list of events that give the auth chain. + missing (list): A list of event_ids indicating what the other + side (`origin`) think we're missing. + rejects (dict): A mapping from event_id to a 2-tuple of reason + string and a proof (or None) of why the event was rejected. + The keys of this dict give the list of events the `origin` has + rejected. + + Args: + origin (str) + content (dict) + event_id (str) + + Returns: + Deferred: Results in `dict` with the same format as `content` + """ + auth_chain = [ + (yield self._check_sigs_and_hash(self.event_from_pdu_json(e))) + for e in content["auth_chain"] + ] + + missing = [ + (yield self._check_sigs_and_hash(self.event_from_pdu_json(e))) + for e in content.get("missing", []) + ] + + ret = yield self.handler.on_query_auth( + origin, event_id, auth_chain, content.get("rejects", []), missing + ) + + time_now = self._clock.time_msec() + send_content = { + "auth_chain": [ + e.get_pdu_json(time_now) + for e in ret["auth_chain"] + ], + "rejects": ret.get("rejects", []), + "missing": ret.get("missing", []), + } + + defer.returnValue( + (200, send_content) + ) + + @log_function + def _get_persisted_pdu(self, origin, event_id, do_auth=True): + """ Get a PDU from the database with given origin and id. + + Returns: + Deferred: Results in a `Pdu`. + """ + return self.handler.get_persisted_pdu( + origin, event_id, do_auth=do_auth + ) + + def _transaction_from_pdus(self, pdu_list): + """Returns a new Transaction containing the given PDUs suitable for + transmission. + """ + time_now = self._clock.time_msec() + pdus = [p.get_pdu_json(time_now) for p in pdu_list] + return Transaction( + origin=self.server_name, + pdus=pdus, + origin_server_ts=int(time_now), + destination=None, + ) + + @defer.inlineCallbacks + @log_function + def _handle_new_pdu(self, origin, pdu, max_recursion=10): + # We reprocess pdus when we have seen them only as outliers + existing = yield self._get_persisted_pdu( + origin, pdu.event_id, do_auth=False + ) + + # FIXME: Currently we fetch an event again when we already have it + # if it has been marked as an outlier. + + already_seen = ( + existing and ( + not existing.internal_metadata.is_outlier() + or pdu.internal_metadata.is_outlier() + ) + ) + if already_seen: + logger.debug("Already seen pdu %s", pdu.event_id) + defer.returnValue({}) + return + + # Check signature. + try: + pdu = yield self._check_sigs_and_hash(pdu) + except SynapseError as e: + raise FederationError( + "ERROR", + e.code, + e.msg, + affected=pdu.event_id, + ) + + state = None + + auth_chain = [] + + have_seen = yield self.store.have_events( + [ev for ev, _ in pdu.prev_events] + ) + + fetch_state = False + + # Get missing pdus if necessary. + if not pdu.internal_metadata.is_outlier(): + # We only backfill backwards to the min depth. + min_depth = yield self.handler.get_min_depth_for_context( + pdu.room_id + ) + + logger.debug( + "_handle_new_pdu min_depth for %s: %d", + pdu.room_id, min_depth + ) + + if min_depth and pdu.depth > min_depth and max_recursion > 0: + for event_id, hashes in pdu.prev_events: + if event_id not in have_seen: + logger.debug( + "_handle_new_pdu requesting pdu %s", + event_id + ) + + try: + new_pdu = yield self.federation_client.get_pdu( + [origin, pdu.origin], + event_id=event_id, + ) + + if new_pdu: + yield self._handle_new_pdu( + origin, + new_pdu, + max_recursion=max_recursion-1 + ) + + logger.debug("Processed pdu %s", event_id) + else: + logger.warn("Failed to get PDU %s", event_id) + fetch_state = True + except: + # TODO(erikj): Do some more intelligent retries. + logger.exception("Failed to get PDU") + fetch_state = True + else: + prevs = {e_id for e_id, _ in pdu.prev_events} + seen = set(have_seen.keys()) + if prevs - seen: + fetch_state = True + else: + fetch_state = True + + if fetch_state: + # We need to get the state at this event, since we haven't + # processed all the prev events. + logger.debug( + "_handle_new_pdu getting state for %s", + pdu.room_id + ) + state, auth_chain = yield self.get_state_for_room( + origin, pdu.room_id, pdu.event_id, + ) + + ret = yield self.handler.on_receive_pdu( + origin, + pdu, + backfilled=False, + state=state, + auth_chain=auth_chain, + ) + + defer.returnValue(ret) + + def __str__(self): + return "<ReplicationLayer(%s)>" % self.server_name + + def event_from_pdu_json(self, pdu_json, outlier=False): + event = FrozenEvent( + pdu_json + ) + + event.internal_metadata.outlier = outlier + + return event + + @defer.inlineCallbacks + def _check_sigs_and_hash(self, pdu): + """Throws a SynapseError if the PDU does not have the correct + signatures. + + Returns: + FrozenEvent: Either the given event or it redacted if it failed the + content hash check. + """ + # Check signatures are correct. + redacted_event = prune_event(pdu) + redacted_pdu_json = redacted_event.get_pdu_json() + + try: + yield self.keyring.verify_json_for_server( + pdu.origin, redacted_pdu_json + ) + except SynapseError: + logger.warn( + "Signature check failed for %s redacted to %s", + encode_canonical_json(pdu.get_pdu_json()), + encode_canonical_json(redacted_pdu_json), + ) + raise + + if not check_event_content_hash(pdu): + logger.warn( + "Event content has been tampered, redacting %s, %s", + pdu.event_id, encode_canonical_json(pdu.get_dict()) + ) + defer.returnValue(redacted_event) + + defer.returnValue(pdu) diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py index a4c29b484b..e442c6c5d5 100644 --- a/synapse/federation/replication.py +++ b/synapse/federation/replication.py @@ -17,23 +17,20 @@ a given transport. """ -from twisted.internet import defer +from .federation_client import FederationClient +from .federation_server import FederationServer -from .units import Transaction, Edu +from .transaction_queue import TransactionQueue from .persistence import TransactionActions -from synapse.util.logutils import log_function -from synapse.util.logcontext import PreserveLoggingContext -from synapse.events import FrozenEvent - import logging logger = logging.getLogger(__name__) -class ReplicationLayer(object): +class ReplicationLayer(FederationClient, FederationServer): """This layer is responsible for replicating with remote home servers over the given transport. I.e., does the sending and receiving of PDUs to remote home servers. @@ -54,922 +51,26 @@ class ReplicationLayer(object): def __init__(self, hs, transport_layer): self.server_name = hs.hostname + self.keyring = hs.get_keyring() + self.transport_layer = transport_layer self.transport_layer.register_received_handler(self) self.transport_layer.register_request_handler(self) - self.store = hs.get_datastore() - # self.pdu_actions = PduActions(self.store) - self.transaction_actions = TransactionActions(self.store) + self.federation_client = self - self._transaction_queue = _TransactionQueue( - hs, self.transaction_actions, transport_layer - ) + self.store = hs.get_datastore() self.handler = None self.edu_handlers = {} self.query_handlers = {} - self._order = 0 - self._clock = hs.get_clock() - self.event_builder_factory = hs.get_event_builder_factory() - - def set_handler(self, handler): - """Sets the handler that the replication layer will use to communicate - receipt of new PDUs from other home servers. The required methods are - documented on :py:class:`.ReplicationHandler`. - """ - self.handler = handler - - def register_edu_handler(self, edu_type, handler): - if edu_type in self.edu_handlers: - raise KeyError("Already have an EDU handler for %s" % (edu_type,)) - - self.edu_handlers[edu_type] = handler - - def register_query_handler(self, query_type, handler): - """Sets the handler callable that will be used to handle an incoming - federation Query of the given type. - - Args: - query_type (str): Category name of the query, which should match - the string used by make_query. - handler (callable): Invoked to handle incoming queries of this type - - handler is invoked as: - result = handler(args) - - where 'args' is a dict mapping strings to strings of the query - arguments. It should return a Deferred that will eventually yield an - object to encode as JSON. - """ - if query_type in self.query_handlers: - raise KeyError( - "Already have a Query handler for %s" % (query_type,) - ) - - self.query_handlers[query_type] = handler - - @log_function - def send_pdu(self, pdu, destinations): - """Informs the replication layer about a new PDU generated within the - home server that should be transmitted to others. - - TODO: Figure out when we should actually resolve the deferred. - - Args: - pdu (Pdu): The new Pdu. - - Returns: - Deferred: Completes when we have successfully processed the PDU - and replicated it to any interested remote home servers. - """ - order = self._order - self._order += 1 - - logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id) - - # TODO, add errback, etc. - self._transaction_queue.enqueue_pdu(pdu, destinations, order) - - logger.debug( - "[%s] transaction_layer.enqueue_pdu... done", - pdu.event_id - ) - - @log_function - def send_edu(self, destination, edu_type, content): - edu = Edu( - origin=self.server_name, - destination=destination, - edu_type=edu_type, - content=content, - ) - - # TODO, add errback, etc. - self._transaction_queue.enqueue_edu(edu) - return defer.succeed(None) - - @log_function - def send_failure(self, failure, destination): - self._transaction_queue.enqueue_failure(failure, destination) - return defer.succeed(None) - - @log_function - def make_query(self, destination, query_type, args, - retry_on_dns_fail=True): - """Sends a federation Query to a remote homeserver of the given type - and arguments. - - Args: - destination (str): Domain name of the remote homeserver - query_type (str): Category of the query type; should match the - handler name used in register_query_handler(). - args (dict): Mapping of strings to strings containing the details - of the query request. - - Returns: - a Deferred which will eventually yield a JSON object from the - response - """ - return self.transport_layer.make_query( - destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail - ) - - @defer.inlineCallbacks - @log_function - def backfill(self, dest, context, limit, extremities): - """Requests some more historic PDUs for the given context from the - given destination server. - - Args: - dest (str): The remote home server to ask. - context (str): The context to backfill. - limit (int): The maximum number of PDUs to return. - extremities (list): List of PDU id and origins of the first pdus - we have seen from the context - - Returns: - Deferred: Results in the received PDUs. - """ - logger.debug("backfill extrem=%s", extremities) - - # If there are no extremeties then we've (probably) reached the start. - if not extremities: - return - - transaction_data = yield self.transport_layer.backfill( - dest, context, extremities, limit) - - logger.debug("backfill transaction_data=%s", repr(transaction_data)) - - transaction = Transaction(**transaction_data) - - pdus = [ - self.event_from_pdu_json(p, outlier=False) - for p in transaction.pdus - ] - for pdu in pdus: - yield self._handle_new_pdu(dest, pdu, backfilled=True) - - defer.returnValue(pdus) - - @defer.inlineCallbacks - @log_function - def get_pdu(self, destination, event_id, outlier=False): - """Requests the PDU with given origin and ID from the remote home - server. - - This will persist the PDU locally upon receipt. - - Args: - destination (str): Which home server to query - pdu_origin (str): The home server that originally sent the pdu. - event_id (str) - outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if - it's from an arbitary point in the context as opposed to part - of the current block of PDUs. Defaults to `False` - - Returns: - Deferred: Results in the requested PDU. - """ - - transaction_data = yield self.transport_layer.get_event( - destination, event_id - ) - - transaction = Transaction(**transaction_data) - - pdu_list = [ - self.event_from_pdu_json(p, outlier=outlier) - for p in transaction.pdus - ] - - pdu = None - if pdu_list: - pdu = pdu_list[0] - yield self._handle_new_pdu(destination, pdu) - - defer.returnValue(pdu) - - @defer.inlineCallbacks - @log_function - def get_state_for_context(self, destination, context, event_id): - """Requests all of the `current` state PDUs for a given context from - a remote home server. - - Args: - destination (str): The remote homeserver to query for the state. - context (str): The context we're interested in. - event_id (str): The id of the event we want the state at. - - Returns: - Deferred: Results in a list of PDUs. - """ - - result = yield self.transport_layer.get_context_state( - destination, - context, - event_id=event_id, - ) - - pdus = [ - self.event_from_pdu_json(p, outlier=True) for p in result["pdus"] - ] - - auth_chain = [ - self.event_from_pdu_json(p, outlier=True) - for p in result.get("auth_chain", []) - ] - - defer.returnValue((pdus, auth_chain)) - - @defer.inlineCallbacks - @log_function - def get_event_auth(self, destination, context, event_id): - res = yield self.transport_layer.get_event_auth( - destination, context, event_id, - ) - - auth_chain = [ - self.event_from_pdu_json(p, outlier=True) - for p in res["auth_chain"] - ] - - auth_chain.sort(key=lambda e: e.depth) - - defer.returnValue(auth_chain) - - @defer.inlineCallbacks - @log_function - def on_backfill_request(self, origin, context, versions, limit): - pdus = yield self.handler.on_backfill_request( - origin, context, versions, limit - ) - - defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) - - @defer.inlineCallbacks - @log_function - def on_incoming_transaction(self, transaction_data): - transaction = Transaction(**transaction_data) - - for p in transaction.pdus: - if "unsigned" in p: - unsigned = p["unsigned"] - if "age" in unsigned: - p["age"] = unsigned["age"] - if "age" in p: - p["age_ts"] = int(self._clock.time_msec()) - int(p["age"]) - del p["age"] - - pdu_list = [ - self.event_from_pdu_json(p) for p in transaction.pdus - ] - - logger.debug("[%s] Got transaction", transaction.transaction_id) - - response = yield self.transaction_actions.have_responded(transaction) - - if response: - logger.debug("[%s] We've already responed to this request", - transaction.transaction_id) - defer.returnValue(response) - return - - logger.debug("[%s] Transaction is new", transaction.transaction_id) - - with PreserveLoggingContext(): - dl = [] - for pdu in pdu_list: - dl.append(self._handle_new_pdu(transaction.origin, pdu)) - - if hasattr(transaction, "edus"): - for edu in [Edu(**x) for x in transaction.edus]: - self.received_edu( - transaction.origin, - edu.edu_type, - edu.content - ) - - results = yield defer.DeferredList(dl) - - ret = [] - for r in results: - if r[0]: - ret.append({}) - else: - logger.exception(r[1]) - ret.append({"error": str(r[1])}) - - logger.debug("Returning: %s", str(ret)) - - yield self.transaction_actions.set_response( - transaction, - 200, response - ) - defer.returnValue((200, response)) - - def received_edu(self, origin, edu_type, content): - if edu_type in self.edu_handlers: - self.edu_handlers[edu_type](origin, content) - else: - logger.warn("Received EDU of type %s with no handler", edu_type) - - @defer.inlineCallbacks - @log_function - def on_context_state_request(self, origin, context, event_id): - if event_id: - pdus = yield self.handler.get_state_for_pdu( - origin, - context, - event_id, - ) - auth_chain = yield self.store.get_auth_chain( - [pdu.event_id for pdu in pdus] - ) - else: - raise NotImplementedError("Specify an event") - - defer.returnValue((200, { - "pdus": [pdu.get_pdu_json() for pdu in pdus], - "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], - })) - - @defer.inlineCallbacks - @log_function - def on_pdu_request(self, origin, event_id): - pdu = yield self._get_persisted_pdu(origin, event_id) - - if pdu: - defer.returnValue( - (200, self._transaction_from_pdus([pdu]).get_dict()) - ) - else: - defer.returnValue((404, "")) - - @defer.inlineCallbacks - @log_function - def on_pull_request(self, origin, versions): - raise NotImplementedError("Pull transacions not implemented") - - @defer.inlineCallbacks - def on_query_request(self, query_type, args): - if query_type in self.query_handlers: - response = yield self.query_handlers[query_type](args) - defer.returnValue((200, response)) - else: - defer.returnValue( - (404, "No handler for Query type '%s'" % (query_type, )) - ) - - @defer.inlineCallbacks - def on_make_join_request(self, context, user_id): - pdu = yield self.handler.on_make_join_request(context, user_id) - time_now = self._clock.time_msec() - defer.returnValue({ - "event": pdu.get_pdu_json(time_now), - }) - - @defer.inlineCallbacks - def on_invite_request(self, origin, content): - pdu = self.event_from_pdu_json(content) - ret_pdu = yield self.handler.on_invite_request(origin, pdu) - time_now = self._clock.time_msec() - defer.returnValue( - ( - 200, - { - "event": ret_pdu.get_pdu_json(time_now), - } - ) - ) - - @defer.inlineCallbacks - def on_send_join_request(self, origin, content): - logger.debug("on_send_join_request: content: %s", content) - pdu = self.event_from_pdu_json(content) - logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) - res_pdus = yield self.handler.on_send_join_request(origin, pdu) - time_now = self._clock.time_msec() - defer.returnValue((200, { - "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], - "auth_chain": [ - p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] - ], - })) - - @defer.inlineCallbacks - def on_event_auth(self, origin, context, event_id): - time_now = self._clock.time_msec() - auth_pdus = yield self.handler.on_event_auth(event_id) - defer.returnValue( - ( - 200, - { - "auth_chain": [ - a.get_pdu_json(time_now) for a in auth_pdus - ], - } - ) - ) - - @defer.inlineCallbacks - def make_join(self, destination, context, user_id): - ret = yield self.transport_layer.make_join( - destination=destination, - context=context, - user_id=user_id, - ) - - pdu_dict = ret["event"] - - logger.debug("Got response to make_join: %s", pdu_dict) - - defer.returnValue(self.event_from_pdu_json(pdu_dict)) - - @defer.inlineCallbacks - def send_join(self, destination, pdu): - time_now = self._clock.time_msec() - _, content = yield self.transport_layer.send_join( - destination, - pdu.room_id, - pdu.event_id, - pdu.get_pdu_json(time_now), - ) - - logger.debug("Got content: %s", content) - - state = [ - self.event_from_pdu_json(p, outlier=True) - for p in content.get("state", []) - ] - - # FIXME: We probably want to do something with the auth_chain given - # to us - - auth_chain = [ - self.event_from_pdu_json(p, outlier=True) - for p in content.get("auth_chain", []) - ] - - auth_chain.sort(key=lambda e: e.depth) - - defer.returnValue({ - "state": state, - "auth_chain": auth_chain, - }) - - @defer.inlineCallbacks - def send_invite(self, destination, context, event_id, pdu): - time_now = self._clock.time_msec() - code, content = yield self.transport_layer.send_invite( - destination=destination, - context=context, - event_id=event_id, - content=pdu.get_pdu_json(time_now), - ) - - pdu_dict = content["event"] - - logger.debug("Got response to send_invite: %s", pdu_dict) - - defer.returnValue(self.event_from_pdu_json(pdu_dict)) - - @log_function - def _get_persisted_pdu(self, origin, event_id, do_auth=True): - """ Get a PDU from the database with given origin and id. - - Returns: - Deferred: Results in a `Pdu`. - """ - return self.handler.get_persisted_pdu( - origin, event_id, do_auth=do_auth - ) - - def _transaction_from_pdus(self, pdu_list): - """Returns a new Transaction containing the given PDUs suitable for - transmission. - """ - time_now = self._clock.time_msec() - pdus = [p.get_pdu_json(time_now) for p in pdu_list] - return Transaction( - origin=self.server_name, - pdus=pdus, - origin_server_ts=int(time_now), - destination=None, - ) - - @defer.inlineCallbacks - @log_function - def _handle_new_pdu(self, origin, pdu, backfilled=False): - # We reprocess pdus when we have seen them only as outliers - existing = yield self._get_persisted_pdu( - origin, pdu.event_id, do_auth=False - ) - - already_seen = ( - existing and ( - not existing.internal_metadata.is_outlier() - or pdu.internal_metadata.is_outlier() - ) - ) - if already_seen: - logger.debug("Already seen pdu %s", pdu.event_id) - defer.returnValue({}) - return - - state = None - - auth_chain = [] - - # We need to make sure we have all the auth events. - # for e_id, _ in pdu.auth_events: - # exists = yield self._get_persisted_pdu( - # origin, - # e_id, - # do_auth=False - # ) - # - # if not exists: - # try: - # logger.debug( - # "_handle_new_pdu fetch missing auth event %s from %s", - # e_id, - # origin, - # ) - # - # yield self.get_pdu( - # origin, - # event_id=e_id, - # outlier=True, - # ) - # - # logger.debug("Processed pdu %s", e_id) - # except: - # logger.warn( - # "Failed to get auth event %s from %s", - # e_id, - # origin - # ) - - # Get missing pdus if necessary. - if not pdu.internal_metadata.is_outlier(): - # We only backfill backwards to the min depth. - min_depth = yield self.handler.get_min_depth_for_context( - pdu.room_id - ) - - logger.debug( - "_handle_new_pdu min_depth for %s: %d", - pdu.room_id, min_depth - ) - - if min_depth and pdu.depth > min_depth: - for event_id, hashes in pdu.prev_events: - exists = yield self._get_persisted_pdu( - origin, - event_id, - do_auth=False - ) - - if not exists: - logger.debug( - "_handle_new_pdu requesting pdu %s", - event_id - ) - - try: - yield self.get_pdu( - origin, - event_id=event_id, - ) - logger.debug("Processed pdu %s", event_id) - except: - # TODO(erikj): Do some more intelligent retries. - logger.exception("Failed to get PDU") - else: - # We need to get the state at this event, since we have reached - # a backward extremity edge. - logger.debug( - "_handle_new_pdu getting state for %s", - pdu.room_id - ) - state, auth_chain = yield self.get_state_for_context( - origin, pdu.room_id, pdu.event_id, - ) - - if not backfilled: - ret = yield self.handler.on_receive_pdu( - origin, - pdu, - backfilled=backfilled, - state=state, - auth_chain=auth_chain, - ) - else: - ret = None - - # yield self.pdu_actions.mark_as_processed(pdu) + self.transaction_actions = TransactionActions(self.store) + self._transaction_queue = TransactionQueue(hs, transport_layer) - defer.returnValue(ret) + self._order = 0 def __str__(self): return "<ReplicationLayer(%s)>" % self.server_name - - def event_from_pdu_json(self, pdu_json, outlier=False): - event = FrozenEvent( - pdu_json - ) - - event.internal_metadata.outlier = outlier - - return event - - -class _TransactionQueue(object): - """This class makes sure we only have one transaction in flight at - a time for a given destination. - - It batches pending PDUs into single transactions. - """ - - def __init__(self, hs, transaction_actions, transport_layer): - self.server_name = hs.hostname - self.transaction_actions = transaction_actions - self.transport_layer = transport_layer - - self._clock = hs.get_clock() - self.store = hs.get_datastore() - - # Is a mapping from destinations -> deferreds. Used to keep track - # of which destinations have transactions in flight and when they are - # done - self.pending_transactions = {} - - # Is a mapping from destination -> list of - # tuple(pending pdus, deferred, order) - self.pending_pdus_by_dest = {} - # destination -> list of tuple(edu, deferred) - self.pending_edus_by_dest = {} - - # destination -> list of tuple(failure, deferred) - self.pending_failures_by_dest = {} - - # HACK to get unique tx id - self._next_txn_id = int(self._clock.time_msec()) - - @defer.inlineCallbacks - @log_function - def enqueue_pdu(self, pdu, destinations, order): - # We loop through all destinations to see whether we already have - # a transaction in progress. If we do, stick it in the pending_pdus - # table and we'll get back to it later. - - destinations = set(destinations) - destinations.discard(self.server_name) - destinations.discard("localhost") - - logger.debug("Sending to: %s", str(destinations)) - - if not destinations: - return - - deferreds = [] - - for destination in destinations: - deferred = defer.Deferred() - self.pending_pdus_by_dest.setdefault(destination, []).append( - (pdu, deferred, order) - ) - - def eb(failure): - if not deferred.called: - deferred.errback(failure) - else: - logger.warn("Failed to send pdu", failure) - - with PreserveLoggingContext(): - self._attempt_new_transaction(destination).addErrback(eb) - - deferreds.append(deferred) - - yield defer.DeferredList(deferreds) - - # NO inlineCallbacks - def enqueue_edu(self, edu): - destination = edu.destination - - if destination == self.server_name: - return - - deferred = defer.Deferred() - self.pending_edus_by_dest.setdefault(destination, []).append( - (edu, deferred) - ) - - def eb(failure): - if not deferred.called: - deferred.errback(failure) - else: - logger.warn("Failed to send edu", failure) - - with PreserveLoggingContext(): - self._attempt_new_transaction(destination).addErrback(eb) - - return deferred - - @defer.inlineCallbacks - def enqueue_failure(self, failure, destination): - deferred = defer.Deferred() - - self.pending_failures_by_dest.setdefault( - destination, [] - ).append( - (failure, deferred) - ) - - yield deferred - - @defer.inlineCallbacks - @log_function - def _attempt_new_transaction(self, destination): - - (retry_last_ts, retry_interval) = (0, 0) - retry_timings = yield self.store.get_destination_retry_timings( - destination - ) - if retry_timings: - (retry_last_ts, retry_interval) = ( - retry_timings.retry_last_ts, retry_timings.retry_interval - ) - if retry_last_ts + retry_interval > int(self._clock.time_msec()): - logger.info( - "TX [%s] not ready for retry yet - " - "dropping transaction for now", - destination, - ) - return - else: - logger.info("TX [%s] is ready for retry", destination) - - logger.info("TX [%s] _attempt_new_transaction", destination) - - if destination in self.pending_transactions: - # XXX: pending_transactions can get stuck on by a never-ending - # request at which point pending_pdus_by_dest just keeps growing. - # we need application-layer timeouts of some flavour of these - # requests - return - - # list of (pending_pdu, deferred, order) - pending_pdus = self.pending_pdus_by_dest.pop(destination, []) - pending_edus = self.pending_edus_by_dest.pop(destination, []) - pending_failures = self.pending_failures_by_dest.pop(destination, []) - - if pending_pdus: - logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d", destination, len(pending_pdus)) - - if not pending_pdus and not pending_edus and not pending_failures: - return - - logger.debug( - "TX [%s] Attempting new transaction " - "(pdus: %d, edus: %d, failures: %d)", - destination, - len(pending_pdus), - len(pending_edus), - len(pending_failures) - ) - - # Sort based on the order field - pending_pdus.sort(key=lambda t: t[2]) - - pdus = [x[0] for x in pending_pdus] - edus = [x[0] for x in pending_edus] - failures = [x[0].get_dict() for x in pending_failures] - deferreds = [ - x[1] - for x in pending_pdus + pending_edus + pending_failures - ] - - try: - self.pending_transactions[destination] = 1 - - logger.debug("TX [%s] Persisting transaction...", destination) - - transaction = Transaction.create_new( - origin_server_ts=int(self._clock.time_msec()), - transaction_id=str(self._next_txn_id), - origin=self.server_name, - destination=destination, - pdus=pdus, - edus=edus, - pdu_failures=failures, - ) - - self._next_txn_id += 1 - - yield self.transaction_actions.prepare_to_send(transaction) - - logger.debug("TX [%s] Persisted transaction", destination) - logger.info( - "TX [%s] Sending transaction [%s]", - destination, - transaction.transaction_id, - ) - - # Actually send the transaction - - # FIXME (erikj): This is a bit of a hack to make the Pdu age - # keys work - def json_data_cb(): - data = transaction.get_dict() - now = int(self._clock.time_msec()) - if "pdus" in data: - for p in data["pdus"]: - if "age_ts" in p: - unsigned = p.setdefault("unsigned", {}) - unsigned["age"] = now - int(p["age_ts"]) - del p["age_ts"] - return data - - code, response = yield self.transport_layer.send_transaction( - transaction, json_data_cb - ) - - logger.info("TX [%s] got %d response", destination, code) - - logger.debug("TX [%s] Sent transaction", destination) - logger.debug("TX [%s] Marking as delivered...", destination) - - yield self.transaction_actions.delivered( - transaction, code, response - ) - - logger.debug("TX [%s] Marked as delivered", destination) - logger.debug("TX [%s] Yielding to callbacks...", destination) - - for deferred in deferreds: - if code == 200: - if retry_last_ts: - # this host is alive! reset retry schedule - yield self.store.set_destination_retry_timings( - destination, 0, 0 - ) - deferred.callback(None) - else: - self.set_retrying(destination, retry_interval) - deferred.errback(RuntimeError("Got status %d" % code)) - - # Ensures we don't continue until all callbacks on that - # deferred have fired - try: - yield deferred - except: - pass - - logger.debug("TX [%s] Yielded to callbacks", destination) - - except Exception as e: - # We capture this here as there as nothing actually listens - # for this finishing functions deferred. - logger.warn( - "TX [%s] Problem in _attempt_transaction: %s", - destination, - e, - ) - - self.set_retrying(destination, retry_interval) - - for deferred in deferreds: - if not deferred.called: - deferred.errback(e) - - finally: - # We want to be *very* sure we delete this after we stop processing - self.pending_transactions.pop(destination, None) - - # Check to see if there is anything else to send. - self._attempt_new_transaction(destination) - - @defer.inlineCallbacks - def set_retrying(self, destination, retry_interval): - # track that this destination is having problems and we should - # give it a chance to recover before trying it again - - if retry_interval: - retry_interval *= 2 - # plateau at hourly retries for now - if retry_interval >= 60 * 60 * 1000: - retry_interval = 60 * 60 * 1000 - else: - retry_interval = 2000 # try again at first after 2 seconds - - yield self.store.set_destination_retry_timings( - destination, - int(self._clock.time_msec()), - retry_interval - ) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py new file mode 100644 index 0000000000..9d4f2c09a2 --- /dev/null +++ b/synapse/federation/transaction_queue.py @@ -0,0 +1,317 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from .persistence import TransactionActions +from .units import Transaction + +from synapse.util.logutils import log_function +from synapse.util.logcontext import PreserveLoggingContext + +import logging + + +logger = logging.getLogger(__name__) + + +class TransactionQueue(object): + """This class makes sure we only have one transaction in flight at + a time for a given destination. + + It batches pending PDUs into single transactions. + """ + + def __init__(self, hs, transport_layer): + self.server_name = hs.hostname + + self.store = hs.get_datastore() + self.transaction_actions = TransactionActions(self.store) + + self.transport_layer = transport_layer + + self._clock = hs.get_clock() + + # Is a mapping from destinations -> deferreds. Used to keep track + # of which destinations have transactions in flight and when they are + # done + self.pending_transactions = {} + + # Is a mapping from destination -> list of + # tuple(pending pdus, deferred, order) + self.pending_pdus_by_dest = {} + # destination -> list of tuple(edu, deferred) + self.pending_edus_by_dest = {} + + # destination -> list of tuple(failure, deferred) + self.pending_failures_by_dest = {} + + # HACK to get unique tx id + self._next_txn_id = int(self._clock.time_msec()) + + @defer.inlineCallbacks + @log_function + def enqueue_pdu(self, pdu, destinations, order): + # We loop through all destinations to see whether we already have + # a transaction in progress. If we do, stick it in the pending_pdus + # table and we'll get back to it later. + + destinations = set(destinations) + destinations.discard(self.server_name) + destinations.discard("localhost") + + logger.debug("Sending to: %s", str(destinations)) + + if not destinations: + return + + deferreds = [] + + for destination in destinations: + deferred = defer.Deferred() + self.pending_pdus_by_dest.setdefault(destination, []).append( + (pdu, deferred, order) + ) + + def eb(failure): + if not deferred.called: + deferred.errback(failure) + else: + logger.warn("Failed to send pdu", failure) + + with PreserveLoggingContext(): + self._attempt_new_transaction(destination).addErrback(eb) + + deferreds.append(deferred) + + yield defer.DeferredList(deferreds) + + # NO inlineCallbacks + def enqueue_edu(self, edu): + destination = edu.destination + + if destination == self.server_name: + return + + deferred = defer.Deferred() + self.pending_edus_by_dest.setdefault(destination, []).append( + (edu, deferred) + ) + + def eb(failure): + if not deferred.called: + deferred.errback(failure) + else: + logger.warn("Failed to send edu", failure) + + with PreserveLoggingContext(): + self._attempt_new_transaction(destination).addErrback(eb) + + return deferred + + @defer.inlineCallbacks + def enqueue_failure(self, failure, destination): + deferred = defer.Deferred() + + self.pending_failures_by_dest.setdefault( + destination, [] + ).append( + (failure, deferred) + ) + + yield deferred + + @defer.inlineCallbacks + @log_function + def _attempt_new_transaction(self, destination): + + (retry_last_ts, retry_interval) = (0, 0) + retry_timings = yield self.store.get_destination_retry_timings( + destination + ) + if retry_timings: + (retry_last_ts, retry_interval) = ( + retry_timings.retry_last_ts, retry_timings.retry_interval + ) + if retry_last_ts + retry_interval > int(self._clock.time_msec()): + logger.info( + "TX [%s] not ready for retry yet - " + "dropping transaction for now", + destination, + ) + return + else: + logger.info("TX [%s] is ready for retry", destination) + + logger.info("TX [%s] _attempt_new_transaction", destination) + + if destination in self.pending_transactions: + # XXX: pending_transactions can get stuck on by a never-ending + # request at which point pending_pdus_by_dest just keeps growing. + # we need application-layer timeouts of some flavour of these + # requests + return + + # list of (pending_pdu, deferred, order) + pending_pdus = self.pending_pdus_by_dest.pop(destination, []) + pending_edus = self.pending_edus_by_dest.pop(destination, []) + pending_failures = self.pending_failures_by_dest.pop(destination, []) + + if pending_pdus: + logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d", + destination, len(pending_pdus)) + + if not pending_pdus and not pending_edus and not pending_failures: + return + + logger.debug( + "TX [%s] Attempting new transaction" + " (pdus: %d, edus: %d, failures: %d)", + destination, + len(pending_pdus), + len(pending_edus), + len(pending_failures) + ) + + # Sort based on the order field + pending_pdus.sort(key=lambda t: t[2]) + + pdus = [x[0] for x in pending_pdus] + edus = [x[0] for x in pending_edus] + failures = [x[0].get_dict() for x in pending_failures] + deferreds = [ + x[1] + for x in pending_pdus + pending_edus + pending_failures + ] + + try: + self.pending_transactions[destination] = 1 + + logger.debug("TX [%s] Persisting transaction...", destination) + + transaction = Transaction.create_new( + origin_server_ts=int(self._clock.time_msec()), + transaction_id=str(self._next_txn_id), + origin=self.server_name, + destination=destination, + pdus=pdus, + edus=edus, + pdu_failures=failures, + ) + + self._next_txn_id += 1 + + yield self.transaction_actions.prepare_to_send(transaction) + + logger.debug("TX [%s] Persisted transaction", destination) + logger.info( + "TX [%s] Sending transaction [%s]", + destination, + transaction.transaction_id, + ) + + # Actually send the transaction + + # FIXME (erikj): This is a bit of a hack to make the Pdu age + # keys work + def json_data_cb(): + data = transaction.get_dict() + now = int(self._clock.time_msec()) + if "pdus" in data: + for p in data["pdus"]: + if "age_ts" in p: + unsigned = p.setdefault("unsigned", {}) + unsigned["age"] = now - int(p["age_ts"]) + del p["age_ts"] + return data + + code, response = yield self.transport_layer.send_transaction( + transaction, json_data_cb + ) + + logger.info("TX [%s] got %d response", destination, code) + + logger.debug("TX [%s] Sent transaction", destination) + logger.debug("TX [%s] Marking as delivered...", destination) + + yield self.transaction_actions.delivered( + transaction, code, response + ) + + logger.debug("TX [%s] Marked as delivered", destination) + logger.debug("TX [%s] Yielding to callbacks...", destination) + + for deferred in deferreds: + if code == 200: + if retry_last_ts: + # this host is alive! reset retry schedule + yield self.store.set_destination_retry_timings( + destination, 0, 0 + ) + deferred.callback(None) + else: + self.set_retrying(destination, retry_interval) + deferred.errback(RuntimeError("Got status %d" % code)) + + # Ensures we don't continue until all callbacks on that + # deferred have fired + try: + yield deferred + except: + pass + + logger.debug("TX [%s] Yielded to callbacks", destination) + + except Exception as e: + # We capture this here as there as nothing actually listens + # for this finishing functions deferred. + logger.warn( + "TX [%s] Problem in _attempt_transaction: %s", + destination, + e, + ) + + self.set_retrying(destination, retry_interval) + + for deferred in deferreds: + if not deferred.called: + deferred.errback(e) + + finally: + # We want to be *very* sure we delete this after we stop processing + self.pending_transactions.pop(destination, None) + + # Check to see if there is anything else to send. + self._attempt_new_transaction(destination) + + @defer.inlineCallbacks + def set_retrying(self, destination, retry_interval): + # track that this destination is having problems and we should + # give it a chance to recover before trying it again + + if retry_interval: + retry_interval *= 2 + # plateau at hourly retries for now + if retry_interval >= 60 * 60 * 1000: + retry_interval = 60 * 60 * 1000 + else: + retry_interval = 2000 # try again at first after 2 seconds + + yield self.store.set_destination_retry_timings( + destination, + int(self._clock.time_msec()), + retry_interval + ) diff --git a/synapse/federation/transport/__init__.py b/synapse/federation/transport/__init__.py new file mode 100644 index 0000000000..6800ac46c5 --- /dev/null +++ b/synapse/federation/transport/__init__.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The transport layer is responsible for both sending transactions to remote +home servers and receiving a variety of requests from other home servers. + +By default this is done over HTTPS (and all home servers are required to +support HTTPS), however individual pairings of servers may decide to +communicate over a different (albeit still reliable) protocol. +""" + +from .server import TransportLayerServer +from .client import TransportLayerClient + + +class TransportLayer(TransportLayerServer, TransportLayerClient): + """This is a basic implementation of the transport layer that translates + transactions and other requests to/from HTTP. + + Attributes: + server_name (str): Local home server host + + server (synapse.http.server.HttpServer): the http server to + register listeners on + + client (synapse.http.client.HttpClient): the http client used to + send requests + + request_handler (TransportRequestHandler): The handler to fire when we + receive requests for data. + + received_handler (TransportReceivedHandler): The handler to fire when + we receive data. + """ + + def __init__(self, homeserver, server_name, server, client): + """ + Args: + server_name (str): Local home server host + server (synapse.protocol.http.HttpServer): the http server to + register listeners on + client (synapse.protocol.http.HttpClient): the http client used to + send requests + """ + self.keyring = homeserver.get_keyring() + self.server_name = server_name + self.server = server + self.client = client + self.request_handler = None + self.received_handler = None diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py new file mode 100644 index 0000000000..4cb1dea2de --- /dev/null +++ b/synapse/federation/transport/client.py @@ -0,0 +1,231 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.urls import FEDERATION_PREFIX as PREFIX +from synapse.util.logutils import log_function + +import logging +import json + + +logger = logging.getLogger(__name__) + + +class TransportLayerClient(object): + """Sends federation HTTP requests to other servers""" + + @log_function + def get_room_state(self, destination, room_id, event_id): + """ Requests all state for a given room from the given server at the + given event. + + Args: + destination (str): The host name of the remote home server we want + to get the state from. + context (str): The name of the context we want the state of + event_id (str): The event we want the context at. + + Returns: + Deferred: Results in a dict received from the remote homeserver. + """ + logger.debug("get_room_state dest=%s, room=%s", + destination, room_id) + + path = PREFIX + "/state/%s/" % room_id + return self.client.get_json( + destination, path=path, args={"event_id": event_id}, + ) + + @log_function + def get_event(self, destination, event_id): + """ Requests the pdu with give id and origin from the given server. + + Args: + destination (str): The host name of the remote home server we want + to get the state from. + event_id (str): The id of the event being requested. + + Returns: + Deferred: Results in a dict received from the remote homeserver. + """ + logger.debug("get_pdu dest=%s, event_id=%s", + destination, event_id) + + path = PREFIX + "/event/%s/" % (event_id, ) + return self.client.get_json(destination, path=path) + + @log_function + def backfill(self, destination, room_id, event_tuples, limit): + """ Requests `limit` previous PDUs in a given context before list of + PDUs. + + Args: + dest (str) + room_id (str) + event_tuples (list) + limt (int) + + Returns: + Deferred: Results in a dict received from the remote homeserver. + """ + logger.debug( + "backfill dest=%s, room_id=%s, event_tuples=%s, limit=%s", + destination, room_id, repr(event_tuples), str(limit) + ) + + if not event_tuples: + # TODO: raise? + return + + path = PREFIX + "/backfill/%s/" % (room_id,) + + args = { + "v": event_tuples, + "limit": [str(limit)], + } + + return self.client.get_json( + destination, + path=path, + args=args, + ) + + @defer.inlineCallbacks + @log_function + def send_transaction(self, transaction, json_data_callback=None): + """ Sends the given Transaction to its destination + + Args: + transaction (Transaction) + + Returns: + Deferred: Results of the deferred is a tuple in the form of + (response_code, response_body) where the response_body is a + python dict decoded from json + """ + logger.debug( + "send_data dest=%s, txid=%s", + transaction.destination, transaction.transaction_id + ) + + if transaction.destination == self.server_name: + raise RuntimeError("Transport layer cannot send to itself!") + + # FIXME: This is only used by the tests. The actual json sent is + # generated by the json_data_callback. + json_data = transaction.get_dict() + + code, response = yield self.client.put_json( + transaction.destination, + path=PREFIX + "/send/%s/" % transaction.transaction_id, + data=json_data, + json_data_callback=json_data_callback, + ) + + logger.debug( + "send_data dest=%s, txid=%s, got response: %d", + transaction.destination, transaction.transaction_id, code + ) + + defer.returnValue((code, response)) + + @defer.inlineCallbacks + @log_function + def make_query(self, destination, query_type, args, retry_on_dns_fail): + path = PREFIX + "/query/%s" % query_type + + response = yield self.client.get_json( + destination=destination, + path=path, + args=args, + retry_on_dns_fail=retry_on_dns_fail, + ) + + defer.returnValue(response) + + @defer.inlineCallbacks + @log_function + def make_join(self, destination, room_id, user_id, retry_on_dns_fail=True): + path = PREFIX + "/make_join/%s/%s" % (room_id, user_id) + + response = yield self.client.get_json( + destination=destination, + path=path, + retry_on_dns_fail=retry_on_dns_fail, + ) + + defer.returnValue(response) + + @defer.inlineCallbacks + @log_function + def send_join(self, destination, room_id, event_id, content): + path = PREFIX + "/send_join/%s/%s" % (room_id, event_id) + + code, content = yield self.client.put_json( + destination=destination, + path=path, + data=content, + ) + + if not 200 <= code < 300: + raise RuntimeError("Got %d from send_join", code) + + defer.returnValue(json.loads(content)) + + @defer.inlineCallbacks + @log_function + def send_invite(self, destination, room_id, event_id, content): + path = PREFIX + "/invite/%s/%s" % (room_id, event_id) + + code, content = yield self.client.put_json( + destination=destination, + path=path, + data=content, + ) + + if not 200 <= code < 300: + raise RuntimeError("Got %d from send_invite", code) + + defer.returnValue(json.loads(content)) + + @defer.inlineCallbacks + @log_function + def get_event_auth(self, destination, room_id, event_id): + path = PREFIX + "/event_auth/%s/%s" % (room_id, event_id) + + response = yield self.client.get_json( + destination=destination, + path=path, + ) + + defer.returnValue(response) + + @defer.inlineCallbacks + @log_function + def send_query_auth(self, destination, room_id, event_id, content): + path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id) + + code, content = yield self.client.post_json( + destination=destination, + path=path, + data=content, + ) + + if not 200 <= code < 300: + raise RuntimeError("Got %d from send_invite", code) + + defer.returnValue(json.loads(content)) diff --git a/synapse/federation/transport.py b/synapse/federation/transport/server.py index 1f0f06e0fe..9c9f8d525b 100644 --- a/synapse/federation/transport.py +++ b/synapse/federation/transport/server.py @@ -13,14 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""The transport layer is responsible for both sending transactions to remote -home servers and receiving a variety of requests from other home servers. - -Typically, this is done over HTTP (and all home servers are required to -support HTTP), however individual pairings of servers may decide to communicate -over a different (albeit still reliable) protocol. -""" - from twisted.internet import defer from synapse.api.urls import FEDERATION_PREFIX as PREFIX @@ -35,241 +27,8 @@ import re logger = logging.getLogger(__name__) -class TransportLayer(object): - """This is a basic implementation of the transport layer that translates - transactions and other requests to/from HTTP. - - Attributes: - server_name (str): Local home server host - - server (synapse.http.server.HttpServer): the http server to - register listeners on - - client (synapse.http.client.HttpClient): the http client used to - send requests - - request_handler (TransportRequestHandler): The handler to fire when we - receive requests for data. - - received_handler (TransportReceivedHandler): The handler to fire when - we receive data. - """ - - def __init__(self, homeserver, server_name, server, client): - """ - Args: - server_name (str): Local home server host - server (synapse.protocol.http.HttpServer): the http server to - register listeners on - client (synapse.protocol.http.HttpClient): the http client used to - send requests - """ - self.keyring = homeserver.get_keyring() - self.server_name = server_name - self.server = server - self.client = client - self.request_handler = None - self.received_handler = None - - @log_function - def get_context_state(self, destination, context, event_id=None): - """ Requests all state for a given context (i.e. room) from the - given server. - - Args: - destination (str): The host name of the remote home server we want - to get the state from. - context (str): The name of the context we want the state of - - Returns: - Deferred: Results in a dict received from the remote homeserver. - """ - logger.debug("get_context_state dest=%s, context=%s", - destination, context) - - subpath = "/state/%s/" % context - - args = {} - if event_id: - args["event_id"] = event_id - - return self._do_request_for_transaction( - destination, subpath, args=args - ) - - @log_function - def get_event(self, destination, event_id): - """ Requests the pdu with give id and origin from the given server. - - Args: - destination (str): The host name of the remote home server we want - to get the state from. - event_id (str): The id of the event being requested. - - Returns: - Deferred: Results in a dict received from the remote homeserver. - """ - logger.debug("get_pdu dest=%s, event_id=%s", - destination, event_id) - - subpath = "/event/%s/" % (event_id, ) - - return self._do_request_for_transaction(destination, subpath) - - @log_function - def backfill(self, dest, context, event_tuples, limit): - """ Requests `limit` previous PDUs in a given context before list of - PDUs. - - Args: - dest (str) - context (str) - event_tuples (list) - limt (int) - - Returns: - Deferred: Results in a dict received from the remote homeserver. - """ - logger.debug( - "backfill dest=%s, context=%s, event_tuples=%s, limit=%s", - dest, context, repr(event_tuples), str(limit) - ) - - if not event_tuples: - # TODO: raise? - return - - subpath = "/backfill/%s/" % (context,) - - args = { - "v": event_tuples, - "limit": [str(limit)], - } - - return self._do_request_for_transaction( - dest, - subpath, - args=args, - ) - - @defer.inlineCallbacks - @log_function - def send_transaction(self, transaction, json_data_callback=None): - """ Sends the given Transaction to its destination - - Args: - transaction (Transaction) - - Returns: - Deferred: Results of the deferred is a tuple in the form of - (response_code, response_body) where the response_body is a - python dict decoded from json - """ - logger.debug( - "send_data dest=%s, txid=%s", - transaction.destination, transaction.transaction_id - ) - - if transaction.destination == self.server_name: - raise RuntimeError("Transport layer cannot send to itself!") - - # FIXME: This is only used by the tests. The actual json sent is - # generated by the json_data_callback. - json_data = transaction.get_dict() - - code, response = yield self.client.put_json( - transaction.destination, - path=PREFIX + "/send/%s/" % transaction.transaction_id, - data=json_data, - json_data_callback=json_data_callback, - ) - - logger.debug( - "send_data dest=%s, txid=%s, got response: %d", - transaction.destination, transaction.transaction_id, code - ) - - defer.returnValue((code, response)) - - @defer.inlineCallbacks - @log_function - def make_query(self, destination, query_type, args, retry_on_dns_fail): - path = PREFIX + "/query/%s" % query_type - - response = yield self.client.get_json( - destination=destination, - path=path, - args=args, - retry_on_dns_fail=retry_on_dns_fail, - ) - - defer.returnValue(response) - - @defer.inlineCallbacks - @log_function - def make_join(self, destination, context, user_id, retry_on_dns_fail=True): - path = PREFIX + "/make_join/%s/%s" % (context, user_id,) - - response = yield self.client.get_json( - destination=destination, - path=path, - retry_on_dns_fail=retry_on_dns_fail, - ) - - defer.returnValue(response) - - @defer.inlineCallbacks - @log_function - def send_join(self, destination, context, event_id, content): - path = PREFIX + "/send_join/%s/%s" % ( - context, - event_id, - ) - - code, content = yield self.client.put_json( - destination=destination, - path=path, - data=content, - ) - - if not 200 <= code < 300: - raise RuntimeError("Got %d from send_join", code) - - defer.returnValue(json.loads(content)) - - @defer.inlineCallbacks - @log_function - def send_invite(self, destination, context, event_id, content): - path = PREFIX + "/invite/%s/%s" % ( - context, - event_id, - ) - - code, content = yield self.client.put_json( - destination=destination, - path=path, - data=content, - ) - - if not 200 <= code < 300: - raise RuntimeError("Got %d from send_invite", code) - - defer.returnValue(json.loads(content)) - - @defer.inlineCallbacks - @log_function - def get_event_auth(self, destination, context, event_id): - path = PREFIX + "/event_auth/%s/%s" % ( - context, - event_id, - ) - - response = yield self.client.get_json( - destination=destination, - path=path, - ) - - defer.returnValue(response) +class TransportLayerServer(object): + """Handles incoming federation HTTP requests""" @defer.inlineCallbacks def _authenticate_request(self, request): @@ -283,7 +42,7 @@ class TransportLayer(object): content = None origin = None - if request.method == "PUT": + if request.method in ["PUT", "POST"]: # TODO: Handle other method types? other content types? try: content_bytes = request.content.read() @@ -373,8 +132,6 @@ class TransportLayer(object): """ self.request_handler = handler - # TODO(markjh): Namespace the federation URI paths - # This is for when someone asks us for everything since version X self.server.register_path( "GET", @@ -477,6 +234,16 @@ class TransportLayer(object): ) ) ) + self.server.register_path( + "POST", + re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"), + self._with_authentication( + lambda origin, content, query, context, event_id: + self._on_query_auth_request( + origin, content, event_id, + ) + ) + ) @defer.inlineCallbacks @log_function @@ -528,34 +295,6 @@ class TransportLayer(object): defer.returnValue((code, response)) - @defer.inlineCallbacks - @log_function - def _do_request_for_transaction(self, destination, subpath, args={}): - """ - Args: - destination (str) - path (str) - args (dict): This is parsed directly to the HttpClient. - - Returns: - Deferred: Results in a dict. - """ - - data = yield self.client.get_json( - destination, - path=PREFIX + subpath, - args=args, - ) - - # Add certain keys to the JSON, ready for decoding as a Transaction - data.update( - origin=destination, - destination=self.server_name, - transaction_id=None - ) - - defer.returnValue(data) - @log_function def _on_backfill_request(self, origin, context, v_list, limits): if not limits: @@ -596,3 +335,12 @@ class TransportLayer(object): ) defer.returnValue((200, content)) + + @defer.inlineCallbacks + @log_function + def _on_query_auth_request(self, origin, content, event_id): + new_content = yield self.request_handler.on_query_auth_request( + origin, content, event_id + ) + + defer.returnValue((200, new_content)) diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index fe071a4bc2..a32eab9316 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -26,6 +26,7 @@ from .presence import PresenceHandler from .directory import DirectoryHandler from .typing import TypingNotificationHandler from .admin import AdminHandler +from .sync import SyncHandler class Handlers(object): @@ -51,3 +52,4 @@ class Handlers(object): self.directory_handler = DirectoryHandler(hs) self.typing_notification_handler = TypingNotificationHandler(hs) self.admin_handler = AdminHandler(hs) + self.sync_handler = SyncHandler(hs) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 38af034b4d..1773fa20aa 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -19,6 +19,7 @@ from synapse.api.errors import LimitExceededError, SynapseError from synapse.util.async import run_on_reactor from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.api.constants import Membership, EventTypes +from synapse.types import UserID import logging @@ -113,7 +114,7 @@ class BaseHandler(object): if event.type == EventTypes.Member: if event.content["membership"] == Membership.INVITE: - invitee = self.hs.parse_userid(event.state_key) + invitee = UserID.from_string(event.state_key) if not self.hs.is_mine(invitee): # TODO: Can we add signature from remote server in a nicer # way? If we have been invited by a remote server, we need @@ -134,7 +135,7 @@ class BaseHandler(object): if k[0] == EventTypes.Member: if s.content["membership"] == Membership.JOIN: destinations.add( - self.hs.parse_userid(s.state_key).domain + UserID.from_string(s.state_key).domain ) except SynapseError: logger.warn( @@ -144,7 +145,5 @@ class BaseHandler(object): yield self.notifier.on_new_room_event(event, extra_users=extra_users) yield federation_handler.handle_new_event( - event, - None, - destinations=destinations, + event, destinations=destinations, ) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 91fceda2ac..58e9a91562 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -19,6 +19,7 @@ from ._base import BaseHandler from synapse.api.errors import SynapseError, Codes, CodeMessageException from synapse.api.constants import EventTypes +from synapse.types import RoomAlias import logging @@ -122,7 +123,7 @@ class DirectoryHandler(BaseHandler): @defer.inlineCallbacks def on_directory_query(self, args): - room_alias = self.hs.parse_roomalias(args["room_alias"]) + room_alias = RoomAlias.from_string(args["room_alias"]) if not self.hs.is_mine(room_alias): raise SynapseError( 400, "Room Alias is not hosted on this Home Server" diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 808219bd10..025e7e7e62 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -17,6 +17,8 @@ from twisted.internet import defer from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logutils import log_function +from synapse.types import UserID +from synapse.events.utils import serialize_event from ._base import BaseHandler @@ -46,39 +48,43 @@ class EventStreamHandler(BaseHandler): @defer.inlineCallbacks @log_function - def get_stream(self, auth_user_id, pagin_config, timeout=0): - auth_user = self.hs.parse_userid(auth_user_id) + def get_stream(self, auth_user_id, pagin_config, timeout=0, + as_client_event=True, affect_presence=True): + auth_user = UserID.from_string(auth_user_id) try: - if auth_user not in self._streams_per_user: - self._streams_per_user[auth_user] = 0 - if auth_user in self._stop_timer_per_user: - try: - self.clock.cancel_call_later( - self._stop_timer_per_user.pop(auth_user) + if affect_presence: + if auth_user not in self._streams_per_user: + self._streams_per_user[auth_user] = 0 + if auth_user in self._stop_timer_per_user: + try: + self.clock.cancel_call_later( + self._stop_timer_per_user.pop(auth_user) + ) + except: + logger.exception("Failed to cancel event timer") + else: + yield self.distributor.fire( + "started_user_eventstream", auth_user ) - except: - logger.exception("Failed to cancel event timer") - else: - yield self.distributor.fire( - "started_user_eventstream", auth_user - ) - self._streams_per_user[auth_user] += 1 + self._streams_per_user[auth_user] += 1 if pagin_config.from_token is None: pagin_config.from_token = None rm_handler = self.hs.get_handlers().room_member_handler - logger.debug("BETA") room_ids = yield rm_handler.get_rooms_for_user(auth_user) - logger.debug("ALPHA") with PreserveLoggingContext(): events, tokens = yield self.notifier.get_events_for( auth_user, room_ids, pagin_config, timeout ) - chunks = [self.hs.serialize_event(e) for e in events] + time_now = self.clock.time_msec() + + chunks = [ + serialize_event(e, time_now, as_client_event) for e in events + ] chunk = { "chunk": chunks, @@ -89,27 +95,28 @@ class EventStreamHandler(BaseHandler): defer.returnValue(chunk) finally: - self._streams_per_user[auth_user] -= 1 - if not self._streams_per_user[auth_user]: - del self._streams_per_user[auth_user] - - # 10 seconds of grace to allow the client to reconnect again - # before we think they're gone - def _later(): - logger.debug( - "_later stopped_user_eventstream %s", auth_user - ) + if affect_presence: + self._streams_per_user[auth_user] -= 1 + if not self._streams_per_user[auth_user]: + del self._streams_per_user[auth_user] + + # 10 seconds of grace to allow the client to reconnect again + # before we think they're gone + def _later(): + logger.debug( + "_later stopped_user_eventstream %s", auth_user + ) - self._stop_timer_per_user.pop(auth_user, None) + self._stop_timer_per_user.pop(auth_user, None) - yield self.distributor.fire( - "stopped_user_eventstream", auth_user - ) + return self.distributor.fire( + "stopped_user_eventstream", auth_user + ) - logger.debug("Scheduling _later: for %s", auth_user) - self._stop_timer_per_user[auth_user] = ( - self.clock.call_later(30, _later) - ) + logger.debug("Scheduling _later: for %s", auth_user) + self._stop_timer_per_user[auth_user] = ( + self.clock.call_later(30, _later) + ) class EventHandler(BaseHandler): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d26975a88a..8bf5a4cc11 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -17,18 +17,16 @@ from ._base import BaseHandler -from synapse.events.utils import prune_event from synapse.api.errors import ( - AuthError, FederationError, SynapseError, StoreError, + AuthError, FederationError, StoreError, ) -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor from synapse.crypto.event_signing import ( - compute_event_signature, check_event_content_hash, - add_hashes_and_signatures, + compute_event_signature, add_hashes_and_signatures, ) -from syutil.jsonutil import encode_canonical_json +from synapse.types import UserID from twisted.internet import defer @@ -75,14 +73,14 @@ class FederationHandler(BaseHandler): @log_function @defer.inlineCallbacks - def handle_new_event(self, event, snapshot, destinations): + def handle_new_event(self, event, destinations): """ Takes in an event from the client to server side, that has already been authed and handled by the state module, and sends it to any remote home servers that may be interested. Args: - event - snapshot (.storage.Snapshot): THe snapshot the event happened after + event: The event to send + destinations: A list of destinations to send it to Returns: Deferred: Resolved when it has successfully been queued for @@ -112,33 +110,6 @@ class FederationHandler(BaseHandler): logger.debug("Processing event: %s", event.event_id) - redacted_event = prune_event(event) - - redacted_pdu_json = redacted_event.get_pdu_json() - try: - yield self.keyring.verify_json_for_server( - event.origin, redacted_pdu_json - ) - except SynapseError as e: - logger.warn( - "Signature check failed for %s redacted to %s", - encode_canonical_json(pdu.get_pdu_json()), - encode_canonical_json(redacted_pdu_json), - ) - raise FederationError( - "ERROR", - e.code, - e.msg, - affected=event.event_id, - ) - - if not check_event_content_hash(event): - logger.warn( - "Event content has been tampered, redacting %s, %s", - event.event_id, encode_canonical_json(event.get_dict()) - ) - event = redacted_event - logger.debug("Event: %s", event) # FIXME (erikj): Awful hack to make the case where we are not currently @@ -148,41 +119,20 @@ class FederationHandler(BaseHandler): event.room_id, self.server_name ) - if not is_in_room and not event.internal_metadata.outlier: + if not is_in_room and not event.internal_metadata.is_outlier(): logger.debug("Got event for room we're not in.") - - replication = self.replication_layer - - if not state: - state, auth_chain = yield replication.get_state_for_context( - origin, context=event.room_id, event_id=event.event_id, - ) - - if not auth_chain: - auth_chain = yield replication.get_event_auth( - origin, - context=event.room_id, - event_id=event.event_id, - ) - - for e in auth_chain: - e.internal_metadata.outlier = True - try: - yield self._handle_new_event(e, fetch_auth_from=origin) - except: - logger.exception( - "Failed to handle auth event %s", - e.event_id, - ) - current_state = state - if state: + if state and auth_chain is not None: for e in state: - logging.info("A :) %r", e) e.internal_metadata.outlier = True try: - yield self._handle_new_event(e) + auth_ids = [e_id for e_id, _ in e.auth_events] + auth = { + (e.type, e.state_key): e for e in auth_chain + if e.event_id in auth_ids + } + yield self._handle_new_event(origin, e, auth_events=auth) except: logger.exception( "Failed to handle state event %s", @@ -191,6 +141,7 @@ class FederationHandler(BaseHandler): try: yield self._handle_new_event( + origin, event, state=state, backfilled=backfilled, @@ -227,7 +178,7 @@ class FederationHandler(BaseHandler): extra_users = [] if event.type == EventTypes.Member: target_user_id = event.state_key - target_user = self.hs.parse_userid(target_user_id) + target_user = UserID.from_string(target_user_id) extra_users.append(target_user) yield self.notifier.on_new_room_event( @@ -236,7 +187,7 @@ class FederationHandler(BaseHandler): if event.type == EventTypes.Member: if event.membership == Membership.JOIN: - user = self.hs.parse_userid(event.state_key) + user = UserID.from_string(event.state_key) yield self.distributor.fire( "user_joined_room", user=user, room_id=event.room_id ) @@ -281,7 +232,7 @@ class FederationHandler(BaseHandler): """ pdu = yield self.replication_layer.send_invite( destination=target_host, - context=event.room_id, + room_id=event.room_id, event_id=event.event_id, pdu=event ) @@ -392,8 +343,19 @@ class FederationHandler(BaseHandler): for e in auth_chain: e.internal_metadata.outlier = True + + if e.event_id == event.event_id: + continue + try: - yield self._handle_new_event(e) + auth_ids = [e_id for e_id, _ in e.auth_events] + auth = { + (e.type, e.state_key): e for e in auth_chain + if e.event_id in auth_ids + } + yield self._handle_new_event( + target_host, e, auth_events=auth + ) except: logger.exception( "Failed to handle auth event %s", @@ -401,11 +363,18 @@ class FederationHandler(BaseHandler): ) for e in state: - # FIXME: Auth these. + if e.event_id == event.event_id: + continue + e.internal_metadata.outlier = True try: + auth_ids = [e_id for e_id, _ in e.auth_events] + auth = { + (e.type, e.state_key): e for e in auth_chain + if e.event_id in auth_ids + } yield self._handle_new_event( - e, fetch_auth_from=target_host + target_host, e, auth_events=auth ) except: logger.exception( @@ -413,10 +382,18 @@ class FederationHandler(BaseHandler): e.event_id, ) + auth_ids = [e_id for e_id, _ in event.auth_events] + auth_events = { + (e.type, e.state_key): e for e in auth_chain + if e.event_id in auth_ids + } + yield self._handle_new_event( + target_host, new_event, state=state, current_state=state, + auth_events=auth_events, ) yield self.notifier.on_new_room_event( @@ -480,7 +457,7 @@ class FederationHandler(BaseHandler): event.internal_metadata.outlier = False - context = yield self._handle_new_event(event) + context = yield self._handle_new_event(origin, event) logger.debug( "on_send_join_request: After _handle_new_event: %s, sigs: %s", @@ -491,7 +468,7 @@ class FederationHandler(BaseHandler): extra_users = [] if event.type == EventTypes.Member: target_user_id = event.state_key - target_user = self.hs.parse_userid(target_user_id) + target_user = UserID.from_string(target_user_id) extra_users.append(target_user) yield self.notifier.on_new_room_event( @@ -500,7 +477,7 @@ class FederationHandler(BaseHandler): if event.type == EventTypes.Member: if event.content["membership"] == Membership.JOIN: - user = self.hs.parse_userid(event.state_key) + user = UserID.from_string(event.state_key) yield self.distributor.fire( "user_joined_room", user=user, room_id=event.room_id ) @@ -514,7 +491,7 @@ class FederationHandler(BaseHandler): if k[0] == EventTypes.Member: if s.content["membership"] == Membership.JOIN: destinations.add( - self.hs.parse_userid(s.state_key).domain + UserID.from_string(s.state_key).domain ) except: logger.warn( @@ -565,7 +542,7 @@ class FederationHandler(BaseHandler): backfilled=False, ) - target_user = self.hs.parse_userid(event.state_key) + target_user = UserID.from_string(event.state_key) yield self.notifier.on_new_room_event( event, extra_users=[target_user], ) @@ -617,13 +594,13 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks @log_function - def on_backfill_request(self, origin, context, pdu_list, limit): - in_room = yield self.auth.check_host_in_room(context, origin) + def on_backfill_request(self, origin, room_id, pdu_list, limit): + in_room = yield self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") events = yield self.store.get_backfill_events( - context, + room_id, pdu_list, limit ) @@ -681,11 +658,12 @@ class FederationHandler(BaseHandler): waiters.pop().callback(None) @defer.inlineCallbacks - def _handle_new_event(self, event, state=None, backfilled=False, - current_state=None, fetch_auth_from=None): + @log_function + def _handle_new_event(self, origin, event, state=None, backfilled=False, + current_state=None, auth_events=None): logger.debug( - "_handle_new_event: Before annotate: %s, sigs: %s", + "_handle_new_event: %s, sigs: %s", event.event_id, event.signatures, ) @@ -693,65 +671,44 @@ class FederationHandler(BaseHandler): event, old_state=state ) + if not auth_events: + auth_events = context.auth_events + logger.debug( - "_handle_new_event: Before auth fetch: %s, sigs: %s", - event.event_id, event.signatures, + "_handle_new_event: %s, auth_events: %s", + event.event_id, auth_events, ) is_new_state = not event.internal_metadata.is_outlier() - known_ids = set( - [s.event_id for s in context.auth_events.values()] - ) - - for e_id, _ in event.auth_events: - if e_id not in known_ids: - e = yield self.store.get_event(e_id, allow_none=True) - - if not e and fetch_auth_from is not None: - # Grab the auth_chain over federation if we are missing - # auth events. - auth_chain = yield self.replication_layer.get_event_auth( - fetch_auth_from, event.event_id, event.room_id - ) - for auth_event in auth_chain: - yield self._handle_new_event(auth_event) - e = yield self.store.get_event(e_id, allow_none=True) - - if not e: - # TODO: Do some conflict res to make sure that we're - # not the ones who are wrong. - logger.info( - "Rejecting %s as %s not in db or %s", - event.event_id, e_id, known_ids, - ) - # FIXME: How does raising AuthError work with federation? - raise AuthError(403, "Cannot find auth event") - - context.auth_events[(e.type, e.state_key)] = e - - logger.debug( - "_handle_new_event: Before hack: %s, sigs: %s", - event.event_id, event.signatures, - ) - + # This is a hack to fix some old rooms where the initial join event + # didn't reference the create event in its auth events. if event.type == EventTypes.Member and not event.auth_events: if len(event.prev_events) == 1: c = yield self.store.get_event(event.prev_events[0][0]) if c.type == EventTypes.Create: - context.auth_events[(c.type, c.state_key)] = c + auth_events[(c.type, c.state_key)] = c - logger.debug( - "_handle_new_event: Before auth check: %s, sigs: %s", - event.event_id, event.signatures, - ) + try: + yield self.do_auth( + origin, event, context, auth_events=auth_events + ) + except AuthError as e: + logger.warn( + "Rejecting %s because %s", + event.event_id, e.msg + ) - self.auth.check(event, auth_events=context.auth_events) + context.rejected = RejectedReason.AUTH_ERROR - logger.debug( - "_handle_new_event: Before persist_event: %s, sigs: %s", - event.event_id, event.signatures, - ) + yield self.store.persist_event( + event, + context=context, + backfilled=backfilled, + is_new_state=False, + current_state=current_state, + ) + raise yield self.store.persist_event( event, @@ -761,9 +718,294 @@ class FederationHandler(BaseHandler): current_state=current_state, ) - logger.debug( - "_handle_new_event: After persist_event: %s, sigs: %s", - event.event_id, event.signatures, + defer.returnValue(context) + + @defer.inlineCallbacks + def on_query_auth(self, origin, event_id, remote_auth_chain, rejects, + missing): + # Just go through and process each event in `remote_auth_chain`. We + # don't want to fall into the trap of `missing` being wrong. + for e in remote_auth_chain: + try: + yield self._handle_new_event(origin, e) + except AuthError: + pass + + # Now get the current auth_chain for the event. + local_auth_chain = yield self.store.get_auth_chain([event_id]) + + # TODO: Check if we would now reject event_id. If so we need to tell + # everyone. + + ret = yield self.construct_auth_difference( + local_auth_chain, remote_auth_chain ) - defer.returnValue(context) + for event in ret["auth_chain"]: + event.signatures.update( + compute_event_signature( + event, + self.hs.hostname, + self.hs.config.signing_key[0] + ) + ) + + logger.debug("on_query_auth reutrning: %s", ret) + + defer.returnValue(ret) + + @defer.inlineCallbacks + @log_function + def do_auth(self, origin, event, context, auth_events): + # Check if we have all the auth events. + res = yield self.store.have_events( + [e_id for e_id, _ in event.auth_events] + ) + + event_auth_events = set(e_id for e_id, _ in event.auth_events) + seen_events = set(res.keys()) + + missing_auth = event_auth_events - seen_events + + if missing_auth: + logger.debug("Missing auth: %s", missing_auth) + # If we don't have all the auth events, we need to get them. + remote_auth_chain = yield self.replication_layer.get_event_auth( + origin, event.room_id, event.event_id + ) + + seen_remotes = yield self.store.have_events( + [e.event_id for e in remote_auth_chain] + ) + + for e in remote_auth_chain: + if e.event_id in seen_remotes.keys(): + continue + + if e.event_id == event.event_id: + continue + + try: + auth_ids = [e_id for e_id, _ in e.auth_events] + auth = { + (e.type, e.state_key): e for e in remote_auth_chain + if e.event_id in auth_ids + } + e.internal_metadata.outlier = True + + logger.debug( + "do_auth %s missing_auth: %s", + event.event_id, e.event_id + ) + yield self._handle_new_event( + origin, e, auth_events=auth + ) + + if e.event_id in event_auth_events: + auth_events[(e.type, e.state_key)] = e + except AuthError: + pass + + # FIXME: Assumes we have and stored all the state for all the + # prev_events + current_state = set(e.event_id for e in auth_events.values()) + different_auth = event_auth_events - current_state + + if different_auth and not event.internal_metadata.is_outlier(): + # Do auth conflict res. + logger.debug("Different auth: %s", different_auth) + + # 1. Get what we think is the auth chain. + auth_ids = self.auth.compute_auth_events(event, context) + local_auth_chain = yield self.store.get_auth_chain(auth_ids) + + # 2. Get remote difference. + result = yield self.replication_layer.query_auth( + origin, + event.room_id, + event.event_id, + local_auth_chain, + ) + + seen_remotes = yield self.store.have_events( + [e.event_id for e in result["auth_chain"]] + ) + + # 3. Process any remote auth chain events we haven't seen. + for ev in result["auth_chain"]: + if ev.event_id in seen_remotes.keys(): + continue + + if ev.event_id == event.event_id: + continue + + try: + auth_ids = [e_id for e_id, _ in ev.auth_events] + auth = { + (e.type, e.state_key): e for e in result["auth_chain"] + if e.event_id in auth_ids + } + ev.internal_metadata.outlier = True + + logger.debug( + "do_auth %s different_auth: %s", + event.event_id, e.event_id + ) + + yield self._handle_new_event( + origin, ev, auth_events=auth + ) + + if ev.event_id in event_auth_events: + auth_events[(ev.type, ev.state_key)] = ev + except AuthError: + pass + + # 4. Look at rejects and their proofs. + # TODO. + + context.current_state.update(auth_events) + context.state_group = None + + try: + self.auth.check(event, auth_events=auth_events) + except AuthError: + raise + + @defer.inlineCallbacks + def construct_auth_difference(self, local_auth, remote_auth): + """ Given a local and remote auth chain, find the differences. This + assumes that we have already processed all events in remote_auth + + Params: + local_auth (list) + remote_auth (list) + + Returns: + dict + """ + + logger.debug("construct_auth_difference Start!") + + # TODO: Make sure we are OK with local_auth or remote_auth having more + # auth events in them than strictly necessary. + + def sort_fun(ev): + return ev.depth, ev.event_id + + logger.debug("construct_auth_difference after sort_fun!") + + # We find the differences by starting at the "bottom" of each list + # and iterating up on both lists. The lists are ordered by depth and + # then event_id, we iterate up both lists until we find the event ids + # don't match. Then we look at depth/event_id to see which side is + # missing that event, and iterate only up that list. Repeat. + + remote_list = list(remote_auth) + remote_list.sort(key=sort_fun) + + local_list = list(local_auth) + local_list.sort(key=sort_fun) + + local_iter = iter(local_list) + remote_iter = iter(remote_list) + + logger.debug("construct_auth_difference before get_next!") + + def get_next(it, opt=None): + try: + return it.next() + except: + return opt + + current_local = get_next(local_iter) + current_remote = get_next(remote_iter) + + logger.debug("construct_auth_difference before while") + + missing_remotes = [] + missing_locals = [] + while current_local or current_remote: + if current_remote is None: + missing_locals.append(current_local) + current_local = get_next(local_iter) + continue + + if current_local is None: + missing_remotes.append(current_remote) + current_remote = get_next(remote_iter) + continue + + if current_local.event_id == current_remote.event_id: + current_local = get_next(local_iter) + current_remote = get_next(remote_iter) + continue + + if current_local.depth < current_remote.depth: + missing_locals.append(current_local) + current_local = get_next(local_iter) + continue + + if current_local.depth > current_remote.depth: + missing_remotes.append(current_remote) + current_remote = get_next(remote_iter) + continue + + # They have the same depth, so we fall back to the event_id order + if current_local.event_id < current_remote.event_id: + missing_locals.append(current_local) + current_local = get_next(local_iter) + + if current_local.event_id > current_remote.event_id: + missing_remotes.append(current_remote) + current_remote = get_next(remote_iter) + continue + + logger.debug("construct_auth_difference after while") + + # missing locals should be sent to the server + # We should find why we are missing remotes, as they will have been + # rejected. + + # Remove events from missing_remotes if they are referencing a missing + # remote. We only care about the "root" rejected ones. + missing_remote_ids = [e.event_id for e in missing_remotes] + base_remote_rejected = list(missing_remotes) + for e in missing_remotes: + for e_id, _ in e.auth_events: + if e_id in missing_remote_ids: + base_remote_rejected.remove(e) + + reason_map = {} + + for e in base_remote_rejected: + reason = yield self.store.get_rejection_reason(e.event_id) + if reason is None: + # FIXME: ERRR?! + logger.warn("Could not find reason for %s", e.event_id) + raise RuntimeError("") + + reason_map[e.event_id] = reason + + if reason == RejectedReason.AUTH_ERROR: + pass + elif reason == RejectedReason.REPLACED: + # TODO: Get proof + pass + elif reason == RejectedReason.NOT_ANCESTOR: + # TODO: Get proof. + pass + + logger.debug("construct_auth_difference returning") + + defer.returnValue({ + "auth_chain": local_auth, + "rejects": { + e.event_id: { + "reason": reason_map[e.event_id], + "proof": None, + } + for e in base_remote_rejected + }, + "missing": [e.event_id for e in missing_locals], + }) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 7195de98b5..6fbd2af4ab 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -18,8 +18,10 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.api.errors import RoomError from synapse.streams.config import PaginationConfig +from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.util.logcontext import PreserveLoggingContext +from synapse.types import UserID from ._base import BaseHandler @@ -67,7 +69,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def get_messages(self, user_id=None, room_id=None, pagin_config=None, - feedback=False): + feedback=False, as_client_event=True): """Get messages in a room. Args: @@ -76,6 +78,7 @@ class MessageHandler(BaseHandler): pagin_config (synapse.api.streams.PaginationConfig): The pagination config rules to apply, if any. feedback (bool): True to get compressed feedback with the messages + as_client_event (bool): True to get events in client-server format. Returns: dict: Pagination API results """ @@ -88,7 +91,7 @@ class MessageHandler(BaseHandler): yield self.hs.get_event_sources().get_current_token() ) - user = self.hs.parse_userid(user_id) + user = UserID.from_string(user_id) events, next_key = yield data_source.get_pagination_rows( user, pagin_config.get_source_config("room"), room_id @@ -98,8 +101,12 @@ class MessageHandler(BaseHandler): "room_key", next_key ) + time_now = self.clock.time_msec() + chunk = { - "chunk": [self.hs.serialize_event(e) for e in events], + "chunk": [ + serialize_event(e, time_now, as_client_event) for e in events + ], "start": pagin_config.from_token.to_string(), "end": next_token.to_string(), } @@ -107,7 +114,8 @@ class MessageHandler(BaseHandler): defer.returnValue(chunk) @defer.inlineCallbacks - def create_and_send_event(self, event_dict, ratelimit=True): + def create_and_send_event(self, event_dict, ratelimit=True, + client=None, txn_id=None): """ Given a dict from a client, create and handle a new event. Creates an FrozenEvent object, filling out auth_events, prev_events, @@ -127,13 +135,13 @@ class MessageHandler(BaseHandler): if ratelimit: self.ratelimit(builder.user_id) # TODO(paul): Why does 'event' not have a 'user' object? - user = self.hs.parse_userid(builder.user_id) + user = UserID.from_string(builder.user_id) assert self.hs.is_mine(user), "User must be our own: %s" % (user,) if builder.type == EventTypes.Member: membership = builder.content.get("membership", None) if membership == Membership.JOIN: - joinee = self.hs.parse_userid(builder.state_key) + joinee = UserID.from_string(builder.state_key) # If event doesn't include a display name, add one. yield self.distributor.fire( "collect_presencelike_data", @@ -141,6 +149,15 @@ class MessageHandler(BaseHandler): builder.content ) + if client is not None: + if client.token_id is not None: + builder.internal_metadata.token_id = client.token_id + if client.device_id is not None: + builder.internal_metadata.device_id = client.device_id + + if txn_id is not None: + builder.internal_metadata.txn_id = txn_id + event, context = yield self._create_new_client_event( builder=builder, ) @@ -207,11 +224,12 @@ class MessageHandler(BaseHandler): # TODO: This is duplicating logic from snapshot_all_rooms current_state = yield self.state_handler.get_current_state(room_id) - defer.returnValue([self.hs.serialize_event(c) for c in current_state]) + now = self.clock.time_msec() + defer.returnValue([serialize_event(c, now) for c in current_state]) @defer.inlineCallbacks def snapshot_all_rooms(self, user_id=None, pagin_config=None, - feedback=False): + feedback=False, as_client_event=True): """Retrieve a snapshot of all rooms the user is invited or has joined. This snapshot may include messages for all rooms where the user is @@ -222,6 +240,7 @@ class MessageHandler(BaseHandler): pagin_config (synapse.api.streams.PaginationConfig): The pagination config used to determine how many messages *PER ROOM* to return. feedback (bool): True to get feedback along with these messages. + as_client_event (bool): True to get events in client-server format. Returns: A list of dicts with "room_id" and "membership" keys for all rooms the user is currently invited or joined in on. Rooms where the user @@ -233,7 +252,7 @@ class MessageHandler(BaseHandler): membership_list=[Membership.INVITE, Membership.JOIN] ) - user = self.hs.parse_userid(user_id) + user = UserID.from_string(user_id) rooms_ret = [] @@ -278,9 +297,13 @@ class MessageHandler(BaseHandler): start_token = now_token.copy_and_replace("room_key", token[0]) end_token = now_token.copy_and_replace("room_key", token[1]) + time_now = self.clock.time_msec() d["messages"] = { - "chunk": [self.hs.serialize_event(m) for m in messages], + "chunk": [ + serialize_event(m, time_now, as_client_event) + for m in messages + ], "start": start_token.to_string(), "end": end_token.to_string(), } @@ -289,7 +312,8 @@ class MessageHandler(BaseHandler): event.room_id ) d["state"] = [ - self.hs.serialize_event(c) for c in current_state + serialize_event(c, time_now, as_client_event) + for c in current_state ] except: logger.exception("Failed to get snapshot") @@ -309,11 +333,12 @@ class MessageHandler(BaseHandler): # TODO(paul): I wish I was called with user objects not user_id # strings... - auth_user = self.hs.parse_userid(user_id) + auth_user = UserID.from_string(user_id) # TODO: These concurrently + time_now = self.clock.time_msec() state_tuples = yield self.state_handler.get_current_state(room_id) - state = [self.hs.serialize_event(x) for x in state_tuples] + state = [serialize_event(x, time_now) for x in state_tuples] member_event = (yield self.store.get_room_member( user_id=user_id, @@ -342,7 +367,7 @@ class MessageHandler(BaseHandler): for m in room_members: try: member_presence = yield presence_handler.get_state( - target_user=self.hs.parse_userid(m.user_id), + target_user=UserID.from_string(m.user_id), auth_user=auth_user, as_event=True, ) @@ -352,11 +377,13 @@ class MessageHandler(BaseHandler): "Failed to get member presence of %r", m.user_id ) + time_now = self.clock.time_msec() + defer.returnValue({ "membership": member_event.membership, "room_id": room_id, "messages": { - "chunk": [self.hs.serialize_event(m) for m in messages], + "chunk": [serialize_event(m, time_now) for m in messages], "start": start_token.to_string(), "end": end_token.to_string(), }, diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 8aeed99274..cd0798c2b0 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -20,6 +20,7 @@ from synapse.api.constants import PresenceState from synapse.util.logutils import log_function from synapse.util.logcontext import PreserveLoggingContext +from synapse.types import UserID from ._base import BaseHandler @@ -86,6 +87,10 @@ class PresenceHandler(BaseHandler): "changed_presencelike_data", self.changed_presencelike_data ) + # outbound signal from the presence module to advertise when a user's + # presence has changed + distributor.declare("user_presence_changed") + self.distributor = distributor self.federation = hs.get_replication_layer() @@ -96,22 +101,22 @@ class PresenceHandler(BaseHandler): self.federation.register_edu_handler( "m.presence_invite", lambda origin, content: self.invite_presence( - observed_user=hs.parse_userid(content["observed_user"]), - observer_user=hs.parse_userid(content["observer_user"]), + observed_user=UserID.from_string(content["observed_user"]), + observer_user=UserID.from_string(content["observer_user"]), ) ) self.federation.register_edu_handler( "m.presence_accept", lambda origin, content: self.accept_presence( - observed_user=hs.parse_userid(content["observed_user"]), - observer_user=hs.parse_userid(content["observer_user"]), + observed_user=UserID.from_string(content["observed_user"]), + observer_user=UserID.from_string(content["observer_user"]), ) ) self.federation.register_edu_handler( "m.presence_deny", lambda origin, content: self.deny_presence( - observed_user=hs.parse_userid(content["observed_user"]), - observer_user=hs.parse_userid(content["observer_user"]), + observed_user=UserID.from_string(content["observed_user"]), + observer_user=UserID.from_string(content["observer_user"]), ) ) @@ -418,7 +423,7 @@ class PresenceHandler(BaseHandler): ) for p in presence: - observed_user = self.hs.parse_userid(p.pop("observed_user_id")) + observed_user = UserID.from_string(p.pop("observed_user_id")) p["observed_user"] = observed_user p.update(self._get_or_offline_usercache(observed_user).get_state()) if "last_active" in p: @@ -441,7 +446,7 @@ class PresenceHandler(BaseHandler): user.localpart, accepted=True ) target_users = set([ - self.hs.parse_userid(x["observed_user_id"]) for x in presence + UserID.from_string(x["observed_user_id"]) for x in presence ]) # Also include people in all my rooms @@ -603,6 +608,7 @@ class PresenceHandler(BaseHandler): room_ids=room_ids, statuscache=statuscache, ) + yield self.distributor.fire("user_presence_changed", user, statuscache) @defer.inlineCallbacks def _push_presence_remote(self, user, destination, state=None): @@ -646,7 +652,7 @@ class PresenceHandler(BaseHandler): deferreds = [] for push in content.get("push", []): - user = self.hs.parse_userid(push["user_id"]) + user = UserID.from_string(push["user_id"]) logger.debug("Incoming presence update from %s", user) @@ -694,7 +700,7 @@ class PresenceHandler(BaseHandler): del self._user_cachemap[user] for poll in content.get("poll", []): - user = self.hs.parse_userid(poll) + user = UserID.from_string(poll) if not self.hs.is_mine(user): continue @@ -709,7 +715,7 @@ class PresenceHandler(BaseHandler): deferreds.append(self._push_presence_remote(user, origin)) for unpoll in content.get("unpoll", []): - user = self.hs.parse_userid(unpoll) + user = UserID.from_string(unpoll) if not self.hs.is_mine(user): continue diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 7777d3cc94..03b2159c53 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.api.constants import EventTypes, Membership from synapse.util.logcontext import PreserveLoggingContext +from synapse.types import UserID from ._base import BaseHandler @@ -169,7 +170,7 @@ class ProfileHandler(BaseHandler): @defer.inlineCallbacks def on_profile_query(self, args): - user = self.hs.parse_userid(args["user_id"]) + user = UserID.from_string(args["user_id"]) if not self.hs.is_mine(user): raise SynapseError(400, "User is not hosted on this Home Server") diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 59719a1fae..23821d321f 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -16,12 +16,14 @@ """Contains functions for performing events on rooms.""" from twisted.internet import defer +from ._base import BaseHandler + from synapse.types import UserID, RoomAlias, RoomID from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import StoreError, SynapseError from synapse.util import stringutils from synapse.util.async import run_on_reactor -from ._base import BaseHandler +from synapse.events.utils import serialize_event import logging @@ -64,7 +66,7 @@ class RoomCreationHandler(BaseHandler): invite_list = config.get("invite", []) for i in invite_list: try: - self.hs.parse_userid(i) + UserID.from_string(i) except: raise SynapseError(400, "Invalid user_id: %s" % (i,)) @@ -114,7 +116,7 @@ class RoomCreationHandler(BaseHandler): servers=[self.hs.hostname], ) - user = self.hs.parse_userid(user_id) + user = UserID.from_string(user_id) creation_events = self._create_events_for_new_room( user, room_id, is_public=is_public ) @@ -246,11 +248,9 @@ class RoomMemberHandler(BaseHandler): @defer.inlineCallbacks def get_room_members(self, room_id): - hs = self.hs - users = yield self.store.get_users_in_room(room_id) - defer.returnValue([hs.parse_userid(u) for u in users]) + defer.returnValue([UserID.from_string(u) for u in users]) @defer.inlineCallbacks def fetch_room_distributions_into(self, room_id, localusers=None, @@ -295,8 +295,9 @@ class RoomMemberHandler(BaseHandler): yield self.auth.check_joined_room(room_id, user_id) member_list = yield self.store.get_room_members(room_id=room_id) + time_now = self.clock.time_msec() event_list = [ - self.hs.serialize_event(entry) + serialize_event(entry, time_now) for entry in member_list ] chunk_data = { @@ -368,7 +369,7 @@ class RoomMemberHandler(BaseHandler): ) if prev_state and prev_state.membership == Membership.JOIN: - user = self.hs.parse_userid(event.user_id) + user = UserID.from_string(event.user_id) self.distributor.fire( "user_left_room", user=user, room_id=event.room_id ) @@ -412,7 +413,7 @@ class RoomMemberHandler(BaseHandler): @defer.inlineCallbacks def _do_join(self, event, context, room_host=None, do_auth=True): - joinee = self.hs.parse_userid(event.state_key) + joinee = UserID.from_string(event.state_key) # room_id = RoomID.from_string(event.room_id, self.hs) room_id = event.room_id @@ -425,10 +426,22 @@ class RoomMemberHandler(BaseHandler): event.room_id, self.hs.hostname ) + if not is_host_in_room: + # is *anyone* in the room? + room_member_keys = [ + v for (k, v) in context.current_state.keys() if ( + k == "m.room.member" + ) + ] + if len(room_member_keys) == 0: + # has the room been created so we can join it? + create_event = context.current_state.get(("m.room.create", "")) + if create_event: + is_host_in_room = True if is_host_in_room: should_do_dance = False - elif room_host: + elif room_host: # TODO: Shouldn't this be remote_room_host? should_do_dance = True else: # TODO(markjh): get prev_state from snapshot @@ -442,7 +455,8 @@ class RoomMemberHandler(BaseHandler): should_do_dance = not self.hs.is_mine(inviter) room_host = inviter.domain else: - should_do_dance = False + # return the same error as join_room_alias does + raise SynapseError(404, "No known servers") if should_do_dance: handler = self.hs.get_handlers().federation_handler @@ -463,7 +477,7 @@ class RoomMemberHandler(BaseHandler): do_auth=do_auth, ) - user = self.hs.parse_userid(event.user_id) + user = UserID.from_string(event.user_id) yield self.distributor.fire( "user_joined_room", user=user, room_id=room_id ) @@ -513,7 +527,7 @@ class RoomMemberHandler(BaseHandler): do_auth): yield run_on_reactor() - target_user = self.hs.parse_userid(event.state_key) + target_user = UserID.from_string(event.state_key) yield self.handle_new_client_event( event, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py new file mode 100644 index 0000000000..962686f4bb --- /dev/null +++ b/synapse/handlers/sync.py @@ -0,0 +1,434 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseHandler + +from synapse.streams.config import PaginationConfig +from synapse.api.constants import Membership, EventTypes + +from twisted.internet import defer + +import collections +import logging + +logger = logging.getLogger(__name__) + + +SyncConfig = collections.namedtuple("SyncConfig", [ + "user", + "client_info", + "limit", + "gap", + "sort", + "backfill", + "filter", +]) + + +class RoomSyncResult(collections.namedtuple("RoomSyncResult", [ + "room_id", + "limited", + "published", + "events", + "state", + "prev_batch", + "ephemeral", +])): + __slots__ = [] + + def __nonzero__(self): + """Make the result appear empty if there are no updates. This is used + to tell if room needs to be part of the sync result. + """ + return bool(self.events or self.state or self.ephemeral) + + +class SyncResult(collections.namedtuple("SyncResult", [ + "next_batch", # Token for the next sync + "private_user_data", # List of private events for the user. + "public_user_data", # List of public events for all users. + "rooms", # RoomSyncResult for each room. +])): + __slots__ = [] + + def __nonzero__(self): + """Make the result appear empty if there are no updates. This is used + to tell if the notifier needs to wait for more events when polling for + events. + """ + return bool( + self.private_user_data or self.public_user_data or self.rooms + ) + + +class SyncHandler(BaseHandler): + + def __init__(self, hs): + super(SyncHandler, self).__init__(hs) + self.event_sources = hs.get_event_sources() + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0): + """Get the sync for a client if we have new data for it now. Otherwise + wait for new data to arrive on the server. If the timeout expires, then + return an empty sync result. + Returns: + A Deferred SyncResult. + """ + if timeout == 0 or since_token is None: + result = yield self.current_sync_for_user(sync_config, since_token) + defer.returnValue(result) + else: + def current_sync_callback(): + return self.current_sync_for_user(sync_config, since_token) + + rm_handler = self.hs.get_handlers().room_member_handler + room_ids = yield rm_handler.get_rooms_for_user(sync_config.user) + result = yield self.notifier.wait_for_events( + sync_config.user, room_ids, + sync_config.filter, timeout, current_sync_callback + ) + defer.returnValue(result) + + def current_sync_for_user(self, sync_config, since_token=None): + """Get the sync for client needed to match what the server has now. + Returns: + A Deferred SyncResult. + """ + if since_token is None: + return self.initial_sync(sync_config) + else: + if sync_config.gap: + return self.incremental_sync_with_gap(sync_config, since_token) + else: + #TODO(mjark): Handle gapless sync + raise NotImplementedError() + + @defer.inlineCallbacks + def initial_sync(self, sync_config): + """Get a sync for a client which is starting without any state + Returns: + A Deferred SyncResult. + """ + if sync_config.sort == "timeline,desc": + # TODO(mjark): Handle going through events in reverse order?. + # What does "most recent events" mean when applying the limits mean + # in this case? + raise NotImplementedError() + + now_token = yield self.event_sources.get_current_token() + + presence_stream = self.event_sources.sources["presence"] + # TODO (mjark): This looks wrong, shouldn't we be getting the presence + # UP to the present rather than after the present? + pagination_config = PaginationConfig(from_token=now_token) + presence, _ = yield presence_stream.get_pagination_rows( + user=sync_config.user, + pagination_config=pagination_config.get_source_config("presence"), + key=None + ) + room_list = yield self.store.get_rooms_for_user_where_membership_is( + user_id=sync_config.user.to_string(), + membership_list=[Membership.INVITE, Membership.JOIN] + ) + + # TODO (mjark): Does public mean "published"? + published_rooms = yield self.store.get_rooms(is_public=True) + published_room_ids = set(r["room_id"] for r in published_rooms) + + rooms = [] + for event in room_list: + room_sync = yield self.initial_sync_for_room( + event.room_id, sync_config, now_token, published_room_ids + ) + rooms.append(room_sync) + + defer.returnValue(SyncResult( + public_user_data=presence, + private_user_data=[], + rooms=rooms, + next_batch=now_token, + )) + + @defer.inlineCallbacks + def initial_sync_for_room(self, room_id, sync_config, now_token, + published_room_ids): + """Sync a room for a client which is starting without any state + Returns: + A Deferred RoomSyncResult. + """ + + recents, prev_batch_token, limited = yield self.load_filtered_recents( + room_id, sync_config, now_token, + ) + + current_state_events = yield self.state_handler.get_current_state( + room_id + ) + + defer.returnValue(RoomSyncResult( + room_id=room_id, + published=room_id in published_room_ids, + events=recents, + prev_batch=prev_batch_token, + state=current_state_events, + limited=limited, + ephemeral=[], + )) + + @defer.inlineCallbacks + def incremental_sync_with_gap(self, sync_config, since_token): + """ Get the incremental delta needed to bring the client up to + date with the server. + Returns: + A Deferred SyncResult. + """ + if sync_config.sort == "timeline,desc": + # TODO(mjark): Handle going through events in reverse order?. + # What does "most recent events" mean when applying the limits mean + # in this case? + raise NotImplementedError() + + now_token = yield self.event_sources.get_current_token() + + presence_source = self.event_sources.sources["presence"] + presence, presence_key = yield presence_source.get_new_events_for_user( + user=sync_config.user, + from_key=since_token.presence_key, + limit=sync_config.limit, + ) + now_token = now_token.copy_and_replace("presence_key", presence_key) + + typing_source = self.event_sources.sources["typing"] + typing, typing_key = yield typing_source.get_new_events_for_user( + user=sync_config.user, + from_key=since_token.typing_key, + limit=sync_config.limit, + ) + now_token = now_token.copy_and_replace("typing_key", typing_key) + + typing_by_room = {event["room_id"]: [event] for event in typing} + for event in typing: + event.pop("room_id") + logger.debug("Typing %r", typing_by_room) + + rm_handler = self.hs.get_handlers().room_member_handler + room_ids = yield rm_handler.get_rooms_for_user(sync_config.user) + + # TODO (mjark): Does public mean "published"? + published_rooms = yield self.store.get_rooms(is_public=True) + published_room_ids = set(r["room_id"] for r in published_rooms) + + room_events, _ = yield self.store.get_room_events_stream( + sync_config.user.to_string(), + from_key=since_token.room_key, + to_key=now_token.room_key, + room_id=None, + limit=sync_config.limit + 1, + ) + + rooms = [] + if len(room_events) <= sync_config.limit: + # There is no gap in any of the rooms. Therefore we can just + # partition the new events by room and return them. + events_by_room_id = {} + for event in room_events: + events_by_room_id.setdefault(event.room_id, []).append(event) + + for room_id in room_ids: + recents = events_by_room_id.get(room_id, []) + state = [event for event in recents if event.is_state()] + if recents: + prev_batch = now_token.copy_and_replace( + "room_key", recents[0].internal_metadata.before + ) + else: + prev_batch = now_token + + state = yield self.check_joined_room( + sync_config, room_id, state + ) + + room_sync = RoomSyncResult( + room_id=room_id, + published=room_id in published_room_ids, + events=recents, + prev_batch=prev_batch, + state=state, + limited=False, + ephemeral=typing_by_room.get(room_id, []) + ) + if room_sync: + rooms.append(room_sync) + else: + for room_id in room_ids: + room_sync = yield self.incremental_sync_with_gap_for_room( + room_id, sync_config, since_token, now_token, + published_room_ids, typing_by_room + ) + if room_sync: + rooms.append(room_sync) + + defer.returnValue(SyncResult( + public_user_data=presence, + private_user_data=[], + rooms=rooms, + next_batch=now_token, + )) + + @defer.inlineCallbacks + def load_filtered_recents(self, room_id, sync_config, now_token, + since_token=None): + limited = True + recents = [] + filtering_factor = 2 + load_limit = max(sync_config.limit * filtering_factor, 100) + max_repeat = 3 # Only try a few times per room, otherwise + room_key = now_token.room_key + + while limited and len(recents) < sync_config.limit and max_repeat: + events, keys = yield self.store.get_recent_events_for_room( + room_id, + limit=load_limit + 1, + from_token=since_token.room_key if since_token else None, + end_token=room_key, + ) + (room_key, _) = keys + loaded_recents = sync_config.filter.filter_room_events(events) + loaded_recents.extend(recents) + recents = loaded_recents + if len(events) <= load_limit: + limited = False + max_repeat -= 1 + + if len(recents) > sync_config.limit: + recents = recents[-sync_config.limit:] + room_key = recents[0].internal_metadata.before + + prev_batch_token = now_token.copy_and_replace( + "room_key", room_key + ) + + defer.returnValue((recents, prev_batch_token, limited)) + + @defer.inlineCallbacks + def incremental_sync_with_gap_for_room(self, room_id, sync_config, + since_token, now_token, + published_room_ids, typing_by_room): + """ Get the incremental delta needed to bring the client up to date for + the room. Gives the client the most recent events and the changes to + state. + Returns: + A Deferred RoomSyncResult + """ + + # TODO(mjark): Check for redactions we might have missed. + + recents, prev_batch_token, limited = yield self.load_filtered_recents( + room_id, sync_config, now_token, since_token, + ) + + logging.debug("Recents %r", recents) + + # TODO(mjark): This seems racy since this isn't being passed a + # token to indicate what point in the stream this is + current_state_events = yield self.state_handler.get_current_state( + room_id + ) + + state_at_previous_sync = yield self.get_state_at_previous_sync( + room_id, since_token=since_token + ) + + state_events_delta = yield self.compute_state_delta( + since_token=since_token, + previous_state=state_at_previous_sync, + current_state=current_state_events, + ) + + state_events_delta = yield self.check_joined_room( + sync_config, room_id, state_events_delta + ) + + room_sync = RoomSyncResult( + room_id=room_id, + published=room_id in published_room_ids, + events=recents, + prev_batch=prev_batch_token, + state=state_events_delta, + limited=limited, + ephemeral=typing_by_room.get(room_id, []) + ) + + logging.debug("Room sync: %r", room_sync) + + defer.returnValue(room_sync) + + @defer.inlineCallbacks + def get_state_at_previous_sync(self, room_id, since_token): + """ Get the room state at the previous sync the client made. + Returns: + A Deferred list of Events. + """ + last_events, token = yield self.store.get_recent_events_for_room( + room_id, end_token=since_token.room_key, limit=1, + ) + + if last_events: + last_event = last_events[0] + last_context = yield self.state_handler.compute_event_context( + last_event + ) + if last_event.is_state(): + state = [last_event] + last_context.current_state.values() + else: + state = last_context.current_state.values() + else: + state = () + defer.returnValue(state) + + def compute_state_delta(self, since_token, previous_state, current_state): + """ Works out the differnce in state between the current state and the + state the client got when it last performed a sync. + Returns: + A list of events. + """ + # TODO(mjark) Check if the state events were received by the server + # after the previous sync, since we need to include those state + # updates even if they occured logically before the previous event. + # TODO(mjark) Check for new redactions in the state events. + previous_dict = {event.event_id: event for event in previous_state} + state_delta = [] + for event in current_state: + if event.event_id not in previous_dict: + state_delta.append(event) + return state_delta + + @defer.inlineCallbacks + def check_joined_room(self, sync_config, room_id, state_delta): + joined = False + for event in state_delta: + if ( + event.type == EventTypes.Member + and event.state_key == sync_config.user.to_string() + ): + if event.content["membership"] == Membership.JOIN: + joined = True + + if joined: + state_delta = yield self.state_handler.get_current_state(room_id) + + defer.returnValue(state_delta) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index ab698b36e1..c69787005f 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -18,6 +18,7 @@ from twisted.internet import defer from ._base import BaseHandler from synapse.api.errors import SynapseError, AuthError +from synapse.types import UserID import logging @@ -83,9 +84,15 @@ class TypingNotificationHandler(BaseHandler): if member in self._member_typing_timer: self.clock.cancel_call_later(self._member_typing_timer[member]) + def _cb(): + logger.debug( + "%s has timed out in %s", target_user.to_string(), room_id + ) + self._stopped_typing(member) + self._member_typing_until[member] = until self._member_typing_timer[member] = self.clock.call_later( - timeout / 1000, lambda: self._stopped_typing(member) + timeout / 1000.0, _cb ) if was_present: @@ -114,6 +121,10 @@ class TypingNotificationHandler(BaseHandler): member = RoomMember(room_id=room_id, user=target_user) + if member in self._member_typing_timer: + self.clock.cancel_call_later(self._member_typing_timer[member]) + del self._member_typing_timer[member] + yield self._stopped_typing(member) @defer.inlineCallbacks @@ -136,8 +147,10 @@ class TypingNotificationHandler(BaseHandler): del self._member_typing_until[member] - self.clock.cancel_call_later(self._member_typing_timer[member]) - del self._member_typing_timer[member] + if member in self._member_typing_timer: + # Don't cancel it - either it already expired, or the real + # stopped_typing() will cancel it + del self._member_typing_timer[member] @defer.inlineCallbacks def _push_update(self, room_id, user, typing): @@ -173,7 +186,7 @@ class TypingNotificationHandler(BaseHandler): @defer.inlineCallbacks def _recv_edu(self, origin, content): room_id = content["room_id"] - user = self.homeserver.parse_userid(content["user_id"]) + user = UserID.from_string(content["user_id"]) localusers = set() diff --git a/synapse/http/client.py b/synapse/http/client.py index 7793bab106..198f575cfa 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -63,6 +63,25 @@ class SimpleHttpClient(object): defer.returnValue(json.loads(body)) @defer.inlineCallbacks + def post_json_get_json(self, uri, post_json): + json_str = json.dumps(post_json) + + logger.info("HTTP POST %s -> %s", json_str, uri) + + response = yield self.agent.request( + "POST", + uri.encode("ascii"), + headers=Headers({ + "Content-Type": ["application/json"] + }), + bodyProducer=FileBodyProducer(StringIO(json_str)) + ) + + body = yield readBody(response) + + defer.returnValue(json.loads(body)) + + @defer.inlineCallbacks def get_json(self, uri, args={}): """ Get's some json from the given host and path diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index aa14782b0f..c7bf1b47b8 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -72,7 +72,6 @@ class MatrixFederationHttpClient(object): requests. """ - def __init__(self, hs): self.hs = hs self.signing_key = hs.config.signing_key[0] @@ -246,6 +245,43 @@ class MatrixFederationHttpClient(object): defer.returnValue((response.code, body)) @defer.inlineCallbacks + def post_json(self, destination, path, data={}): + """ Sends the specifed json data using POST + + Args: + destination (str): The remote server to send the HTTP request + to. + path (str): The HTTP path. + data (dict): A dict containing the data that will be used as + the request body. This will be encoded as JSON. + + Returns: + Deferred: Succeeds when we get a 2xx HTTP response. The result + will be the decoded JSON body. On a 4xx or 5xx error response a + CodeMessageException is raised. + """ + + def body_callback(method, url_bytes, headers_dict): + self.sign_request( + destination, method, url_bytes, headers_dict, data + ) + return _JsonProducer(data) + + response = yield self._create_request( + destination.encode("ascii"), + "POST", + path.encode("ascii"), + body_callback=body_callback, + headers_dict={"Content-Type": ["application/json"]}, + ) + + logger.debug("Getting resp body") + body = yield readBody(response) + logger.debug("Got resp body") + + defer.returnValue((response.code, body)) + + @defer.inlineCallbacks def get_json(self, destination, path, args={}, retry_on_dns_fail=True): """ GETs some json from the given host homeserver and path diff --git a/synapse/http/server.py b/synapse/http/server.py index 8015a22edf..0f6539e1be 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -16,7 +16,7 @@ from synapse.http.agent_name import AGENT_NAME from synapse.api.errors import ( - cs_exception, SynapseError, CodeMessageException + cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError ) from synapse.util.logcontext import LoggingContext @@ -139,11 +139,7 @@ class JsonResource(HttpServer, resource.Resource): return # Huh. No one wanted to handle that? Fiiiiiine. Send 400. - self._send_response( - request, - 400, - {"error": "Unrecognized request"} - ) + raise UnrecognizedRequestError() except CodeMessageException as e: if isinstance(e, SynapseError): logger.info("%s SynapseError: %s - %s", request, e.code, e.msg) diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py new file mode 100644 index 0000000000..a4eb6c817c --- /dev/null +++ b/synapse/http/servlet.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" This module contains base REST classes for constructing REST servlets. """ + +from synapse.api.errors import SynapseError + +import logging + + +logger = logging.getLogger(__name__) + + +class RestServlet(object): + + """ A Synapse REST Servlet. + + An implementing class can either provide its own custom 'register' method, + or use the automatic pattern handling provided by the base class. + + To use this latter, the implementing class instead provides a `PATTERN` + class attribute containing a pre-compiled regular expression. The automatic + register method will then use this method to register any of the following + instance methods associated with the corresponding HTTP method: + + on_GET + on_PUT + on_POST + on_DELETE + on_OPTIONS + + Automatically handles turning CodeMessageExceptions thrown by these methods + into the appropriate HTTP response. + """ + + def register(self, http_server): + """ Register this servlet with the given HTTP server. """ + if hasattr(self, "PATTERN"): + pattern = self.PATTERN + + for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"): + if hasattr(self, "on_%s" % (method)): + method_handler = getattr(self, "on_%s" % (method)) + http_server.register_path(method, pattern, method_handler) + else: + raise NotImplementedError("RestServlet must register something.") + + @staticmethod + def parse_integer(request, name, default=None, required=False): + if name in request.args: + try: + return int(request.args[name][0]) + except: + message = "Query parameter %r must be an integer" % (name,) + raise SynapseError(400, message) + else: + if required: + message = "Missing integer query parameter %r" % (name,) + raise SynapseError(400, message) + else: + return default + + @staticmethod + def parse_boolean(request, name, default=None, required=False): + if name in request.args: + try: + return { + "true": True, + "false": False, + }[request.args[name][0]] + except: + message = ( + "Boolean query parameter %r must be one of" + " ['true', 'false']" + ) % (name,) + raise SynapseError(400, message) + else: + if required: + message = "Missing boolean query parameter %r" % (name,) + raise SynapseError(400, message) + else: + return default + + @staticmethod + def parse_string(request, name, default=None, required=False, + allowed_values=None, param_type="string"): + if name in request.args: + value = request.args[name][0] + if allowed_values is not None and value not in allowed_values: + message = "Query parameter %r must be one of [%s]" % ( + name, ", ".join(repr(v) for v in allowed_values) + ) + raise SynapseError(message) + else: + return value + else: + if required: + message = "Missing %s query parameter %r" % (param_type, name) + raise SynapseError(400, message) + else: + return default diff --git a/synapse/notifier.py b/synapse/notifier.py index b9d52d0c4c..e3b6ead620 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.util.logutils import log_function from synapse.util.logcontext import PreserveLoggingContext from synapse.util.async import run_on_reactor +from synapse.types import StreamToken import logging @@ -205,6 +206,53 @@ class Notifier(object): [notify(l).addErrback(eb) for l in listeners] ) + @defer.inlineCallbacks + def wait_for_events(self, user, rooms, filter, timeout, callback): + """Wait until the callback returns a non empty response or the + timeout fires. + """ + + deferred = defer.Deferred() + + from_token = StreamToken("s0", "0", "0") + + listener = [_NotificationListener( + user=user, + rooms=rooms, + from_token=from_token, + limit=1, + timeout=timeout, + deferred=deferred, + )] + + if timeout: + self._register_with_keys(listener[0]) + + result = yield callback() + if timeout: + timed_out = [False] + + def _timeout_listener(): + timed_out[0] = True + listener[0].notify(self, [], from_token, from_token) + + self.clock.call_later(timeout/1000., _timeout_listener) + while not result and not timed_out[0]: + yield deferred + deferred = defer.Deferred() + listener[0] = _NotificationListener( + user=user, + rooms=rooms, + from_token=from_token, + limit=1, + timeout=timeout, + deferred=deferred, + ) + self._register_with_keys(listener[0]) + result = yield callback() + + defer.returnValue(result) + def get_events_for(self, user, rooms, pagination_config, timeout): """ For the given user and rooms, return any new events for them. If there are no new events wait for up to `timeout` milliseconds for any @@ -244,14 +292,14 @@ class Notifier(object): ) if timeout: - self.clock.call_later(timeout/1000.0, _timeout_listener) - self._register_with_keys(listener) yield self._check_for_updates(listener) if not timeout: _timeout_listener() + else: + self.clock.call_later(timeout/1000.0, _timeout_listener) return diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py new file mode 100644 index 0000000000..28e5dae81d --- /dev/null +++ b/synapse/push/__init__.py @@ -0,0 +1,410 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.streams.config import PaginationConfig +from synapse.types import StreamToken, UserID + +import synapse.util.async +import baserules + +import logging +import fnmatch +import json +import re + +logger = logging.getLogger(__name__) + + +class Pusher(object): + INITIAL_BACKOFF = 1000 + MAX_BACKOFF = 60 * 60 * 1000 + GIVE_UP_AFTER = 24 * 60 * 60 * 1000 + DEFAULT_ACTIONS = ['notify'] + + INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") + + def __init__(self, _hs, instance_handle, user_name, app_id, + app_display_name, device_display_name, pushkey, pushkey_ts, + data, last_token, last_success, failing_since): + self.hs = _hs + self.evStreamHandler = self.hs.get_handlers().event_stream_handler + self.store = self.hs.get_datastore() + self.clock = self.hs.get_clock() + self.instance_handle = instance_handle + self.user_name = user_name + self.app_id = app_id + self.app_display_name = app_display_name + self.device_display_name = device_display_name + self.pushkey = pushkey + self.pushkey_ts = pushkey_ts + self.data = data + self.last_token = last_token + self.last_success = last_success # not actually used + self.backoff_delay = Pusher.INITIAL_BACKOFF + self.failing_since = failing_since + self.alive = True + + # The last value of last_active_time that we saw + self.last_last_active_time = 0 + self.has_unread = True + + @defer.inlineCallbacks + def _actions_for_event(self, ev): + """ + This should take into account notification settings that the user + has configured both globally and per-room when we have the ability + to do such things. + """ + if ev['user_id'] == self.user_name: + # let's assume you probably know about messages you sent yourself + defer.returnValue(['dont_notify']) + + if ev['type'] == 'm.room.member': + if ev['state_key'] != self.user_name: + defer.returnValue(['dont_notify']) + + rules = yield self.store.get_push_rules_for_user_name(self.user_name) + + for r in rules: + r['conditions'] = json.loads(r['conditions']) + r['actions'] = json.loads(r['actions']) + + user_name_localpart = UserID.from_string(self.user_name).localpart + + rules.extend(baserules.make_base_rules(user_name_localpart)) + + # get *our* member event for display name matching + member_events_for_room = yield self.store.get_current_state( + room_id=ev['room_id'], + event_type='m.room.member', + state_key=None + ) + my_display_name = None + room_member_count = 0 + for mev in member_events_for_room: + if mev.content['membership'] != 'join': + continue + + # This loop does two things: + # 1) Find our current display name + if mev.state_key == self.user_name and 'displayname' in mev.content: + my_display_name = mev.content['displayname'] + + # and 2) Get the number of people in that room + room_member_count += 1 + + for r in rules: + matches = True + + conditions = r['conditions'] + actions = r['actions'] + + for c in conditions: + matches &= self._event_fulfills_condition( + ev, c, display_name=my_display_name, + room_member_count=room_member_count + ) + # ignore rules with no actions (we have an explict 'dont_notify' + if len(actions) == 0: + logger.warn( + "Ignoring rule id %s with no actions for user %s" % + (r['rule_id'], r['user_name']) + ) + continue + if matches: + defer.returnValue(actions) + + defer.returnValue(Pusher.DEFAULT_ACTIONS) + + def _event_fulfills_condition(self, ev, condition, display_name, room_member_count): + if condition['kind'] == 'event_match': + if 'pattern' not in condition: + logger.warn("event_match condition with no pattern") + return False + pat = condition['pattern'] + + if pat.strip("*?[]") == pat: + # no special glob characters so we assume the user means + # 'contains this string' rather than 'is this string' + pat = "*%s*" % (pat,) + + val = _value_for_dotted_key(condition['key'], ev) + if val is None: + return False + return fnmatch.fnmatch(val.upper(), pat.upper()) + elif condition['kind'] == 'device': + if 'instance_handle' not in condition: + return True + return condition['instance_handle'] == self.instance_handle + elif condition['kind'] == 'contains_display_name': + # This is special because display names can be different + # between rooms and so you can't really hard code it in a rule. + # Optimisation: we should cache these names and update them from + # the event stream. + if 'content' not in ev or 'body' not in ev['content']: + return False + if not display_name: + return False + return fnmatch.fnmatch( + ev['content']['body'].upper(), "*%s*" % (display_name.upper(),) + ) + elif condition['kind'] == 'room_member_count': + if 'is' not in condition: + return False + m = Pusher.INEQUALITY_EXPR.match(condition['is']) + if not m: + return False + ineq = m.group(1) + rhs = m.group(2) + if not rhs.isdigit(): + return False + rhs = int(rhs) + + if ineq == '' or ineq == '==': + return room_member_count == rhs + elif ineq == '<': + return room_member_count < rhs + elif ineq == '>': + return room_member_count > rhs + elif ineq == '>=': + return room_member_count >= rhs + elif ineq == '<=': + return room_member_count <= rhs + else: + return False + else: + return True + + @defer.inlineCallbacks + def get_context_for_event(self, ev): + name_aliases = yield self.store.get_room_name_and_aliases( + ev['room_id'] + ) + + ctx = {'aliases': name_aliases[1]} + if name_aliases[0] is not None: + ctx['name'] = name_aliases[0] + + their_member_events_for_room = yield self.store.get_current_state( + room_id=ev['room_id'], + event_type='m.room.member', + state_key=ev['user_id'] + ) + for mev in their_member_events_for_room: + if mev.content['membership'] == 'join' and 'displayname' in mev.content: + dn = mev.content['displayname'] + if dn is not None: + ctx['sender_display_name'] = dn + + defer.returnValue(ctx) + + @defer.inlineCallbacks + def start(self): + if not self.last_token: + # First-time setup: get a token to start from (we can't + # just start from no token, ie. 'now' + # because we need the result to be reproduceable in case + # we fail to dispatch the push) + config = PaginationConfig(from_token=None, limit='1') + chunk = yield self.evStreamHandler.get_stream( + self.user_name, config, timeout=0) + self.last_token = chunk['end'] + self.store.update_pusher_last_token( + self.user_name, self.pushkey, self.last_token) + logger.info("Pusher %s for user %s starting from token %s", + self.pushkey, self.user_name, self.last_token) + + while self.alive: + from_tok = StreamToken.from_string(self.last_token) + config = PaginationConfig(from_token=from_tok, limit='1') + chunk = yield self.evStreamHandler.get_stream( + self.user_name, config, + timeout=100*365*24*60*60*1000, affect_presence=False + ) + + # limiting to 1 may get 1 event plus 1 presence event, so + # pick out the actual event + single_event = None + for c in chunk['chunk']: + if 'event_id' in c: # Hmmm... + single_event = c + break + if not single_event: + self.last_token = chunk['end'] + continue + + if not self.alive: + continue + + processed = False + actions = yield self._actions_for_event(single_event) + tweaks = _tweaks_for_actions(actions) + + if len(actions) == 0: + logger.warn("Empty actions! Using default action.") + actions = Pusher.DEFAULT_ACTIONS + if 'notify' not in actions and 'dont_notify' not in actions: + logger.warn("Neither notify nor dont_notify in actions: adding default") + actions.extend(Pusher.DEFAULT_ACTIONS) + if 'dont_notify' in actions: + logger.debug( + "%s for %s: dont_notify", + single_event['event_id'], self.user_name + ) + processed = True + else: + rejected = yield self.dispatch_push(single_event, tweaks) + self.has_unread = True + if isinstance(rejected, list) or isinstance(rejected, tuple): + processed = True + for pk in rejected: + if pk != self.pushkey: + # for sanity, we only remove the pushkey if it + # was the one we actually sent... + logger.warn( + ("Ignoring rejected pushkey %s because we" + " didn't send it"), pk + ) + else: + logger.info( + "Pushkey %s was rejected: removing", + pk + ) + yield self.hs.get_pusherpool().remove_pusher( + self.app_id, pk + ) + + if not self.alive: + continue + + if processed: + self.backoff_delay = Pusher.INITIAL_BACKOFF + self.last_token = chunk['end'] + self.store.update_pusher_last_token_and_success( + self.user_name, + self.pushkey, + self.last_token, + self.clock.time_msec() + ) + if self.failing_since: + self.failing_since = None + self.store.update_pusher_failing_since( + self.user_name, + self.pushkey, + self.failing_since) + else: + if not self.failing_since: + self.failing_since = self.clock.time_msec() + self.store.update_pusher_failing_since( + self.user_name, + self.pushkey, + self.failing_since + ) + + if (self.failing_since and + self.failing_since < + self.clock.time_msec() - Pusher.GIVE_UP_AFTER): + # we really only give up so that if the URL gets + # fixed, we don't suddenly deliver a load + # of old notifications. + logger.warn("Giving up on a notification to user %s, " + "pushkey %s", + self.user_name, self.pushkey) + self.backoff_delay = Pusher.INITIAL_BACKOFF + self.last_token = chunk['end'] + self.store.update_pusher_last_token( + self.user_name, + self.pushkey, + self.last_token + ) + + self.failing_since = None + self.store.update_pusher_failing_since( + self.user_name, + self.pushkey, + self.failing_since + ) + else: + logger.warn("Failed to dispatch push for user %s " + "(failing for %dms)." + "Trying again in %dms", + self.user_name, + self.clock.time_msec() - self.failing_since, + self.backoff_delay) + yield synapse.util.async.sleep(self.backoff_delay / 1000.0) + self.backoff_delay *= 2 + if self.backoff_delay > Pusher.MAX_BACKOFF: + self.backoff_delay = Pusher.MAX_BACKOFF + + def stop(self): + self.alive = False + + def dispatch_push(self, p, tweaks): + """ + Overridden by implementing classes to actually deliver the notification + Args: + p: The event to notify for as a single event from the event stream + Returns: If the notification was delivered, an array containing any + pushkeys that were rejected by the push gateway. + False if the notification could not be delivered (ie. + should be retried). + """ + pass + + def reset_badge_count(self): + pass + + def presence_changed(self, state): + """ + We clear badge counts whenever a user's last_active time is bumped + This is by no means perfect but I think it's the best we can do + without read receipts. + """ + if 'last_active' in state.state: + last_active = state.state['last_active'] + if last_active > self.last_last_active_time: + self.last_last_active_time = last_active + if self.has_unread: + logger.info("Resetting badge count for %s", self.user_name) + self.reset_badge_count() + self.has_unread = False + + +def _value_for_dotted_key(dotted_key, event): + parts = dotted_key.split(".") + val = event + while len(parts) > 0: + if parts[0] not in val: + return None + val = val[parts[0]] + parts = parts[1:] + return val + + +def _tweaks_for_actions(actions): + tweaks = {} + for a in actions: + if not isinstance(a, dict): + continue + if 'set_sound' in a: + tweaks['sound'] = a['set_sound'] + return tweaks + + +class PusherConfigException(Exception): + def __init__(self, msg): + super(PusherConfigException, self).__init__(msg) diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py new file mode 100644 index 0000000000..382de118e0 --- /dev/null +++ b/synapse/push/baserules.py @@ -0,0 +1,48 @@ +def make_base_rules(user_name): + rules = [ + { + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'content.body', + 'pattern': '*%s*' % (user_name,), # Matrix ID match + } + ], + 'actions': [ + 'notify', + { + 'set_sound': 'default' + } + ] + }, + { + 'conditions': [ + { + 'kind': 'contains_display_name' + } + ], + 'actions': [ + 'notify', + { + 'set_sound': 'default' + } + ] + }, + { + 'conditions': [ + { + 'kind': 'room_member_count', + 'is': '2' + } + ], + 'actions': [ + 'notify', + { + 'set_sound': 'default' + } + ] + } + ] + for r in rules: + r['priority_class'] = 0 + return rules \ No newline at end of file diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py new file mode 100644 index 0000000000..7c6953c989 --- /dev/null +++ b/synapse/push/httppusher.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.push import Pusher, PusherConfigException +from synapse.http.client import SimpleHttpClient + +from twisted.internet import defer + +import logging + +logger = logging.getLogger(__name__) + + +class HttpPusher(Pusher): + def __init__(self, _hs, instance_handle, user_name, app_id, + app_display_name, device_display_name, pushkey, pushkey_ts, + data, last_token, last_success, failing_since): + super(HttpPusher, self).__init__( + _hs, + instance_handle, + user_name, + app_id, + app_display_name, + device_display_name, + pushkey, + pushkey_ts, + data, + last_token, + last_success, + failing_since + ) + if 'url' not in data: + raise PusherConfigException( + "'url' required in data for HTTP pusher" + ) + self.url = data['url'] + self.httpCli = SimpleHttpClient(self.hs) + self.data_minus_url = {} + self.data_minus_url.update(self.data) + del self.data_minus_url['url'] + + @defer.inlineCallbacks + def _build_notification_dict(self, event, tweaks): + # we probably do not want to push for every presence update + # (we may want to be able to set up notifications when specific + # people sign in, but we'd want to only deliver the pertinent ones) + # Actually, presence events will not get this far now because we + # need to filter them out in the main Pusher code. + if 'event_id' not in event: + defer.returnValue(None) + + ctx = yield self.get_context_for_event(event) + + d = { + 'notification': { + 'id': event['event_id'], + 'type': event['type'], + 'sender': event['user_id'], + 'counts': { # -- we don't mark messages as read yet so + # we have no way of knowing + # Just set the badge to 1 until we have read receipts + 'unread': 1, + # 'missed_calls': 2 + }, + 'devices': [ + { + 'app_id': self.app_id, + 'pushkey': self.pushkey, + 'pushkey_ts': long(self.pushkey_ts / 1000), + 'data': self.data_minus_url, + 'tweaks': tweaks + } + ] + } + } + if event['type'] == 'm.room.member': + d['notification']['membership'] = event['content']['membership'] + if 'content' in event: + d['notification']['content'] = event['content'] + + if len(ctx['aliases']): + d['notification']['room_alias'] = ctx['aliases'][0] + if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0: + d['notification']['sender_display_name'] = ctx['sender_display_name'] + if 'name' in ctx and len(ctx['name']) > 0: + d['notification']['room_name'] = ctx['name'] + + defer.returnValue(d) + + @defer.inlineCallbacks + def dispatch_push(self, event, tweaks): + notification_dict = yield self._build_notification_dict(event, tweaks) + if not notification_dict: + defer.returnValue([]) + try: + resp = yield self.httpCli.post_json_get_json(self.url, notification_dict) + except: + logger.exception("Failed to push %s ", self.url) + defer.returnValue(False) + rejected = [] + if 'rejected' in resp: + rejected = resp['rejected'] + defer.returnValue(rejected) + + @defer.inlineCallbacks + def reset_badge_count(self): + d = { + 'notification': { + 'id': '', + 'type': None, + 'sender': '', + 'counts': { + 'unread': 0, + 'missed_calls': 0 + }, + 'devices': [ + { + 'app_id': self.app_id, + 'pushkey': self.pushkey, + 'pushkey_ts': long(self.pushkey_ts / 1000), + 'data': self.data_minus_url, + } + ] + } + } + try: + resp = yield self.httpCli.post_json_get_json(self.url, d) + except: + logger.exception("Failed to push %s ", self.url) + defer.returnValue(False) + rejected = [] + if 'rejected' in resp: + rejected = resp['rejected'] + defer.returnValue(rejected) diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py new file mode 100644 index 0000000000..4892c21e7b --- /dev/null +++ b/synapse/push/pusherpool.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from httppusher import HttpPusher +from synapse.push import PusherConfigException + +import logging +import json + +logger = logging.getLogger(__name__) + + +class PusherPool: + def __init__(self, _hs): + self.hs = _hs + self.store = self.hs.get_datastore() + self.pushers = {} + self.last_pusher_started = -1 + + distributor = self.hs.get_distributor() + distributor.observe( + "user_presence_changed", self.user_presence_changed + ) + + @defer.inlineCallbacks + def user_presence_changed(self, user, state): + user_name = user.to_string() + + # until we have read receipts, pushers use this to reset a user's + # badge counters to zero + for p in self.pushers.values(): + if p.user_name == user_name: + yield p.presence_changed(state) + + @defer.inlineCallbacks + def start(self): + pushers = yield self.store.get_all_pushers() + for p in pushers: + p['data'] = json.loads(p['data']) + self._start_pushers(pushers) + + @defer.inlineCallbacks + def add_pusher(self, user_name, instance_handle, kind, app_id, + app_display_name, device_display_name, pushkey, lang, data): + # we try to create the pusher just to validate the config: it + # will then get pulled out of the database, + # recreated, added and started: this means we have only one + # code path adding pushers. + self._create_pusher({ + "user_name": user_name, + "kind": kind, + "instance_handle": instance_handle, + "app_id": app_id, + "app_display_name": app_display_name, + "device_display_name": device_display_name, + "pushkey": pushkey, + "pushkey_ts": self.hs.get_clock().time_msec(), + "lang": lang, + "data": data, + "last_token": None, + "last_success": None, + "failing_since": None + }) + yield self._add_pusher_to_store( + user_name, instance_handle, kind, app_id, + app_display_name, device_display_name, + pushkey, lang, data + ) + + @defer.inlineCallbacks + def _add_pusher_to_store(self, user_name, instance_handle, kind, app_id, + app_display_name, device_display_name, + pushkey, lang, data): + yield self.store.add_pusher( + user_name=user_name, + instance_handle=instance_handle, + kind=kind, + app_id=app_id, + app_display_name=app_display_name, + device_display_name=device_display_name, + pushkey=pushkey, + pushkey_ts=self.hs.get_clock().time_msec(), + lang=lang, + data=json.dumps(data) + ) + self._refresh_pusher((app_id, pushkey)) + + def _create_pusher(self, pusherdict): + if pusherdict['kind'] == 'http': + return HttpPusher( + self.hs, + instance_handle=pusherdict['instance_handle'], + user_name=pusherdict['user_name'], + app_id=pusherdict['app_id'], + app_display_name=pusherdict['app_display_name'], + device_display_name=pusherdict['device_display_name'], + pushkey=pusherdict['pushkey'], + pushkey_ts=pusherdict['pushkey_ts'], + data=pusherdict['data'], + last_token=pusherdict['last_token'], + last_success=pusherdict['last_success'], + failing_since=pusherdict['failing_since'] + ) + else: + raise PusherConfigException( + "Unknown pusher type '%s' for user %s" % + (pusherdict['kind'], pusherdict['user_name']) + ) + + @defer.inlineCallbacks + def _refresh_pusher(self, app_id_pushkey): + p = yield self.store.get_pushers_by_app_id_and_pushkey( + app_id_pushkey + ) + p['data'] = json.loads(p['data']) + + self._start_pushers([p]) + + def _start_pushers(self, pushers): + logger.info("Starting %d pushers", len(pushers)) + for pusherdict in pushers: + p = self._create_pusher(pusherdict) + if p: + fullid = "%s:%s" % (pusherdict['app_id'], pusherdict['pushkey']) + if fullid in self.pushers: + self.pushers[fullid].stop() + self.pushers[fullid] = p + p.start() + + @defer.inlineCallbacks + def remove_pusher(self, app_id, pushkey): + fullid = "%s:%s" % (app_id, pushkey) + if fullid in self.pushers: + logger.info("Stopping pusher %s", fullid) + self.pushers[fullid].stop() + del self.pushers[fullid] + yield self.store.delete_pusher_by_app_id_pushkey(app_id, pushkey) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py new file mode 100644 index 0000000000..826a36f203 --- /dev/null +++ b/synapse/python_dependencies.py @@ -0,0 +1,80 @@ +import logging +from distutils.version import LooseVersion + +logger = logging.getLogger(__name__) + +REQUIREMENTS = { + "syutil==0.0.2": ["syutil"], + "matrix_angular_sdk==0.6.0": ["syweb>=0.6.0"], + "Twisted==14.0.2": ["twisted==14.0.2"], + "service_identity>=1.0.0": ["service_identity>=1.0.0"], + "pyopenssl>=0.14": ["OpenSSL>=0.14"], + "pyyaml": ["yaml"], + "pyasn1": ["pyasn1"], + "pynacl": ["nacl"], + "daemonize": ["daemonize"], + "py-bcrypt": ["bcrypt"], + "frozendict>=0.4": ["frozendict"], + "pillow": ["PIL"], +} + + +class MissingRequirementError(Exception): + pass + + +def check_requirements(): + """Checks that all the modules needed by synapse have been correctly + installed and are at the correct version""" + for dependency, module_requirements in REQUIREMENTS.items(): + for module_requirement in module_requirements: + if ">=" in module_requirement: + module_name, required_version = module_requirement.split(">=") + version_test = ">=" + elif "==" in module_requirement: + module_name, required_version = module_requirement.split("==") + version_test = "==" + else: + module_name = module_requirement + version_test = None + + try: + module = __import__(module_name) + except ImportError: + logging.exception( + "Can't import %r which is part of %r", + module_name, dependency + ) + raise MissingRequirementError( + "Can't import %r which is part of %r" + % (module_name, dependency) + ) + version = getattr(module, "__version__", None) + file_path = getattr(module, "__file__", None) + logger.info( + "Using %r version %r from %r to satisfy %r", + module_name, version, file_path, dependency + ) + + if version_test == ">=": + if version is None: + raise MissingRequirementError( + "Version of %r isn't set as __version__ of module %r" + % (dependency, module_name) + ) + if LooseVersion(version) < LooseVersion(required_version): + raise MissingRequirementError( + "Version of %r in %r is too old. %r < %r" + % (dependency, file_path, version, required_version) + ) + elif version_test == "==": + if version is None: + raise MissingRequirementError( + "Version of %r isn't set as __version__ of module %r" + % (dependency, module_name) + ) + if LooseVersion(version) != LooseVersion(required_version): + raise MissingRequirementError( + "Unexpected version of %r in %r. %r != %r" + % (dependency, file_path, version, required_version) + ) diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 88ec9cd27d..1a84d94cd9 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd +# Copyright 2015 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,36 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - -from . import ( - room, events, register, login, profile, presence, initial_sync, directory, - voip, admin, -) - - -class RestServletFactory(object): - - """ A factory for creating REST servlets. - - These REST servlets represent the entire client-server REST API. Generally - speaking, they serve as wrappers around events and the handlers that - process them. - - See synapse.events for information on synapse events. - """ - - def __init__(self, hs): - client_resource = hs.get_resource_for_client() - - # TODO(erikj): There *must* be a better way of doing this. - room.register_servlets(hs, client_resource) - events.register_servlets(hs, client_resource) - register.register_servlets(hs, client_resource) - login.register_servlets(hs, client_resource) - profile.register_servlets(hs, client_resource) - presence.register_servlets(hs, client_resource) - initial_sync.register_servlets(hs, client_resource) - directory.register_servlets(hs, client_resource) - voip.register_servlets(hs, client_resource) - admin.register_servlets(hs, client_resource) diff --git a/synapse/rest/base.py b/synapse/rest/base.py deleted file mode 100644 index c583945527..0000000000 --- a/synapse/rest/base.py +++ /dev/null @@ -1,80 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" This module contains base REST classes for constructing REST servlets. """ -from synapse.api.urls import CLIENT_PREFIX -from synapse.rest.transactions import HttpTransactionStore -import re - -import logging - - -logger = logging.getLogger(__name__) - - -def client_path_pattern(path_regex): - """Creates a regex compiled client path with the correct client path - prefix. - - Args: - path_regex (str): The regex string to match. This should NOT have a ^ - as this will be prefixed. - Returns: - SRE_Pattern - """ - return re.compile("^" + CLIENT_PREFIX + path_regex) - - -class RestServlet(object): - - """ A Synapse REST Servlet. - - An implementing class can either provide its own custom 'register' method, - or use the automatic pattern handling provided by the base class. - - To use this latter, the implementing class instead provides a `PATTERN` - class attribute containing a pre-compiled regular expression. The automatic - register method will then use this method to register any of the following - instance methods associated with the corresponding HTTP method: - - on_GET - on_PUT - on_POST - on_DELETE - on_OPTIONS - - Automatically handles turning CodeMessageExceptions thrown by these methods - into the appropriate HTTP response. - """ - - def __init__(self, hs): - self.hs = hs - - self.handlers = hs.get_handlers() - self.builder_factory = hs.get_event_builder_factory() - self.auth = hs.get_auth() - self.txns = HttpTransactionStore() - - def register(self, http_server): - """ Register this servlet with the given HTTP server. """ - if hasattr(self, "PATTERN"): - pattern = self.PATTERN - - for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"): - if hasattr(self, "on_%s" % (method)): - method_handler = getattr(self, "on_%s" % (method)) - http_server.register_path(method, pattern, method_handler) - else: - raise NotImplementedError("RestServlet must register something.") diff --git a/synapse/rest/client/__init__.py b/synapse/rest/client/__init__.py new file mode 100644 index 0000000000..1a84d94cd9 --- /dev/null +++ b/synapse/rest/client/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/rest/client/v1/__init__.py b/synapse/rest/client/v1/__init__.py new file mode 100644 index 0000000000..d8d01cdd16 --- /dev/null +++ b/synapse/rest/client/v1/__init__.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import ( + room, events, register, login, profile, presence, initial_sync, directory, + voip, admin, pusher, push_rule +) + +from synapse.http.server import JsonResource + + +class ClientV1RestResource(JsonResource): + """A resource for version 1 of the matrix client API.""" + + def __init__(self, hs): + JsonResource.__init__(self) + self.register_servlets(self, hs) + + @staticmethod + def register_servlets(client_resource, hs): + room.register_servlets(hs, client_resource) + events.register_servlets(hs, client_resource) + register.register_servlets(hs, client_resource) + login.register_servlets(hs, client_resource) + profile.register_servlets(hs, client_resource) + presence.register_servlets(hs, client_resource) + initial_sync.register_servlets(hs, client_resource) + directory.register_servlets(hs, client_resource) + voip.register_servlets(hs, client_resource) + admin.register_servlets(hs, client_resource) + pusher.register_servlets(hs, client_resource) + push_rule.register_servlets(hs, client_resource) diff --git a/synapse/rest/admin.py b/synapse/rest/client/v1/admin.py index 0aa83514c8..2ce754b028 100644 --- a/synapse/rest/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -16,20 +16,22 @@ from twisted.internet import defer from synapse.api.errors import AuthError, SynapseError -from base import RestServlet, client_path_pattern +from synapse.types import UserID + +from base import ClientV1RestServlet, client_path_pattern import logging logger = logging.getLogger(__name__) -class WhoisRestServlet(RestServlet): +class WhoisRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/admin/whois/(?P<user_id>[^/]*)") @defer.inlineCallbacks def on_GET(self, request, user_id): - target_user = self.hs.parse_userid(user_id) - auth_user = yield self.auth.get_user_by_req(request) + target_user = UserID.from_string(user_id) + auth_user, client = yield self.auth.get_user_by_req(request) is_admin = yield self.auth.is_server_admin(auth_user) if not is_admin and target_user != auth_user: diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py new file mode 100644 index 0000000000..72332bdb10 --- /dev/null +++ b/synapse/rest/client/v1/base.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains base REST classes for constructing client v1 servlets. +""" + +from synapse.http.servlet import RestServlet +from synapse.api.urls import CLIENT_PREFIX +from .transactions import HttpTransactionStore +import re + +import logging + + +logger = logging.getLogger(__name__) + + +def client_path_pattern(path_regex): + """Creates a regex compiled client path with the correct client path + prefix. + + Args: + path_regex (str): The regex string to match. This should NOT have a ^ + as this will be prefixed. + Returns: + SRE_Pattern + """ + return re.compile("^" + CLIENT_PREFIX + path_regex) + + +class ClientV1RestServlet(RestServlet): + """A base Synapse REST Servlet for the client version 1 API. + """ + + def __init__(self, hs): + self.hs = hs + self.handlers = hs.get_handlers() + self.builder_factory = hs.get_event_builder_factory() + self.auth = hs.get_auth() + self.txns = HttpTransactionStore() diff --git a/synapse/rest/directory.py b/synapse/rest/client/v1/directory.py index 7ff44fdd9e..8f65efec5f 100644 --- a/synapse/rest/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -17,7 +17,8 @@ from twisted.internet import defer from synapse.api.errors import AuthError, SynapseError, Codes -from base import RestServlet, client_path_pattern +from synapse.types import RoomAlias +from .base import ClientV1RestServlet, client_path_pattern import json import logging @@ -30,12 +31,12 @@ def register_servlets(hs, http_server): ClientDirectoryServer(hs).register(http_server) -class ClientDirectoryServer(RestServlet): +class ClientDirectoryServer(ClientV1RestServlet): PATTERN = client_path_pattern("/directory/room/(?P<room_alias>[^/]*)$") @defer.inlineCallbacks def on_GET(self, request, room_alias): - room_alias = self.hs.parse_roomalias(room_alias) + room_alias = RoomAlias.from_string(room_alias) dir_handler = self.handlers.directory_handler res = yield dir_handler.get_association(room_alias) @@ -44,7 +45,7 @@ class ClientDirectoryServer(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_alias): - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) content = _parse_json(request) if not "room_id" in content: @@ -53,7 +54,7 @@ class ClientDirectoryServer(RestServlet): logger.debug("Got content: %s", content) - room_alias = self.hs.parse_roomalias(room_alias) + room_alias = RoomAlias.from_string(room_alias) logger.debug("Got room name: %s", room_alias.to_string()) @@ -84,7 +85,7 @@ class ClientDirectoryServer(RestServlet): @defer.inlineCallbacks def on_DELETE(self, request, room_alias): - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) is_admin = yield self.auth.is_server_admin(user) if not is_admin: @@ -92,7 +93,7 @@ class ClientDirectoryServer(RestServlet): dir_handler = self.handlers.directory_handler - room_alias = self.hs.parse_roomalias(room_alias) + room_alias = RoomAlias.from_string(room_alias) yield dir_handler.delete_association( user.to_string(), room_alias diff --git a/synapse/rest/events.py b/synapse/rest/client/v1/events.py index cf6d13f817..77b7c25a03 100644 --- a/synapse/rest/events.py +++ b/synapse/rest/client/v1/events.py @@ -18,7 +18,8 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.streams.config import PaginationConfig -from synapse.rest.base import RestServlet, client_path_pattern +from .base import ClientV1RestServlet, client_path_pattern +from synapse.events.utils import serialize_event import logging @@ -26,14 +27,14 @@ import logging logger = logging.getLogger(__name__) -class EventStreamRestServlet(RestServlet): +class EventStreamRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/events$") DEFAULT_LONGPOLL_TIME_MS = 30000 @defer.inlineCallbacks def on_GET(self, request): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) try: handler = self.handlers.event_stream_handler pagin_config = PaginationConfig.from_request(request) @@ -44,8 +45,11 @@ class EventStreamRestServlet(RestServlet): except ValueError: raise SynapseError(400, "timeout must be in milliseconds.") + as_client_event = "raw" not in request.args + chunk = yield handler.get_stream( - auth_user.to_string(), pagin_config, timeout=timeout + auth_user.to_string(), pagin_config, timeout=timeout, + as_client_event=as_client_event ) except: logger.exception("Event stream failed") @@ -58,17 +62,22 @@ class EventStreamRestServlet(RestServlet): # TODO: Unit test gets, with and without auth, with different kinds of events. -class EventRestServlet(RestServlet): +class EventRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/events/(?P<event_id>[^/]*)$") + def __init__(self, hs): + super(EventRestServlet, self).__init__(hs) + self.clock = hs.get_clock() + @defer.inlineCallbacks def on_GET(self, request, event_id): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) handler = self.handlers.event_handler event = yield handler.get_event(auth_user, event_id) + time_now = self.clock.time_msec() if event: - defer.returnValue((200, self.hs.serialize_event(event))) + defer.returnValue((200, serialize_event(event, time_now))) else: defer.returnValue((404, "Event not found.")) diff --git a/synapse/rest/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index a571589581..4a259bba64 100644 --- a/synapse/rest/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -16,23 +16,26 @@ from twisted.internet import defer from synapse.streams.config import PaginationConfig -from base import RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_pattern # TODO: Needs unit testing -class InitialSyncRestServlet(RestServlet): +class InitialSyncRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/initialSync$") @defer.inlineCallbacks def on_GET(self, request): - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) with_feedback = "feedback" in request.args + as_client_event = "raw" not in request.args pagination_config = PaginationConfig.from_request(request) handler = self.handlers.message_handler content = yield handler.snapshot_all_rooms( user_id=user.to_string(), pagin_config=pagination_config, - feedback=with_feedback) + feedback=with_feedback, + as_client_event=as_client_event + ) defer.returnValue((200, content)) diff --git a/synapse/rest/login.py b/synapse/rest/client/v1/login.py index 6b8deff67b..7116ac98e8 100644 --- a/synapse/rest/login.py +++ b/synapse/rest/client/v1/login.py @@ -17,12 +17,12 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.types import UserID -from base import RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_pattern import json -class LoginRestServlet(RestServlet): +class LoginRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/login$") PASS_TYPE = "m.login.password" @@ -64,7 +64,7 @@ class LoginRestServlet(RestServlet): defer.returnValue((200, result)) -class LoginFallbackRestServlet(RestServlet): +class LoginFallbackRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/login/fallback$") def on_GET(self, request): @@ -73,7 +73,7 @@ class LoginFallbackRestServlet(RestServlet): return (200, {}) -class PasswordResetRestServlet(RestServlet): +class PasswordResetRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/login/reset") @defer.inlineCallbacks diff --git a/synapse/rest/presence.py b/synapse/rest/client/v1/presence.py index ca4d2d21f0..7feb4aadb1 100644 --- a/synapse/rest/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -18,7 +18,8 @@ from twisted.internet import defer from synapse.api.errors import SynapseError -from base import RestServlet, client_path_pattern +from synapse.types import UserID +from .base import ClientV1RestServlet, client_path_pattern import json import logging @@ -26,13 +27,13 @@ import logging logger = logging.getLogger(__name__) -class PresenceStatusRestServlet(RestServlet): +class PresenceStatusRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/presence/(?P<user_id>[^/]*)/status") @defer.inlineCallbacks def on_GET(self, request, user_id): - auth_user = yield self.auth.get_user_by_req(request) - user = self.hs.parse_userid(user_id) + auth_user, client = yield self.auth.get_user_by_req(request) + user = UserID.from_string(user_id) state = yield self.handlers.presence_handler.get_state( target_user=user, auth_user=auth_user) @@ -41,8 +42,8 @@ class PresenceStatusRestServlet(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user = yield self.auth.get_user_by_req(request) - user = self.hs.parse_userid(user_id) + auth_user, client = yield self.auth.get_user_by_req(request) + user = UserID.from_string(user_id) state = {} try: @@ -71,13 +72,13 @@ class PresenceStatusRestServlet(RestServlet): return (200, {}) -class PresenceListRestServlet(RestServlet): +class PresenceListRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/presence/list/(?P<user_id>[^/]*)") @defer.inlineCallbacks def on_GET(self, request, user_id): - auth_user = yield self.auth.get_user_by_req(request) - user = self.hs.parse_userid(user_id) + auth_user, client = yield self.auth.get_user_by_req(request) + user = UserID.from_string(user_id) if not self.hs.is_mine(user): raise SynapseError(400, "User not hosted on this Home Server") @@ -96,8 +97,8 @@ class PresenceListRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id): - auth_user = yield self.auth.get_user_by_req(request) - user = self.hs.parse_userid(user_id) + auth_user, client = yield self.auth.get_user_by_req(request) + user = UserID.from_string(user_id) if not self.hs.is_mine(user): raise SynapseError(400, "User not hosted on this Home Server") @@ -118,7 +119,7 @@ class PresenceListRestServlet(RestServlet): raise SynapseError(400, "Bad invite value.") if len(u) == 0: continue - invited_user = self.hs.parse_userid(u) + invited_user = UserID.from_string(u) yield self.handlers.presence_handler.send_invite( observer_user=user, observed_user=invited_user ) @@ -129,7 +130,7 @@ class PresenceListRestServlet(RestServlet): raise SynapseError(400, "Bad drop value.") if len(u) == 0: continue - dropped_user = self.hs.parse_userid(u) + dropped_user = UserID.from_string(u) yield self.handlers.presence_handler.drop( observer_user=user, observed_user=dropped_user ) diff --git a/synapse/rest/profile.py b/synapse/rest/client/v1/profile.py index dc6eb424b0..15d6f3fc6c 100644 --- a/synapse/rest/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -16,17 +16,18 @@ """ This module contains REST servlets to do with profile: /profile/<paths> """ from twisted.internet import defer -from base import RestServlet, client_path_pattern +from .base import ClientV1RestServlet, client_path_pattern +from synapse.types import UserID import json -class ProfileDisplaynameRestServlet(RestServlet): +class ProfileDisplaynameRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/displayname") @defer.inlineCallbacks def on_GET(self, request, user_id): - user = self.hs.parse_userid(user_id) + user = UserID.from_string(user_id) displayname = yield self.handlers.profile_handler.get_displayname( user, @@ -36,8 +37,8 @@ class ProfileDisplaynameRestServlet(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user = yield self.auth.get_user_by_req(request) - user = self.hs.parse_userid(user_id) + auth_user, client = yield self.auth.get_user_by_req(request) + user = UserID.from_string(user_id) try: content = json.loads(request.content.read()) @@ -54,12 +55,12 @@ class ProfileDisplaynameRestServlet(RestServlet): return (200, {}) -class ProfileAvatarURLRestServlet(RestServlet): +class ProfileAvatarURLRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/avatar_url") @defer.inlineCallbacks def on_GET(self, request, user_id): - user = self.hs.parse_userid(user_id) + user = UserID.from_string(user_id) avatar_url = yield self.handlers.profile_handler.get_avatar_url( user, @@ -69,8 +70,8 @@ class ProfileAvatarURLRestServlet(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user = yield self.auth.get_user_by_req(request) - user = self.hs.parse_userid(user_id) + auth_user, client = yield self.auth.get_user_by_req(request) + user = UserID.from_string(user_id) try: content = json.loads(request.content.read()) @@ -87,12 +88,12 @@ class ProfileAvatarURLRestServlet(RestServlet): return (200, {}) -class ProfileRestServlet(RestServlet): +class ProfileRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)") @defer.inlineCallbacks def on_GET(self, request, user_id): - user = self.hs.parse_userid(user_id) + user = UserID.from_string(user_id) displayname = yield self.handlers.profile_handler.get_displayname( user, diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py new file mode 100644 index 0000000000..faa7919fbb --- /dev/null +++ b/synapse/rest/client/v1/push_rule.py @@ -0,0 +1,406 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError, NotFoundError, \ + StoreError +from .base import ClientV1RestServlet, client_path_pattern +from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException +import synapse.push.baserules as baserules + +import json + + +class PushRuleRestServlet(ClientV1RestServlet): + PATTERN = client_path_pattern("/pushrules/.*$") + PRIORITY_CLASS_MAP = { + 'default': 0, + 'underride': 1, + 'sender': 2, + 'room': 3, + 'content': 4, + 'override': 5, + } + PRIORITY_CLASS_INVERSE_MAP = {v: k for k, v in PRIORITY_CLASS_MAP.items()} + SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( + "Unrecognised request: You probably wanted a trailing slash") + + def rule_spec_from_path(self, path): + if len(path) < 2: + raise UnrecognizedRequestError() + if path[0] != 'pushrules': + raise UnrecognizedRequestError() + + scope = path[1] + path = path[2:] + if scope not in ['global', 'device']: + raise UnrecognizedRequestError() + + device = None + if scope == 'device': + if len(path) == 0: + raise UnrecognizedRequestError() + device = path[0] + path = path[1:] + + if len(path) == 0: + raise UnrecognizedRequestError() + + template = path[0] + path = path[1:] + + if len(path) == 0: + raise UnrecognizedRequestError() + + rule_id = path[0] + + spec = { + 'scope': scope, + 'template': template, + 'rule_id': rule_id + } + if device: + spec['device'] = device + return spec + + def rule_tuple_from_request_object(self, rule_template, rule_id, req_obj, device=None): + if rule_template in ['override', 'underride']: + if 'conditions' not in req_obj: + raise InvalidRuleException("Missing 'conditions'") + conditions = req_obj['conditions'] + for c in conditions: + if 'kind' not in c: + raise InvalidRuleException("Condition without 'kind'") + elif rule_template == 'room': + conditions = [{ + 'kind': 'event_match', + 'key': 'room_id', + 'pattern': rule_id + }] + elif rule_template == 'sender': + conditions = [{ + 'kind': 'event_match', + 'key': 'user_id', + 'pattern': rule_id + }] + elif rule_template == 'content': + if 'pattern' not in req_obj: + raise InvalidRuleException("Content rule missing 'pattern'") + pat = req_obj['pattern'] + + conditions = [{ + 'kind': 'event_match', + 'key': 'content.body', + 'pattern': pat + }] + else: + raise InvalidRuleException("Unknown rule template: %s" % (rule_template,)) + + if device: + conditions.append({ + 'kind': 'device', + 'instance_handle': device + }) + + if 'actions' not in req_obj: + raise InvalidRuleException("No actions found") + actions = req_obj['actions'] + + for a in actions: + if a in ['notify', 'dont_notify', 'coalesce']: + pass + elif isinstance(a, dict) and 'set_sound' in a: + pass + else: + raise InvalidRuleException("Unrecognised action") + + return conditions, actions + + @defer.inlineCallbacks + def on_PUT(self, request): + spec = self.rule_spec_from_path(request.postpath) + try: + priority_class = _priority_class_from_spec(spec) + except InvalidRuleException as e: + raise SynapseError(400, e.message) + + user, _ = yield self.auth.get_user_by_req(request) + + if spec['template'] == 'default': + raise SynapseError(403, "The default rules are immutable.") + + content = _parse_json(request) + + try: + (conditions, actions) = self.rule_tuple_from_request_object( + spec['template'], + spec['rule_id'], + content, + device=spec['device'] if 'device' in spec else None + ) + except InvalidRuleException as e: + raise SynapseError(400, e.message) + + before = request.args.get("before", None) + if before and len(before): + before = before[0] + after = request.args.get("after", None) + if after and len(after): + after = after[0] + + try: + yield self.hs.get_datastore().add_push_rule( + user_name=user.to_string(), + rule_id=spec['rule_id'], + priority_class=priority_class, + conditions=conditions, + actions=actions, + before=before, + after=after + ) + except InconsistentRuleException as e: + raise SynapseError(400, e.message) + except RuleNotFoundException as e: + raise SynapseError(400, e.message) + + defer.returnValue((200, {})) + + @defer.inlineCallbacks + def on_DELETE(self, request): + spec = self.rule_spec_from_path(request.postpath) + try: + priority_class = _priority_class_from_spec(spec) + except InvalidRuleException as e: + raise SynapseError(400, e.message) + + user, _ = yield self.auth.get_user_by_req(request) + + if 'device' in spec: + rules = yield self.hs.get_datastore().get_push_rules_for_user_name( + user.to_string() + ) + + for r in rules: + conditions = json.loads(r['conditions']) + ih = _instance_handle_from_conditions(conditions) + if ih == spec['device'] and r['priority_class'] == priority_class: + yield self.hs.get_datastore().delete_push_rule( + user.to_string(), spec['rule_id'] + ) + defer.returnValue((200, {})) + raise NotFoundError() + else: + try: + yield self.hs.get_datastore().delete_push_rule( + user.to_string(), spec['rule_id'], + priority_class=priority_class + ) + defer.returnValue((200, {})) + except StoreError as e: + if e.code == 404: + raise NotFoundError() + else: + raise + + @defer.inlineCallbacks + def on_GET(self, request): + user, _ = yield self.auth.get_user_by_req(request) + + # we build up the full structure and then decide which bits of it + # to send which means doing unnecessary work sometimes but is + # is probably not going to make a whole lot of difference + rawrules = yield self.hs.get_datastore().get_push_rules_for_user_name(user.to_string()) + for r in rawrules: + r["conditions"] = json.loads(r["conditions"]) + r["actions"] = json.loads(r["actions"]) + rawrules.extend(baserules.make_base_rules(user.to_string())) + + rules = {'global': {}, 'device': {}} + + rules['global'] = _add_empty_priority_class_arrays(rules['global']) + + for r in rawrules: + rulearray = None + + template_name = _priority_class_to_template_name(r['priority_class']) + + if r['priority_class'] > PushRuleRestServlet.PRIORITY_CLASS_MAP['override']: + # per-device rule + instance_handle = _instance_handle_from_conditions(r["conditions"]) + r = _strip_device_condition(r) + if not instance_handle: + continue + if instance_handle not in rules['device']: + rules['device'][instance_handle] = {} + rules['device'][instance_handle] = ( + _add_empty_priority_class_arrays( + rules['device'][instance_handle] + ) + ) + + rulearray = rules['device'][instance_handle][template_name] + else: + rulearray = rules['global'][template_name] + + template_rule = _rule_to_template(r) + if template_rule: + rulearray.append(template_rule) + + path = request.postpath[1:] + + if path == []: + # we're a reference impl: pedantry is our job. + raise UnrecognizedRequestError( + PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR + ) + + if path[0] == '': + defer.returnValue((200, rules)) + elif path[0] == 'global': + path = path[1:] + result = _filter_ruleset_with_path(rules['global'], path) + defer.returnValue((200, result)) + elif path[0] == 'device': + path = path[1:] + if path == []: + raise UnrecognizedRequestError( + PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR + ) + if path[0] == '': + defer.returnValue((200, rules['device'])) + + instance_handle = path[0] + path = path[1:] + if instance_handle not in rules['device']: + ret = {} + ret = _add_empty_priority_class_arrays(ret) + defer.returnValue((200, ret)) + ruleset = rules['device'][instance_handle] + result = _filter_ruleset_with_path(ruleset, path) + defer.returnValue((200, result)) + else: + raise UnrecognizedRequestError() + + def on_OPTIONS(self, _): + return 200, {} + + +def _add_empty_priority_class_arrays(d): + for pc in PushRuleRestServlet.PRIORITY_CLASS_MAP.keys(): + d[pc] = [] + return d + + +def _instance_handle_from_conditions(conditions): + """ + Given a list of conditions, return the instance handle of the + device rule if there is one + """ + for c in conditions: + if c['kind'] == 'device': + return c['instance_handle'] + return None + + +def _filter_ruleset_with_path(ruleset, path): + if path == []: + raise UnrecognizedRequestError( + PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR + ) + + if path[0] == '': + return ruleset + template_kind = path[0] + if template_kind not in ruleset: + raise UnrecognizedRequestError() + path = path[1:] + if path == []: + raise UnrecognizedRequestError( + PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR + ) + if path[0] == '': + return ruleset[template_kind] + rule_id = path[0] + for r in ruleset[template_kind]: + if r['rule_id'] == rule_id: + return r + raise NotFoundError + + +def _priority_class_from_spec(spec): + if spec['template'] not in PushRuleRestServlet.PRIORITY_CLASS_MAP.keys(): + raise InvalidRuleException("Unknown template: %s" % (spec['kind'])) + pc = PushRuleRestServlet.PRIORITY_CLASS_MAP[spec['template']] + + if spec['scope'] == 'device': + pc += len(PushRuleRestServlet.PRIORITY_CLASS_MAP) + + return pc + + +def _priority_class_to_template_name(pc): + if pc > PushRuleRestServlet.PRIORITY_CLASS_MAP['override']: + # per-device + prio_class_index = pc - len(PushRuleRestServlet.PRIORITY_CLASS_MAP) + return PushRuleRestServlet.PRIORITY_CLASS_INVERSE_MAP[prio_class_index] + else: + return PushRuleRestServlet.PRIORITY_CLASS_INVERSE_MAP[pc] + + +def _rule_to_template(rule): + template_name = _priority_class_to_template_name(rule['priority_class']) + if template_name in ['default']: + return {k: rule[k] for k in ["conditions", "actions"]} + elif template_name in ['override', 'underride']: + return {k: rule[k] for k in ["rule_id", "conditions", "actions"]} + elif template_name in ["sender", "room"]: + return {k: rule[k] for k in ["rule_id", "actions"]} + elif template_name == 'content': + if len(rule["conditions"]) != 1: + return None + thecond = rule["conditions"][0] + if "pattern" not in thecond: + return None + ret = {k: rule[k] for k in ["rule_id", "actions"]} + ret["pattern"] = thecond["pattern"] + return ret + + +def _strip_device_condition(rule): + for i, c in enumerate(rule['conditions']): + if c['kind'] == 'device': + del rule['conditions'][i] + return rule + + +class InvalidRuleException(Exception): + pass + + +# XXX: C+ped from rest/room.py - surely this should be common? +def _parse_json(request): + try: + content = json.loads(request.content.read()) + if type(content) != dict: + raise SynapseError(400, "Content must be a JSON object.", + errcode=Codes.NOT_JSON) + return content + except ValueError: + raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) + + +def register_servlets(hs, http_server): + PushRuleRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py new file mode 100644 index 0000000000..353a4a6589 --- /dev/null +++ b/synapse/rest/client/v1/pusher.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import SynapseError, Codes +from synapse.push import PusherConfigException +from .base import ClientV1RestServlet, client_path_pattern + +import json + + +class PusherRestServlet(ClientV1RestServlet): + PATTERN = client_path_pattern("/pushers/set$") + + @defer.inlineCallbacks + def on_POST(self, request): + user, _ = yield self.auth.get_user_by_req(request) + + content = _parse_json(request) + + pusher_pool = self.hs.get_pusherpool() + + if ('pushkey' in content and 'app_id' in content + and 'kind' in content and + content['kind'] is None): + yield pusher_pool.remove_pusher( + content['app_id'], content['pushkey'] + ) + defer.returnValue((200, {})) + + reqd = ['instance_handle', 'kind', 'app_id', 'app_display_name', + 'device_display_name', 'pushkey', 'lang', 'data'] + missing = [] + for i in reqd: + if i not in content: + missing.append(i) + if len(missing): + raise SynapseError(400, "Missing parameters: "+','.join(missing), + errcode=Codes.MISSING_PARAM) + + try: + yield pusher_pool.add_pusher( + user_name=user.to_string(), + instance_handle=content['instance_handle'], + kind=content['kind'], + app_id=content['app_id'], + app_display_name=content['app_display_name'], + device_display_name=content['device_display_name'], + pushkey=content['pushkey'], + lang=content['lang'], + data=content['data'] + ) + except PusherConfigException as pce: + raise SynapseError(400, "Config Error: "+pce.message, + errcode=Codes.MISSING_PARAM) + + defer.returnValue((200, {})) + + def on_OPTIONS(self, _): + return 200, {} + + +# XXX: C+ped from rest/room.py - surely this should be common? +def _parse_json(request): + try: + content = json.loads(request.content.read()) + if type(content) != dict: + raise SynapseError(400, "Content must be a JSON object.", + errcode=Codes.NOT_JSON) + return content + except ValueError: + raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) + + +def register_servlets(hs, http_server): + PusherRestServlet(hs).register(http_server) diff --git a/synapse/rest/register.py b/synapse/rest/client/v1/register.py index e3b26902d9..c0423c2d45 100644 --- a/synapse/rest/register.py +++ b/synapse/rest/client/v1/register.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, Codes from synapse.api.constants import LoginType -from base import RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_pattern import synapse.util.stringutils as stringutils from synapse.util.async import run_on_reactor @@ -42,7 +42,7 @@ else: compare_digest = lambda a, b: a == b -class RegisterRestServlet(RestServlet): +class RegisterRestServlet(ClientV1RestServlet): """Handles registration with the home server. This servlet is in control of the registration flow; the registration diff --git a/synapse/rest/room.py b/synapse/rest/client/v1/room.py index e40773758a..410f19ccf6 100644 --- a/synapse/rest/room.py +++ b/synapse/rest/client/v1/room.py @@ -16,10 +16,12 @@ """ This module contains REST servlets to do with rooms: /rooms/<paths> """ from twisted.internet import defer -from base import RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_pattern from synapse.api.errors import SynapseError, Codes from synapse.streams.config import PaginationConfig from synapse.api.constants import EventTypes, Membership +from synapse.types import UserID, RoomID, RoomAlias +from synapse.events.utils import serialize_event import json import logging @@ -29,7 +31,7 @@ import urllib logger = logging.getLogger(__name__) -class RoomCreateRestServlet(RestServlet): +class RoomCreateRestServlet(ClientV1RestServlet): # No PATTERN; we have custom dispatch rules here def register(self, http_server): @@ -60,7 +62,7 @@ class RoomCreateRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) room_config = self.get_room_config(request) info = yield self.make_room(room_config, auth_user, None) @@ -93,7 +95,7 @@ class RoomCreateRestServlet(RestServlet): # TODO: Needs unit testing for generic events -class RoomStateEventRestServlet(RestServlet): +class RoomStateEventRestServlet(ClientV1RestServlet): def register(self, http_server): # /room/$roomid/state/$eventtype no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$" @@ -123,7 +125,7 @@ class RoomStateEventRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_type, state_key): - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) msg_handler = self.handlers.message_handler data = yield msg_handler.get_room_data( @@ -140,8 +142,8 @@ class RoomStateEventRestServlet(RestServlet): defer.returnValue((200, data.get_dict()["content"])) @defer.inlineCallbacks - def on_PUT(self, request, room_id, event_type, state_key): - user = yield self.auth.get_user_by_req(request) + def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): + user, client = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -156,13 +158,15 @@ class RoomStateEventRestServlet(RestServlet): event_dict["state_key"] = state_key msg_handler = self.handlers.message_handler - yield msg_handler.create_and_send_event(event_dict) + yield msg_handler.create_and_send_event( + event_dict, client=client, txn_id=txn_id, + ) defer.returnValue((200, {})) # TODO: Needs unit testing for generic events + feedback -class RoomSendEventRestServlet(RestServlet): +class RoomSendEventRestServlet(ClientV1RestServlet): def register(self, http_server): # /rooms/$roomid/send/$event_type[/$txn_id] @@ -170,8 +174,8 @@ class RoomSendEventRestServlet(RestServlet): register_txn_path(self, PATTERN, http_server, with_get=True) @defer.inlineCallbacks - def on_POST(self, request, room_id, event_type): - user = yield self.auth.get_user_by_req(request) + def on_POST(self, request, room_id, event_type, txn_id=None): + user, client = yield self.auth.get_user_by_req(request) content = _parse_json(request) msg_handler = self.handlers.message_handler @@ -181,7 +185,9 @@ class RoomSendEventRestServlet(RestServlet): "content": content, "room_id": room_id, "sender": user.to_string(), - } + }, + client=client, + txn_id=txn_id, ) defer.returnValue((200, {"event_id": event.event_id})) @@ -198,14 +204,14 @@ class RoomSendEventRestServlet(RestServlet): except KeyError: pass - response = yield self.on_POST(request, room_id, event_type) + response = yield self.on_POST(request, room_id, event_type, txn_id) self.txns.store_client_transaction(request, txn_id, response) defer.returnValue(response) # TODO: Needs unit testing for room ID + alias joins -class JoinRoomAliasServlet(RestServlet): +class JoinRoomAliasServlet(ClientV1RestServlet): def register(self, http_server): # /join/$room_identifier[/$txn_id] @@ -213,8 +219,8 @@ class JoinRoomAliasServlet(RestServlet): register_txn_path(self, PATTERN, http_server) @defer.inlineCallbacks - def on_POST(self, request, room_identifier): - user = yield self.auth.get_user_by_req(request) + def on_POST(self, request, room_identifier, txn_id=None): + user, client = yield self.auth.get_user_by_req(request) # the identifier could be a room alias or a room id. Try one then the # other if it fails to parse, without swallowing other valid @@ -223,10 +229,10 @@ class JoinRoomAliasServlet(RestServlet): identifier = None is_room_alias = False try: - identifier = self.hs.parse_roomalias(room_identifier) + identifier = RoomAlias.from_string(room_identifier) is_room_alias = True except SynapseError: - identifier = self.hs.parse_roomid(room_identifier) + identifier = RoomID.from_string(room_identifier) # TODO: Support for specifying the home server to join with? @@ -243,10 +249,12 @@ class JoinRoomAliasServlet(RestServlet): "room_id": identifier.to_string(), "sender": user.to_string(), "state_key": user.to_string(), - } + }, + client=client, + txn_id=txn_id, ) - defer.returnValue((200, {})) + defer.returnValue((200, {"room_id": identifier.to_string()})) @defer.inlineCallbacks def on_PUT(self, request, room_identifier, txn_id): @@ -257,14 +265,14 @@ class JoinRoomAliasServlet(RestServlet): except KeyError: pass - response = yield self.on_POST(request, room_identifier) + response = yield self.on_POST(request, room_identifier, txn_id) self.txns.store_client_transaction(request, txn_id, response) defer.returnValue(response) # TODO: Needs unit testing -class PublicRoomListRestServlet(RestServlet): +class PublicRoomListRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/publicRooms$") @defer.inlineCallbacks @@ -275,13 +283,13 @@ class PublicRoomListRestServlet(RestServlet): # TODO: Needs unit testing -class RoomMemberListRestServlet(RestServlet): +class RoomMemberListRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/members$") @defer.inlineCallbacks def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) handler = self.handlers.room_member_handler members = yield handler.get_room_members_as_pagination_chunk( room_id=room_id, @@ -289,7 +297,7 @@ class RoomMemberListRestServlet(RestServlet): for event in members["chunk"]: # FIXME: should probably be state_key here, not user_id - target_user = self.hs.parse_userid(event["user_id"]) + target_user = UserID.from_string(event["user_id"]) # Presence is an optional cache; don't fail if we can't fetch it try: presence_handler = self.handlers.presence_handler @@ -304,33 +312,36 @@ class RoomMemberListRestServlet(RestServlet): # TODO: Needs unit testing -class RoomMessageListRestServlet(RestServlet): +class RoomMessageListRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/messages$") @defer.inlineCallbacks def on_GET(self, request, room_id): - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) pagination_config = PaginationConfig.from_request( request, default_limit=10, ) with_feedback = "feedback" in request.args + as_client_event = "raw" not in request.args handler = self.handlers.message_handler msgs = yield handler.get_messages( room_id=room_id, user_id=user.to_string(), pagin_config=pagination_config, - feedback=with_feedback) + feedback=with_feedback, + as_client_event=as_client_event + ) defer.returnValue((200, msgs)) # TODO: Needs unit testing -class RoomStateRestServlet(RestServlet): +class RoomStateRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/state$") @defer.inlineCallbacks def on_GET(self, request, room_id): - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) handler = self.handlers.message_handler # Get all the current state for this room events = yield handler.get_state_events( @@ -341,12 +352,12 @@ class RoomStateRestServlet(RestServlet): # TODO: Needs unit testing -class RoomInitialSyncRestServlet(RestServlet): +class RoomInitialSyncRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/initialSync$") @defer.inlineCallbacks def on_GET(self, request, room_id): - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) pagination_config = PaginationConfig.from_request(request) content = yield self.handlers.message_handler.room_initial_sync( room_id=room_id, @@ -356,9 +367,13 @@ class RoomInitialSyncRestServlet(RestServlet): defer.returnValue((200, content)) -class RoomTriggerBackfill(RestServlet): +class RoomTriggerBackfill(ClientV1RestServlet): PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/backfill$") + def __init__(self, hs): + super(RoomTriggerBackfill, self).__init__(hs) + self.clock = hs.get_clock() + @defer.inlineCallbacks def on_GET(self, request, room_id): remote_server = urllib.unquote( @@ -370,12 +385,14 @@ class RoomTriggerBackfill(RestServlet): handler = self.handlers.federation_handler events = yield handler.backfill(remote_server, room_id, limit) - res = [self.hs.serialize_event(event) for event in events] + time_now = self.clock.time_msec() + + res = [serialize_event(event, time_now) for event in events] defer.returnValue((200, res)) # TODO: Needs unit testing -class RoomMembershipRestServlet(RestServlet): +class RoomMembershipRestServlet(ClientV1RestServlet): def register(self, http_server): # /rooms/$roomid/[invite|join|leave] @@ -384,8 +401,8 @@ class RoomMembershipRestServlet(RestServlet): register_txn_path(self, PATTERN, http_server) @defer.inlineCallbacks - def on_POST(self, request, room_id, membership_action): - user = yield self.auth.get_user_by_req(request) + def on_POST(self, request, room_id, membership_action, txn_id=None): + user, client = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -407,7 +424,9 @@ class RoomMembershipRestServlet(RestServlet): "room_id": room_id, "sender": user.to_string(), "state_key": state_key, - } + }, + client=client, + txn_id=txn_id, ) defer.returnValue((200, {})) @@ -421,20 +440,22 @@ class RoomMembershipRestServlet(RestServlet): except KeyError: pass - response = yield self.on_POST(request, room_id, membership_action) + response = yield self.on_POST( + request, room_id, membership_action, txn_id + ) self.txns.store_client_transaction(request, txn_id, response) defer.returnValue(response) -class RoomRedactEventRestServlet(RestServlet): +class RoomRedactEventRestServlet(ClientV1RestServlet): def register(self, http_server): PATTERN = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") register_txn_path(self, PATTERN, http_server) @defer.inlineCallbacks - def on_POST(self, request, room_id, event_id): - user = yield self.auth.get_user_by_req(request) + def on_POST(self, request, room_id, event_id, txn_id=None): + user, client = yield self.auth.get_user_by_req(request) content = _parse_json(request) msg_handler = self.handlers.message_handler @@ -445,7 +466,9 @@ class RoomRedactEventRestServlet(RestServlet): "room_id": room_id, "sender": user.to_string(), "redacts": event_id, - } + }, + client=client, + txn_id=txn_id, ) defer.returnValue((200, {"event_id": event.event_id})) @@ -459,23 +482,23 @@ class RoomRedactEventRestServlet(RestServlet): except KeyError: pass - response = yield self.on_POST(request, room_id, event_id) + response = yield self.on_POST(request, room_id, event_id, txn_id) self.txns.store_client_transaction(request, txn_id, response) defer.returnValue(response) -class RoomTypingRestServlet(RestServlet): +class RoomTypingRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern( "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$" ) @defer.inlineCallbacks def on_PUT(self, request, room_id, user_id): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) room_id = urllib.unquote(room_id) - target_user = self.hs.parse_userid(urllib.unquote(user_id)) + target_user = UserID.from_string(urllib.unquote(user_id)) content = _parse_json(request) diff --git a/synapse/rest/transactions.py b/synapse/rest/client/v1/transactions.py index d933fea18a..d933fea18a 100644 --- a/synapse/rest/transactions.py +++ b/synapse/rest/client/v1/transactions.py diff --git a/synapse/rest/voip.py b/synapse/rest/client/v1/voip.py index 011c35e69b..11d08fbced 100644 --- a/synapse/rest/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from base import RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_pattern import hmac @@ -23,12 +23,12 @@ import hashlib import base64 -class VoipRestServlet(RestServlet): +class VoipRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/voip/turnServer$") @defer.inlineCallbacks def on_GET(self, request): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) turnUris = self.hs.config.turn_uris turnSecret = self.hs.config.turn_shared_secret diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py new file mode 100644 index 0000000000..8f611de3a8 --- /dev/null +++ b/synapse/rest/client/v2_alpha/__init__.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import ( + sync, + filter +) + +from synapse.http.server import JsonResource + + +class ClientV2AlphaRestResource(JsonResource): + """A resource for version 2 alpha of the matrix client API.""" + + def __init__(self, hs): + JsonResource.__init__(self) + self.register_servlets(self, hs) + + @staticmethod + def register_servlets(client_resource, hs): + sync.register_servlets(hs, client_resource) + filter.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py new file mode 100644 index 0000000000..22dc5cb862 --- /dev/null +++ b/synapse/rest/client/v2_alpha/_base.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains base REST classes for constructing client v1 servlets. +""" + +from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX +import re + +import logging + + +logger = logging.getLogger(__name__) + + +def client_v2_pattern(path_regex): + """Creates a regex compiled client path with the correct client path + prefix. + + Args: + path_regex (str): The regex string to match. This should NOT have a ^ + as this will be prefixed. + Returns: + SRE_Pattern + """ + return re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex) diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py new file mode 100644 index 0000000000..6ddc495d23 --- /dev/null +++ b/synapse/rest/client/v2_alpha/filter.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import AuthError, SynapseError +from synapse.http.servlet import RestServlet +from synapse.types import UserID + +from ._base import client_v2_pattern + +import json +import logging + + +logger = logging.getLogger(__name__) + + +class GetFilterRestServlet(RestServlet): + PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)") + + def __init__(self, hs): + super(GetFilterRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.filtering = hs.get_filtering() + + @defer.inlineCallbacks + def on_GET(self, request, user_id, filter_id): + target_user = UserID.from_string(user_id) + auth_user, client = yield self.auth.get_user_by_req(request) + + if target_user != auth_user: + raise AuthError(403, "Cannot get filters for other users") + + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only get filters for local users") + + try: + filter_id = int(filter_id) + except: + raise SynapseError(400, "Invalid filter_id") + + try: + filter = yield self.filtering.get_user_filter( + user_localpart=target_user.localpart, + filter_id=filter_id, + ) + + defer.returnValue((200, filter.filter_json)) + except KeyError: + raise SynapseError(400, "No such filter") + + +class CreateFilterRestServlet(RestServlet): + PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter") + + def __init__(self, hs): + super(CreateFilterRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.filtering = hs.get_filtering() + + @defer.inlineCallbacks + def on_POST(self, request, user_id): + target_user = UserID.from_string(user_id) + auth_user, client = yield self.auth.get_user_by_req(request) + + if target_user != auth_user: + raise AuthError(403, "Cannot create filters for other users") + + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only create filters for local users") + + try: + content = json.loads(request.content.read()) + + # TODO(paul): check for required keys and invalid keys + except: + raise SynapseError(400, "Invalid filter definition") + + filter_id = yield self.filtering.add_user_filter( + user_localpart=target_user.localpart, + user_filter=content, + ) + + defer.returnValue((200, {"filter_id": str(filter_id)})) + + +def register_servlets(hs, http_server): + GetFilterRestServlet(hs).register(http_server) + CreateFilterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py new file mode 100644 index 0000000000..81d5cf8ead --- /dev/null +++ b/synapse/rest/client/v2_alpha/sync.py @@ -0,0 +1,207 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.http.servlet import RestServlet +from synapse.handlers.sync import SyncConfig +from synapse.types import StreamToken +from synapse.events.utils import ( + serialize_event, format_event_for_client_v2_without_event_id, +) +from synapse.api.filtering import Filter +from ._base import client_v2_pattern + +import logging + +logger = logging.getLogger(__name__) + + +class SyncRestServlet(RestServlet): + """ + + GET parameters:: + timeout(int): How long to wait for new events in milliseconds. + limit(int): Maxiumum number of events per room to return. + gap(bool): Create gaps the message history if limit is exceeded to + ensure that the client has the most recent messages. Defaults to + "true". + sort(str,str): tuple of sort key (e.g. "timeline") and direction + (e.g. "asc", "desc"). Defaults to "timeline,asc". + since(batch_token): Batch token when asking for incremental deltas. + set_presence(str): What state the device presence should be set to. + default is "online". + backfill(bool): Should the HS request message history from other + servers. This may take a long time making it unsuitable for clients + expecting a prompt response. Defaults to "true". + filter(filter_id): A filter to apply to the events returned. + filter_*: Filter override parameters. + + Response JSON:: + { + "next_batch": // batch token for the next /sync + "private_user_data": // private events for this user. + "public_user_data": // public events for all users including the + // public events for this user. + "rooms": [{ // List of rooms with updates. + "room_id": // Id of the room being updated + "limited": // Was the per-room event limit exceeded? + "published": // Is the room published by our HS? + "event_map": // Map of EventID -> event JSON. + "events": { // The recent events in the room if gap is "true" + // otherwise the next events in the room. + "batch": [] // list of EventIDs in the "event_map". + "prev_batch": // back token for getting previous events. + } + "state": [] // list of EventIDs updating the current state to + // be what it should be at the end of the batch. + "ephemeral": [] + }] + } + """ + + PATTERN = client_v2_pattern("/sync$") + ALLOWED_SORT = set(["timeline,asc", "timeline,desc"]) + ALLOWED_PRESENCE = set(["online", "offline", "idle"]) + + def __init__(self, hs): + super(SyncRestServlet, self).__init__() + self.auth = hs.get_auth() + self.sync_handler = hs.get_handlers().sync_handler + self.clock = hs.get_clock() + self.filtering = hs.get_filtering() + + @defer.inlineCallbacks + def on_GET(self, request): + user, client = yield self.auth.get_user_by_req(request) + + timeout = self.parse_integer(request, "timeout", default=0) + limit = self.parse_integer(request, "limit", required=True) + gap = self.parse_boolean(request, "gap", default=True) + sort = self.parse_string( + request, "sort", default="timeline,asc", + allowed_values=self.ALLOWED_SORT + ) + since = self.parse_string(request, "since") + set_presence = self.parse_string( + request, "set_presence", default="online", + allowed_values=self.ALLOWED_PRESENCE + ) + backfill = self.parse_boolean(request, "backfill", default=False) + filter_id = self.parse_string(request, "filter", default=None) + + logger.info( + "/sync: user=%r, timeout=%r, limit=%r, gap=%r, sort=%r, since=%r," + " set_presence=%r, backfill=%r, filter_id=%r" % ( + user, timeout, limit, gap, sort, since, set_presence, + backfill, filter_id + ) + ) + + # TODO(mjark): Load filter and apply overrides. + try: + filter = yield self.filtering.get_user_filter( + user.localpart, filter_id + ) + except: + filter = Filter({}) + # filter = filter.apply_overrides(http_request) + #if filter.matches(event): + # # stuff + + sync_config = SyncConfig( + user=user, + client_info=client, + gap=gap, + limit=limit, + sort=sort, + backfill=backfill, + filter=filter, + ) + + if since is not None: + since_token = StreamToken.from_string(since) + else: + since_token = None + + sync_result = yield self.sync_handler.wait_for_sync_for_user( + sync_config, since_token=since_token, timeout=timeout + ) + + time_now = self.clock.time_msec() + + response_content = { + "public_user_data": self.encode_user_data( + sync_result.public_user_data, filter, time_now + ), + "private_user_data": self.encode_user_data( + sync_result.private_user_data, filter, time_now + ), + "rooms": self.encode_rooms( + sync_result.rooms, filter, time_now, client.token_id + ), + "next_batch": sync_result.next_batch.to_string(), + } + + defer.returnValue((200, response_content)) + + def encode_user_data(self, events, filter, time_now): + return events + + def encode_rooms(self, rooms, filter, time_now, token_id): + return [ + self.encode_room(room, filter, time_now, token_id) + for room in rooms + ] + + @staticmethod + def encode_room(room, filter, time_now, token_id): + event_map = {} + state_events = filter.filter_room_state(room.state) + recent_events = filter.filter_room_events(room.events) + state_event_ids = [] + recent_event_ids = [] + for event in state_events: + # TODO(mjark): Respect formatting requirements in the filter. + event_map[event.event_id] = serialize_event( + event, time_now, token_id=token_id, + event_format=format_event_for_client_v2_without_event_id, + ) + state_event_ids.append(event.event_id) + + for event in recent_events: + # TODO(mjark): Respect formatting requirements in the filter. + event_map[event.event_id] = serialize_event( + event, time_now, token_id=token_id, + event_format=format_event_for_client_v2_without_event_id, + ) + recent_event_ids.append(event.event_id) + result = { + "room_id": room.room_id, + "event_map": event_map, + "events": { + "batch": recent_event_ids, + "prev_batch": room.prev_batch.to_string(), + }, + "state": state_event_ids, + "limited": room.limited, + "published": room.published, + "ephemeral": room.ephemeral, + } + return result + + +def register_servlets(hs, http_server): + SyncRestServlet(hs).register(http_server) diff --git a/synapse/media/__init__.py b/synapse/rest/media/__init__.py index e69de29bb2..e69de29bb2 100644 --- a/synapse/media/__init__.py +++ b/synapse/rest/media/__init__.py diff --git a/synapse/media/v0/__init__.py b/synapse/rest/media/v0/__init__.py index e69de29bb2..e69de29bb2 100644 --- a/synapse/media/v0/__init__.py +++ b/synapse/rest/media/v0/__init__.py diff --git a/synapse/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py index 79ae0e3d74..22e26e3cd5 100644 --- a/synapse/media/v0/content_repository.py +++ b/synapse/rest/media/v0/content_repository.py @@ -66,7 +66,7 @@ class ContentRepoResource(resource.Resource): @defer.inlineCallbacks def map_request_to_name(self, request): # auth the user - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) # namespace all file uploads on the user prefix = base64.urlsafe_b64encode( diff --git a/synapse/media/v1/__init__.py b/synapse/rest/media/v1/__init__.py index 619999d268..d6c6690577 100644 --- a/synapse/media/v1/__init__.py +++ b/synapse/rest/media/v1/__init__.py @@ -22,7 +22,8 @@ except IOError as e: if str(e).startswith("decoder jpeg not available"): raise Exception( "FATAL: jpeg codec not supported. Install pillow correctly! " - " 'sudo apt-get install libjpeg-dev' then 'pip install -I pillow'" + " 'sudo apt-get install libjpeg-dev' then 'pip uninstall pillow &&" + " pip install pillow --user'" ) except Exception: # any other exception is fine @@ -36,7 +37,8 @@ except IOError as e: if str(e).startswith("decoder zip not available"): raise Exception( "FATAL: zip codec not supported. Install pillow correctly! " - " 'sudo apt-get install libjpeg-dev' then 'pip install -I pillow'" + " 'sudo apt-get install libjpeg-dev' then 'pip uninstall pillow &&" + " pip install pillow --user'" ) except Exception: # any other exception is fine diff --git a/synapse/media/v1/base_resource.py b/synapse/rest/media/v1/base_resource.py index 688e7376ad..688e7376ad 100644 --- a/synapse/media/v1/base_resource.py +++ b/synapse/rest/media/v1/base_resource.py diff --git a/synapse/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index c585bb11f7..c585bb11f7 100644 --- a/synapse/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py diff --git a/synapse/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index ed9a58e9d9..ed9a58e9d9 100644 --- a/synapse/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py diff --git a/synapse/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 461cc001f1..461cc001f1 100644 --- a/synapse/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py diff --git a/synapse/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 84f5e3463c..84f5e3463c 100644 --- a/synapse/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py diff --git a/synapse/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index bc86efea8f..28404f2b7b 100644 --- a/synapse/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -82,7 +82,7 @@ class Thumbnailer(object): def save_image(self, output_image, output_type, output_path): output_bytes_io = BytesIO() - output_image.save(output_bytes_io, self.FORMATS[output_type]) + output_image.save(output_bytes_io, self.FORMATS[output_type], quality=70) output_bytes = output_bytes_io.getvalue() with open(output_path, "wb") as output_file: output_file.write(output_bytes) diff --git a/synapse/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index b1718a630b..b939a30e19 100644 --- a/synapse/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -42,7 +42,7 @@ class UploadResource(BaseMediaResource): @defer.inlineCallbacks def _async_render_POST(self, request): try: - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point content_length = request.getHeader("Content-Length") diff --git a/synapse/server.py b/synapse/server.py index c3bf46abbf..23bdad0c7c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -20,21 +20,20 @@ # Imports required for the default HomeServer() implementation from synapse.federation import initialize_http_replication -from synapse.events.utils import serialize_event from synapse.notifier import Notifier from synapse.api.auth import Auth from synapse.handlers import Handlers -from synapse.rest import RestServletFactory from synapse.state import StateHandler from synapse.storage import DataStore -from synapse.types import UserID, RoomAlias, RoomID, EventID from synapse.util import Clock from synapse.util.distributor import Distributor from synapse.util.lockutils import LockManager from synapse.streams.events import EventSources from synapse.api.ratelimiting import Ratelimiter from synapse.crypto.keyring import Keyring +from synapse.push.pusherpool import PusherPool from synapse.events.builder import EventBuilderFactory +from synapse.api.filtering import Filtering class BaseHomeServer(object): @@ -72,6 +71,7 @@ class BaseHomeServer(object): 'notifier', 'distributor', 'resource_for_client', + 'resource_for_client_v2_alpha', 'resource_for_federation', 'resource_for_web_client', 'resource_for_content_repo', @@ -80,7 +80,9 @@ class BaseHomeServer(object): 'event_sources', 'ratelimiter', 'keyring', + 'pusherpool', 'event_builder_factory', + 'filtering', ] def __init__(self, hostname, **kwargs): @@ -125,33 +127,6 @@ class BaseHomeServer(object): setattr(BaseHomeServer, "get_%s" % (depname), _get) - # TODO: Why are these parse_ methods so high up along with other globals? - # Surely these should be in a util package or in the api package? - - # Other utility methods - def parse_userid(self, s): - """Parse the string given by 's' as a User ID and return a UserID - object.""" - return UserID.from_string(s) - - def parse_roomalias(self, s): - """Parse the string given by 's' as a Room Alias and return a RoomAlias - object.""" - return RoomAlias.from_string(s) - - def parse_roomid(self, s): - """Parse the string given by 's' as a Room ID and return a RoomID - object.""" - return RoomID.from_string(s) - - def parse_eventid(self, s): - """Parse the string given by 's' as a Event ID and return a EventID - object.""" - return EventID.from_string(s) - - def serialize_event(self, e): - return serialize_event(self, e) - def get_ip_from_request(self, request): # May be an X-Forwarding-For header depending on config ip_addr = request.getClientIP() @@ -203,9 +178,6 @@ class HomeServer(BaseHomeServer): def build_auth(self): return Auth(self) - def build_rest_servlet_factory(self): - return RestServletFactory(self) - def build_state_handler(self): return StateHandler(self) @@ -230,8 +202,8 @@ class HomeServer(BaseHomeServer): hostname=self.hostname, ) - def register_servlets(self): - """ Register all servlets associated with this HomeServer. - """ - # Simply building the ServletFactory is sufficient to have it register - self.get_rest_servlet_factory() + def build_filtering(self): + return Filtering(self) + + def build_pusherpool(self): + return PusherPool(self) diff --git a/synapse/state.py b/synapse/state.py index 8144fa02b4..8a056ee955 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor from synapse.api.constants import EventTypes +from synapse.api.errors import AuthError from synapse.events.snapshot import EventContext from collections import namedtuple @@ -36,12 +37,16 @@ def _get_state_key_from_event(event): KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) +AuthEventTypes = (EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,) + + class StateHandler(object): """ Responsible for doing state conflict resolution. """ def __init__(self, hs): self.store = hs.get_datastore() + self.hs = hs @defer.inlineCallbacks def get_current_state(self, room_id, event_type=None, state_key=""): @@ -163,10 +168,17 @@ class StateHandler(object): first is the name of a state group if one and only one is involved, otherwise `None`. """ + logger.debug("resolve_state_groups event_ids %s", event_ids) + state_groups = yield self.store.get_state_groups( event_ids ) + logger.debug( + "resolve_state_groups state_groups %s", + state_groups.keys() + ) + group_names = set(state_groups.keys()) if len(group_names) == 1: name, state_list = state_groups.items().pop() @@ -210,64 +222,93 @@ class StateHandler(object): else: prev_states = [] + auth_events = { + k: e for k, e in unconflicted_state.items() + if k[0] in AuthEventTypes + } + try: - new_state = {} - new_state.update(unconflicted_state) - for key, events in conflicted_state.items(): - new_state[key] = self._resolve_state_events(events) + resolved_state = self._resolve_state_events( + conflicted_state, auth_events + ) except: logger.exception("Failed to resolve state") raise - defer.returnValue((None, new_state, prev_states)) - - def _get_power_level_from_event_state(self, event, user_id): - if hasattr(event, "old_state_events") and event.old_state_events: - key = (EventTypes.PowerLevels, "", ) - power_level_event = event.old_state_events.get(key) - level = None - if power_level_event: - level = power_level_event.content.get("users", {}).get( - user_id - ) - if not level: - level = power_level_event.content.get("users_default", 0) + new_state = unconflicted_state + new_state.update(resolved_state) - return level - else: - return 0 + defer.returnValue((None, new_state, prev_states)) @log_function - def _resolve_state_events(self, events): - curr_events = events - - new_powers = [ - self._get_power_level_from_event_state(e, e.user_id) - for e in curr_events - ] - - new_powers = [ - int(p) if p else 0 for p in new_powers - ] + def _resolve_state_events(self, conflicted_state, auth_events): + """ This is where we actually decide which of the conflicted state to + use. + + We resolve conflicts in the following order: + 1. power levels + 2. memberships + 3. other events. + """ + resolved_state = {} + power_key = (EventTypes.PowerLevels, "") + if power_key in conflicted_state.items(): + power_levels = conflicted_state[power_key] + resolved_state[power_key] = self._resolve_auth_events(power_levels) + + auth_events.update(resolved_state) + + for key, events in conflicted_state.items(): + if key[0] == EventTypes.Member: + resolved_state[key] = self._resolve_auth_events( + events, + auth_events + ) - max_power = max(new_powers) + auth_events.update(resolved_state) - curr_events = [ - z[0] for z in zip(curr_events, new_powers) - if z[1] == max_power - ] + for key, events in conflicted_state.items(): + if key not in resolved_state: + resolved_state[key] = self._resolve_normal_events( + events, auth_events + ) - if not curr_events: - raise RuntimeError("Max didn't get a max?") - elif len(curr_events) == 1: - return curr_events[0] - - # TODO: For now, just choose the one with the largest event_id. - return ( - sorted( - curr_events, - key=lambda e: hashlib.sha1( - e.event_id + e.user_id + e.room_id + e.type - ).hexdigest() - )[0] - ) + return resolved_state + + def _resolve_auth_events(self, events, auth_events): + reverse = [i for i in reversed(self._ordered_events(events))] + + auth_events = dict(auth_events) + + prev_event = reverse[0] + for event in reverse[1:]: + auth_events[(prev_event.type, prev_event.state_key)] = prev_event + try: + # FIXME: hs.get_auth() is bad style, but we need to do it to + # get around circular deps. + self.hs.get_auth().check(event, auth_events) + prev_event = event + except AuthError: + return prev_event + + return event + + def _resolve_normal_events(self, events, auth_events): + for event in self._ordered_events(events): + try: + # FIXME: hs.get_auth() is bad style, but we need to do it to + # get around circular deps. + self.hs.get_auth().check(event, auth_events) + return event + except AuthError: + pass + + # Use the last event (the one with the least depth) if they all fail + # the auth check. + return event + + def _ordered_events(self, events): + def key_func(e): + return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() + + return sorted(events, key=key_func) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 4beb951b9f..7c54b1b9d3 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -29,10 +29,14 @@ from .stream import StreamStore from .transactions import TransactionStore from .keys import KeyStore from .event_federation import EventFederationStore +from .pusher import PusherStore +from .push_rule import PushRuleStore from .media_repository import MediaRepositoryStore +from .rejections import RejectionsStore from .state import StateStore from .signatures import SignatureStore +from .filtering import FilteringStore from syutil.base64util import decode_base64 from syutil.jsonutil import encode_canonical_json @@ -60,13 +64,16 @@ SCHEMAS = [ "state", "event_edges", "event_signatures", + "pusher", "media_repository", + "filtering", + "rejections", ] # Remember to update this number every time an incompatible change is made to # database schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 11 +SCHEMA_VERSION = 12 class _RollbackButIsFineException(Exception): @@ -82,6 +89,10 @@ class DataStore(RoomMemberStore, RoomStore, DirectoryStore, KeyStore, StateStore, SignatureStore, EventFederationStore, MediaRepositoryStore, + RejectionsStore, + FilteringStore, + PusherStore, + PushRuleStore ): def __init__(self, hs): @@ -224,6 +235,9 @@ class DataStore(RoomMemberStore, RoomStore, if not outlier: self._store_state_groups_txn(txn, event, context) + if context.rejected: + self._store_rejections_txn(txn, event.event_id, context.rejected) + if current_state: txn.execute( "DELETE FROM current_state_events WHERE room_id = ?", @@ -262,7 +276,7 @@ class DataStore(RoomMemberStore, RoomStore, or_replace=True, ) - if is_new_state: + if is_new_state and not context.rejected: self._simple_insert_txn( txn, "current_state_events", @@ -288,7 +302,7 @@ class DataStore(RoomMemberStore, RoomStore, or_ignore=True, ) - if not backfilled: + if not backfilled and not context.rejected: self._simple_insert_txn( txn, table="state_forward_extremities", @@ -370,9 +384,12 @@ class DataStore(RoomMemberStore, RoomStore, "redacted": del_sql, } - if event_type: + if event_type and state_key is not None: sql += " AND s.type = ? AND s.state_key = ? " args = (room_id, event_type, state_key) + elif event_type: + sql += " AND s.type = ?" + args = (room_id, event_type) else: args = (room_id, ) @@ -382,6 +399,41 @@ class DataStore(RoomMemberStore, RoomStore, defer.returnValue(events) @defer.inlineCallbacks + def get_room_name_and_aliases(self, room_id): + del_sql = ( + "SELECT event_id FROM redactions WHERE redacts = e.event_id " + "LIMIT 1" + ) + + sql = ( + "SELECT e.*, (%(redacted)s) AS redacted FROM events as e " + "INNER JOIN current_state_events as c ON e.event_id = c.event_id " + "INNER JOIN state_events as s ON e.event_id = s.event_id " + "WHERE c.room_id = ? " + ) % { + "redacted": del_sql, + } + + sql += " AND ((s.type = 'm.room.name' AND s.state_key = '')" + sql += " OR s.type = 'm.room.aliases')" + args = (room_id,) + + results = yield self._execute_and_decode(sql, *args) + + events = yield self._parse_events(results) + + name = None + aliases = [] + + for e in events: + if e.type == 'm.room.name': + name = e.content['name'] + elif e.type == 'm.room.aliases': + aliases.extend(e.content['aliases']) + + defer.returnValue((name, aliases)) + + @defer.inlineCallbacks def _get_min_token(self): row = yield self._execute( None, @@ -417,6 +469,35 @@ class DataStore(RoomMemberStore, RoomStore, ], ) + def have_events(self, event_ids): + """Given a list of event ids, check if we have already processed them. + + Returns: + dict: Has an entry for each event id we already have seen. Maps to + the rejected reason string if we rejected the event, else maps to + None. + """ + def f(txn): + sql = ( + "SELECT e.event_id, reason FROM events as e " + "LEFT JOIN rejections as r ON e.event_id = r.event_id " + "WHERE e.event_id = ?" + ) + + res = {} + for event_id in event_ids: + txn.execute(sql, (event_id,)) + row = txn.fetchone() + if row: + _, rejected = row + res[event_id] = rejected + + return res + + return self.runInteraction( + "have_events", f, + ) + def schema_path(schema): """ Get a filesystem path for the named database schema diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index f660fc6eaf..b350fd61f1 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -193,6 +193,50 @@ class SQLBaseStore(object): txn.execute(sql, values.values()) return txn.lastrowid + def _simple_upsert(self, table, keyvalues, values): + """ + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + Returns: A deferred + """ + return self.runInteraction( + "_simple_upsert", + self._simple_upsert_txn, table, keyvalues, values + ) + + def _simple_upsert_txn(self, txn, table, keyvalues, values): + # Try to update + sql = "UPDATE %s SET %s WHERE %s" % ( + table, + ", ".join("%s = ?" % (k,) for k in values), + " AND ".join("%s = ?" % (k,) for k in keyvalues) + ) + sqlargs = values.values() + keyvalues.values() + logger.debug( + "[SQL] %s Args=%s", + sql, sqlargs, + ) + + txn.execute(sql, sqlargs) + if txn.rowcount == 0: + # We didn't update and rows so insert a new one + allvalues = {} + allvalues.update(keyvalues) + allvalues.update(values) + + sql = "INSERT INTO %s (%s) VALUES (%s)" % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" for _ in allvalues) + ) + logger.debug( + "[SQL] %s Args=%s", + sql, keyvalues.values(), + ) + txn.execute(sql, allvalues.values()) + def _simple_select_one(self, table, keyvalues, retcols, allow_none=False): """Executes a SELECT query on the named table, which is expected to @@ -344,8 +388,8 @@ class SQLBaseStore(object): if updatevalues: update_sql = "UPDATE %s SET %s WHERE %s" % ( table, - ", ".join("%s = ?" % (k) for k in updatevalues), - " AND ".join("%s = ?" % (k) for k in keyvalues) + ", ".join("%s = ?" % (k,) for k in updatevalues), + " AND ".join("%s = ?" % (k,) for k in keyvalues) ) def func(txn): @@ -458,10 +502,12 @@ class SQLBaseStore(object): return [e for e in events if e] def _get_event_txn(self, txn, event_id, check_redacted=True, - get_prev_content=False): + get_prev_content=False, allow_rejected=False): sql = ( - "SELECT internal_metadata, json, r.event_id FROM event_json as e " + "SELECT e.internal_metadata, e.json, r.event_id, rej.reason " + "FROM event_json as e " "LEFT JOIN redactions as r ON e.event_id = r.redacts " + "LEFT JOIN rejections as rej on rej.event_id = e.event_id " "WHERE e.event_id = ? " "LIMIT 1 " ) @@ -473,13 +519,16 @@ class SQLBaseStore(object): if not res: return None - internal_metadata, js, redacted = res + internal_metadata, js, redacted, rejected_reason = res - return self._get_event_from_row_txn( - txn, internal_metadata, js, redacted, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - ) + if allow_rejected or not rejected_reason: + return self._get_event_from_row_txn( + txn, internal_metadata, js, redacted, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + ) + else: + return None def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted, check_redacted=True, get_prev_content=False): diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py new file mode 100644 index 0000000000..e86eeced45 --- /dev/null +++ b/synapse/storage/filtering.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from ._base import SQLBaseStore + +import json + + +class FilteringStore(SQLBaseStore): + @defer.inlineCallbacks + def get_user_filter(self, user_localpart, filter_id): + def_json = yield self._simple_select_one_onecol( + table="user_filters", + keyvalues={ + "user_id": user_localpart, + "filter_id": filter_id, + }, + retcol="filter_json", + allow_none=False, + ) + + defer.returnValue(json.loads(def_json)) + + def add_user_filter(self, user_localpart, user_filter): + def_json = json.dumps(user_filter) + + # Need an atomic transaction to SELECT the maximal ID so far then + # INSERT a new one + def _do_txn(txn): + sql = ( + "SELECT MAX(filter_id) FROM user_filters " + "WHERE user_id = ?" + ) + txn.execute(sql, (user_localpart,)) + max_id = txn.fetchone()[0] + if max_id is None: + filter_id = 0 + else: + filter_id = max_id + 1 + + sql = ( + "INSERT INTO user_filters (user_id, filter_id, filter_json)" + "VALUES(?, ?, ?)" + ) + txn.execute(sql, (user_localpart, filter_id, def_json)) + + return filter_id + + return self.runInteraction("add_user_filter", _do_txn) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py new file mode 100644 index 0000000000..27502d2399 --- /dev/null +++ b/synapse/storage/push_rule.py @@ -0,0 +1,213 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections + +from ._base import SQLBaseStore, Table +from twisted.internet import defer + +import logging +import copy +import json + +logger = logging.getLogger(__name__) + + +class PushRuleStore(SQLBaseStore): + @defer.inlineCallbacks + def get_push_rules_for_user_name(self, user_name): + sql = ( + "SELECT "+",".join(PushRuleTable.fields)+" " + "FROM "+PushRuleTable.table_name+" " + "WHERE user_name = ? " + "ORDER BY priority_class DESC, priority DESC" + ) + rows = yield self._execute(None, sql, user_name) + + dicts = [] + for r in rows: + d = {} + for i, f in enumerate(PushRuleTable.fields): + d[f] = r[i] + dicts.append(d) + + defer.returnValue(dicts) + + @defer.inlineCallbacks + def add_push_rule(self, before, after, **kwargs): + vals = copy.copy(kwargs) + if 'conditions' in vals: + vals['conditions'] = json.dumps(vals['conditions']) + if 'actions' in vals: + vals['actions'] = json.dumps(vals['actions']) + # we could check the rest of the keys are valid column names + # but sqlite will do that anyway so I think it's just pointless. + if 'id' in vals: + del vals['id'] + + if before or after: + ret = yield self.runInteraction( + "_add_push_rule_relative_txn", + self._add_push_rule_relative_txn, + before=before, + after=after, + **vals + ) + defer.returnValue(ret) + else: + ret = yield self.runInteraction( + "_add_push_rule_highest_priority_txn", + self._add_push_rule_highest_priority_txn, + **vals + ) + defer.returnValue(ret) + + def _add_push_rule_relative_txn(self, txn, user_name, **kwargs): + after = None + relative_to_rule = None + if 'after' in kwargs and kwargs['after']: + after = kwargs['after'] + relative_to_rule = after + if 'before' in kwargs and kwargs['before']: + relative_to_rule = kwargs['before'] + + # get the priority of the rule we're inserting after/before + sql = ( + "SELECT priority_class, priority FROM ? " + "WHERE user_name = ? and rule_id = ?" % (PushRuleTable.table_name,) + ) + txn.execute(sql, (user_name, relative_to_rule)) + res = txn.fetchall() + if not res: + raise RuleNotFoundException("before/after rule not found: %s" % (relative_to_rule)) + priority_class, base_rule_priority = res[0] + + if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class: + raise InconsistentRuleException( + "Given priority class does not match class of relative rule" + ) + + new_rule = copy.copy(kwargs) + if 'before' in new_rule: + del new_rule['before'] + if 'after' in new_rule: + del new_rule['after'] + new_rule['priority_class'] = priority_class + new_rule['user_name'] = user_name + + # check if the priority before/after is free + new_rule_priority = base_rule_priority + if after: + new_rule_priority -= 1 + else: + new_rule_priority += 1 + + new_rule['priority'] = new_rule_priority + + sql = ( + "SELECT COUNT(*) FROM " + PushRuleTable.table_name + + " WHERE user_name = ? AND priority_class = ? AND priority = ?" + ) + txn.execute(sql, (user_name, priority_class, new_rule_priority)) + res = txn.fetchall() + num_conflicting = res[0][0] + + # if there are conflicting rules, bump everything + if num_conflicting: + sql = "UPDATE "+PushRuleTable.table_name+" SET priority = priority " + if after: + sql += "-1" + else: + sql += "+1" + sql += " WHERE user_name = ? AND priority_class = ? AND priority " + if after: + sql += "<= ?" + else: + sql += ">= ?" + + txn.execute(sql, (user_name, priority_class, new_rule_priority)) + + # now insert the new rule + sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" (" + sql += ",".join(new_rule.keys())+") VALUES (" + sql += ", ".join(["?" for _ in new_rule.keys()])+")" + + txn.execute(sql, new_rule.values()) + + def _add_push_rule_highest_priority_txn(self, txn, user_name, + priority_class, **kwargs): + # find the highest priority rule in that class + sql = ( + "SELECT COUNT(*), MAX(priority) FROM " + PushRuleTable.table_name + + " WHERE user_name = ? and priority_class = ?" + ) + txn.execute(sql, (user_name, priority_class)) + res = txn.fetchall() + (how_many, highest_prio) = res[0] + + new_prio = 0 + if how_many > 0: + new_prio = highest_prio + 1 + + # and insert the new rule + new_rule = copy.copy(kwargs) + if 'id' in new_rule: + del new_rule['id'] + new_rule['user_name'] = user_name + new_rule['priority_class'] = priority_class + new_rule['priority'] = new_prio + + sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" (" + sql += ",".join(new_rule.keys())+") VALUES (" + sql += ", ".join(["?" for _ in new_rule.keys()])+")" + + txn.execute(sql, new_rule.values()) + + @defer.inlineCallbacks + def delete_push_rule(self, user_name, rule_id, **kwargs): + """ + Delete a push rule. Args specify the row to be deleted and can be + any of the columns in the push_rule table, but below are the + standard ones + + Args: + user_name (str): The matrix ID of the push rule owner + rule_id (str): The rule_id of the rule to be deleted + """ + yield self._simple_delete_one(PushRuleTable.table_name, kwargs) + + +class RuleNotFoundException(Exception): + pass + + +class InconsistentRuleException(Exception): + pass + + +class PushRuleTable(Table): + table_name = "push_rules" + + fields = [ + "id", + "user_name", + "rule_id", + "priority_class", + "priority", + "conditions", + "actions", + ] + + EntryType = collections.namedtuple("PushRuleEntry", fields) diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py new file mode 100644 index 0000000000..f253c9e2c3 --- /dev/null +++ b/synapse/storage/pusher.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections + +from ._base import SQLBaseStore, Table +from twisted.internet import defer + +from synapse.api.errors import StoreError + +import logging + +logger = logging.getLogger(__name__) + + +class PusherStore(SQLBaseStore): + @defer.inlineCallbacks + def get_pushers_by_app_id_and_pushkey(self, app_id_and_pushkey): + sql = ( + "SELECT id, user_name, kind, instance_handle, app_id," + "app_display_name, device_display_name, pushkey, ts, data, " + "last_token, last_success, failing_since " + "FROM pushers " + "WHERE app_id = ? AND pushkey = ?" + ) + + rows = yield self._execute( + None, sql, app_id_and_pushkey[0], app_id_and_pushkey[1] + ) + + ret = [ + { + "id": r[0], + "user_name": r[1], + "kind": r[2], + "instance_handle": r[3], + "app_id": r[4], + "app_display_name": r[5], + "device_display_name": r[6], + "pushkey": r[7], + "pushkey_ts": r[8], + "data": r[9], + "last_token": r[10], + "last_success": r[11], + "failing_since": r[12] + } + for r in rows + ] + + defer.returnValue(ret[0]) + + @defer.inlineCallbacks + def get_all_pushers(self): + sql = ( + "SELECT id, user_name, kind, instance_handle, app_id," + "app_display_name, device_display_name, pushkey, ts, data, " + "last_token, last_success, failing_since " + "FROM pushers" + ) + + rows = yield self._execute(None, sql) + + ret = [ + { + "id": r[0], + "user_name": r[1], + "kind": r[2], + "instance_handle": r[3], + "app_id": r[4], + "app_display_name": r[5], + "device_display_name": r[6], + "pushkey": r[7], + "pushkey_ts": r[8], + "data": r[9], + "last_token": r[10], + "last_success": r[11], + "failing_since": r[12] + } + for r in rows + ] + + defer.returnValue(ret) + + @defer.inlineCallbacks + def add_pusher(self, user_name, instance_handle, kind, app_id, + app_display_name, device_display_name, + pushkey, pushkey_ts, lang, data): + try: + yield self._simple_upsert( + PushersTable.table_name, + dict( + app_id=app_id, + pushkey=pushkey, + ), + dict( + user_name=user_name, + kind=kind, + instance_handle=instance_handle, + app_display_name=app_display_name, + device_display_name=device_display_name, + ts=pushkey_ts, + lang=lang, + data=data + )) + except Exception as e: + logger.error("create_pusher with failed: %s", e) + raise StoreError(500, "Problem creating pusher.") + + @defer.inlineCallbacks + def delete_pusher_by_app_id_pushkey(self, app_id, pushkey): + yield self._simple_delete_one( + PushersTable.table_name, + dict(app_id=app_id, pushkey=pushkey) + ) + + @defer.inlineCallbacks + def update_pusher_last_token(self, user_name, pushkey, last_token): + yield self._simple_update_one( + PushersTable.table_name, + {'user_name': user_name, 'pushkey': pushkey}, + {'last_token': last_token} + ) + + @defer.inlineCallbacks + def update_pusher_last_token_and_success(self, user_name, pushkey, + last_token, last_success): + yield self._simple_update_one( + PushersTable.table_name, + {'user_name': user_name, 'pushkey': pushkey}, + {'last_token': last_token, 'last_success': last_success} + ) + + @defer.inlineCallbacks + def update_pusher_failing_since(self, user_name, pushkey, failing_since): + yield self._simple_update_one( + PushersTable.table_name, + {'user_name': user_name, 'pushkey': pushkey}, + {'failing_since': failing_since} + ) + + +class PushersTable(Table): + table_name = "pushers" + + fields = [ + "id", + "user_name", + "kind", + "instance_handle", + "app_id", + "app_display_name", + "device_display_name", + "pushkey", + "pushkey_ts", + "data", + "last_token", + "last_success", + "failing_since" + ] + + EntryType = collections.namedtuple("PusherEntry", fields) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 75dffa4db2..029b07cc66 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -122,7 +122,8 @@ class RegistrationStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.admin, access_tokens.device_id" + "SELECT users.name, users.admin," + " access_tokens.device_id, access_tokens.id as token_id" " FROM users" " INNER JOIN access_tokens on users.id = access_tokens.user_id" " WHERE token = ?" diff --git a/synapse/storage/rejections.py b/synapse/storage/rejections.py new file mode 100644 index 0000000000..4e1a9a2783 --- /dev/null +++ b/synapse/storage/rejections.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore + +import logging + +logger = logging.getLogger(__name__) + + +class RejectionsStore(SQLBaseStore): + def _store_rejections_txn(self, txn, event_id, reason): + self._simple_insert_txn( + txn, + table="rejections", + values={ + "event_id": event_id, + "reason": reason, + "last_check": self._clock.time_msec(), + } + ) + + def get_rejection_reason(self, event_id): + return self._simple_select_one_onecol( + table="rejections", + retcol="reason", + keyvalues={ + "event_id": event_id, + }, + allow_none=True, + ) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 978b2c4a48..6542f8e4f8 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -58,13 +58,6 @@ class RoomStore(SQLBaseStore): logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") - def store_room_config(self, room_id, visibility): - return self._simple_update_one( - table=RoomsTable.table_name, - keyvalues={"room_id": room_id}, - updatevalues={"is_public": visibility} - ) - def get_room(self, room_id): """Retrieve a room. diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index e59e65529b..c69dd995ce 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -20,6 +20,7 @@ from collections import namedtuple from ._base import SQLBaseStore from synapse.api.constants import Membership +from synapse.types import UserID import logging @@ -39,7 +40,7 @@ class RoomMemberStore(SQLBaseStore): """ try: target_user_id = event.state_key - domain = self.hs.parse_userid(target_user_id).domain + domain = UserID.from_string(target_user_id).domain except: logger.exception( "Failed to parse target_user_id=%s", target_user_id @@ -84,7 +85,7 @@ class RoomMemberStore(SQLBaseStore): for e in member_events: try: joined_domains.add( - self.hs.parse_userid(e.state_key).domain + UserID.from_string(e.state_key).domain ) except: # FIXME: How do we deal with invalid user ids in the db? diff --git a/synapse/storage/schema/delta/v12.sql b/synapse/storage/schema/delta/v12.sql new file mode 100644 index 0000000000..a6867cba62 --- /dev/null +++ b/synapse/storage/schema/delta/v12.sql @@ -0,0 +1,54 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS rejections( + event_id TEXT NOT NULL, + reason TEXT NOT NULL, + last_check TEXT NOT NULL, + CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE +); + +-- Push notification endpoints that users have configured +CREATE TABLE IF NOT EXISTS pushers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT NOT NULL, + instance_handle varchar(32) NOT NULL, + kind varchar(8) NOT NULL, + app_id varchar(64) NOT NULL, + app_display_name varchar(64) NOT NULL, + device_display_name varchar(128) NOT NULL, + pushkey blob NOT NULL, + ts BIGINT NOT NULL, + lang varchar(8), + data blob, + last_token TEXT, + last_success BIGINT, + failing_since BIGINT, + FOREIGN KEY(user_name) REFERENCES users(name), + UNIQUE (app_id, pushkey) +); + +CREATE TABLE IF NOT EXISTS push_rules ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT NOT NULL, + rule_id TEXT NOT NULL, + priority_class TINYINT NOT NULL, + priority INTEGER NOT NULL DEFAULT 0, + conditions TEXT NOT NULL, + actions TEXT NOT NULL, + UNIQUE(user_name, rule_id) +); + +CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name); diff --git a/synapse/storage/schema/delta/v13.sql b/synapse/storage/schema/delta/v13.sql new file mode 100644 index 0000000000..beb39ca201 --- /dev/null +++ b/synapse/storage/schema/delta/v13.sql @@ -0,0 +1,24 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS user_filters( + user_id TEXT, + filter_id INTEGER, + filter_json TEXT, + FOREIGN KEY(user_id) REFERENCES users(id) +); + +CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters( + user_id, filter_id +); diff --git a/synapse/storage/schema/filtering.sql b/synapse/storage/schema/filtering.sql new file mode 100644 index 0000000000..beb39ca201 --- /dev/null +++ b/synapse/storage/schema/filtering.sql @@ -0,0 +1,24 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS user_filters( + user_id TEXT, + filter_id INTEGER, + filter_json TEXT, + FOREIGN KEY(user_id) REFERENCES users(id) +); + +CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters( + user_id, filter_id +); diff --git a/synapse/storage/schema/pusher.sql b/synapse/storage/schema/pusher.sql new file mode 100644 index 0000000000..8c4dfd5c1b --- /dev/null +++ b/synapse/storage/schema/pusher.sql @@ -0,0 +1,46 @@ +/* Copyright 2014 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +-- Push notification endpoints that users have configured +CREATE TABLE IF NOT EXISTS pushers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT NOT NULL, + instance_handle varchar(32) NOT NULL, + kind varchar(8) NOT NULL, + app_id varchar(64) NOT NULL, + app_display_name varchar(64) NOT NULL, + device_display_name varchar(128) NOT NULL, + pushkey blob NOT NULL, + ts BIGINT NOT NULL, + lang varchar(8), + data blob, + last_token TEXT, + last_success BIGINT, + failing_since BIGINT, + FOREIGN KEY(user_name) REFERENCES users(name), + UNIQUE (app_id, pushkey) +); + +CREATE TABLE IF NOT EXISTS push_rules ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT NOT NULL, + rule_id TEXT NOT NULL, + priority_class TINYINT NOT NULL, + priority INTEGER NOT NULL DEFAULT 0, + conditions TEXT NOT NULL, + actions TEXT NOT NULL, + UNIQUE(user_name, rule_id) +); + +CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name); diff --git a/synapse/storage/schema/rejections.sql b/synapse/storage/schema/rejections.sql new file mode 100644 index 0000000000..bd2a8b1bb5 --- /dev/null +++ b/synapse/storage/schema/rejections.sql @@ -0,0 +1,21 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS rejections( + event_id TEXT NOT NULL, + reason TEXT NOT NULL, + last_check TEXT NOT NULL, + CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE +); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 5327517704..71db16d0e5 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -78,12 +78,6 @@ class StateStore(SQLBaseStore): f, ) - def store_state_groups(self, event): - return self.runInteraction( - "store_state_groups", - self._store_state_groups_txn, event - ) - def _store_state_groups_txn(self, txn, event, context): if context.current_state is None: return diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index bedc3c6c52..3ccb6f8a61 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -39,6 +39,8 @@ from ._base import SQLBaseStore from synapse.api.errors import SynapseError from synapse.util.logutils import log_function +from collections import namedtuple + import logging @@ -52,91 +54,79 @@ _STREAM_TOKEN = "stream" _TOPOLOGICAL_TOKEN = "topological" -def _parse_stream_token(string): - try: - if string[0] != 's': - raise - return int(string[1:]) - except: - raise SynapseError(400, "Invalid token") - - -def _parse_topological_token(string): - try: - if string[0] != 't': - raise - parts = string[1:].split('-', 1) - return (int(parts[0]), int(parts[1])) - except: - raise SynapseError(400, "Invalid token") - - -def is_stream_token(string): - try: - _parse_stream_token(string) - return True - except: - return False - - -def is_topological_token(string): - try: - _parse_topological_token(string) - return True - except: - return False - - -def _get_token_bound(token, comparison): - try: - s = _parse_stream_token(token) - return "%s %s %d" % ("stream_ordering", comparison, s) - except: - pass - - try: - top, stream = _parse_topological_token(token) - return "%s %s %d AND %s %s %d" % ( - "topological_ordering", comparison, top, - "stream_ordering", comparison, stream, - ) - except: - pass - - raise SynapseError(400, "Invalid token") - - -class StreamStore(SQLBaseStore): - @log_function - def get_room_events(self, user_id, from_key, to_key, room_id, limit=0, - direction='f', with_feedback=False): - # We deal with events request in two different ways depending on if - # this looks like an /events request or a pagination request. - is_events = ( - direction == 'f' - and user_id - and is_stream_token(from_key) - and to_key and is_stream_token(to_key) - ) +class _StreamToken(namedtuple("_StreamToken", "topological stream")): + """Tokens are positions between events. The token "s1" comes after event 1. + + s0 s1 + | | + [0] V [1] V [2] + + Tokens can either be a point in the live event stream or a cursor going + through historic events. + + When traversing the live event stream events are ordered by when they + arrived at the homeserver. + + When traversing historic events the events are ordered by their depth in + the event graph "topological_ordering" and then by when they arrived at the + homeserver "stream_ordering". + + Live tokens start with an "s" followed by the "stream_ordering" id of the + event it comes after. Historic tokens start with a "t" followed by the + "topological_ordering" id of the event it comes after, follewed by "-", + followed by the "stream_ordering" id of the event it comes after. + """ + __slots__ = [] + + @classmethod + def parse(cls, string): + try: + if string[0] == 's': + return cls(topological=None, stream=int(string[1:])) + if string[0] == 't': + parts = string[1:].split('-', 1) + return cls(topological=int(parts[0]), stream=int(parts[1])) + except: + pass + raise SynapseError(400, "Invalid token %r" % (string,)) + + @classmethod + def parse_stream_token(cls, string): + try: + if string[0] == 's': + return cls(topological=None, stream=int(string[1:])) + except: + pass + raise SynapseError(400, "Invalid token %r" % (string,)) + + def __str__(self): + if self.topological is not None: + return "t%d-%d" % (self.topological, self.stream) + else: + return "s%d" % (self.stream,) - if is_events: - return self.get_room_events_stream( - user_id=user_id, - from_key=from_key, - to_key=to_key, - room_id=room_id, - limit=limit, - with_feedback=with_feedback, + def lower_bound(self): + if self.topological is None: + return "(%d < %s)" % (self.stream, "stream_ordering") + else: + return "(%d < %s OR (%d == %s AND %d < %s))" % ( + self.topological, "topological_ordering", + self.topological, "topological_ordering", + self.stream, "stream_ordering", ) + + def upper_bound(self): + if self.topological is None: + return "(%d >= %s)" % (self.stream, "stream_ordering") else: - return self.paginate_room_events( - from_key=from_key, - to_key=to_key, - room_id=room_id, - limit=limit, - with_feedback=with_feedback, + return "(%d > %s OR (%d == %s AND %d >= %s))" % ( + self.topological, "topological_ordering", + self.topological, "topological_ordering", + self.stream, "stream_ordering", ) + +class StreamStore(SQLBaseStore): @log_function def get_room_events_stream(self, user_id, from_key, to_key, room_id, limit=0, with_feedback=False): @@ -162,8 +152,8 @@ class StreamStore(SQLBaseStore): limit = MAX_STREAM_SIZE # From and to keys should be integers from ordering. - from_id = _parse_stream_token(from_key) - to_id = _parse_stream_token(to_key) + from_id = _StreamToken.parse_stream_token(from_key) + to_id = _StreamToken.parse_stream_token(to_key) if from_key == to_key: return defer.succeed(([], to_key)) @@ -181,7 +171,7 @@ class StreamStore(SQLBaseStore): } def f(txn): - txn.execute(sql, (user_id, user_id, from_id, to_id,)) + txn.execute(sql, (user_id, user_id, from_id.stream, to_id.stream,)) rows = self.cursor_to_dict(txn) @@ -191,8 +181,11 @@ class StreamStore(SQLBaseStore): get_prev_content=True ) + self._set_before_and_after(ret, rows) + if rows: key = "s%d" % max([r["stream_ordering"] for r in rows]) + else: # Assume we didn't get anything because there was nothing to # get. @@ -211,17 +204,21 @@ class StreamStore(SQLBaseStore): # Tokens really represent positions between elements, but we use # the convention of pointing to the event before the gap. Hence # we have a bit of asymmetry when it comes to equalities. - from_comp = '<=' if direction == 'b' else '>' - to_comp = '>' if direction == 'b' else '<=' - order = "DESC" if direction == 'b' else "ASC" - args = [room_id] - - bounds = _get_token_bound(from_key, from_comp) - if to_key: - bounds = "%s AND %s" % ( - bounds, _get_token_bound(to_key, to_comp) - ) + if direction == 'b': + order = "DESC" + bounds = _StreamToken.parse(from_key).upper_bound() + if to_key: + bounds = "%s AND %s" % ( + bounds, _StreamToken.parse(to_key).lower_bound() + ) + else: + order = "ASC" + bounds = _StreamToken.parse(from_key).lower_bound() + if to_key: + bounds = "%s AND %s" % ( + bounds, _StreamToken.parse(to_key).upper_bound() + ) if int(limit) > 0: args.append(int(limit)) @@ -249,9 +246,13 @@ class StreamStore(SQLBaseStore): topo = rows[-1]["topological_ordering"] toke = rows[-1]["stream_ordering"] if direction == 'b': - topo -= 1 + # Tokens are positions between events. + # This token points *after* the last event in the chunk. + # We need it to point to the event before it in the chunk + # when we are going backwards so we subtract one from the + # stream part. toke -= 1 - next_token = "t%s-%s" % (topo, toke) + next_token = str(_StreamToken(topo, toke)) else: # TODO (erikj): We should work out what to do here instead. next_token = to_key if to_key else from_key @@ -262,35 +263,62 @@ class StreamStore(SQLBaseStore): get_prev_content=True ) + self._set_before_and_after(events, rows) + return events, next_token, return self.runInteraction("paginate_room_events", f) def get_recent_events_for_room(self, room_id, limit, end_token, - with_feedback=False): + with_feedback=False, from_token=None): # TODO (erikj): Handle compressed feedback - sql = ( - "SELECT stream_ordering, topological_ordering, event_id FROM events " - "WHERE room_id = ? AND stream_ordering <= ? AND outlier = 0 " - "ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ? " - ) + end_token = _StreamToken.parse_stream_token(end_token) - def f(txn): - txn.execute(sql, (room_id, end_token, limit,)) + if from_token is None: + sql = ( + "SELECT stream_ordering, topological_ordering, event_id" + " FROM events" + " WHERE room_id = ? AND stream_ordering <= ? AND outlier = 0" + " ORDER BY topological_ordering DESC, stream_ordering DESC" + " LIMIT ?" + ) + else: + from_token = _StreamToken.parse_stream_token(from_token) + sql = ( + "SELECT stream_ordering, topological_ordering, event_id" + " FROM events" + " WHERE room_id = ? AND stream_ordering > ?" + " AND stream_ordering <= ? AND outlier = 0" + " ORDER BY topological_ordering DESC, stream_ordering DESC" + " LIMIT ?" + ) + + def get_recent_events_for_room_txn(txn): + if from_token is None: + txn.execute(sql, (room_id, end_token.stream, limit,)) + else: + txn.execute(sql, ( + room_id, from_token.stream, end_token.stream, limit + )) rows = self.cursor_to_dict(txn) rows.reverse() # As we selected with reverse ordering if rows: + # Tokens are positions between events. + # This token points *after* the last event in the chunk. + # We need it to point to the event before it in the chunk + # since we are going backwards so we subtract one from the + # stream part. topo = rows[0]["topological_ordering"] - toke = rows[0]["stream_ordering"] - start_token = "t%s-%s" % (topo, toke) + toke = rows[0]["stream_ordering"] - 1 + start_token = str(_StreamToken(topo, toke)) - token = (start_token, end_token) + token = (start_token, str(end_token)) else: - token = (end_token, end_token) + token = (str(end_token), str(end_token)) events = self._get_events_txn( txn, @@ -298,9 +326,13 @@ class StreamStore(SQLBaseStore): get_prev_content=True ) + self._set_before_and_after(events, rows) + return events, token - return self.runInteraction("get_recent_events_for_room", f) + return self.runInteraction( + "get_recent_events_for_room", get_recent_events_for_room_txn + ) def get_room_events_max_id(self): return self.runInteraction( @@ -322,3 +354,12 @@ class StreamStore(SQLBaseStore): key = res[0]["m"] return "s%d" % (key,) + + @staticmethod + def _set_before_and_after(events, rows): + for event, row in zip(events, rows): + stream = row["stream_ordering"] + topo = event.depth + internal = event.internal_metadata + internal.before = str(_StreamToken(topo, stream - 1)) + internal.after = str(_StreamToken(topo, stream)) diff --git a/synapse/types.py b/synapse/types.py index faac729ff2..f6a1b0bbcf 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -119,3 +119,6 @@ class StreamToken( d = self._asdict() d[key] = new_value return StreamToken(**d) + + +ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id")) |