diff options
45 files changed, 1444 insertions, 623 deletions
diff --git a/AUTHORS.rst b/AUTHORS.rst index 8711a6ae5c..3dcb1c2a89 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -57,3 +57,6 @@ Florent Violleau <floviolleau at gmail dot com> Niklas Riekenbrauck <nikriek at gmail dot.com> * Add JWT support for registration and login + +Christoph Witzany <christoph at web.crofting.com> + * Add LDAP support for authentication diff --git a/README.rst b/README.rst index 6136e0c1fe..02e7c61d1e 100644 --- a/README.rst +++ b/README.rst @@ -104,7 +104,7 @@ Installing prerequisites on Ubuntu or Debian:: sudo apt-get install build-essential python2.7-dev libffi-dev \ python-pip python-setuptools sqlite3 \ - libssl-dev python-virtualenv libjpeg-dev + libssl-dev python-virtualenv libjpeg-dev libxslt1-dev Installing prerequisites on ArchLinux:: @@ -557,6 +557,23 @@ as the primary means of identity and E2E encryption is not complete. As such, we are running a single identity server (https://matrix.org) at the current time. + +URL Previews +============ + +Synapse 0.15.0 introduces an experimental new API for previewing URLs at +/_matrix/media/r0/preview_url. This is disabled by default. To turn it on +you must enable the `url_preview_enabled: True` config parameter and explicitly +specify the IP ranges that Synapse is not allowed to spider for previewing in +the `url_preview_ip_range_blacklist` configuration parameter. This is critical +from a security perspective to stop arbitrary Matrix users spidering 'internal' +URLs on your network. At the very least we recommend that your loopback and +RFC1918 IP addresses are blacklisted. + +This also requires the optional lxml and netaddr python dependencies to be +installed. + + Password reset ============== diff --git a/UPGRADE.rst b/UPGRADE.rst index 4f08cbb96a..699f04c2c2 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -30,6 +30,14 @@ running: python synapse/python_dependencies.py | xargs -n1 pip install +Upgrading to v0.15.0 +==================== + +If you want to use the new URL previewing API (/_matrix/media/r0/preview_url) +then you have to explicitly enable it in the config and update your dependencies +dependencies. See README.rst for details. + + Upgrading to v0.11.0 ==================== diff --git a/docs/url_previews.rst b/docs/url_previews.rst new file mode 100644 index 0000000000..634d9d907f --- /dev/null +++ b/docs/url_previews.rst @@ -0,0 +1,74 @@ +URL Previews +============ + +Design notes on a URL previewing service for Matrix: + +Options are: + + 1. Have an AS which listens for URLs, downloads them, and inserts an event that describes their metadata. + * Pros: + * Decouples the implementation entirely from Synapse. + * Uses existing Matrix events & content repo to store the metadata. + * Cons: + * Which AS should provide this service for a room, and why should you trust it? + * Doesn't work well with E2E; you'd have to cut the AS into every room + * the AS would end up subscribing to every room anyway. + + 2. Have a generic preview API (nothing to do with Matrix) that provides a previewing service: + * Pros: + * Simple and flexible; can be used by any clients at any point + * Cons: + * If each HS provides one of these independently, all the HSes in a room may needlessly DoS the target URI + * We need somewhere to store the URL metadata rather than just using Matrix itself + * We can't piggyback on matrix to distribute the metadata between HSes. + + 3. Make the synapse of the sending user responsible for spidering the URL and inserting an event asynchronously which describes the metadata. + * Pros: + * Works transparently for all clients + * Piggy-backs nicely on using Matrix for distributing the metadata. + * No confusion as to which AS + * Cons: + * Doesn't work with E2E + * We might want to decouple the implementation of the spider from the HS, given spider behaviour can be quite complicated and evolve much more rapidly than the HS. It's more like a bot than a core part of the server. + + 4. Make the sending client use the preview API and insert the event itself when successful. + * Pros: + * Works well with E2E + * No custom server functionality + * Lets the client customise the preview that they send (like on FB) + * Cons: + * Entirely specific to the sending client, whereas it'd be nice if /any/ URL was correctly previewed if clients support it. + + 5. Have the option of specifying a shared (centralised) previewing service used by a room, to avoid all the different HSes in the room DoSing the target. + +Best solution is probably a combination of both 2 and 4. + * Sending clients do their best to create and send a preview at the point of sending the message, perhaps delaying the message until the preview is computed? (This also lets the user validate the preview before sending) + * Receiving clients have the option of going and creating their own preview if one doesn't arrive soon enough (or if the original sender didn't create one) + +This is a bit magical though in that the preview could come from two entirely different sources - the sending HS or your local one. However, this can always be exposed to users: "Generate your own URL previews if none are available?" + +This is tantamount also to senders calculating their own thumbnails for sending in advance of the main content - we are trusting the sender not to lie about the content in the thumbnail. Whereas currently thumbnails are calculated by the receiving homeserver to avoid this attack. + +However, this kind of phishing attack does exist whether we let senders pick their thumbnails or not, in that a malicious sender can send normal text messages around the attachment claiming it to be legitimate. We could rely on (future) reputation/abuse management to punish users who phish (be it with bogus metadata or bogus descriptions). Bogus metadata is particularly bad though, especially if it's avoidable. + +As a first cut, let's do #2 and have the receiver hit the API to calculate its own previews (as it does currently for image thumbnails). We can then extend/optimise this to option 4 as a special extra if needed. + +API +--- + +GET /_matrix/media/r0/preview_url?url=http://wherever.com +200 OK +{ + "og:type" : "article" + "og:url" : "https://twitter.com/matrixdotorg/status/684074366691356672" + "og:title" : "Matrix on Twitter" + "og:image" : "https://pbs.twimg.com/profile_images/500400952029888512/yI0qtFi7_400x400.png" + "og:description" : "“Synapse 0.12 is out! Lots of polishing, performance &amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP”" + "og:site_name" : "Twitter" +} + +* Downloads the URL + * If HTML, just stores it in RAM and parses it for OG meta tags + * Download any media OG meta tags to the media repo, and refer to them in the OG via mxc:// URIs. + * If a media filetype we know we can thumbnail: store it on disk, and hand it to the thumbnailer. Generate OG meta tags from the thumbnailer contents. + * Otherwise, don't bother downloading further. diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index bc90605324..6da6a1b62e 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -100,11 +100,6 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning("push_bulk to %s threw exception %s", uri, ex) defer.returnValue(False) - @defer.inlineCallbacks - def push(self, service, event, txn_id=None): - response = yield self.push_bulk(service, [event], txn_id) - defer.returnValue(response) - def _serialize(self, events): time_now = self.clock.time_msec() return [ diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index acf74c8761..9a80ac39ec 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -30,13 +30,14 @@ from .saml2 import SAML2Config from .cas import CasConfig from .password import PasswordConfig from .jwt import JWTConfig +from .ldap import LDAPConfig class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig, - JWTConfig, PasswordConfig,): + JWTConfig, LDAPConfig, PasswordConfig,): pass diff --git a/synapse/config/ldap.py b/synapse/config/ldap.py new file mode 100644 index 0000000000..9c14593a99 --- /dev/null +++ b/synapse/config/ldap.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 Niklas Riekenbrauck +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + + +class LDAPConfig(Config): + def read_config(self, config): + ldap_config = config.get("ldap_config", None) + if ldap_config: + self.ldap_enabled = ldap_config.get("enabled", False) + self.ldap_server = ldap_config["server"] + self.ldap_port = ldap_config["port"] + self.ldap_tls = ldap_config.get("tls", False) + self.ldap_search_base = ldap_config["search_base"] + self.ldap_search_property = ldap_config["search_property"] + self.ldap_email_property = ldap_config["email_property"] + self.ldap_full_name_property = ldap_config["full_name_property"] + else: + self.ldap_enabled = False + self.ldap_server = None + self.ldap_port = None + self.ldap_tls = False + self.ldap_search_base = None + self.ldap_search_property = None + self.ldap_email_property = None + self.ldap_full_name_property = None + + def default_config(self, **kwargs): + return """\ + # ldap_config: + # enabled: true + # server: "ldap://localhost" + # port: 389 + # tls: false + # search_base: "ou=Users,dc=example,dc=com" + # search_property: "cn" + # email_property: "email" + # full_name_property: "givenName" + """ diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 2e96c09013..49922c6d03 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -16,6 +16,8 @@ from ._base import Config from collections import namedtuple +import sys + ThumbnailRequirement = namedtuple( "ThumbnailRequirement", ["width", "height", "method", "media_type"] ) @@ -23,7 +25,7 @@ ThumbnailRequirement = namedtuple( def parse_thumbnail_requirements(thumbnail_sizes): """ Takes a list of dictionaries with "width", "height", and "method" keys - and creates a map from image media types to the thumbnail size, thumnailing + and creates a map from image media types to the thumbnail size, thumbnailing method, and thumbnail media type to precalculate Args: @@ -53,12 +55,25 @@ class ContentRepositoryConfig(Config): def read_config(self, config): self.max_upload_size = self.parse_size(config["max_upload_size"]) self.max_image_pixels = self.parse_size(config["max_image_pixels"]) + self.max_spider_size = self.parse_size(config["max_spider_size"]) self.media_store_path = self.ensure_directory(config["media_store_path"]) self.uploads_path = self.ensure_directory(config["uploads_path"]) self.dynamic_thumbnails = config["dynamic_thumbnails"] self.thumbnail_requirements = parse_thumbnail_requirements( config["thumbnail_sizes"] ) + self.url_preview_enabled = config["url_preview_enabled"] + if self.url_preview_enabled: + try: + from netaddr import IPSet + if "url_preview_ip_range_blacklist" in config: + self.url_preview_ip_range_blacklist = IPSet( + config["url_preview_ip_range_blacklist"] + ) + if "url_preview_url_blacklist" in config: + self.url_preview_url_blacklist = config["url_preview_url_blacklist"] + except ImportError: + sys.stderr.write("\nmissing netaddr dep - disabling preview_url API\n") def default_config(self, **kwargs): media_store = self.default_path("media_store") @@ -80,7 +95,7 @@ class ContentRepositoryConfig(Config): # the resolution requested by the client. If true then whenever # a new resolution is requested by the client the server will # generate a new thumbnail. If false the server will pick a thumbnail - # from a precalcualted list. + # from a precalculated list. dynamic_thumbnails: false # List of thumbnail to precalculate when an image is uploaded. @@ -100,4 +115,62 @@ class ContentRepositoryConfig(Config): - width: 800 height: 600 method: scale + + # Is the preview URL API enabled? If enabled, you *must* specify + # an explicit url_preview_ip_range_blacklist of IPs that the spider is + # denied from accessing. + url_preview_enabled: False + + # List of IP address CIDR ranges that the URL preview spider is denied + # from accessing. There are no defaults: you must explicitly + # specify a list for URL previewing to work. You should specify any + # internal services in your network that you do not want synapse to try + # to connect to, otherwise anyone in any Matrix room could cause your + # synapse to issue arbitrary GET requests to your internal services, + # causing serious security issues. + # + # url_preview_ip_range_blacklist: + # - '127.0.0.0/8' + # - '10.0.0.0/8' + # - '172.16.0.0/12' + # - '192.168.0.0/16' + + # Optional list of URL matches that the URL preview spider is + # denied from accessing. You should use url_preview_ip_range_blacklist + # in preference to this, otherwise someone could define a public DNS + # entry that points to a private IP address and circumvent the blacklist. + # This is more useful if you know there is an entire shape of URL that + # you know that will never want synapse to try to spider. + # + # Each list entry is a dictionary of url component attributes as returned + # by urlparse.urlsplit as applied to the absolute form of the URL. See + # https://docs.python.org/2/library/urlparse.html#urlparse.urlsplit + # The values of the dictionary are treated as an filename match pattern + # applied to that component of URLs, unless they start with a ^ in which + # case they are treated as a regular expression match. If all the + # specified component matches for a given list item succeed, the URL is + # blacklisted. + # + # url_preview_url_blacklist: + # # blacklist any URL with a username in its URI + # - username: '*' + # + # # blacklist all *.google.com URLs + # - netloc: 'google.com' + # - netloc: '*.google.com' + # + # # blacklist all plain HTTP URLs + # - scheme: 'http' + # + # # blacklist http(s)://www.acme.com/foo + # - netloc: 'www.acme.com' + # path: '/foo' + # + # # blacklist any URL with a literal IPv4 address + # - netloc: '^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$' + + # The largest allowed URL preview spidering size in bytes + max_spider_size: "10M" + + """ % locals() diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d5d6faa85f..7a13a8b11c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -49,6 +49,21 @@ class AuthHandler(BaseHandler): self.sessions = {} self.INVALID_TOKEN_HTTP_STATUS = 401 + self.ldap_enabled = hs.config.ldap_enabled + self.ldap_server = hs.config.ldap_server + self.ldap_port = hs.config.ldap_port + self.ldap_tls = hs.config.ldap_tls + self.ldap_search_base = hs.config.ldap_search_base + self.ldap_search_property = hs.config.ldap_search_property + self.ldap_email_property = hs.config.ldap_email_property + self.ldap_full_name_property = hs.config.ldap_full_name_property + + if self.ldap_enabled is True: + import ldap + logger.info("Import ldap version: %s", ldap.__version__) + + self.hs = hs # FIXME better possibility to access registrationHandler later? + @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): """ @@ -215,8 +230,10 @@ class AuthHandler(BaseHandler): if not user_id.startswith('@'): user_id = UserID.create(user_id, self.hs.hostname).to_string() - user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) - self._check_password(user_id, password, password_hash) + if not (yield self._check_password(user_id, password)): + logger.warn("Failed password login for user %s", user_id) + raise LoginError(403, "", errcode=Codes.FORBIDDEN) + defer.returnValue(user_id) @defer.inlineCallbacks @@ -340,8 +357,10 @@ class AuthHandler(BaseHandler): StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ - user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) - self._check_password(user_id, password, password_hash) + + if not (yield self._check_password(user_id, password)): + logger.warn("Failed password login for user %s", user_id) + raise LoginError(403, "", errcode=Codes.FORBIDDEN) logger.info("Logging in user %s", user_id) access_token = yield self.issue_access_token(user_id) @@ -407,11 +426,60 @@ class AuthHandler(BaseHandler): else: defer.returnValue(user_infos.popitem()) - def _check_password(self, user_id, password, stored_hash): - """Checks that user_id has passed password, raises LoginError if not.""" - if not self.validate_hash(password, stored_hash): - logger.warn("Failed password login for user %s", user_id) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) + @defer.inlineCallbacks + def _check_password(self, user_id, password): + defer.returnValue( + not ( + (yield self._check_ldap_password(user_id, password)) + or + (yield self._check_local_password(user_id, password)) + )) + + @defer.inlineCallbacks + def _check_local_password(self, user_id, password): + try: + user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) + defer.returnValue(not self.validate_hash(password, password_hash)) + except LoginError: + defer.returnValue(False) + + @defer.inlineCallbacks + def _check_ldap_password(self, user_id, password): + if self.ldap_enabled is not True: + logger.debug("LDAP not configured") + defer.returnValue(False) + + import ldap + + logger.info("Authenticating %s with LDAP" % user_id) + try: + ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port) + logger.debug("Connecting LDAP server at %s" % ldap_url) + l = ldap.initialize(ldap_url) + if self.ldap_tls: + logger.debug("Initiating TLS") + self._connection.start_tls_s() + + local_name = UserID.from_string(user_id).localpart + + dn = "%s=%s, %s" % ( + self.ldap_search_property, + local_name, + self.ldap_search_base) + logger.debug("DN for LDAP authentication: %s" % dn) + + l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8')) + + if not (yield self.does_user_exist(user_id)): + handler = self.hs.get_handlers().registration_handler + user_id, access_token = ( + yield handler.register(localpart=local_name) + ) + + defer.returnValue(True) + except ldap.LDAPError, e: + logger.warn("LDAP error: %s", e) + defer.returnValue(False) @defer.inlineCallbacks def issue_access_token(self, user_id): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3686738e59..83dab32bcb 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -40,6 +40,7 @@ from synapse.events.utils import prune_event from synapse.util.retryutils import NotRetryingDestination from synapse.push.action_generator import ActionGenerator +from synapse.util.distributor import user_joined_room from twisted.internet import defer @@ -49,10 +50,6 @@ import logging logger = logging.getLogger(__name__) -def user_joined_room(distributor, user, room_id): - return distributor.fire("user_joined_room", user, room_id) - - class FederationHandler(BaseHandler): """Handles events that originated from federation. Responsible for: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 10608c0dd9..f51feda2f4 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -34,10 +34,6 @@ import logging logger = logging.getLogger(__name__) -def collect_presencelike_data(distributor, user, content): - return distributor.fire("collect_presencelike_data", user, content) - - class MessageHandler(BaseHandler): def __init__(self, hs): @@ -49,35 +45,6 @@ class MessageHandler(BaseHandler): self.snapshot_cache = SnapshotCache() @defer.inlineCallbacks - def get_message(self, msg_id=None, room_id=None, sender_id=None, - user_id=None): - """ Retrieve a message. - - Args: - msg_id (str): The message ID to obtain. - room_id (str): The room where the message resides. - sender_id (str): The user ID of the user who sent the message. - user_id (str): The user ID of the user making this request. - Returns: - The message, or None if no message exists. - Raises: - SynapseError if something went wrong. - """ - yield self.auth.check_joined_room(room_id, user_id) - - # Pull out the message from the db -# msg = yield self.store.get_message( -# room_id=room_id, -# msg_id=msg_id, -# user_id=sender_id -# ) - - # TODO (erikj): Once we work out the correct c-s api we need to think - # on how to do this. - - defer.returnValue(None) - - @defer.inlineCallbacks def get_messages(self, requester, room_id=None, pagin_config=None, as_client_event=True): """Get messages in a room. @@ -202,12 +169,8 @@ class MessageHandler(BaseHandler): membership = builder.content.get("membership", None) target = UserID.from_string(builder.state_key) - if membership == Membership.JOIN: + if membership in {Membership.JOIN, Membership.INVITE}: # If event doesn't include a display name, add one. - yield collect_presencelike_data( - self.distributor, target, builder.content - ) - elif membership == Membership.INVITE: profile = self.hs.get_handlers().profile_handler content = builder.content diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index b45eafbb49..e37409170d 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -17,7 +17,6 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.types import UserID, Requester -from synapse.util import unwrapFirstError from ._base import BaseHandler @@ -27,14 +26,6 @@ import logging logger = logging.getLogger(__name__) -def changed_presencelike_data(distributor, user, state): - return distributor.fire("changed_presencelike_data", user, state) - - -def collect_presencelike_data(distributor, user, content): - return distributor.fire("collect_presencelike_data", user, content) - - class ProfileHandler(BaseHandler): def __init__(self, hs): @@ -46,17 +37,9 @@ class ProfileHandler(BaseHandler): ) distributor = hs.get_distributor() - self.distributor = distributor - - distributor.declare("collect_presencelike_data") - distributor.declare("changed_presencelike_data") distributor.observe("registered_user", self.registered_user) - distributor.observe( - "collect_presencelike_data", self.collect_presencelike_data - ) - def registered_user(self, user): return self.store.create_profile(user.localpart) @@ -105,10 +88,6 @@ class ProfileHandler(BaseHandler): target_user.localpart, new_displayname ) - yield changed_presencelike_data(self.distributor, target_user, { - "displayname": new_displayname, - }) - yield self._update_join_states(requester) @defer.inlineCallbacks @@ -152,31 +131,9 @@ class ProfileHandler(BaseHandler): target_user.localpart, new_avatar_url ) - yield changed_presencelike_data(self.distributor, target_user, { - "avatar_url": new_avatar_url, - }) - yield self._update_join_states(requester) @defer.inlineCallbacks - def collect_presencelike_data(self, user, state): - if not self.hs.is_mine(user): - defer.returnValue(None) - - (displayname, avatar_url) = yield defer.gatherResults( - [ - self.store.get_profile_displayname(user.localpart), - self.store.get_profile_avatar_url(user.localpart), - ], - consumeErrors=True - ).addErrback(unwrapFirstError) - - state["displayname"] = displayname - state["avatar_url"] = avatar_url - - defer.returnValue(None) - - @defer.inlineCallbacks def on_profile_query(self, args): user = UserID.from_string(args["user_id"]) if not self.hs.is_mine(user): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index f287ee247b..b0862067e1 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -23,6 +23,7 @@ from synapse.api.errors import ( from ._base import BaseHandler from synapse.util.async import run_on_reactor from synapse.http.client import CaptchaServerHttpClient +from synapse.util.distributor import registered_user import logging import urllib @@ -30,10 +31,6 @@ import urllib logger = logging.getLogger(__name__) -def registered_user(distributor, user): - return distributor.fire("registered_user", user) - - class RegistrationHandler(BaseHandler): def __init__(self, hs): diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3e1d9282d7..ea306cd42a 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -25,7 +25,6 @@ from synapse.api.constants import ( from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.util import stringutils from synapse.util.async import concurrently_execute -from synapse.util.logcontext import preserve_context_over_fn from synapse.util.caches.response_cache import ResponseCache from collections import OrderedDict @@ -39,20 +38,6 @@ logger = logging.getLogger(__name__) id_server_scheme = "https://" -def user_left_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_left_room", user=user, room_id=room_id - ) - - -def user_joined_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_joined_room", user=user, room_id=room_id - ) - - class RoomCreationHandler(BaseHandler): PRESETS_DICT = { diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 8c41cb6f3c..b69f36aefe 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -23,7 +23,8 @@ from synapse.api.constants import ( EventTypes, Membership, ) from synapse.api.errors import AuthError, SynapseError, Codes -from synapse.util.logcontext import preserve_context_over_fn +from synapse.util.async import Linearizer +from synapse.util.distributor import user_left_room, user_joined_room from signedjson.sign import verify_signed_json from signedjson.key import decode_verify_key_bytes @@ -37,20 +38,6 @@ logger = logging.getLogger(__name__) id_server_scheme = "https://" -def user_left_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_left_room", user=user, room_id=room_id - ) - - -def user_joined_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_joined_room", user=user, room_id=room_id - ) - - class RoomMemberHandler(BaseHandler): # TODO(paul): This handler currently contains a messy conflation of # low-level API that works on UserID objects and so on, and REST-level @@ -60,6 +47,8 @@ class RoomMemberHandler(BaseHandler): def __init__(self, hs): super(RoomMemberHandler, self).__init__(hs) + self.member_linearizer = Linearizer() + self.clock = hs.get_clock() self.distributor = hs.get_distributor() @@ -183,6 +172,34 @@ class RoomMemberHandler(BaseHandler): third_party_signed=None, ratelimit=True, ): + key = (target, room_id,) + + with (yield self.member_linearizer.queue(key)): + result = yield self._update_membership( + requester, + target, + room_id, + action, + txn_id=txn_id, + remote_room_hosts=remote_room_hosts, + third_party_signed=third_party_signed, + ratelimit=ratelimit, + ) + + defer.returnValue(result) + + @defer.inlineCallbacks + def _update_membership( + self, + requester, + target, + room_id, + action, + txn_id=None, + remote_room_hosts=None, + third_party_signed=None, + ratelimit=True, + ): effective_membership_state = action if action in ["kick", "unban"]: effective_membership_state = "leave" @@ -375,19 +392,6 @@ class RoomMemberHandler(BaseHandler): and guest_access.content["guest_access"] == "can_join" ) - def _should_do_dance(self, current_state, inviter, room_hosts=None): - # TODO: Shouldn't this be remote_room_host? - room_hosts = room_hosts or [] - - is_host_in_room = self.is_host_in_room(current_state) - if is_host_in_room: - return False, room_hosts - - if inviter and not self.hs.is_mine(inviter): - room_hosts.append(inviter.domain) - - return True, room_hosts - @defer.inlineCallbacks def lookup_room_alias(self, room_alias): """ diff --git a/synapse/http/client.py b/synapse/http/client.py index cbd45b2bbe..6c89b20984 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -15,17 +15,24 @@ from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE -from synapse.api.errors import CodeMessageException +from synapse.api.errors import ( + CodeMessageException, SynapseError, Codes, +) from synapse.util.logcontext import preserve_context_over_fn import synapse.metrics +from synapse.http.endpoint import SpiderEndpoint from canonicaljson import encode_canonical_json -from twisted.internet import defer, reactor, ssl +from twisted.internet import defer, reactor, ssl, protocol +from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint from twisted.web.client import ( - Agent, readBody, FileBodyProducer, PartialDownloadError, + BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent, + readBody, FileBodyProducer, PartialDownloadError, ) +from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers +from twisted.web._newclient import ResponseDone from StringIO import StringIO @@ -238,6 +245,107 @@ class SimpleHttpClient(object): else: raise CodeMessageException(response.code, body) + # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. + # The two should be factored out. + + @defer.inlineCallbacks + def get_file(self, url, output_stream, max_size=None): + """GETs a file from a given URL + Args: + url (str): The URL to GET + output_stream (file): File to write the response body to. + Returns: + A (int,dict,string,int) tuple of the file length, dict of the response + headers, absolute URI of the response and HTTP response code. + """ + + response = yield self.request( + "GET", + url.encode("ascii"), + headers=Headers({ + b"User-Agent": [self.user_agent], + }) + ) + + headers = dict(response.headers.getAllRawHeaders()) + + if 'Content-Length' in headers and headers['Content-Length'] > max_size: + logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) + raise SynapseError( + 502, + "Requested file is too large > %r bytes" % (self.max_size,), + Codes.TOO_LARGE, + ) + + if response.code > 299: + logger.warn("Got %d when downloading %s" % (response.code, url)) + raise SynapseError( + 502, + "Got error %d" % (response.code,), + Codes.UNKNOWN, + ) + + # TODO: if our Content-Type is HTML or something, just read the first + # N bytes into RAM rather than saving it all to disk only to read it + # straight back in again + + try: + length = yield preserve_context_over_fn( + _readBodyToFile, + response, output_stream, max_size + ) + except Exception as e: + logger.exception("Failed to download body") + raise SynapseError( + 502, + ("Failed to download remote body: %s" % e), + Codes.UNKNOWN, + ) + + defer.returnValue((length, headers, response.request.absoluteURI, response.code)) + + +# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. +# The two should be factored out. + +class _ReadBodyToFileProtocol(protocol.Protocol): + def __init__(self, stream, deferred, max_size): + self.stream = stream + self.deferred = deferred + self.length = 0 + self.max_size = max_size + + def dataReceived(self, data): + self.stream.write(data) + self.length += len(data) + if self.max_size is not None and self.length >= self.max_size: + self.deferred.errback(SynapseError( + 502, + "Requested file is too large > %r bytes" % (self.max_size,), + Codes.TOO_LARGE, + )) + self.deferred = defer.Deferred() + self.transport.loseConnection() + + def connectionLost(self, reason): + if reason.check(ResponseDone): + self.deferred.callback(self.length) + elif reason.check(PotentialDataLoss): + # stolen from https://github.com/twisted/treq/pull/49/files + # http://twistedmatrix.com/trac/ticket/4840 + self.deferred.callback(self.length) + else: + self.deferred.errback(reason) + + +# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. +# The two should be factored out. + +def _readBodyToFile(response, stream, max_size): + d = defer.Deferred() + response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size)) + return d + class CaptchaServerHttpClient(SimpleHttpClient): """ @@ -269,6 +377,59 @@ class CaptchaServerHttpClient(SimpleHttpClient): defer.returnValue(e.response) +class SpiderEndpointFactory(object): + def __init__(self, hs): + self.blacklist = hs.config.url_preview_ip_range_blacklist + self.policyForHTTPS = hs.get_http_client_context_factory() + + def endpointForURI(self, uri): + logger.info("Getting endpoint for %s", uri.toBytes()) + if uri.scheme == "http": + return SpiderEndpoint( + reactor, uri.host, uri.port, self.blacklist, + endpoint=TCP4ClientEndpoint, + endpoint_kw_args={ + 'timeout': 15 + }, + ) + elif uri.scheme == "https": + tlsPolicy = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port) + return SpiderEndpoint( + reactor, uri.host, uri.port, self.blacklist, + endpoint=SSL4ClientEndpoint, + endpoint_kw_args={ + 'sslContextFactory': tlsPolicy, + 'timeout': 15 + }, + ) + else: + logger.warn("Can't get endpoint for unrecognised scheme %s", uri.scheme) + + +class SpiderHttpClient(SimpleHttpClient): + """ + Separate HTTP client for spidering arbitrary URLs. + Special in that it follows retries and has a UA that looks + like a browser. + + used by the preview_url endpoint in the content repo. + """ + def __init__(self, hs): + SimpleHttpClient.__init__(self, hs) + # clobber the base class's agent and UA: + self.agent = ContentDecoderAgent( + BrowserLikeRedirectAgent( + Agent.usingEndpointFactory( + reactor, + SpiderEndpointFactory(hs) + ) + ), [('gzip', GzipDecoder)] + ) + # We could look like Chrome: + # self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko) + # Chrome Safari" % hs.version_string) + + def encode_urlencode_args(args): return {k: encode_urlencode_arg(v) for k, v in args.items()} diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 4775f6707d..a456dc19da 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -22,6 +22,7 @@ from twisted.names.error import DNSNameError, DomainError import collections import logging import random +import time logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ SERVER_CACHE = {} _Server = collections.namedtuple( - "_Server", "priority weight host port" + "_Server", "priority weight host port expires" ) @@ -74,6 +75,37 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, return transport_endpoint(reactor, domain, port, **endpoint_kw_args) +class SpiderEndpoint(object): + """An endpoint which refuses to connect to blacklisted IP addresses + Implements twisted.internet.interfaces.IStreamClientEndpoint. + """ + def __init__(self, reactor, host, port, blacklist, + endpoint=TCP4ClientEndpoint, endpoint_kw_args={}): + self.reactor = reactor + self.host = host + self.port = port + self.blacklist = blacklist + self.endpoint = endpoint + self.endpoint_kw_args = endpoint_kw_args + + @defer.inlineCallbacks + def connect(self, protocolFactory): + address = yield self.reactor.resolve(self.host) + + from netaddr import IPAddress + if IPAddress(address) in self.blacklist: + raise ConnectError( + "Refusing to spider blacklisted IP address %s" % address + ) + + logger.info("Connecting to %s:%s", address, self.port) + endpoint = self.endpoint( + self.reactor, address, self.port, **self.endpoint_kw_args + ) + connection = yield endpoint.connect(protocolFactory) + defer.returnValue(connection) + + class SRVClientEndpoint(object): """An endpoint which looks up SRV records for a service. Cycles through the list of servers starting with each call to connect @@ -92,7 +124,8 @@ class SRVClientEndpoint(object): host=domain, port=default_port, priority=0, - weight=0 + weight=0, + expires=0, ) else: self.default_server = None @@ -118,7 +151,7 @@ class SRVClientEndpoint(object): return self.default_server else: raise ConnectError( - "Not server available for %s", self.service_name + "Not server available for %s" % self.service_name ) min_priority = self.servers[0].priority @@ -153,7 +186,13 @@ class SRVClientEndpoint(object): @defer.inlineCallbacks -def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): +def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time): + cache_entry = cache.get(service_name, None) + if cache_entry: + if all(s.expires > int(clock.time()) for s in cache_entry): + servers = list(cache_entry) + defer.returnValue(servers) + servers = [] try: @@ -166,34 +205,33 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): and answers[0].type == dns.SRV and answers[0].payload and answers[0].payload.target == dns.Name('.')): - raise ConnectError("Service %s unavailable", service_name) + raise ConnectError("Service %s unavailable" % service_name) for answer in answers: if answer.type != dns.SRV or not answer.payload: continue payload = answer.payload - host = str(payload.target) + srv_ttl = answer.ttl try: answers, _, _ = yield dns_client.lookupAddress(host) except DNSNameError: continue - ips = [ - answer.payload.dottedQuad() - for answer in answers - if answer.type == dns.A and answer.payload - ] - - for ip in ips: - servers.append(_Server( - host=ip, - port=int(payload.port), - priority=int(payload.priority), - weight=int(payload.weight) - )) + for answer in answers: + if answer.type == dns.A and answer.payload: + ip = answer.payload.dottedQuad() + host_ttl = min(srv_ttl, answer.ttl) + + servers.append(_Server( + host=ip, + port=int(payload.port), + priority=int(payload.priority), + weight=int(payload.weight), + expires=int(clock.time()) + host_ttl, + )) servers.sort() cache[service_name] = list(servers) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index cf1414b4db..1adbdd9421 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -41,7 +41,11 @@ REQUIREMENTS = { CONDITIONAL_REQUIREMENTS = { "web_client": { "matrix_angular_sdk>=0.6.8": ["syweb>=0.6.8"], - } + }, + "preview_url": { + "lxml>=3.6.0": ["lxml"], + "netaddr>=0.7.18": ["netaddr"], + }, } diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 680dc89536..cfc728a038 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -69,6 +69,7 @@ class SlavedEventStore(BaseSlavedStore): "_get_current_state_for_key" ] + get_event = DataStore.get_event.__func__ get_current_state = DataStore.get_current_state.__func__ get_current_state_for_key = DataStore.get_current_state_for_key.__func__ get_rooms_for_user_where_membership_is = ( @@ -89,8 +90,11 @@ class SlavedEventStore(BaseSlavedStore): _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__ _parse_events_txn = DataStore._parse_events_txn.__func__ _get_events_txn = DataStore._get_events_txn.__func__ + _enqueue_events = DataStore._enqueue_events.__func__ + _do_fetch = DataStore._do_fetch.__func__ _fetch_events_txn = DataStore._fetch_events_txn.__func__ _fetch_event_rows = DataStore._fetch_event_rows.__func__ + _get_event_from_row = DataStore._get_event_from_row.__func__ _get_event_from_row_txn = DataStore._get_event_from_row_txn.__func__ _get_rooms_for_user_where_membership_is_txn = ( DataStore._get_rooms_for_user_where_membership_is_txn.__func__ @@ -100,7 +104,7 @@ class SlavedEventStore(BaseSlavedStore): def stream_positions(self): result = super(SlavedEventStore, self).stream_positions() result["events"] = self._stream_id_gen.get_current_token() - result["backfilled"] = self._backfill_id_gen.get_current_token() + result["backfill"] = self._backfill_id_gen.get_current_token() return result def process_replication(self, result): @@ -142,7 +146,6 @@ class SlavedEventStore(BaseSlavedStore): position = row[0] internal = json.loads(row[1]) event_json = json.loads(row[2]) - event = FrozenEvent(event_json, internal_metadata_dict=internal) self._invalidate_caches_for_event( event, backfilled, reset_state=position in state_resets @@ -158,6 +161,8 @@ class SlavedEventStore(BaseSlavedStore): self._invalidate_get_event_cache(event.event_id) + self.get_latest_event_ids_in_room.invalidate((event.room_id,)) + if not backfilled: self._events_stream_cache.entity_has_changed( event.room_id, event.internal_metadata.stream_ordering diff --git a/synapse/rest/media/v1/base_resource.py b/synapse/rest/media/v1/base_resource.py index 58ef91c0b8..2b1938dc8e 100644 --- a/synapse/rest/media/v1/base_resource.py +++ b/synapse/rest/media/v1/base_resource.py @@ -72,6 +72,7 @@ class BaseMediaResource(Resource): self.store = hs.get_datastore() self.max_upload_size = hs.config.max_upload_size self.max_image_pixels = hs.config.max_image_pixels + self.max_spider_size = hs.config.max_spider_size self.filepaths = filepaths self.version_string = hs.version_string self.downloads = {} diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 7dfb027dd1..97b7e84af9 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -17,6 +17,7 @@ from .upload_resource import UploadResource from .download_resource import DownloadResource from .thumbnail_resource import ThumbnailResource from .identicon_resource import IdenticonResource +from .preview_url_resource import PreviewUrlResource from .filepath import MediaFilePaths from twisted.web.resource import Resource @@ -78,3 +79,9 @@ class MediaRepositoryResource(Resource): self.putChild("download", DownloadResource(hs, filepaths)) self.putChild("thumbnail", ThumbnailResource(hs, filepaths)) self.putChild("identicon", IdenticonResource()) + if hs.config.url_preview_enabled: + try: + self.putChild("preview_url", PreviewUrlResource(hs, filepaths)) + except Exception as e: + logger.warn("Failed to mount preview_url") + logger.exception(e) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py new file mode 100644 index 0000000000..4dd97ac0e3 --- /dev/null +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -0,0 +1,461 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base_resource import BaseMediaResource + +from twisted.web.server import NOT_DONE_YET +from twisted.internet import defer +from urlparse import urlparse, urlsplit, urlunparse + +from synapse.api.errors import ( + SynapseError, Codes, +) +from synapse.util.stringutils import random_string +from synapse.util.caches.expiringcache import ExpiringCache +from synapse.http.client import SpiderHttpClient +from synapse.http.server import ( + request_handler, respond_with_json_bytes +) +from synapse.util.async import ObservableDeferred +from synapse.util.stringutils import is_ascii + +import os +import re +import fnmatch +import cgi +import ujson as json + +import logging +logger = logging.getLogger(__name__) + +try: + from lxml import html +except ImportError: + pass + + +class PreviewUrlResource(BaseMediaResource): + isLeaf = True + + def __init__(self, hs, filepaths): + try: + if html: + pass + except: + raise RuntimeError("Disabling PreviewUrlResource as lxml not available") + + if not hasattr(hs.config, "url_preview_ip_range_blacklist"): + logger.warn( + "For security, you must specify an explicit target IP address " + "blacklist in url_preview_ip_range_blacklist for url previewing " + "to work" + ) + raise RuntimeError( + "Disabling PreviewUrlResource as " + "url_preview_ip_range_blacklist not specified" + ) + + BaseMediaResource.__init__(self, hs, filepaths) + self.client = SpiderHttpClient(hs) + if hasattr(hs.config, "url_preview_url_blacklist"): + self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist + + # simple memory cache mapping urls to OG metadata + self.cache = ExpiringCache( + cache_name="url_previews", + clock=self.clock, + # don't spider URLs more often than once an hour + expiry_ms=60 * 60 * 1000, + ) + self.cache.start() + + self.downloads = {} + + def render_GET(self, request): + self._async_render_GET(request) + return NOT_DONE_YET + + @request_handler + @defer.inlineCallbacks + def _async_render_GET(self, request): + + # XXX: if get_user_by_req fails, what should we do in an async render? + requester = yield self.auth.get_user_by_req(request) + url = request.args.get("url")[0] + if "ts" in request.args: + ts = int(request.args.get("ts")[0]) + else: + ts = self.clock.time_msec() + + # impose the URL pattern blacklist + if hasattr(self, "url_preview_url_blacklist"): + url_tuple = urlsplit(url) + for entry in self.url_preview_url_blacklist: + match = True + for attrib in entry: + pattern = entry[attrib] + value = getattr(url_tuple, attrib) + logger.debug(( + "Matching attrib '%s' with value '%s' against" + " pattern '%s'" + ) % (attrib, value, pattern)) + + if value is None: + match = False + continue + + if pattern.startswith('^'): + if not re.match(pattern, getattr(url_tuple, attrib)): + match = False + continue + else: + if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern): + match = False + continue + if match: + logger.warn( + "URL %s blocked by url_blacklist entry %s", url, entry + ) + raise SynapseError( + 403, "URL blocked by url pattern blacklist entry", + Codes.UNKNOWN + ) + + # first check the memory cache - good to handle all the clients on this + # HS thundering away to preview the same URL at the same time. + og = self.cache.get(url) + if og: + respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) + return + + # then check the URL cache in the DB (which will also provide us with + # historical previews, if we have any) + cache_result = yield self.store.get_url_cache(url, ts) + if ( + cache_result and + cache_result["download_ts"] + cache_result["expires"] > ts and + cache_result["response_code"] / 100 == 2 + ): + respond_with_json_bytes( + request, 200, cache_result["og"].encode('utf-8'), + send_cors=True + ) + return + + # Ensure only one download for a given URL is active at a time + download = self.downloads.get(url) + if download is None: + download = self._download_url(url, requester.user) + download = ObservableDeferred( + download, + consumeErrors=True + ) + self.downloads[url] = download + + @download.addBoth + def callback(media_info): + del self.downloads[url] + return media_info + media_info = yield download.observe() + + # FIXME: we should probably update our cache now anyway, so that + # even if the OG calculation raises, we don't keep hammering on the + # remote server. For now, leave it uncached to aid debugging OG + # calculation problems + + logger.debug("got media_info of '%s'" % media_info) + + if self._is_media(media_info['media_type']): + dims = yield self._generate_local_thumbnails( + media_info['filesystem_id'], media_info + ) + + og = { + "og:description": media_info['download_name'], + "og:image": "mxc://%s/%s" % ( + self.server_name, media_info['filesystem_id'] + ), + "og:image:type": media_info['media_type'], + "matrix:image:size": media_info['media_length'], + } + + if dims: + og["og:image:width"] = dims['width'] + og["og:image:height"] = dims['height'] + else: + logger.warn("Couldn't get dims for %s" % url) + + # define our OG response for this media + elif self._is_html(media_info['media_type']): + # TODO: somehow stop a big HTML tree from exploding synapse's RAM + + try: + tree = html.parse(media_info['filename']) + og = yield self._calc_og(tree, media_info, requester) + except UnicodeDecodeError: + # XXX: evil evil bodge + # Empirically, sites like google.com mix Latin-1 and utf-8 + # encodings in the same page. The rogue Latin-1 characters + # cause lxml to choke with a UnicodeDecodeError, so if we + # see this we go and do a manual decode of the HTML before + # handing it to lxml as utf-8 encoding, counter-intuitively, + # which seems to make it happier... + file = open(media_info['filename']) + body = file.read() + file.close() + tree = html.fromstring(body.decode('utf-8', 'ignore')) + og = yield self._calc_og(tree, media_info, requester) + + else: + logger.warn("Failed to find any OG data in %s", url) + og = {} + + logger.debug("Calculated OG for %s as %s" % (url, og)) + + # store OG in ephemeral in-memory cache + self.cache[url] = og + + # store OG in history-aware DB cache + yield self.store.store_url_cache( + url, + media_info["response_code"], + media_info["etag"], + media_info["expires"], + json.dumps(og), + media_info["filesystem_id"], + media_info["created_ts"], + ) + + respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) + + @defer.inlineCallbacks + def _calc_og(self, tree, media_info, requester): + # suck our tree into lxml and define our OG response. + + # if we see any image URLs in the OG response, then spider them + # (although the client could choose to do this by asking for previews of those + # URLs to avoid DoSing the server) + + # "og:type" : "video", + # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw", + # "og:site_name" : "YouTube", + # "og:video:type" : "application/x-shockwave-flash", + # "og:description" : "Fun stuff happening here", + # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon", + # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg", + # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", + # "og:video:width" : "1280" + # "og:video:height" : "720", + # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", + + og = {} + for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): + og[tag.attrib['property']] = tag.attrib['content'] + + # TODO: grab article: meta tags too, e.g.: + + # "article:publisher" : "https://www.facebook.com/thethudonline" /> + # "article:author" content="https://www.facebook.com/thethudonline" /> + # "article:tag" content="baby" /> + # "article:section" content="Breaking News" /> + # "article:published_time" content="2016-03-31T19:58:24+00:00" /> + # "article:modified_time" content="2016-04-01T18:31:53+00:00" /> + + if 'og:title' not in og: + # do some basic spidering of the HTML + title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") + og['og:title'] = title[0].text.strip() if title else None + + if 'og:image' not in og: + # TODO: extract a favicon failing all else + meta_image = tree.xpath( + "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content" + ) + if meta_image: + og['og:image'] = self._rebase_url(meta_image[0], media_info['uri']) + else: + # TODO: consider inlined CSS styles as well as width & height attribs + images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") + images = sorted(images, key=lambda i: ( + -1 * int(i.attrib['width']) * int(i.attrib['height']) + )) + if not images: + images = tree.xpath("//img[@src]") + if images: + og['og:image'] = images[0].attrib['src'] + + # pre-cache the image for posterity + # FIXME: it might be cleaner to use the same flow as the main /preview_url request + # itself and benefit from the same caching etc. But for now we just rely on the + # caching on the master request to speed things up. + if 'og:image' in og and og['og:image']: + image_info = yield self._download_url( + self._rebase_url(og['og:image'], media_info['uri']), requester.user + ) + + if self._is_media(image_info['media_type']): + # TODO: make sure we don't choke on white-on-transparent images + dims = yield self._generate_local_thumbnails( + image_info['filesystem_id'], image_info + ) + if dims: + og["og:image:width"] = dims['width'] + og["og:image:height"] = dims['height'] + else: + logger.warn("Couldn't get dims for %s" % og["og:image"]) + + og["og:image"] = "mxc://%s/%s" % ( + self.server_name, image_info['filesystem_id'] + ) + og["og:image:type"] = image_info['media_type'] + og["matrix:image:size"] = image_info['media_length'] + else: + del og["og:image"] + + if 'og:description' not in og: + meta_description = tree.xpath( + "//*/meta" + "[translate(@name, 'DESCRIPTION', 'description')='description']" + "/@content") + if meta_description: + og['og:description'] = meta_description[0] + else: + # grab any text nodes which are inside the <body/> tag... + # unless they are within an HTML5 semantic markup tag... + # <header/>, <nav/>, <aside/>, <footer/> + # ...or if they are within a <script/> or <style/> tag. + # This is a very very very coarse approximation to a plain text + # render of the page. + text_nodes = tree.xpath("//text()[not(ancestor::header | ancestor::nav | " + "ancestor::aside | ancestor::footer | " + "ancestor::script | ancestor::style)]" + + "[ancestor::body]") + text = '' + for text_node in text_nodes: + if len(text) < 500: + text += text_node + ' ' + else: + break + text = re.sub(r'[\t ]+', ' ', text) + text = re.sub(r'[\t \r\n]*[\r\n]+', '\n', text) + text = text.strip()[:500] + og['og:description'] = text if text else None + + # TODO: delete the url downloads to stop diskfilling, + # as we only ever cared about its OG + defer.returnValue(og) + + def _rebase_url(self, url, base): + base = list(urlparse(base)) + url = list(urlparse(url)) + if not url[0]: # fix up schema + url[0] = base[0] or "http" + if not url[1]: # fix up hostname + url[1] = base[1] + if not url[2].startswith('/'): + url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2] + return urlunparse(url) + + @defer.inlineCallbacks + def _download_url(self, url, user): + # TODO: we should probably honour robots.txt... except in practice + # we're most likely being explicitly triggered by a human rather than a + # bot, so are we really a robot? + + # XXX: horrible duplication with base_resource's _download_remote_file() + file_id = random_string(24) + + fname = self.filepaths.local_media_filepath(file_id) + self._makedirs(fname) + + try: + with open(fname, "wb") as f: + logger.debug("Trying to get url '%s'" % url) + length, headers, uri, code = yield self.client.get_file( + url, output_stream=f, max_size=self.max_spider_size, + ) + # FIXME: pass through 404s and other error messages nicely + + media_type = headers["Content-Type"][0] + time_now_ms = self.clock.time_msec() + + content_disposition = headers.get("Content-Disposition", None) + if content_disposition: + _, params = cgi.parse_header(content_disposition[0],) + download_name = None + + # First check if there is a valid UTF-8 filename + download_name_utf8 = params.get("filename*", None) + if download_name_utf8: + if download_name_utf8.lower().startswith("utf-8''"): + download_name = download_name_utf8[7:] + + # If there isn't check for an ascii name. + if not download_name: + download_name_ascii = params.get("filename", None) + if download_name_ascii and is_ascii(download_name_ascii): + download_name = download_name_ascii + + if download_name: + download_name = urlparse.unquote(download_name) + try: + download_name = download_name.decode("utf-8") + except UnicodeDecodeError: + download_name = None + else: + download_name = None + + yield self.store.store_local_media( + media_id=file_id, + media_type=media_type, + time_now_ms=self.clock.time_msec(), + upload_name=download_name, + media_length=length, + user_id=user, + ) + + except Exception as e: + os.remove(fname) + raise SynapseError( + 500, ("Failed to download content: %s" % e), + Codes.UNKNOWN + ) + + defer.returnValue({ + "media_type": media_type, + "media_length": length, + "download_name": download_name, + "created_ts": time_now_ms, + "filesystem_id": file_id, + "filename": fname, + "uri": uri, + "response_code": code, + # FIXME: we should calculate a proper expiration based on the + # Cache-Control and Expire headers. But for now, assume 1 hour. + "expires": 60 * 60 * 1000, + "etag": headers["ETag"][0] if "ETag" in headers else None, + }) + + def _is_media(self, content_type): + if content_type.lower().startswith("image/"): + return True + + def _is_html(self, content_type): + content_type = content_type.lower() + if ( + content_type.startswith("text/html") or + content_type.startswith("application/xhtml") + ): + return True diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index ab52499785..513b445688 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -72,6 +72,11 @@ class ThumbnailResource(BaseMediaResource): self._respond_404(request) return + if media_info["media_type"] == "image/svg+xml": + file_path = self.filepaths.local_media_filepath(media_id) + yield self._respond_with_file(request, media_info["media_type"], file_path) + return + thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) if thumbnail_infos: @@ -103,6 +108,11 @@ class ThumbnailResource(BaseMediaResource): self._respond_404(request) return + if media_info["media_type"] == "image/svg+xml": + file_path = self.filepaths.local_media_filepath(media_id) + yield self._respond_with_file(request, media_info["media_type"], file_path) + return + thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) for info in thumbnail_infos: t_w = info["thumbnail_width"] == desired_width @@ -138,6 +148,11 @@ class ThumbnailResource(BaseMediaResource): desired_method, desired_type): media_info = yield self._get_remote_media(server_name, media_id) + if media_info["media_type"] == "image/svg+xml": + file_path = self.filepaths.remote_media_filepath(server_name, media_id) + yield self._respond_with_file(request, media_info["media_type"], file_path) + return + thumbnail_infos = yield self.store.get_remote_media_thumbnails( server_name, media_id, ) @@ -181,6 +196,11 @@ class ThumbnailResource(BaseMediaResource): # We should proxy the thumbnail from the remote server instead. media_info = yield self._get_remote_media(server_name, media_id) + if media_info["media_type"] == "image/svg+xml": + file_path = self.filepaths.remote_media_filepath(server_name, media_id) + yield self._respond_with_file(request, media_info["media_type"], file_path) + return + thumbnail_infos = yield self.store.get_remote_media_thumbnails( server_name, media_id, ) @@ -208,6 +228,8 @@ class ThumbnailResource(BaseMediaResource): @defer.inlineCallbacks def _respond_default_thumbnail(self, request, media_info, width, height, method, m_type): + # XXX: how is this meant to work? store.get_default_thumbnails + # appears to always return [] so won't this always 404? media_type = media_info["media_type"] top_level_type = media_type.split("/")[0] sub_type = media_type.split("/")[-1].split(";")[0] diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 04d7fcf6d6..1e27c2c0ce 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -810,12 +810,6 @@ class SQLBaseStore(object): return txn.execute(sql, keyvalues.values()) - def get_next_stream_id(self): - with self._next_stream_id_lock: - i = self._next_stream_id - self._next_stream_id += 1 - return i - def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value): # Fetch a mapping of room_id -> max stream position for "recent" rooms. diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py index 9d3ba32478..a820fcf07f 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/media_repository.py @@ -25,7 +25,7 @@ class MediaRepositoryStore(SQLBaseStore): def get_local_media(self, media_id): """Get the metadata for a local piece of media Returns: - None if the meia_id doesn't exist. + None if the media_id doesn't exist. """ return self._simple_select_one( "local_media_repository", @@ -50,6 +50,61 @@ class MediaRepositoryStore(SQLBaseStore): desc="store_local_media", ) + def get_url_cache(self, url, ts): + """Get the media_id and ts for a cached URL as of the given timestamp + Returns: + None if the URL isn't cached. + """ + def get_url_cache_txn(txn): + # get the most recently cached result (relative to the given ts) + sql = ( + "SELECT response_code, etag, expires, og, media_id, download_ts" + " FROM local_media_repository_url_cache" + " WHERE url = ? AND download_ts <= ?" + " ORDER BY download_ts DESC LIMIT 1" + ) + txn.execute(sql, (url, ts)) + row = txn.fetchone() + + if not row: + # ...or if we've requested a timestamp older than the oldest + # copy in the cache, return the oldest copy (if any) + sql = ( + "SELECT response_code, etag, expires, og, media_id, download_ts" + " FROM local_media_repository_url_cache" + " WHERE url = ? AND download_ts > ?" + " ORDER BY download_ts ASC LIMIT 1" + ) + txn.execute(sql, (url, ts)) + row = txn.fetchone() + + if not row: + return None + + return dict(zip(( + 'response_code', 'etag', 'expires', 'og', 'media_id', 'download_ts' + ), row)) + + return self.runInteraction( + "get_url_cache", get_url_cache_txn + ) + + def store_url_cache(self, url, response_code, etag, expires, og, media_id, + download_ts): + return self._simple_insert( + "local_media_repository_url_cache", + { + "url": url, + "response_code": response_code, + "etag": etag, + "expires": expires, + "og": og, + "media_id": media_id, + "download_ts": download_ts, + }, + desc="store_url_cache", + ) + def get_local_media_thumbnails(self, media_id): return self._simple_select_list( "local_media_repository_thumbnails", diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 00833422af..57f14fd12b 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -30,18 +30,6 @@ SCHEMA_VERSION = 31 dir_path = os.path.abspath(os.path.dirname(__file__)) -def read_schema(path): - """ Read the named database schema. - - Args: - path: Path of the database schema. - Returns: - A string containing the database schema. - """ - with open(path) as schema_file: - return schema_file.read() - - class PrepareDatabaseException(Exception): pass diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 59b4ef5ce6..07f5fae8dd 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -176,16 +176,6 @@ class PresenceStore(SQLBaseStore): desc="disallow_presence_visible", ) - def is_presence_visible(self, observed_localpart, observer_userid): - return self._simple_select_one( - table="presence_allow_inbound", - keyvalues={"observed_user_id": observed_localpart, - "observer_user_id": observer_userid}, - retcols=["observed_user_id"], - allow_none=True, - desc="is_presence_visible", - ) - def add_presence_list_pending(self, observer_localpart, observed_userid): return self._simple_insert( table="presence_list", diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 088ad0f914..08a54cbdd1 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -124,26 +124,6 @@ class RoomMemberStore(SQLBaseStore): with self._stream_id_gen.get_next() as stream_ordering: yield self.runInteraction("locally_reject_invite", f, stream_ordering) - def get_room_member(self, user_id, room_id): - """Retrieve the current state of a room member. - - Args: - user_id (str): The member's user ID. - room_id (str): The room the member is in. - Returns: - Deferred: Results in a MembershipEvent or None. - """ - return self.runInteraction( - "get_room_member", - self._get_members_events_txn, - room_id, - user_id=user_id, - ).addCallback( - self._get_events - ).addCallback( - lambda events: events[0] if events else None - ) - @cached(max_entries=5000) def get_users_in_room(self, room_id): def f(txn): @@ -206,19 +186,6 @@ class RoomMemberStore(SQLBaseStore): defer.returnValue(invite) defer.returnValue(None) - def get_leave_and_ban_events_for_user(self, user_id): - """ Get all the leave events for a user - Args: - user_id (str): The user ID. - Returns: - A deferred list of event objects. - """ - return self.get_rooms_for_user_where_membership_is( - user_id, (Membership.LEAVE, Membership.BAN) - ).addCallback(lambda leaves: self._get_events([ - leave.event_id for leave in leaves - ])) - def get_rooms_for_user_where_membership_is(self, user_id, membership_list): """ Get all the rooms for this user where the membership for this user matches one in the membership list. diff --git a/synapse/storage/schema/delta/31/local_media_repository_url_cache.sql b/synapse/storage/schema/delta/31/local_media_repository_url_cache.sql new file mode 100644 index 0000000000..9efb4280eb --- /dev/null +++ b/synapse/storage/schema/delta/31/local_media_repository_url_cache.sql @@ -0,0 +1,27 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE local_media_repository_url_cache( + url TEXT, -- the URL being cached + response_code INTEGER, -- the HTTP response code of this download attempt + etag TEXT, -- the etag header of this response + expires INTEGER, -- the number of ms this response was valid for + og TEXT, -- cache of the OG metadata of this URL as JSON + media_id TEXT, -- the media_id, if any, of the URL's content in the repo + download_ts BIGINT -- the timestamp of this download attempt +); + +CREATE INDEX local_media_repository_url_cache_by_url_download_ts + ON local_media_repository_url_cache(url, download_ts); diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 76bcd9cd00..95b12559a6 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -303,96 +303,6 @@ class StreamStore(SQLBaseStore): defer.returnValue(ret) - def get_room_events_stream( - self, - user_id, - from_key, - to_key, - limit=0, - is_guest=False, - room_ids=None - ): - room_ids = room_ids or [] - room_ids = [r for r in room_ids] - if is_guest: - current_room_membership_sql = ( - "SELECT c.room_id FROM history_visibility AS h" - " INNER JOIN current_state_events AS c" - " ON h.event_id = c.event_id" - " WHERE c.room_id IN (%s)" - " AND h.history_visibility = 'world_readable'" % ( - ",".join(map(lambda _: "?", room_ids)) - ) - ) - current_room_membership_args = room_ids - else: - current_room_membership_sql = ( - "SELECT m.room_id FROM room_memberships as m " - " INNER JOIN current_state_events as c" - " ON m.event_id = c.event_id AND c.state_key = m.user_id" - " WHERE m.user_id = ? AND m.membership = 'join'" - ) - current_room_membership_args = [user_id] - - # We also want to get any membership events about that user, e.g. - # invites or leave notifications. - membership_sql = ( - "SELECT m.event_id FROM room_memberships as m " - "INNER JOIN current_state_events as c ON m.event_id = c.event_id " - "WHERE m.user_id = ? " - ) - membership_args = [user_id] - - if limit: - limit = max(limit, MAX_STREAM_SIZE) - else: - limit = MAX_STREAM_SIZE - - # From and to keys should be integers from ordering. - from_id = RoomStreamToken.parse_stream_token(from_key) - to_id = RoomStreamToken.parse_stream_token(to_key) - - if from_key == to_key: - return defer.succeed(([], to_key)) - - sql = ( - "SELECT e.event_id, e.stream_ordering FROM events AS e WHERE " - "(e.outlier = ? AND (room_id IN (%(current)s)) OR " - "(event_id IN (%(invites)s))) " - "AND e.stream_ordering > ? AND e.stream_ordering <= ? " - "ORDER BY stream_ordering ASC LIMIT %(limit)d " - ) % { - "current": current_room_membership_sql, - "invites": membership_sql, - "limit": limit - } - - def f(txn): - args = ([False] + current_room_membership_args + membership_args + - [from_id.stream, to_id.stream]) - txn.execute(sql, args) - - rows = self.cursor_to_dict(txn) - - ret = self._get_events_txn( - txn, - [r["event_id"] for r in rows], - get_prev_content=True - ) - - self._set_before_and_after(ret, rows) - - if rows: - key = "s%d" % max(r["stream_ordering"] for r in rows) - else: - # Assume we didn't get anything because there was nothing to - # get. - key = to_key - - return ret, key - - return self.runInteraction("get_room_events_stream", f) - @defer.inlineCallbacks def paginate_room_events(self, room_id, from_key, to_key=None, direction='b', limit=-1): diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index f69f1cdad4..46cf93ff87 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -112,7 +112,7 @@ class StreamIdGenerator(object): self._current + self._step * (n + 1), self._step ) - self._current += n + self._current += n * self._step for next_id in next_ids: self._unfinished_ids.append(next_id) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 3b9da5b34a..b462495eb8 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -49,9 +49,6 @@ class Clock(object): l.start(msec / 1000.0, now=False) return l - def stop_looping_call(self, loop): - loop.stop() - def call_later(self, delay, callback, *args, **kwargs): """Call something later diff --git a/synapse/util/async.py b/synapse/util/async.py index cd4d90f3cf..0d6f48e2d8 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -16,9 +16,13 @@ from twisted.internet import defer, reactor -from .logcontext import PreserveLoggingContext, preserve_fn +from .logcontext import ( + PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, +) from synapse.util import unwrapFirstError +from contextlib import contextmanager + @defer.inlineCallbacks def sleep(seconds): @@ -137,3 +141,47 @@ def concurrently_execute(func, args, limit): preserve_fn(_concurrently_execute_inner)() for _ in xrange(limit) ], consumeErrors=True).addErrback(unwrapFirstError) + + +class Linearizer(object): + """Linearizes access to resources based on a key. Useful to ensure only one + thing is happening at a time on a given resource. + + Example: + + with (yield linearizer.queue("test_key")): + # do some work. + + """ + def __init__(self): + self.key_to_defer = {} + + @defer.inlineCallbacks + def queue(self, key): + # If there is already a deferred in the queue, we pull it out so that + # we can wait on it later. + # Then we replace it with a deferred that we resolve *after* the + # context manager has exited. + # We only return the context manager after the previous deferred has + # resolved. + # This all has the net effect of creating a chain of deferreds that + # wait for the previous deferred before starting their work. + current_defer = self.key_to_defer.get(key) + + new_defer = defer.Deferred() + self.key_to_defer[key] = new_defer + + if current_defer: + yield preserve_context_over_deferred(current_defer) + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + new_defer.callback(None) + current_d = self.key_to_defer.get(key) + if current_d is new_defer: + self.key_to_defer.pop(key, None) + + defer.returnValue(_ctx_manager()) diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index be310ba320..36686b479e 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -35,7 +35,7 @@ class ResponseCache(object): return None def set(self, key, deferred): - result = ObservableDeferred(deferred) + result = ObservableDeferred(deferred, consumeErrors=True) self.pending_result_cache[key] = result def remove(r): diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 8875813de4..d7cccc06b1 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -15,7 +15,9 @@ from twisted.internet import defer -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import ( + PreserveLoggingContext, preserve_context_over_fn +) from synapse.util import unwrapFirstError @@ -25,6 +27,24 @@ import logging logger = logging.getLogger(__name__) +def registered_user(distributor, user): + return distributor.fire("registered_user", user) + + +def user_left_room(distributor, user, room_id): + return preserve_context_over_fn( + distributor.fire, + "user_left_room", user=user, room_id=room_id + ) + + +def user_joined_room(distributor, user, room_id): + return preserve_context_over_fn( + distributor.fire, + "user_joined_room", user=user, room_id=room_id + ) + + class Distributor(object): """A central dispatch point for loosely-connected pieces of code to register, observe, and fire signals. diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 4076eed269..1101881a2d 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -100,20 +100,6 @@ class _PerHostRatelimiter(object): self.current_processing = set() self.request_times = [] - def is_empty(self): - time_now = self.clock.time_msec() - self.request_times[:] = [ - r for r in self.request_times - if time_now - r < self.window_size - ] - - return not ( - self.ready_request_queue - or self.sleeping_requests - or self.current_processing - or self.request_times - ) - @contextlib.contextmanager def ratelimit(self): # `contextlib.contextmanager` takes a generator and turns it into a diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index b490bb8725..a100f151d4 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -21,10 +21,6 @@ _string_with_symbols = ( ) -def origin_from_ucid(ucid): - return ucid.split("@", 1)[1] - - def random_string(length): return ''.join(random.choice(string.ascii_letters) for _ in xrange(length)) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 0f525a8943..983caafe8a 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -51,7 +51,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase): def check(self, method, args, expected_result=None): master_result = yield getattr(self.master_store, method)(*args) slaved_result = yield getattr(self.slaved_store, method)(*args) - self.assertEqual(master_result, slaved_result) if expected_result is not None: self.assertEqual(master_result, expected_result) self.assertEqual(slaved_result, expected_result) + self.assertEqual(master_result, slaved_result) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 351d777fb2..baa4a26eb5 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -14,20 +14,47 @@ from ._base import BaseSlavedStoreTestCase -from synapse.events import FrozenEvent +from synapse.events import FrozenEvent, _EventInternalMetadata from synapse.events.snapshot import EventContext from synapse.storage.roommember import RoomsForUser from twisted.internet import defer + USER_ID = "@feeling:blue" USER_ID_2 = "@bright:blue" OUTLIER = {"outlier": True} ROOM_ID = "!room:blue" +def dict_equals(self, other): + return self.__dict__ == other.__dict__ + + +def patch__eq__(cls): + eq = getattr(cls, "__eq__", None) + cls.__eq__ = dict_equals + + def unpatch(): + if eq is not None: + cls.__eq__ = eq + return unpatch + + class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): + def setUp(self): + # Patch up the equality operator for events so that we can check + # whether lists of events match using assertEquals + self.unpatches = [ + patch__eq__(_EventInternalMetadata), + patch__eq__(FrozenEvent), + ] + return super(SlavedEventStoreTestCase, self).setUp() + + def tearDown(self): + [unpatch() for unpatch in self.unpatches] + @defer.inlineCallbacks def test_room_name_and_aliases(self): create = yield self.persist(type="m.room.create", key="", creator=USER_ID) @@ -116,13 +143,121 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID]) yield self.check("get_rooms_for_user", (USER_ID_2,), []) + @defer.inlineCallbacks + def test_get_latest_event_ids_in_room(self): + create = yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.replicate() + yield self.check( + "get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id] + ) + + join = yield self.persist( + type="m.room.member", key=USER_ID, membership="join", + prev_events=[(create.event_id, {})], + ) + yield self.replicate() + yield self.check( + "get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id] + ) + + @defer.inlineCallbacks + def test_get_current_state(self): + # Create the room. + create = yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.replicate() + yield self.check( + "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), [] + ) + + # Join the room. + join1 = yield self.persist( + type="m.room.member", key=USER_ID, membership="join", + ) + yield self.replicate() + yield self.check( + "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), + [join1] + ) + + # Add some other user to the room. + join2 = yield self.persist( + type="m.room.member", key=USER_ID_2, membership="join", + ) + yield self.replicate() + yield self.check( + "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2), + [join2] + ) + + # Leave the room, then rejoin the room clobbering state. + yield self.persist(type="m.room.member", key=USER_ID, membership="leave") + join3 = yield self.persist( + type="m.room.member", key=USER_ID, membership="join", + reset_state=[create] + ) + yield self.replicate() + yield self.check( + "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2), + [] + ) + yield self.check( + "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), + [join3] + ) + + @defer.inlineCallbacks + def test_redactions(self): + yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.persist(type="m.room.member", key=USER_ID, membership="join") + + msg = yield self.persist( + type="m.room.message", msgtype="m.text", body="Hello" + ) + yield self.replicate() + yield self.check("get_event", [msg.event_id], msg) + + redaction = yield self.persist( + type="m.room.redaction", redacts=msg.event_id + ) + yield self.replicate() + + msg_dict = msg.get_dict() + msg_dict["content"] = {} + msg_dict["unsigned"]["redacted_by"] = redaction.event_id + msg_dict["unsigned"]["redacted_because"] = redaction + redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) + yield self.check("get_event", [msg.event_id], redacted) + + @defer.inlineCallbacks + def test_backfilled_redactions(self): + yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.persist(type="m.room.member", key=USER_ID, membership="join") + + msg = yield self.persist( + type="m.room.message", msgtype="m.text", body="Hello" + ) + yield self.replicate() + yield self.check("get_event", [msg.event_id], msg) + + redaction = yield self.persist( + type="m.room.redaction", redacts=msg.event_id, backfill=True + ) + yield self.replicate() + + msg_dict = msg.get_dict() + msg_dict["content"] = {} + msg_dict["unsigned"]["redacted_by"] = redaction.event_id + msg_dict["unsigned"]["redacted_because"] = redaction + redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) + yield self.check("get_event", [msg.event_id], redacted) + event_id = 0 @defer.inlineCallbacks def persist( self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={}, state=None, reset_state=False, backfill=False, - depth=None, prev_events=[], auth_events=[], prev_state=[], + depth=None, prev_events=[], auth_events=[], prev_state=[], redacts=None, **content ): """ @@ -147,6 +282,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): event_dict["state_key"] = key event_dict["prev_state"] = prev_state + if redacts is not None: + event_dict["redacts"] = redacts + event = FrozenEvent(event_dict, internal_metadata_dict=internal) self.event_id += 1 diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py index ec78f007ca..63203cea35 100644 --- a/tests/storage/test_presence.py +++ b/tests/storage/test_presence.py @@ -35,33 +35,6 @@ class PresenceStoreTestCase(unittest.TestCase): self.u_banana = UserID.from_string("@banana:test") @defer.inlineCallbacks - def test_visibility(self): - self.assertFalse((yield self.store.is_presence_visible( - observed_localpart=self.u_apple.localpart, - observer_userid=self.u_banana.to_string(), - ))) - - yield self.store.allow_presence_visible( - observed_localpart=self.u_apple.localpart, - observer_userid=self.u_banana.to_string(), - ) - - self.assertTrue((yield self.store.is_presence_visible( - observed_localpart=self.u_apple.localpart, - observer_userid=self.u_banana.to_string(), - ))) - - yield self.store.disallow_presence_visible( - observed_localpart=self.u_apple.localpart, - observer_userid=self.u_banana.to_string(), - ) - - self.assertFalse((yield self.store.is_presence_visible( - observed_localpart=self.u_apple.localpart, - observer_userid=self.u_banana.to_string(), - ))) - - @defer.inlineCallbacks def test_presence_list(self): self.assertEquals( [], diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 5880409867..6afaca3a61 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -110,22 +110,10 @@ class RedactionTestCase(unittest.TestCase): self.room1, self.u_alice, Membership.JOIN ) - start = yield self.store.get_room_events_max_id() - msg_event = yield self.inject_message(self.room1, self.u_alice, u"t") - end = yield self.store.get_room_events_max_id() - - results, _ = yield self.store.get_room_events_stream( - self.u_alice.to_string(), - start, - end, - ) - - self.assertEqual(1, len(results)) - # Check event has not been redacted: - event = results[0] + event = yield self.store.get_event(msg_event.event_id) self.assertObjectHasAttributes( { @@ -144,17 +132,7 @@ class RedactionTestCase(unittest.TestCase): self.room1, msg_event.event_id, self.u_alice, reason ) - results, _ = yield self.store.get_room_events_stream( - self.u_alice.to_string(), - start, - end, - ) - - self.assertEqual(1, len(results)) - - # Check redaction - - event = results[0] + event = yield self.store.get_event(msg_event.event_id) self.assertEqual(msg_event.event_id, event.event_id) @@ -184,25 +162,12 @@ class RedactionTestCase(unittest.TestCase): self.room1, self.u_alice, Membership.JOIN ) - start = yield self.store.get_room_events_max_id() - msg_event = yield self.inject_room_member( self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}, ) - end = yield self.store.get_room_events_max_id() - - results, _ = yield self.store.get_room_events_stream( - self.u_alice.to_string(), - start, - end, - ) - - self.assertEqual(1, len(results)) - - # Check event has not been redacted: - event = results[0] + event = yield self.store.get_event(msg_event.event_id) self.assertObjectHasAttributes( { @@ -221,17 +186,9 @@ class RedactionTestCase(unittest.TestCase): self.room1, msg_event.event_id, self.u_alice, reason ) - results, _ = yield self.store.get_room_events_stream( - self.u_alice.to_string(), - start, - end, - ) - - self.assertEqual(1, len(results)) - # Check redaction - event = results[0] + event = yield self.store.get_event(msg_event.event_id) self.assertTrue("redacted_because" in event.unsigned) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index b029ff0584..997090fe35 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -71,13 +71,6 @@ class RoomMemberStoreTestCase(unittest.TestCase): yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN) self.assertEquals( - Membership.JOIN, - (yield self.store.get_room_member( - user_id=self.u_alice.to_string(), - room_id=self.room.to_string(), - )).membership - ) - self.assertEquals( [self.u_alice.to_string()], [m.user_id for m in ( yield self.store.get_room_members(self.room.to_string()) diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py deleted file mode 100644 index da322152c7..0000000000 --- a/tests/storage/test_stream.py +++ /dev/null @@ -1,185 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from tests import unittest -from twisted.internet import defer - -from synapse.api.constants import EventTypes, Membership -from synapse.types import UserID, RoomID -from tests.storage.event_injector import EventInjector - -from tests.utils import setup_test_homeserver - -from mock import Mock - - -class StreamStoreTestCase(unittest.TestCase): - - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver( - resource_for_federation=Mock(), - http_client=None, - ) - - self.store = hs.get_datastore() - self.event_builder_factory = hs.get_event_builder_factory() - self.event_injector = EventInjector(hs) - self.handlers = hs.get_handlers() - self.message_handler = self.handlers.message_handler - - self.u_alice = UserID.from_string("@alice:test") - self.u_bob = UserID.from_string("@bob:test") - - self.room1 = RoomID.from_string("!abc123:test") - self.room2 = RoomID.from_string("!xyx987:test") - - @defer.inlineCallbacks - def test_event_stream_get_other(self): - # Both bob and alice joins the room - yield self.event_injector.inject_room_member( - self.room1, self.u_alice, Membership.JOIN - ) - yield self.event_injector.inject_room_member( - self.room1, self.u_bob, Membership.JOIN - ) - - # Initial stream key: - start = yield self.store.get_room_events_max_id() - - yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") - - end = yield self.store.get_room_events_max_id() - - results, _ = yield self.store.get_room_events_stream( - self.u_bob.to_string(), - start, - end, - ) - - self.assertEqual(1, len(results)) - - event = results[0] - - self.assertObjectHasAttributes( - { - "type": EventTypes.Message, - "user_id": self.u_alice.to_string(), - "content": {"body": "test", "msgtype": "message"}, - }, - event, - ) - - @defer.inlineCallbacks - def test_event_stream_get_own(self): - # Both bob and alice joins the room - yield self.event_injector.inject_room_member( - self.room1, self.u_alice, Membership.JOIN - ) - yield self.event_injector.inject_room_member( - self.room1, self.u_bob, Membership.JOIN - ) - - # Initial stream key: - start = yield self.store.get_room_events_max_id() - - yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") - - end = yield self.store.get_room_events_max_id() - - results, _ = yield self.store.get_room_events_stream( - self.u_alice.to_string(), - start, - end, - ) - - self.assertEqual(1, len(results)) - - event = results[0] - - self.assertObjectHasAttributes( - { - "type": EventTypes.Message, - "user_id": self.u_alice.to_string(), - "content": {"body": "test", "msgtype": "message"}, - }, - event, - ) - - @defer.inlineCallbacks - def test_event_stream_join_leave(self): - # Both bob and alice joins the room - yield self.event_injector.inject_room_member( - self.room1, self.u_alice, Membership.JOIN - ) - yield self.event_injector.inject_room_member( - self.room1, self.u_bob, Membership.JOIN - ) - - # Then bob leaves again. - yield self.event_injector.inject_room_member( - self.room1, self.u_bob, Membership.LEAVE - ) - - # Initial stream key: - start = yield self.store.get_room_events_max_id() - - yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") - - end = yield self.store.get_room_events_max_id() - - results, _ = yield self.store.get_room_events_stream( - self.u_bob.to_string(), - start, - end, - ) - - # We should not get the message, as it happened *after* bob left. - self.assertEqual(0, len(results)) - - @defer.inlineCallbacks - def test_event_stream_prev_content(self): - yield self.event_injector.inject_room_member( - self.room1, self.u_bob, Membership.JOIN - ) - - yield self.event_injector.inject_room_member( - self.room1, self.u_alice, Membership.JOIN - ) - - start = yield self.store.get_room_events_max_id() - - yield self.event_injector.inject_room_member( - self.room1, self.u_alice, Membership.JOIN, - ) - - end = yield self.store.get_room_events_max_id() - - results, _ = yield self.store.get_room_events_stream( - self.u_bob.to_string(), - start, - end, - ) - - # We should not get the message, as it happened *after* bob left. - self.assertEqual(1, len(results)) - - event = results[0] - - self.assertTrue( - "prev_content" in event.unsigned, - msg="No prev_content key" - ) diff --git a/tests/test_dns.py b/tests/test_dns.py index 637b1606f8..c394c57ee7 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -21,6 +21,8 @@ from mock import Mock from synapse.http.endpoint import resolve_service +from tests.utils import MockClock + class DnsTestCase(unittest.TestCase): @@ -63,14 +65,17 @@ class DnsTestCase(unittest.TestCase): self.assertEquals(servers[0].host, ip_address) @defer.inlineCallbacks - def test_from_cache(self): + def test_from_cache_expired_and_dns_fail(self): dns_client_mock = Mock() dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) service_name = "test_service.examle.com" + entry = Mock(spec_set=["expires"]) + entry.expires = 0 + cache = { - service_name: [object()] + service_name: [entry] } servers = yield resolve_service( @@ -83,6 +88,31 @@ class DnsTestCase(unittest.TestCase): self.assertEquals(servers, cache[service_name]) @defer.inlineCallbacks + def test_from_cache(self): + clock = MockClock() + + dns_client_mock = Mock(spec_set=['lookupService']) + dns_client_mock.lookupService = Mock(spec_set=[]) + + service_name = "test_service.examle.com" + + entry = Mock(spec_set=["expires"]) + entry.expires = 999999999 + + cache = { + service_name: [entry] + } + + servers = yield resolve_service( + service_name, dns_client=dns_client_mock, cache=cache, clock=clock, + ) + + self.assertFalse(dns_client_mock.lookupService.called) + + self.assertEquals(len(servers), 1) + self.assertEquals(servers, cache[service_name]) + + @defer.inlineCallbacks def test_empty_cache(self): dns_client_mock = Mock() diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py new file mode 100644 index 0000000000..afcba482f9 --- /dev/null +++ b/tests/util/test_linearizer.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tests import unittest + +from twisted.internet import defer + +from synapse.util.async import Linearizer + + +class LinearizerTestCase(unittest.TestCase): + + @defer.inlineCallbacks + def test_linearizer(self): + linearizer = Linearizer() + + key = object() + + d1 = linearizer.queue(key) + cm1 = yield d1 + + d2 = linearizer.queue(key) + self.assertFalse(d2.called) + + with cm1: + self.assertFalse(d2.called) + + self.assertTrue(d2.called) + + with (yield d2): + pass |