From 7dd0c1730a1ea5962a77b9bbb883c1690b25b686 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sun, 24 Jan 2016 18:47:27 -0500 Subject: initial WIP of a tentative preview_url endpoint - incomplete, untested, experimental, etc. just putting it here for safekeeping for now --- synapse/config/repository.py | 6 +- synapse/http/client.py | 81 +++++++++++++ synapse/rest/media/v1/media_repository.py | 3 + synapse/rest/media/v1/preview_url_resource.py | 164 ++++++++++++++++++++++++++ 4 files changed, 253 insertions(+), 1 deletion(-) create mode 100644 synapse/rest/media/v1/preview_url_resource.py (limited to 'synapse') diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 2fcf872449..33fff5616d 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -53,6 +53,7 @@ class ContentRepositoryConfig(Config): def read_config(self, config): self.max_upload_size = self.parse_size(config["max_upload_size"]) self.max_image_pixels = self.parse_size(config["max_image_pixels"]) + self.max_spider_size = self.parse_size(config["max_spider_size"]) self.media_store_path = self.ensure_directory(config["media_store_path"]) self.uploads_path = self.ensure_directory(config["uploads_path"]) self.dynamic_thumbnails = config["dynamic_thumbnails"] @@ -73,6 +74,9 @@ class ContentRepositoryConfig(Config): # The largest allowed upload size in bytes max_upload_size: "10M" + # The largest allowed URL preview spidering size in bytes + max_spider_size: "10M" + # Maximum number of pixels that will be thumbnailed max_image_pixels: "32M" @@ -80,7 +84,7 @@ class ContentRepositoryConfig(Config): # the resolution requested by the client. If true then whenever # a new resolution is requested by the client the server will # generate a new thumbnail. If false the server will pick a thumbnail - # from a precalcualted list. + # from a precalculated list. dynamic_thumbnails: false # List of thumbnail to precalculate when an image is uploaded. diff --git a/synapse/http/client.py b/synapse/http/client.py index fdd90b1c3c..25d319f126 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -238,6 +238,87 @@ class SimpleHttpClient(object): else: raise CodeMessageException(response.code, body) + # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. + # The two should be factored out. + + @defer.inlineCallbacks + def get_file(self, url, output_stream, args={}, max_size=None): + """GETs a file from a given URL + Args: + url (str): The URL to GET + output_stream (file): File to write the response body to. + Returns: + A (int,dict) tuple of the file length and a dict of the response + headers. + """ + + def body_callback(method, url_bytes, headers_dict): + self.sign_request(destination, method, url_bytes, headers_dict) + return None + + response = yield self.request( + "GET", + url.encode("ascii"), + headers=Headers({ + b"User-Agent": [self.user_agent], + }) + ) + + headers = dict(response.headers.getAllRawHeaders()) + + if headers['Content-Length'] > max_size: + logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) + # XXX: do we want to explicitly drop the connection here somehow? if so, how? + raise # what should we be raising here? + + # TODO: if our Content-Type is HTML or something, just read the first + # N bytes into RAM rather than saving it all to disk only to read it + # straight back in again + + try: + length = yield preserve_context_over_fn( + _readBodyToFile, + response, output_stream, max_size + ) + except: + logger.exception("Failed to download body") + raise + + defer.returnValue((length, headers)) + + +# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. +# The two should be factored out. + +class _ReadBodyToFileProtocol(protocol.Protocol): + def __init__(self, stream, deferred, max_size): + self.stream = stream + self.deferred = deferred + self.length = 0 + self.max_size = max_size + + def dataReceived(self, data): + self.stream.write(data) + self.length += len(data) + if self.max_size is not None and self.length >= self.max_size: + logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) + self.deferred = defer.Deferred() + self.transport.loseConnection() + + def connectionLost(self, reason): + if reason.check(ResponseDone): + self.deferred.callback(self.length) + else: + self.deferred.errback(reason) + + +# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. +# The two should be factored out. + +def _readBodyToFile(response, stream, max_size): + d = defer.Deferred() + response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size)) + return d class CaptchaServerHttpClient(SimpleHttpClient): """ diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 7dfb027dd1..8f3491b91c 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -17,6 +17,7 @@ from .upload_resource import UploadResource from .download_resource import DownloadResource from .thumbnail_resource import ThumbnailResource from .identicon_resource import IdenticonResource +from .preview_url_resource import PreviewUrlResource from .filepath import MediaFilePaths from twisted.web.resource import Resource @@ -78,3 +79,5 @@ class MediaRepositoryResource(Resource): self.putChild("download", DownloadResource(hs, filepaths)) self.putChild("thumbnail", ThumbnailResource(hs, filepaths)) self.putChild("identicon", IdenticonResource()) + self.putChild("preview_url", PreviewUrlResource(hs, filepaths)) + diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py new file mode 100644 index 0000000000..fb8ab3096f --- /dev/null +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -0,0 +1,164 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.web.resource import Resource +from lxml import html +from synapse.http.client import SimpleHttpClient +from synapse.http.server import respond_with_json_bytes +from simplejson import json + +import logging +logger = logging.getLogger(__name__) + +class PreviewUrlResource(Resource): + isLeaf = True + + def __init__(self, hs, filepaths): + Resource.__init__(self) + self.client = SimpleHttpClient(hs) + self.filepaths = filepaths + self.max_spider_size = hs.config.max_spider_size + self.server_name = hs.hostname + self.clock = hs.get_clock() + + def render_GET(self, request): + self._async_render_GET(request) + return NOT_DONE_YET + + @request_handler + @defer.inlineCallbacks + def _async_render_GET(self, request): + url = request.args.get("url") + + try: + # TODO: keep track of whether there's an ongoing request for this preview + # and block and return their details if there is one. + + media_info = self._download_url(url) + except: + os.remove(fname) + raise + + if self._is_media(media_type): + dims = yield self._generate_local_thumbnails( + media_info.filesystem_id, media_info + ) + + og = { + "og:description" : media_info.download_name, + "og:image" : "mxc://%s/%s" % (self.server_name, media_info.filesystem_id), + "og:image:type" : media_info.media_type, + "og:image:width" : dims.width, + "og:image:height" : dims.height, + } + + # define our OG response for this media + elif self._is_html(media_type): + tree = html.parse(media_info.filename) + + # suck it up into lxml and define our OG response. + # if we see any URLs in the OG response, then spider them + # (although the client could choose to do this by asking for previews of those URLs to avoid DoSing the server) + + # "og:type" : "article" + # "og:url" : "https://twitter.com/matrixdotorg/status/684074366691356672" + # "og:title" : "Matrix on Twitter" + # "og:image" : "https://pbs.twimg.com/profile_images/500400952029888512/yI0qtFi7_400x400.png" + # "og:description" : "“Synapse 0.12 is out! Lots of polishing, performance &amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP”" + # "og:site_name" : "Twitter" + + og = {} + for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): + og[tag.attrib['property']] = tag.attrib['content'] + + # TODO: store our OG details in a cache (and expire them when stale) + # TODO: delete the content to stop diskfilling, as we only ever cared about its OG + + respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) + + def _download_url(url): + requester = yield self.auth.get_user_by_req(request) + + # XXX: horrible duplication with base_resource's _download_remote_file() + file_id = random_string(24) + + fname = self.filepaths.local_media_filepath(file_id) + self._makedirs(fname) + + try: + with open(fname, "wb") as f: + length, headers = yield self.client.get_file( + url, output_stream=f, max_size=self.max_spider_size, + ) + media_type = headers["Content-Type"][0] + time_now_ms = self.clock.time_msec() + + content_disposition = headers.get("Content-Disposition", None) + if content_disposition: + _, params = cgi.parse_header(content_disposition[0],) + download_name = None + + # First check if there is a valid UTF-8 filename + download_name_utf8 = params.get("filename*", None) + if download_name_utf8: + if download_name_utf8.lower().startswith("utf-8''"): + download_name = download_name_utf8[7:] + + # If there isn't check for an ascii name. + if not download_name: + download_name_ascii = params.get("filename", None) + if download_name_ascii and is_ascii(download_name_ascii): + download_name = download_name_ascii + + if download_name: + download_name = urlparse.unquote(download_name) + try: + download_name = download_name.decode("utf-8") + except UnicodeDecodeError: + download_name = None + else: + download_name = None + + yield self.store.store_local_media( + media_id=fname, + media_type=media_type, + time_now_ms=self.clock.time_msec(), + upload_name=download_name, + media_length=length, + user_id=requester.user, + ) + + except: + os.remove(fname) + raise + + return { + "media_type": media_type, + "media_length": length, + "download_name": download_name, + "created_ts": time_now_ms, + "filesystem_id": file_id, + "filename": fname, + } + + + def _is_media(content_type): + if content_type.lower().startswith("image/"): + return True + + def _is_html(content_type): + content_type = content_type.lower() + if content_type == "text/html" or + content_type.startswith("application/xhtml"): + return True -- cgit 1.4.1 From 3b554bda267402fd43a9e462eccf4060077f37dc Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 24 Mar 2016 13:19:39 +0000 Subject: Never notify for member events. This fixes https://github.com/vector-im/vector-web/issues/828 --- synapse/push/baserules.py | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) (limited to 'synapse') diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 86a2998bcc..792af70eb7 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -160,7 +160,27 @@ BASE_APPEND_OVRRIDE_RULES = [ 'actions': [ 'dont_notify', ] - } + }, + # Will we sometimes want to know about people joining and leaving? + # Perhaps: if so, this could be expanded upon. Seems the most usual case + # is that we don't though. We add this override rule so that even if + # the room rule is set to notify, we don't get notifications about + # join/leave/avatar/displayname events. + # See also: https://matrix.org/jira/browse/SYN-607 + { + 'rule_id': 'global/override/.m.rule.member_event', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'type', + 'pattern': 'm.room.member', + '_id': '_member', + } + ], + 'actions': [ + 'dont_notify' + ] + }, ] @@ -261,25 +281,6 @@ BASE_APPEND_UNDERRIDE_RULES = [ } ] }, - # This is too simple: https://matrix.org/jira/browse/SYN-607 - # Removing for now - # { - # 'rule_id': 'global/underride/.m.rule.member_event', - # 'conditions': [ - # { - # 'kind': 'event_match', - # 'key': 'type', - # 'pattern': 'm.room.member', - # '_id': '_member', - # } - # ], - # 'actions': [ - # 'notify', { - # 'set_tweak': 'highlight', - # 'value': False - # } - # ] - # }, { 'rule_id': 'global/underride/.m.rule.message', 'conditions': [ -- cgit 1.4.1 From 191c7bef6bbb80f66f66e95387940c3bb6b5a0cf Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 24 Mar 2016 17:47:31 +0000 Subject: Deduplicate identical /sync requests --- synapse/handlers/sync.py | 16 +++++++++++- synapse/rest/client/v2_alpha/sync.py | 3 +++ synapse/util/caches/response_cache.py | 46 +++++++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 synapse/util/caches/response_cache.py (limited to 'synapse') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 1f6fde8e8a..48ab5707e1 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -20,6 +20,7 @@ from synapse.api.constants import Membership, EventTypes from synapse.util import unwrapFirstError from synapse.util.logcontext import LoggingContext, preserve_fn from synapse.util.metrics import Measure +from synapse.util.caches.response_cache import ResponseCache from synapse.push.clientformat import format_push_rules_for_user from twisted.internet import defer @@ -35,6 +36,7 @@ SyncConfig = collections.namedtuple("SyncConfig", [ "user", "filter_collection", "is_guest", + "request_key", ]) @@ -136,8 +138,8 @@ class SyncHandler(BaseHandler): super(SyncHandler, self).__init__(hs) self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() + self.response_cache = ResponseCache() - @defer.inlineCallbacks def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, full_state=False): """Get the sync for a client if we have new data for it now. Otherwise @@ -146,7 +148,19 @@ class SyncHandler(BaseHandler): Returns: A Deferred SyncResult. """ + result = self.response_cache.get(sync_config.request_key) + if not result: + result = self.response_cache.set( + sync_config.request_key, + self._wait_for_sync_for_user( + sync_config, since_token, timeout, full_state + ) + ) + return result + @defer.inlineCallbacks + def _wait_for_sync_for_user(self, sync_config, since_token, timeout, + full_state): context = LoggingContext.current_context() if context: if since_token is None: diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index de4a020ad4..c5785d7074 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -115,6 +115,8 @@ class SyncRestServlet(RestServlet): ) ) + request_key = (user, timeout, since, filter_id, full_state) + if filter_id: if filter_id.startswith('{'): try: @@ -134,6 +136,7 @@ class SyncRestServlet(RestServlet): user=user, filter_collection=filter, is_guest=requester.is_guest, + request_key=request_key, ) if since is not None: diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py new file mode 100644 index 0000000000..1c2e344269 --- /dev/null +++ b/synapse/util/caches/response_cache.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util.async import ObservableDeferred + + +class ResponseCache(object): + """ + This caches a deferred response. Until the deferred completes it will be + returned from the cache. This means that if the client retries the request + while the response is still being computed, that original response will be + used rather than trying to compute a new response. + """ + + def __init__(self): + self.pending_result_cache = {} # Request that haven't finished yet. + + def get(self, key): + result = self.pending_result_cache.get(key) + if result is not None: + return result.observe() + else: + return None + + def set(self, key, deferred): + result = ObservableDeferred(deferred) + self.pending_result_cache[key] = result + + def remove(r): + self.pending_result_cache.pop(key, None) + return r + + result.addBoth(remove) + return result.observe() -- cgit 1.4.1 From 54a546091abdf70c740d1e59b025e79c44df7455 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 24 Mar 2016 18:02:10 +0000 Subject: Add a response cache for getting the public room list --- synapse/handlers/room.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index d5c56ce0d6..133183a257 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -25,6 +25,7 @@ from synapse.api.constants import ( from synapse.api.errors import AuthError, StoreError, SynapseError, Codes from synapse.util import stringutils, unwrapFirstError from synapse.util.logcontext import preserve_context_over_fn +from synapse.util.caches.response_cache import ResponseCache from signedjson.sign import verify_signed_json from signedjson.key import decode_verify_key_bytes @@ -939,9 +940,18 @@ class RoomMemberHandler(BaseHandler): class RoomListHandler(BaseHandler): + def __init__(self, hs): + super(RoomListHandler, self).__init__(hs) + self.response_cache = ResponseCache() - @defer.inlineCallbacks def get_public_room_list(self): + result = self.response_cache.get(()) + if not result: + result = self.response_cache.set((), self._get_public_room_list()) + return result + + @defer.inlineCallbacks + def _get_public_room_list(self): room_ids = yield self.store.get_public_room_ids() @defer.inlineCallbacks -- cgit 1.4.1 From 77cba688edb9216f5c578c931e96142722641b70 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 24 Mar 2016 18:02:37 +0000 Subject: Fix typo --- synapse/util/caches/response_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index 1c2e344269..be310ba320 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -25,7 +25,7 @@ class ResponseCache(object): """ def __init__(self): - self.pending_result_cache = {} # Request that haven't finished yet. + self.pending_result_cache = {} # Requests that haven't finished yet. def get(self, key): result = self.pending_result_cache.get(key) -- cgit 1.4.1 From adafa24b0a8f539c114c7d45f36f7b62743557f6 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Fri, 25 Mar 2016 23:38:19 +0000 Subject: typo --- synapse/replication/resource.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 8c1ae0fbc7..37a1d3960c 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -76,7 +76,7 @@ class ReplicationResource(Resource): The response is a JSON object with keys for each stream with updates. Under each key is a JSON object with: - * "postion": The current position of the stream. + * "position": The current position of the stream. * "field_names": The names of the fields in each row. * "rows": The updates as an array of arrays. -- cgit 1.4.1 From ec0cf996c94cb11f2a9b51369b886fb275b26ee5 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Fri, 25 Mar 2016 23:38:19 +0000 Subject: typo --- synapse/replication/resource.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 8c1ae0fbc7..37a1d3960c 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -76,7 +76,7 @@ class ReplicationResource(Resource): The response is a JSON object with keys for each stream with updates. Under each key is a JSON object with: - * "postion": The current position of the stream. + * "position": The current position of the stream. * "field_names": The names of the fields in each row. * "rows": The updates as an array of arrays. -- cgit 1.4.1 From dd4287ca5d0c3e3df566748e0dd6ab36398f64b4 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Tue, 29 Mar 2016 02:07:57 +0100 Subject: make it build --- synapse/http/client.py | 2 +- synapse/python_dependencies.py | 1 + synapse/rest/media/v1/preview_url_resource.py | 17 +++++++++-------- 3 files changed, 11 insertions(+), 9 deletions(-) (limited to 'synapse') diff --git a/synapse/http/client.py b/synapse/http/client.py index 127690e534..a735300db0 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -21,7 +21,7 @@ import synapse.metrics from canonicaljson import encode_canonical_json -from twisted.internet import defer, reactor, ssl +from twisted.internet import defer, reactor, ssl, protocol from twisted.web.client import ( Agent, readBody, FileBodyProducer, PartialDownloadError, ) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 0a6043ae8d..d12ef15043 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -36,6 +36,7 @@ REQUIREMENTS = { "blist": ["blist"], "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"], "pymacaroons-pynacl": ["pymacaroons"], + "lxml>=3.6.0": ["lxml"], } CONDITIONAL_REQUIREMENTS = { "web_client": { diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index fb8ab3096f..5c8e20e23c 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -13,10 +13,11 @@ # limitations under the License. from twisted.web.resource import Resource +from twisted.internet import defer from lxml import html from synapse.http.client import SimpleHttpClient -from synapse.http.server import respond_with_json_bytes -from simplejson import json +from synapse.http.server import request_handler, respond_with_json_bytes +import ujson as json import logging logger = logging.getLogger(__name__) @@ -75,7 +76,7 @@ class PreviewUrlResource(Resource): # "og:url" : "https://twitter.com/matrixdotorg/status/684074366691356672" # "og:title" : "Matrix on Twitter" # "og:image" : "https://pbs.twimg.com/profile_images/500400952029888512/yI0qtFi7_400x400.png" - # "og:description" : "“Synapse 0.12 is out! Lots of polishing, performance &amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP”" + # "og:description" : "Synapse 0.12 is out! Lots of polishing, performance &amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP" # "og:site_name" : "Twitter" og = {} @@ -143,15 +144,15 @@ class PreviewUrlResource(Resource): os.remove(fname) raise - return { + yield ({ "media_type": media_type, "media_length": length, "download_name": download_name, "created_ts": time_now_ms, "filesystem_id": file_id, "filename": fname, - } - + }) + return def _is_media(content_type): if content_type.lower().startswith("image/"): @@ -159,6 +160,6 @@ class PreviewUrlResource(Resource): def _is_html(content_type): content_type = content_type.lower() - if content_type == "text/html" or - content_type.startswith("application/xhtml"): + if (content_type == "text/html" or + content_type.startswith("application/xhtml")): return True -- cgit 1.4.1 From 64b4aead15927be56d7433250462c03f2d1f4565 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Tue, 29 Mar 2016 03:13:25 +0100 Subject: make it work --- docs/url_previews.rst | 2 +- synapse/http/client.py | 3 +- synapse/rest/media/v1/base_resource.py | 1 + synapse/rest/media/v1/preview_url_resource.py | 131 +++++++++++++++----------- 4 files changed, 80 insertions(+), 57 deletions(-) (limited to 'synapse') diff --git a/docs/url_previews.rst b/docs/url_previews.rst index 1dc6ee0c45..634d9d907f 100644 --- a/docs/url_previews.rst +++ b/docs/url_previews.rst @@ -56,7 +56,7 @@ As a first cut, let's do #2 and have the receiver hit the API to calculate its o API --- -GET /_matrix/media/r0/previewUrl?url=http://wherever.com +GET /_matrix/media/r0/preview_url?url=http://wherever.com 200 OK { "og:type" : "article" diff --git a/synapse/http/client.py b/synapse/http/client.py index a735300db0..cfdea91b57 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -26,6 +26,7 @@ from twisted.web.client import ( Agent, readBody, FileBodyProducer, PartialDownloadError, ) from twisted.web.http_headers import Headers +from twisted.web._newclient import ResponseDone from StringIO import StringIO @@ -266,7 +267,7 @@ class SimpleHttpClient(object): headers = dict(response.headers.getAllRawHeaders()) - if headers['Content-Length'] > max_size: + if 'Content-Length' in headers and headers['Content-Length'] > max_size: logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) # XXX: do we want to explicitly drop the connection here somehow? if so, how? raise # what should we be raising here? diff --git a/synapse/rest/media/v1/base_resource.py b/synapse/rest/media/v1/base_resource.py index 58ef91c0b8..2b1938dc8e 100644 --- a/synapse/rest/media/v1/base_resource.py +++ b/synapse/rest/media/v1/base_resource.py @@ -72,6 +72,7 @@ class BaseMediaResource(Resource): self.store = hs.get_datastore() self.max_upload_size = hs.config.max_upload_size self.max_image_pixels = hs.config.max_image_pixels + self.max_spider_size = hs.config.max_spider_size self.filepaths = filepaths self.version_string = hs.version_string self.downloads = {} diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 5c8e20e23c..408b103367 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -12,26 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .base_resource import BaseMediaResource +from synapse.api.errors import Codes from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET from twisted.internet import defer from lxml import html +from synapse.util.stringutils import random_string from synapse.http.client import SimpleHttpClient -from synapse.http.server import request_handler, respond_with_json_bytes +from synapse.http.server import request_handler, respond_with_json, respond_with_json_bytes + +import os import ujson as json import logging logger = logging.getLogger(__name__) -class PreviewUrlResource(Resource): +class PreviewUrlResource(BaseMediaResource): isLeaf = True def __init__(self, hs, filepaths): - Resource.__init__(self) + BaseMediaResource.__init__(self, hs, filepaths) self.client = SimpleHttpClient(hs) - self.filepaths = filepaths - self.max_spider_size = hs.config.max_spider_size - self.server_name = hs.hostname - self.clock = hs.get_clock() def render_GET(self, request): self._async_render_GET(request) @@ -40,57 +42,76 @@ class PreviewUrlResource(Resource): @request_handler @defer.inlineCallbacks def _async_render_GET(self, request): - url = request.args.get("url") try: + # XXX: if get_user_by_req fails, what should we do in an async render? + requester = yield self.auth.get_user_by_req(request) + url = request.args.get("url")[0] + # TODO: keep track of whether there's an ongoing request for this preview # and block and return their details if there is one. - media_info = self._download_url(url) + media_info = yield self._download_url(url, requester.user) + + logger.warn("got media_info of '%s'" % media_info) + + if self._is_media(media_info['media_type']): + dims = yield self._generate_local_thumbnails( + media_info.filesystem_id, media_info + ) + + og = { + "og:description" : media_info.download_name, + "og:image" : "mxc://%s/%s" % (self.server_name, media_info.filesystem_id), + "og:image:type" : media_info['media_type'], + "og:image:width" : dims.width, + "og:image:height" : dims.height, + } + + # define our OG response for this media + elif self._is_html(media_info['media_type']): + tree = html.parse(media_info['filename']) + logger.warn(html.tostring(tree)) + + # suck it up into lxml and define our OG response. + # if we see any URLs in the OG response, then spider them + # (although the client could choose to do this by asking for previews of those URLs to avoid DoSing the server) + + # "og:type" : "article" + # "og:url" : "https://twitter.com/matrixdotorg/status/684074366691356672" + # "og:title" : "Matrix on Twitter" + # "og:image" : "https://pbs.twimg.com/profile_images/500400952029888512/yI0qtFi7_400x400.png" + # "og:description" : "Synapse 0.12 is out! Lots of polishing, performance &amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP" + # "og:site_name" : "Twitter" + + og = {} + for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): + og[tag.attrib['property']] = tag.attrib['content'] + + # TODO: store our OG details in a cache (and expire them when stale) + # TODO: delete the content to stop diskfilling, as we only ever cared about its OG + else: + logger.warn("Failed to find any OG data in %s", url) + og = {} + + respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) except: - os.remove(fname) + # XXX: if we don't explicitly respond here, the request never returns. + # isn't this what server.py's wrapper is meant to be doing for us? + respond_with_json( + request, + 500, + { + "error": "Internal server error", + "errcode": Codes.UNKNOWN, + }, + send_cors=True + ) raise - if self._is_media(media_type): - dims = yield self._generate_local_thumbnails( - media_info.filesystem_id, media_info - ) - - og = { - "og:description" : media_info.download_name, - "og:image" : "mxc://%s/%s" % (self.server_name, media_info.filesystem_id), - "og:image:type" : media_info.media_type, - "og:image:width" : dims.width, - "og:image:height" : dims.height, - } - - # define our OG response for this media - elif self._is_html(media_type): - tree = html.parse(media_info.filename) - - # suck it up into lxml and define our OG response. - # if we see any URLs in the OG response, then spider them - # (although the client could choose to do this by asking for previews of those URLs to avoid DoSing the server) - - # "og:type" : "article" - # "og:url" : "https://twitter.com/matrixdotorg/status/684074366691356672" - # "og:title" : "Matrix on Twitter" - # "og:image" : "https://pbs.twimg.com/profile_images/500400952029888512/yI0qtFi7_400x400.png" - # "og:description" : "Synapse 0.12 is out! Lots of polishing, performance &amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP" - # "og:site_name" : "Twitter" - - og = {} - for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): - og[tag.attrib['property']] = tag.attrib['content'] - - # TODO: store our OG details in a cache (and expire them when stale) - # TODO: delete the content to stop diskfilling, as we only ever cared about its OG - - respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) - - def _download_url(url): - requester = yield self.auth.get_user_by_req(request) + @defer.inlineCallbacks + def _download_url(self, url, user): # XXX: horrible duplication with base_resource's _download_remote_file() file_id = random_string(24) @@ -99,6 +120,7 @@ class PreviewUrlResource(Resource): try: with open(fname, "wb") as f: + logger.warn("Trying to get url '%s'" % url) length, headers = yield self.client.get_file( url, output_stream=f, max_size=self.max_spider_size, ) @@ -137,14 +159,14 @@ class PreviewUrlResource(Resource): time_now_ms=self.clock.time_msec(), upload_name=download_name, media_length=length, - user_id=requester.user, + user_id=user, ) except: os.remove(fname) raise - yield ({ + defer.returnValue({ "media_type": media_type, "media_length": length, "download_name": download_name, @@ -152,14 +174,13 @@ class PreviewUrlResource(Resource): "filesystem_id": file_id, "filename": fname, }) - return - def _is_media(content_type): + def _is_media(self, content_type): if content_type.lower().startswith("image/"): return True - def _is_html(content_type): + def _is_html(self, content_type): content_type = content_type.lower() - if (content_type == "text/html" or + if (content_type.startswith("text/html") or content_type.startswith("application/xhtml")): return True -- cgit 1.4.1 From 19038582d3957eef2b662d28035361ecf9d3a84e Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Tue, 29 Mar 2016 03:14:16 +0100 Subject: debug --- synapse/rest/media/v1/preview_url_resource.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'synapse') diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 408b103367..4f7c9e3d1b 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -53,7 +53,7 @@ class PreviewUrlResource(BaseMediaResource): media_info = yield self._download_url(url, requester.user) - logger.warn("got media_info of '%s'" % media_info) + logger.debug("got media_info of '%s'" % media_info) if self._is_media(media_info['media_type']): dims = yield self._generate_local_thumbnails( @@ -71,7 +71,6 @@ class PreviewUrlResource(BaseMediaResource): # define our OG response for this media elif self._is_html(media_info['media_type']): tree = html.parse(media_info['filename']) - logger.warn(html.tostring(tree)) # suck it up into lxml and define our OG response. # if we see any URLs in the OG response, then spider them @@ -120,7 +119,7 @@ class PreviewUrlResource(BaseMediaResource): try: with open(fname, "wb") as f: - logger.warn("Trying to get url '%s'" % url) + logger.debug("Trying to get url '%s'" % url) length, headers = yield self.client.get_file( url, output_stream=f, max_size=self.max_spider_size, ) -- cgit 1.4.1 From 721b2bfa851bcf91948e166587dce4da666739b1 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Tue, 29 Mar 2016 03:32:52 +0100 Subject: implement redirects --- synapse/http/client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'synapse') diff --git a/synapse/http/client.py b/synapse/http/client.py index cfdea91b57..71b2e3375e 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -23,7 +23,7 @@ from canonicaljson import encode_canonical_json from twisted.internet import defer, reactor, ssl, protocol from twisted.web.client import ( - Agent, readBody, FileBodyProducer, PartialDownloadError, + RedirectAgent, Agent, readBody, FileBodyProducer, PartialDownloadError, ) from twisted.web.http_headers import Headers from twisted.web._newclient import ResponseDone @@ -59,11 +59,11 @@ class SimpleHttpClient(object): # The default context factory in Twisted 14.0.0 (which we require) is # BrowserLikePolicyForHTTPS which will do regular cert validation # 'like a browser' - self.agent = Agent( + self.agent = RedirectAgent(Agent( reactor, connectTimeout=15, contextFactory=hs.get_http_client_context_factory() - ) + )) self.user_agent = hs.version_string if hs.config.user_agent_suffix: self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,) -- cgit 1.4.1 From ae5831d30354c713cd1693f3b74cf048de7428a7 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Tue, 29 Mar 2016 03:32:55 +0100 Subject: fix bugs --- synapse/rest/media/v1/preview_url_resource.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) (limited to 'synapse') diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 4f7c9e3d1b..b999944e86 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -57,15 +57,15 @@ class PreviewUrlResource(BaseMediaResource): if self._is_media(media_info['media_type']): dims = yield self._generate_local_thumbnails( - media_info.filesystem_id, media_info + media_info['filesystem_id'], media_info ) og = { - "og:description" : media_info.download_name, - "og:image" : "mxc://%s/%s" % (self.server_name, media_info.filesystem_id), + "og:description" : media_info['download_name'], + "og:image" : "mxc://%s/%s" % (self.server_name, media_info['filesystem_id']), "og:image:type" : media_info['media_type'], - "og:image:width" : dims.width, - "og:image:height" : dims.height, + "og:image:width" : dims['width'], + "og:image:height" : dims['height'], } # define our OG response for this media @@ -123,6 +123,7 @@ class PreviewUrlResource(BaseMediaResource): length, headers = yield self.client.get_file( url, output_stream=f, max_size=self.max_spider_size, ) + # FIXME: handle 404s sanely - don't spider an error page media_type = headers["Content-Type"][0] time_now_ms = self.clock.time_msec() -- cgit 1.4.1 From 3f9948a069498e9966166a0fa581bdbf872d4ad3 Mon Sep 17 00:00:00 2001 From: Niklas Riekenbrauck Date: Mon, 28 Mar 2016 21:33:40 +0200 Subject: Add JWT support --- synapse/config/homeserver.py | 3 ++- synapse/config/jwt.py | 37 ++++++++++++++++++++++++++++ synapse/python_dependencies.py | 1 + synapse/rest/client/v1/login.py | 53 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 93 insertions(+), 1 deletion(-) create mode 100644 synapse/config/jwt.py (limited to 'synapse') diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index a08c170f1d..acf74c8761 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -29,13 +29,14 @@ from .key import KeyConfig from .saml2 import SAML2Config from .cas import CasConfig from .password import PasswordConfig +from .jwt import JWTConfig class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig, - PasswordConfig,): + JWTConfig, PasswordConfig,): pass diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py new file mode 100644 index 0000000000..4cb092bbec --- /dev/null +++ b/synapse/config/jwt.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 Niklas Riekenbrauck +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + + +class JWTConfig(Config): + def read_config(self, config): + jwt_config = config.get("jwt_config", None) + if jwt_config: + self.jwt_enabled = jwt_config.get("enabled", False) + self.jwt_secret = jwt_config["secret"] + self.jwt_algorithm = jwt_config["algorithm"] + else: + self.jwt_enabled = False + self.jwt_secret = None + self.jwt_algorithm = None + + def default_config(self, **kwargs): + return """\ + # jwt_config: + # enabled: true + # secret: "a secret" + # algorithm: "HS256" + """ diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 0a6043ae8d..cf1414b4db 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -36,6 +36,7 @@ REQUIREMENTS = { "blist": ["blist"], "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"], "pymacaroons-pynacl": ["pymacaroons"], + "pyjwt": ["jwt"], } CONDITIONAL_REQUIREMENTS = { "web_client": { diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index fe593d07ce..d14ce3efa2 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -33,6 +33,9 @@ from saml2.client import Saml2Client import xml.etree.ElementTree as ET +import jwt +from jwt.exceptions import InvalidTokenError + logger = logging.getLogger(__name__) @@ -43,12 +46,16 @@ class LoginRestServlet(ClientV1RestServlet): SAML2_TYPE = "m.login.saml2" CAS_TYPE = "m.login.cas" TOKEN_TYPE = "m.login.token" + JWT_TYPE = "m.login.jwt" def __init__(self, hs): super(LoginRestServlet, self).__init__(hs) self.idp_redirect_url = hs.config.saml2_idp_redirect_url self.password_enabled = hs.config.password_enabled self.saml2_enabled = hs.config.saml2_enabled + self.jwt_enabled = hs.config.jwt_enabled + self.jwt_secret = hs.config.jwt_secret + self.jwt_algorithm = hs.config.jwt_algorithm self.cas_enabled = hs.config.cas_enabled self.cas_server_url = hs.config.cas_server_url self.cas_required_attributes = hs.config.cas_required_attributes @@ -57,6 +64,8 @@ class LoginRestServlet(ClientV1RestServlet): def on_GET(self, request): flows = [] + if self.jwt_enabled: + flows.append({"type": LoginRestServlet.JWT_TYPE}) if self.saml2_enabled: flows.append({"type": LoginRestServlet.SAML2_TYPE}) if self.cas_enabled: @@ -98,6 +107,10 @@ class LoginRestServlet(ClientV1RestServlet): "uri": "%s%s" % (self.idp_redirect_url, relay_state) } defer.returnValue((200, result)) + elif self.jwt_enabled and (login_submission["type"] == + LoginRestServlet.JWT_TYPE): + result = yield self.do_jwt_login(login_submission) + defer.returnValue(result) # TODO Delete this after all CAS clients switch to token login instead elif self.cas_enabled and (login_submission["type"] == LoginRestServlet.CAS_TYPE): @@ -209,6 +222,46 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) + @defer.inlineCallbacks + def do_jwt_login(self, login_submission): + token = login_submission['token'] + if token is None: + raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + + try: + payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm]) + except InvalidTokenError: + raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) + + user = payload['user'] + if user is None: + raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) + + user_id = UserID.create(user, self.hs.hostname).to_string() + auth_handler = self.handlers.auth_handler + user_exists = yield auth_handler.does_user_exist(user_id) + if user_exists: + user_id, access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id(user_id) + ) + result = { + "user_id": user_id, # may have changed + "access_token": access_token, + "refresh_token": refresh_token, + "home_server": self.hs.hostname, + } + else: + user_id, access_token = ( + yield self.handlers.registration_handler.register(localpart=user) + ) + result = { + "user_id": user_id, # may have changed + "access_token": access_token, + "home_server": self.hs.hostname, + } + + defer.returnValue((200, result)) + # TODO Delete this after all CAS clients switch to token login instead def parse_cas_response(self, cas_response_body): root = ET.fromstring(cas_response_body) -- cgit 1.4.1 From fddb6fddc1f1e70ab79d8d4ed276f722ab2ea058 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 30 Mar 2016 10:54:01 +0100 Subject: Require user to have left room to forget room This dramatically simplifies the forget API code - in particular it no longer generates a leave event. --- synapse/handlers/room.py | 22 ++++++++++++++++------ synapse/rest/client/v1/room.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 6 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 133183a257..1d4c2c39a1 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -416,8 +416,6 @@ class RoomMemberHandler(BaseHandler): effective_membership_state = action if action in ["kick", "unban"]: effective_membership_state = "leave" - elif action == "forget": - effective_membership_state = "leave" if third_party_signed is not None: replication = self.hs.get_replication_layer() @@ -473,9 +471,6 @@ class RoomMemberHandler(BaseHandler): remote_room_hosts=remote_room_hosts, ) - if action == "forget": - yield self.forget(requester.user, room_id) - @defer.inlineCallbacks def send_membership_event( self, @@ -935,8 +930,23 @@ class RoomMemberHandler(BaseHandler): display_name = data["display_name"] defer.returnValue((token, public_keys, fallback_public_key, display_name)) + @defer.inlineCallbacks def forget(self, user, room_id): - return self.store.forget(user.to_string(), room_id) + user_id = user.to_string() + + member = yield self.state_handler.get_current_state( + room_id=room_id, + event_type=EventTypes.Member, + state_key=user_id + ) + membership = member.membership if member else None + + if membership is not None and membership != Membership.LEAVE: + raise SynapseError(400, "User %s in room %s" % ( + user_id, room_id + )) + + yield self.store.forget(user_id, room_id) class RoomListHandler(BaseHandler): diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index a1fa7daf79..ccb6e3c45e 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -405,6 +405,43 @@ class RoomEventContext(ClientV1RestServlet): defer.returnValue((200, results)) +class RoomForgetRestServlet(ClientV1RestServlet): + def register(self, http_server): + # /rooms/$roomid/[invite|join|leave] + PATTERNS = ("/rooms/(?P[^/]*)/forget") + register_txn_path(self, PATTERNS, http_server) + + @defer.inlineCallbacks + def on_POST(self, request, room_id, txn_id=None): + requester = yield self.auth.get_user_by_req( + request, + allow_guest=False, + ) + + yield self.handlers.room_member_handler.forget( + user=requester.user, + room_id=room_id, + ) + + defer.returnValue((200, {})) + + @defer.inlineCallbacks + def on_PUT(self, request, room_id, txn_id): + try: + defer.returnValue( + self.txns.get_client_transaction(request, txn_id) + ) + except KeyError: + pass + + response = yield self.on_POST( + request, room_id, txn_id + ) + + self.txns.store_client_transaction(request, txn_id, response) + defer.returnValue(response) + + # TODO: Needs unit testing class RoomMembershipRestServlet(ClientV1RestServlet): @@ -624,6 +661,7 @@ def register_servlets(hs, http_server): RoomMemberListRestServlet(hs).register(http_server) RoomMessageListRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) + RoomForgetRestServlet(hs).register(http_server) RoomMembershipRestServlet(hs).register(http_server) RoomSendEventRestServlet(hs).register(http_server) PublicRoomListRestServlet(hs).register(http_server) -- cgit 1.4.1 From 1e25f62ee6a8aaa65c139e264ec2be1f8831eb16 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 30 Mar 2016 12:55:02 +0100 Subject: Use a stream id generator to assign state group ids --- synapse/events/__init__.py | 2 +- synapse/storage/__init__.py | 2 +- synapse/storage/events.py | 90 +++++++++++++++++++++++++-------------------- synapse/storage/state.py | 16 ++++---- 4 files changed, 60 insertions(+), 50 deletions(-) (limited to 'synapse') diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 23f8b612ae..925a83c645 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -31,7 +31,7 @@ class _EventInternalMetadata(object): return dict(self.__dict__) def is_outlier(self): - return hasattr(self, "outlier") and self.outlier + return getattr(self, "outlier", False) def _event_dict_property(key): diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 250ba536ea..aaad38039e 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -116,7 +116,7 @@ class DataStore(RoomMemberStore, RoomStore, ) self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") - self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") + self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 5233430028..5f675ab09b 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -26,6 +26,7 @@ from synapse.api.constants import EventTypes from canonicaljson import encode_canonical_json from contextlib import contextmanager + import logging import math import ujson as json @@ -79,41 +80,57 @@ class EventsStore(SQLBaseStore): len(events_and_contexts) ) + state_group_id_manager = self._state_groups_id_gen.get_next_mult( + len(events_and_contexts) + ) with stream_ordering_manager as stream_orderings: - for (event, _), stream in zip(events_and_contexts, stream_orderings): - event.internal_metadata.stream_ordering = stream - - chunks = [ - events_and_contexts[x:x + 100] - for x in xrange(0, len(events_and_contexts), 100) - ] + with state_group_id_manager as state_group_ids: + for (event, context), stream, state_group_id in zip( + events_and_contexts, stream_orderings, state_group_ids + ): + event.internal_metadata.stream_ordering = stream + # Assign a state group_id in case a new id is needed for + # this context. In theory we only need to assign this + # for contexts that have current_state and aren't outliers + # but that make the code more complicated. Assigning an ID + # per event only causes the state_group_ids to grow as fast + # as the stream_ordering so in practise shouldn't be a problem. + context.new_state_group_id = state_group_id + + chunks = [ + events_and_contexts[x:x + 100] + for x in xrange(0, len(events_and_contexts), 100) + ] - for chunk in chunks: - # We can't easily parallelize these since different chunks - # might contain the same event. :( - yield self.runInteraction( - "persist_events", - self._persist_events_txn, - events_and_contexts=chunk, - backfilled=backfilled, - is_new_state=is_new_state, - ) + for chunk in chunks: + # We can't easily parallelize these since different chunks + # might contain the same event. :( + yield self.runInteraction( + "persist_events", + self._persist_events_txn, + events_and_contexts=chunk, + backfilled=backfilled, + is_new_state=is_new_state, + ) @defer.inlineCallbacks @log_function def persist_event(self, event, context, is_new_state=True, current_state=None): + try: with self._stream_id_gen.get_next() as stream_ordering: - event.internal_metadata.stream_ordering = stream_ordering - yield self.runInteraction( - "persist_event", - self._persist_event_txn, - event=event, - context=context, - is_new_state=is_new_state, - current_state=current_state, - ) + with self._state_groups_id_gen.get_next() as state_group_id: + event.internal_metadata.stream_ordering = stream_ordering + context.new_state_group_id = state_group_id + yield self.runInteraction( + "persist_event", + self._persist_event_txn, + event=event, + context=context, + is_new_state=is_new_state, + current_state=current_state, + ) except _RollbackButIsFineException: pass @@ -178,7 +195,7 @@ class EventsStore(SQLBaseStore): @log_function def _persist_event_txn(self, txn, event, context, - is_new_state=True, current_state=None): + is_new_state, current_state): # We purposefully do this first since if we include a `current_state` # key, we *want* to update the `current_state_events` table if current_state: @@ -215,7 +232,7 @@ class EventsStore(SQLBaseStore): @log_function def _persist_events_txn(self, txn, events_and_contexts, backfilled, - is_new_state=True): + is_new_state): depth_updates = {} for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids @@ -282,9 +299,7 @@ class EventsStore(SQLBaseStore): outlier_persisted = have_persisted[event.event_id] if not event.internal_metadata.is_outlier() and outlier_persisted: - self._store_state_groups_txn( - txn, event, context, - ) + self._store_mult_state_groups_txn(txn, ((event, context),)) metadata_json = encode_json( event.internal_metadata.get_dict() @@ -310,19 +325,14 @@ class EventsStore(SQLBaseStore): self._update_extremeties(txn, [event]) - events_and_contexts = filter( - lambda ec: ec[0] not in to_remove, - events_and_contexts - ) + events_and_contexts = [ + ec for ec in events_and_contexts if ec[0] not in to_remove + ] if not events_and_contexts: return - self._store_mult_state_groups_txn(txn, [ - (event, context) - for event, context in events_and_contexts - if not event.internal_metadata.is_outlier() - ]) + self._store_mult_state_groups_txn(txn, events_and_contexts) self._handle_mult_prev_events( txn, diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 02cefdff26..30d1060ecd 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -64,12 +64,12 @@ class StateStore(SQLBaseStore): for group, state_map in group_to_state.items() }) - def _store_state_groups_txn(self, txn, event, context): - return self._store_mult_state_groups_txn(txn, [(event, context)]) - def _store_mult_state_groups_txn(self, txn, events_and_contexts): state_groups = {} for event, context in events_and_contexts: + if event.internal_metadata.is_outlier(): + continue + if context.current_state is None: continue @@ -82,7 +82,8 @@ class StateStore(SQLBaseStore): if event.is_state(): state_events[(event.type, event.state_key)] = event - state_group = self._state_groups_id_gen.get_next() + state_group = context.new_state_group_id + self._simple_insert_txn( txn, table="state_groups", @@ -114,11 +115,10 @@ class StateStore(SQLBaseStore): table="event_to_state_groups", values=[ { - "state_group": state_groups[event.event_id], - "event_id": event.event_id, + "state_group": state_group_id, + "event_id": event_id, } - for event, context in events_and_contexts - if context.current_state is not None + for event_id, state_group_id in state_groups.items() ], ) -- cgit 1.4.1 From 08a8514b7a05bf2b6d1f8a5d8a3b8985c78ade9e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 30 Mar 2016 15:05:33 +0100 Subject: Remove spurious comment --- synapse/rest/client/v1/room.py | 1 - 1 file changed, 1 deletion(-) (limited to 'synapse') diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index ccb6e3c45e..b223fb7e5f 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -407,7 +407,6 @@ class RoomEventContext(ClientV1RestServlet): class RoomForgetRestServlet(ClientV1RestServlet): def register(self, http_server): - # /rooms/$roomid/[invite|join|leave] PATTERNS = ("/rooms/(?P[^/]*)/forget") register_txn_path(self, PATTERNS, http_server) -- cgit 1.4.1 From 73b6bf4629dbfa547a3e5d65395de2ed9c4c08c0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 30 Mar 2016 15:09:18 +0100 Subject: Only forget room if you were in the room --- synapse/handlers/room.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 1d4c2c39a1..71f7ab3d22 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -946,7 +946,8 @@ class RoomMemberHandler(BaseHandler): user_id, room_id )) - yield self.store.forget(user_id, room_id) + if membership: + yield self.store.forget(user_id, room_id) class RoomListHandler(BaseHandler): -- cgit 1.4.1 From 31a9eceda5cf00b0482baf1c8bf1e138c823f621 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 30 Mar 2016 15:58:20 +0100 Subject: Add a replication stream for state groups --- synapse/replication/resource.py | 36 +++++++++++++++++++++++++++++------- synapse/storage/events.py | 6 +++++- synapse/storage/state.py | 30 ++++++++++++++++++++++++++++++ tests/replication/test_resource.py | 30 +++++++++++++++++++++++++++--- 4 files changed, 91 insertions(+), 11 deletions(-) (limited to 'synapse') diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 8c1ae0fbc7..096a79a7a4 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -38,6 +38,7 @@ STREAM_NAMES = ( ("backfill",), ("push_rules",), ("pushers",), + ("state",), ) @@ -123,6 +124,7 @@ class ReplicationResource(Resource): backfill_token = yield self.store.get_current_backfill_token() push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() pushers_token = self.store.get_pushers_stream_token() + state_token = self.store.get_state_stream_token() defer.returnValue(_ReplicationToken( room_stream_token, @@ -133,6 +135,7 @@ class ReplicationResource(Resource): backfill_token, push_rules_token, pushers_token, + state_token, )) @request_handler @@ -156,6 +159,7 @@ class ReplicationResource(Resource): yield self.receipts(writer, current_token, limit) yield self.push_rules(writer, current_token, limit) yield self.pushers(writer, current_token, limit) + yield self.state(writer, current_token, limit) self.streams(writer, current_token) logger.info("Replicated %d rows", writer.total) @@ -205,12 +209,12 @@ class ReplicationResource(Resource): current_token.backfill, current_token.events, limit ) - writer.write_header_and_rows( - "events", events_rows, ("position", "internal", "json") - ) - writer.write_header_and_rows( - "backfill", backfill_rows, ("position", "internal", "json") - ) + writer.write_header_and_rows("events", events_rows, ( + "position", "internal", "json", "state_group" + )) + writer.write_header_and_rows("backfill", backfill_rows, ( + "position", "internal", "json", "state_group" + )) @defer.inlineCallbacks def presence(self, writer, current_token): @@ -320,6 +324,24 @@ class ReplicationResource(Resource): "position", "user_id", "app_id", "pushkey" )) + @defer.inlineCallbacks + def state(self, writer, current_token, limit): + current_position = current_token.state + + state = parse_integer(writer.request, "state") + if state is not None: + state_groups, state_group_state = ( + yield self.store.get_all_new_state_groups( + state, current_position, limit + ) + ) + writer.write_header_and_rows("state_groups", state_groups, ( + "position", "room_id", "event_id" + )) + writer.write_header_and_rows("state_group_state", state_group_state, ( + "position", "type", "state_key", "event_id" + )) + class _Writer(object): """Writes the streams as a JSON object as the response to the request""" @@ -350,7 +372,7 @@ class _Writer(object): class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( "events", "presence", "typing", "receipts", "account_data", "backfill", - "push_rules", "pushers" + "push_rules", "pushers", "state" ))): __slots__ = [] diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 5f675ab09b..a4b8995496 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1097,10 +1097,12 @@ class EventsStore(SQLBaseStore): new events or as backfilled events""" def get_all_new_events_txn(txn): sql = ( - "SELECT e.stream_ordering, ej.internal_metadata, ej.json" + "SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group" " FROM events as e" " JOIN event_json as ej" " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" + " LEFT JOIN event_to_state_groups as eg" + " ON e.event_id = eg.event_id" " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?" " ORDER BY e.stream_ordering ASC" " LIMIT ?" @@ -1116,6 +1118,8 @@ class EventsStore(SQLBaseStore): " FROM events as e" " JOIN event_json as ej" " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" + " LEFT JOIN event_to_state_groups as eg" + " ON e.event_id = eg.event_id" " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?" " ORDER BY e.stream_ordering DESC" " LIMIT ?" diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 30d1060ecd..7fc9a4f264 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -429,3 +429,33 @@ class StateStore(SQLBaseStore): } defer.returnValue(results) + + def get_all_new_state_groups(self, last_id, current_id, limit): + def get_all_new_state_groups_txn(txn): + sql = ( + "SELECT id, room_id, event_id FROM state_groups" + " WHERE ? < id AND id <= ? ORDER BY id LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + groups = txn.fetchall() + + if not groups: + return ([], []) + + lower_bound = groups[0][0] + upper_bound = groups[-1][0] + sql = ( + "SELECT state_group, type, state_key, event_id" + " FROM state_groups_state" + " WHERE ? <= state_group AND state_group <= ?" + ) + + txn.execute(sql, (lower_bound, upper_bound)) + state_group_state = txn.fetchall() + return (groups, state_group_state) + return self.runInteraction( + "get_all_new_state_groups", get_all_new_state_groups_txn + ) + + def get_state_stream_token(self): + return self._state_groups_id_gen.get_max_token() diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py index f4b5fb3328..b1dd7b4a74 100644 --- a/tests/replication/test_resource.py +++ b/tests/replication/test_resource.py @@ -58,15 +58,21 @@ class ReplicationResourceCase(unittest.TestCase): self.assertEquals(body, {}) @defer.inlineCallbacks - def test_events(self): - get = self.get(events="-1", timeout="0") + def test_events_and_state(self): + get = self.get(events="-1", state="-1", timeout="0") yield self.hs.get_handlers().room_creation_handler.create_room( Requester(self.user, "", False), {} ) code, body = yield get self.assertEquals(code, 200) self.assertEquals(body["events"]["field_names"], [ - "position", "internal", "json" + "position", "internal", "json", "state_group" + ]) + self.assertEquals(body["state_groups"]["field_names"], [ + "position", "room_id", "event_id" + ]) + self.assertEquals(body["state_group_state"]["field_names"], [ + "position", "type", "state_key", "event_id" ]) @defer.inlineCallbacks @@ -132,6 +138,7 @@ class ReplicationResourceCase(unittest.TestCase): test_timeout_backfill = _test_timeout("backfill") test_timeout_push_rules = _test_timeout("push_rules") test_timeout_pushers = _test_timeout("pushers") + test_timeout_state = _test_timeout("state") @defer.inlineCallbacks def send_text_message(self, room_id, message): @@ -182,4 +189,21 @@ class ReplicationResourceCase(unittest.TestCase): ) response_body = json.loads(response_json) + if response_code == 200: + self.check_response(response_body) + defer.returnValue((response_code, response_body)) + + def check_response(self, response_body): + for name, stream in response_body.items(): + self.assertIn("field_names", stream) + field_names = stream["field_names"] + self.assertIn("rows", stream) + self.assertTrue(stream["rows"]) + for row in stream["rows"]: + self.assertEquals( + len(row), len(field_names), + "%s: len(row = %r) == len(field_names = %r)" % ( + name, row, field_names + ) + ) -- cgit 1.4.1 From 61407986b40e28b590961d364f5618bbe7d44e94 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 30 Mar 2016 16:18:46 +0100 Subject: Add a entry to current_state_resets table when the current state is reset --- synapse/storage/events.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'synapse') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index a4b8995496..bd4d503b6d 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -205,6 +205,15 @@ class EventsStore(SQLBaseStore): txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) txn.call_after(self.get_room_name_and_aliases, event.room_id) + # Add an entry to the current_state_resets table to record the point + # where we clobbered the current state + stream_order = event.internal_metadata.stream_ordering + self._simple_insert_txn( + txn, + table="current_state_resets", + values={"event_stream_ordering": stream_order} + ) + self._simple_delete_txn( txn, table="current_state_events", -- cgit 1.4.1 From 8b8052909f5352e55998b8c7fbb26a8274e75366 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 30 Mar 2016 16:20:07 +0100 Subject: return the state_group for backfill --- synapse/storage/events.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index a4b8995496..e0ef7f46b2 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1114,7 +1114,8 @@ class EventsStore(SQLBaseStore): new_forward_events = [] sql = ( - "SELECT -e.stream_ordering, ej.internal_metadata, ej.json" + "SELECT -e.stream_ordering, ej.internal_metadata, ej.json," + " eg.state_group" " FROM events as e" " JOIN event_json as ej" " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" -- cgit 1.4.1 From 1fbb094c6fbaab33ef8e17802e37057e83718e7e Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 30 Mar 2016 17:19:56 +0100 Subject: Add replication streams for ex outliers and current state resets --- synapse/replication/resource.py | 17 ++++++- synapse/storage/events.py | 60 +++++++++++++++++++++++- synapse/storage/schema/delta/30/state_stream.sql | 38 +++++++++++++++ 3 files changed, 113 insertions(+), 2 deletions(-) create mode 100644 synapse/storage/schema/delta/30/state_stream.sql (limited to 'synapse') diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 096a79a7a4..7afa1242d5 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -204,7 +204,11 @@ class ReplicationResource(Resource): request_events = current_token.events if request_backfill is None: request_backfill = current_token.backfill - events_rows, backfill_rows = yield self.store.get_all_new_events( + ( + events_rows, backfill_rows, + forward_ex_outliers, backward_ex_outliers, + state_resets + ) = yield self.store.get_all_new_events( request_backfill, request_events, current_token.backfill, current_token.events, limit @@ -215,6 +219,17 @@ class ReplicationResource(Resource): writer.write_header_and_rows("backfill", backfill_rows, ( "position", "internal", "json", "state_group" )) + writer.write_header_and_rows( + "forward_ex_outliers", forward_ex_outliers, + ("position", "event_id", "state_group") + ) + writer.write_header_and_rows( + "backward_ex_outliers", backward_ex_outliers, + ("position", "event_id", "state_group") + ) + writer.write_header_and_rows( + "state_resets", state_resets, ("position",) + ) @defer.inlineCallbacks def presence(self, writer, current_token): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index bd4d503b6d..9725a3fed7 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -323,6 +323,18 @@ class EventsStore(SQLBaseStore): (metadata_json, event.event_id,) ) + stream_order = event.internal_metadata.stream_ordering + state_group_id = context.state_group or context.new_state_group_id + self._simple_insert_txn( + txn, + table="ex_outlier_stream", + values={ + "event_stream_ordering": stream_order, + "event_id": event.event_id, + "state_group": state_group_id, + } + ) + sql = ( "UPDATE events SET outlier = ?" " WHERE event_id = ?" @@ -1119,8 +1131,34 @@ class EventsStore(SQLBaseStore): if last_forward_id != current_forward_id: txn.execute(sql, (last_forward_id, current_forward_id, limit)) new_forward_events = txn.fetchall() + + if len(new_forward_events) == limit: + upper_bound = new_forward_events[-1][0] + else: + upper_bound = current_forward_id + + sql = ( + "SELECT -event_stream_ordering FROM current_state_resets" + " WHERE ? < event_stream_ordering" + " AND event_stream_ordering <= ?" + " ORDER BY event_stream_ordering ASC" + ) + txn.execute(sql, (last_forward_id, upper_bound)) + state_resets = txn.fetchall() + + sql = ( + "SELECT -event_stream_ordering, event_id, state_group" + " FROM ex_outlier_stream" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (last_forward_id, upper_bound)) + forward_ex_outliers = txn.fetchall() else: new_forward_events = [] + state_resets = [] + forward_ex_outliers = [] sql = ( "SELECT -e.stream_ordering, ej.internal_metadata, ej.json" @@ -1136,8 +1174,28 @@ class EventsStore(SQLBaseStore): if last_backfill_id != current_backfill_id: txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) new_backfill_events = txn.fetchall() + + if len(new_backfill_events) == limit: + upper_bound = new_backfill_events[-1][0] + else: + upper_bound = current_backfill_id + + sql = ( + "SELECT -event_stream_ordering, event_id, state_group" + " FROM ex_outlier_stream" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_backfill_id, -upper_bound)) + backward_ex_outliers = txn.fetchall() else: new_backfill_events = [] + backward_ex_outliers = [] - return (new_forward_events, new_backfill_events) + return ( + new_forward_events, new_backfill_events, + forward_ex_outliers, backward_ex_outliers, + state_resets, + ) return self.runInteraction("get_all_new_events", get_all_new_events_txn) diff --git a/synapse/storage/schema/delta/30/state_stream.sql b/synapse/storage/schema/delta/30/state_stream.sql new file mode 100644 index 0000000000..706fe1dcf4 --- /dev/null +++ b/synapse/storage/schema/delta/30/state_stream.sql @@ -0,0 +1,38 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +/** + * The positions in the event stream_ordering when the current_state was + * replaced by the state at the event. + */ + +CREATE TABLE IF NOT EXISTS current_state_resets( + event_stream_ordering BIGINT PRIMARY KEY NOT NULL +); + +/* The outlier events that have aquired a state group typically through + * backfill. This is tracked separately to the events table, as assigning a + * state group change the position of the existing event in the stream + * ordering. + * However since a stream_ordering is assigned in persist_event for the + * (event, state) pair, we can use that stream_ordering to identify when + * the new state was assigned for the event. + */ +CREATE TABLE IF NOT EXISTS ex_outlier_stream( + event_stream_ordering BIGINT PRIMARY KEY NOT NULL, + event_id TEXT NOT NULL, + state_group BIGINT NOT NULL +); -- cgit 1.4.1 From 9113316b0e53afd874822d26a9913d2b97f57b53 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Fri, 25 Mar 2016 23:38:19 +0000 Subject: typo --- synapse/replication/resource.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 096a79a7a4..33cb2eafa3 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -77,7 +77,7 @@ class ReplicationResource(Resource): The response is a JSON object with keys for each stream with updates. Under each key is a JSON object with: - * "postion": The current position of the stream. + * "position": The current position of the stream. * "field_names": The names of the fields in each row. * "rows": The updates as an array of arrays. -- cgit 1.4.1 From a8a5dd3b44a4526307502bd621ee0bd43c87c77f Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Thu, 31 Mar 2016 01:55:21 +0100 Subject: handle requests with missing content-length headers (e.g. YouTube) --- synapse/http/client.py | 33 +++++++++++++++++++++------ synapse/rest/media/v1/preview_url_resource.py | 4 ++-- 2 files changed, 28 insertions(+), 9 deletions(-) (limited to 'synapse') diff --git a/synapse/http/client.py b/synapse/http/client.py index 71b2e3375e..30f31a915d 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -23,8 +23,9 @@ from canonicaljson import encode_canonical_json from twisted.internet import defer, reactor, ssl, protocol from twisted.web.client import ( - RedirectAgent, Agent, readBody, FileBodyProducer, PartialDownloadError, + BrowserLikeRedirectAgent, Agent, readBody, FileBodyProducer, PartialDownloadError, ) +from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers from twisted.web._newclient import ResponseDone @@ -59,11 +60,11 @@ class SimpleHttpClient(object): # The default context factory in Twisted 14.0.0 (which we require) is # BrowserLikePolicyForHTTPS which will do regular cert validation # 'like a browser' - self.agent = RedirectAgent(Agent( + self.agent = Agent( reactor, connectTimeout=15, contextFactory=hs.get_http_client_context_factory() - )) + ) self.user_agent = hs.version_string if hs.config.user_agent_suffix: self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,) @@ -253,10 +254,6 @@ class SimpleHttpClient(object): headers. """ - def body_callback(method, url_bytes, headers_dict): - self.sign_request(destination, method, url_bytes, headers_dict) - return None - response = yield self.request( "GET", url.encode("ascii"), @@ -309,6 +306,10 @@ class _ReadBodyToFileProtocol(protocol.Protocol): def connectionLost(self, reason): if reason.check(ResponseDone): self.deferred.callback(self.length) + elif reason.check(PotentialDataLoss): + # stolen from https://github.com/twisted/treq/pull/49/files + # http://twistedmatrix.com/trac/ticket/4840 + self.deferred.callback(self.length) else: self.deferred.errback(reason) @@ -350,6 +351,24 @@ class CaptchaServerHttpClient(SimpleHttpClient): # twisted dislikes google's response, no content length. defer.returnValue(e.response) +class SpiderHttpClient(SimpleHttpClient): + """ + Separate HTTP client for spidering arbitrary URLs. + Special in that it follows retries and has a UA that looks + like a browser. + + used by the preview_url endpoint in the content repo. + """ + def __init__(self, hs): + SimpleHttpClient.__init__(self, hs) + # clobber the base class's agent and UA: + self.agent = BrowserLikeRedirectAgent(Agent( + reactor, + connectTimeout=15, + contextFactory=hs.get_http_client_context_factory() + )) + # Look like Chrome for now + #self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko) Chrome Safari" % hs.version_string) def encode_urlencode_args(args): return {k: encode_urlencode_arg(v) for k, v in args.items()} diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index b999944e86..ca2529cc10 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -19,7 +19,7 @@ from twisted.web.server import NOT_DONE_YET from twisted.internet import defer from lxml import html from synapse.util.stringutils import random_string -from synapse.http.client import SimpleHttpClient +from synapse.http.client import SpiderHttpClient from synapse.http.server import request_handler, respond_with_json, respond_with_json_bytes import os @@ -33,7 +33,7 @@ class PreviewUrlResource(BaseMediaResource): def __init__(self, hs, filepaths): BaseMediaResource.__init__(self, hs, filepaths) - self.client = SimpleHttpClient(hs) + self.client = SpiderHttpClient(hs) def render_GET(self, request): self._async_render_GET(request) -- cgit 1.4.1 From f699b8f997ed743af0cfa7046428915a7f42610b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 31 Mar 2016 10:04:28 +0100 Subject: Read from DNS cache if within TTL --- synapse/http/endpoint.py | 39 +++++++++++++++++++++++---------------- tests/test_dns.py | 5 ++++- 2 files changed, 27 insertions(+), 17 deletions(-) (limited to 'synapse') diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 4775f6707d..e80d00e2af 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -22,6 +22,7 @@ from twisted.names.error import DNSNameError, DomainError import collections import logging import random +import time logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ SERVER_CACHE = {} _Server = collections.namedtuple( - "_Server", "priority weight host port" + "_Server", "priority weight host port expires" ) @@ -92,7 +93,8 @@ class SRVClientEndpoint(object): host=domain, port=default_port, priority=0, - weight=0 + weight=0, + expires=0, ) else: self.default_server = None @@ -154,6 +156,12 @@ class SRVClientEndpoint(object): @defer.inlineCallbacks def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): + cache_entry = cache.get(service_name, None) + if cache_entry: + if all(s.expires > int(time.time()) for s in cache_entry): + servers = list(cache_entry) + defer.returnValue(servers) + servers = [] try: @@ -173,27 +181,26 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): continue payload = answer.payload - host = str(payload.target) + srv_ttl = answer.ttl try: answers, _, _ = yield dns_client.lookupAddress(host) except DNSNameError: continue - ips = [ - answer.payload.dottedQuad() - for answer in answers - if answer.type == dns.A and answer.payload - ] - - for ip in ips: - servers.append(_Server( - host=ip, - port=int(payload.port), - priority=int(payload.priority), - weight=int(payload.weight) - )) + for answer in answers: + if answer.type == dns.A and answer.payload: + ip = answer.payload.dottedQuad() + host_ttl = min(srv_ttl, answer.ttl) + + servers.append(_Server( + host=ip, + port=int(payload.port), + priority=int(payload.priority), + weight=int(payload.weight), + expires=int(time.time()) + host_ttl, + )) servers.sort() cache[service_name] = list(servers) diff --git a/tests/test_dns.py b/tests/test_dns.py index 637b1606f8..e006ed1a59 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -69,8 +69,11 @@ class DnsTestCase(unittest.TestCase): service_name = "test_service.examle.com" + entry = Mock(spec_set=["expires"]) + entry.expires = 999999999 + cache = { - service_name: [object()] + service_name: [entry] } servers = yield resolve_service( -- cgit 1.4.1 From c27c51484a306caf3946f205a933878c563d3d8a Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 31 Mar 2016 10:12:31 +0100 Subject: Don't ignore the obey overlay if the rule has an enabled attribute of False Fixes https://github.com/vector-im/vector-web/issues/1244 --- synapse/push/push_rule_evaluator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'synapse') diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 51f73a5b78..c3c2877629 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -133,8 +133,9 @@ class PushRuleEvaluator: enabled = self.enabled_map.get(r['rule_id'], None) if enabled is not None and not enabled: continue - - if not r.get("enabled", True): + elif enabled is None and not r.get("enabled", True): + # if no override, check enabled on the rule itself + # (may have come from a base rule) continue conditions = r['conditions'] -- cgit 1.4.1 From f9d3665c8841335cd70325dd758b4193c462ca60 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 31 Mar 2016 10:23:48 +0100 Subject: Allow clock to be passed in to func --- synapse/http/endpoint.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'synapse') diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index e80d00e2af..bc28a2959a 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -155,10 +155,10 @@ class SRVClientEndpoint(object): @defer.inlineCallbacks -def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): +def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time): cache_entry = cache.get(service_name, None) if cache_entry: - if all(s.expires > int(time.time()) for s in cache_entry): + if all(s.expires > int(clock.time()) for s in cache_entry): servers = list(cache_entry) defer.returnValue(servers) @@ -199,7 +199,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): port=int(payload.port), priority=int(payload.priority), weight=int(payload.weight), - expires=int(time.time()) + host_ttl, + expires=int(clock.time()) + host_ttl, )) servers.sort() -- cgit 1.4.1 From 2ec54260350b46c937527bd566b713cf3544f1d2 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 31 Mar 2016 10:33:02 +0100 Subject: Use a namedtuple rather than tuple unpacking --- synapse/replication/resource.py | 16 ++++++---------- synapse/storage/events.py | 11 +++++++++-- 2 files changed, 15 insertions(+), 12 deletions(-) (limited to 'synapse') diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 7afa1242d5..69afcb03d2 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -204,31 +204,27 @@ class ReplicationResource(Resource): request_events = current_token.events if request_backfill is None: request_backfill = current_token.backfill - ( - events_rows, backfill_rows, - forward_ex_outliers, backward_ex_outliers, - state_resets - ) = yield self.store.get_all_new_events( + res = yield self.store.get_all_new_events( request_backfill, request_events, current_token.backfill, current_token.events, limit ) - writer.write_header_and_rows("events", events_rows, ( + writer.write_header_and_rows("events", res.new_forward_events, ( "position", "internal", "json", "state_group" )) - writer.write_header_and_rows("backfill", backfill_rows, ( + writer.write_header_and_rows("backfill", res.new_backfill_events, ( "position", "internal", "json", "state_group" )) writer.write_header_and_rows( - "forward_ex_outliers", forward_ex_outliers, + "forward_ex_outliers", res.forward_ex_outliers, ("position", "event_id", "state_group") ) writer.write_header_and_rows( - "backward_ex_outliers", backward_ex_outliers, + "backward_ex_outliers", res.backward_ex_outliers, ("position", "event_id", "state_group") ) writer.write_header_and_rows( - "state_resets", state_resets, ("position",) + "state_resets", res.state_resets, ("position",) ) @defer.inlineCallbacks diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 9725a3fed7..b7ad045e41 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -25,7 +25,7 @@ from synapse.api.constants import EventTypes from canonicaljson import encode_canonical_json from contextlib import contextmanager - +from collections import namedtuple import logging import math @@ -1193,9 +1193,16 @@ class EventsStore(SQLBaseStore): new_backfill_events = [] backward_ex_outliers = [] - return ( + return AllNewEventsResult( new_forward_events, new_backfill_events, forward_ex_outliers, backward_ex_outliers, state_resets, ) return self.runInteraction("get_all_new_events", get_all_new_events_txn) + + +AllNewEventsResult = namedtuple("AllNewEventsResult", [ + "new_forward_events", "new_backfill_events", + "forward_ex_outliers", "backward_ex_outliers", + "state_resets" +]) -- cgit 1.4.1 From 5260db7663ff242e8a0adcbdfeeaf0f6e7ba1e96 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 31 Mar 2016 10:49:16 +0100 Subject: Line length --- synapse/handlers/room.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 71f7ab3d22..a230dc37f2 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -536,7 +536,9 @@ class RoomMemberHandler(BaseHandler): if not is_host_in_room: # perhaps we've been invited - inviter = self.get_inviter(target_user.to_string(), context.current_state) + inviter = self.get_inviter( + target_user.to_string(), context.current_state + ) if not inviter: raise SynapseError(404, "Not a known room") -- cgit 1.4.1 From 0d3d7de6fcb98972532bf9aaa983ddd8befb3db8 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Thu, 31 Mar 2016 12:42:27 +0100 Subject: sync in changes from matrixfederationclient --- synapse/http/client.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'synapse') diff --git a/synapse/http/client.py b/synapse/http/client.py index 30f31a915d..219b734268 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -244,7 +244,7 @@ class SimpleHttpClient(object): # The two should be factored out. @defer.inlineCallbacks - def get_file(self, url, output_stream, args={}, max_size=None): + def get_file(self, url, output_stream, max_size=None): """GETs a file from a given URL Args: url (str): The URL to GET @@ -299,7 +299,11 @@ class _ReadBodyToFileProtocol(protocol.Protocol): self.stream.write(data) self.length += len(data) if self.max_size is not None and self.length >= self.max_size: - logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) + self.deferred.errback(SynapseError( + 502, + "Requested file is too large > %r bytes" % (self.max_size,), + Codes.TOO_LARGE, + )) self.deferred = defer.Deferred() self.transport.loseConnection() -- cgit 1.4.1 From d35780eda02935165f38f3fcbf8843f40026eeaa Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 31 Mar 2016 13:08:45 +0100 Subject: Split out RoomMemberHandler --- synapse/handlers/__init__.py | 3 +- synapse/handlers/room.py | 605 +------------------------------------ synapse/handlers/room_member.py | 646 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 651 insertions(+), 603 deletions(-) create mode 100644 synapse/handlers/room_member.py (limited to 'synapse') diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 66d2c01123..f4dbf47c1d 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -17,8 +17,9 @@ from synapse.appservice.scheduler import AppServiceScheduler from synapse.appservice.api import ApplicationServiceApi from .register import RegistrationHandler from .room import ( - RoomCreationHandler, RoomMemberHandler, RoomListHandler, RoomContextHandler, + RoomCreationHandler, RoomListHandler, RoomContextHandler, ) +from .room_member import RoomMemberHandler from .message import MessageHandler from .events import EventStreamHandler, EventHandler from .federation import FederationHandler diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a230dc37f2..ee99ded214 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -18,20 +18,16 @@ from twisted.internet import defer from ._base import BaseHandler -from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken, Requester +from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken from synapse.api.constants import ( - EventTypes, Membership, JoinRules, RoomCreationPreset, + EventTypes, JoinRules, RoomCreationPreset, ) -from synapse.api.errors import AuthError, StoreError, SynapseError, Codes +from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.util import stringutils, unwrapFirstError from synapse.util.logcontext import preserve_context_over_fn from synapse.util.caches.response_cache import ResponseCache -from signedjson.sign import verify_signed_json -from signedjson.key import decode_verify_key_bytes - from collections import OrderedDict -from unpaddedbase64 import decode_base64 import logging import math @@ -357,601 +353,6 @@ class RoomCreationHandler(BaseHandler): ) -class RoomMemberHandler(BaseHandler): - # TODO(paul): This handler currently contains a messy conflation of - # low-level API that works on UserID objects and so on, and REST-level - # API that takes ID strings and returns pagination chunks. These concerns - # ought to be separated out a lot better. - - def __init__(self, hs): - super(RoomMemberHandler, self).__init__(hs) - - self.clock = hs.get_clock() - - self.distributor = hs.get_distributor() - self.distributor.declare("user_joined_room") - self.distributor.declare("user_left_room") - - @defer.inlineCallbacks - def get_room_members(self, room_id): - users = yield self.store.get_users_in_room(room_id) - - defer.returnValue([UserID.from_string(u) for u in users]) - - @defer.inlineCallbacks - def fetch_room_distributions_into(self, room_id, localusers=None, - remotedomains=None, ignore_user=None): - """Fetch the distribution of a room, adding elements to either - 'localusers' or 'remotedomains', which should be a set() if supplied. - If ignore_user is set, ignore that user. - - This function returns nothing; its result is performed by the - side-effect on the two passed sets. This allows easy accumulation of - member lists of multiple rooms at once if required. - """ - members = yield self.get_room_members(room_id) - for member in members: - if ignore_user is not None and member == ignore_user: - continue - - if self.hs.is_mine(member): - if localusers is not None: - localusers.add(member) - else: - if remotedomains is not None: - remotedomains.add(member.domain) - - @defer.inlineCallbacks - def update_membership( - self, - requester, - target, - room_id, - action, - txn_id=None, - remote_room_hosts=None, - third_party_signed=None, - ratelimit=True, - ): - effective_membership_state = action - if action in ["kick", "unban"]: - effective_membership_state = "leave" - - if third_party_signed is not None: - replication = self.hs.get_replication_layer() - yield replication.exchange_third_party_invite( - third_party_signed["sender"], - target.to_string(), - room_id, - third_party_signed, - ) - - msg_handler = self.hs.get_handlers().message_handler - - content = {"membership": effective_membership_state} - if requester.is_guest: - content["kind"] = "guest" - - event, context = yield msg_handler.create_event( - { - "type": EventTypes.Member, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - "state_key": target.to_string(), - - # For backwards compatibility: - "membership": effective_membership_state, - }, - token_id=requester.access_token_id, - txn_id=txn_id, - ) - - old_state = context.current_state.get((EventTypes.Member, event.state_key)) - old_membership = old_state.content.get("membership") if old_state else None - if action == "unban" and old_membership != "ban": - raise SynapseError( - 403, - "Cannot unban user who was not banned (membership=%s)" % old_membership, - errcode=Codes.BAD_STATE - ) - if old_membership == "ban" and action != "unban": - raise SynapseError( - 403, - "Cannot %s user who was is banned" % (action,), - errcode=Codes.BAD_STATE - ) - - member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.send_membership_event( - requester, - event, - context, - ratelimit=ratelimit, - remote_room_hosts=remote_room_hosts, - ) - - @defer.inlineCallbacks - def send_membership_event( - self, - requester, - event, - context, - remote_room_hosts=None, - ratelimit=True, - ): - """ - Change the membership status of a user in a room. - - Args: - requester (Requester): The local user who requested the membership - event. If None, certain checks, like whether this homeserver can - act as the sender, will be skipped. - event (SynapseEvent): The membership event. - context: The context of the event. - is_guest (bool): Whether the sender is a guest. - room_hosts ([str]): Homeservers which are likely to already be in - the room, and could be danced with in order to join this - homeserver for the first time. - ratelimit (bool): Whether to rate limit this request. - Raises: - SynapseError if there was a problem changing the membership. - """ - remote_room_hosts = remote_room_hosts or [] - - target_user = UserID.from_string(event.state_key) - room_id = event.room_id - - if requester is not None: - sender = UserID.from_string(event.sender) - assert sender == requester.user, ( - "Sender (%s) must be same as requester (%s)" % - (sender, requester.user) - ) - assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) - else: - requester = Requester(target_user, None, False) - - message_handler = self.hs.get_handlers().message_handler - prev_event = message_handler.deduplicate_state_event(event, context) - if prev_event is not None: - return - - action = "send" - - if event.membership == Membership.JOIN: - if requester.is_guest and not self._can_guest_join(context.current_state): - # This should be an auth check, but guests are a local concept, - # so don't really fit into the general auth process. - raise AuthError(403, "Guest access not allowed") - do_remote_join_dance, remote_room_hosts = self._should_do_dance( - context, - (self.get_inviter(event.state_key, context.current_state)), - remote_room_hosts, - ) - if do_remote_join_dance: - action = "remote_join" - elif event.membership == Membership.LEAVE: - is_host_in_room = self.is_host_in_room(context.current_state) - - if not is_host_in_room: - # perhaps we've been invited - inviter = self.get_inviter( - target_user.to_string(), context.current_state - ) - if not inviter: - raise SynapseError(404, "Not a known room") - - if self.hs.is_mine(inviter): - # the inviter was on our server, but has now left. Carry on - # with the normal rejection codepath. - # - # This is a bit of a hack, because the room might still be - # active on other servers. - pass - else: - # send the rejection to the inviter's HS. - remote_room_hosts = remote_room_hosts + [inviter.domain] - action = "remote_reject" - - federation_handler = self.hs.get_handlers().federation_handler - - if action == "remote_join": - if len(remote_room_hosts) == 0: - raise SynapseError(404, "No known servers") - - # We don't do an auth check if we are doing an invite - # join dance for now, since we're kinda implicitly checking - # that we are allowed to join when we decide whether or not we - # need to do the invite/join dance. - yield federation_handler.do_invite_join( - remote_room_hosts, - event.room_id, - event.user_id, - event.content, - ) - elif action == "remote_reject": - yield federation_handler.do_remotely_reject_invite( - remote_room_hosts, - room_id, - event.user_id - ) - else: - yield self.handle_new_client_event( - requester, - event, - context, - extra_users=[target_user], - ratelimit=ratelimit, - ) - - prev_member_event = context.current_state.get( - (EventTypes.Member, target_user.to_string()), - None - ) - - if event.membership == Membership.JOIN: - if not prev_member_event or prev_member_event.membership != Membership.JOIN: - # Only fire user_joined_room if the user has acutally joined the - # room. Don't bother if the user is just changing their profile - # info. - yield user_joined_room(self.distributor, target_user, room_id) - elif event.membership == Membership.LEAVE: - if prev_member_event and prev_member_event.membership == Membership.JOIN: - user_left_room(self.distributor, target_user, room_id) - - def _can_guest_join(self, current_state): - """ - Returns whether a guest can join a room based on its current state. - """ - guest_access = current_state.get((EventTypes.GuestAccess, ""), None) - return ( - guest_access - and guest_access.content - and "guest_access" in guest_access.content - and guest_access.content["guest_access"] == "can_join" - ) - - def _should_do_dance(self, context, inviter, room_hosts=None): - # TODO: Shouldn't this be remote_room_host? - room_hosts = room_hosts or [] - - is_host_in_room = self.is_host_in_room(context.current_state) - if is_host_in_room: - return False, room_hosts - - if inviter and not self.hs.is_mine(inviter): - room_hosts.append(inviter.domain) - - return True, room_hosts - - @defer.inlineCallbacks - def lookup_room_alias(self, room_alias): - """ - Get the room ID associated with a room alias. - - Args: - room_alias (RoomAlias): The alias to look up. - Returns: - A tuple of: - The room ID as a RoomID object. - Hosts likely to be participating in the room ([str]). - Raises: - SynapseError if room alias could not be found. - """ - directory_handler = self.hs.get_handlers().directory_handler - mapping = yield directory_handler.get_association(room_alias) - - if not mapping: - raise SynapseError(404, "No such room alias") - - room_id = mapping["room_id"] - servers = mapping["servers"] - - defer.returnValue((RoomID.from_string(room_id), servers)) - - def get_inviter(self, user_id, current_state): - prev_state = current_state.get((EventTypes.Member, user_id)) - if prev_state and prev_state.membership == Membership.INVITE: - return UserID.from_string(prev_state.user_id) - return None - - @defer.inlineCallbacks - def get_joined_rooms_for_user(self, user): - """Returns a list of roomids that the user has any of the given - membership states in.""" - - rooms = yield self.store.get_rooms_for_user( - user.to_string(), - ) - - # For some reason the list of events contains duplicates - # TODO(paul): work out why because I really don't think it should - room_ids = set(r.room_id for r in rooms) - - defer.returnValue(room_ids) - - @defer.inlineCallbacks - def do_3pid_invite( - self, - room_id, - inviter, - medium, - address, - id_server, - requester, - txn_id - ): - invitee = yield self._lookup_3pid( - id_server, medium, address - ) - - if invitee: - handler = self.hs.get_handlers().room_member_handler - yield handler.update_membership( - requester, - UserID.from_string(invitee), - room_id, - "invite", - txn_id=txn_id, - ) - else: - yield self._make_and_store_3pid_invite( - requester, - id_server, - medium, - address, - room_id, - inviter, - txn_id=txn_id - ) - - @defer.inlineCallbacks - def _lookup_3pid(self, id_server, medium, address): - """Looks up a 3pid in the passed identity server. - - Args: - id_server (str): The server name (including port, if required) - of the identity server to use. - medium (str): The type of the third party identifier (e.g. "email"). - address (str): The third party identifier (e.g. "foo@example.com"). - - Returns: - (str) the matrix ID of the 3pid, or None if it is not recognized. - """ - try: - data = yield self.hs.get_simple_http_client().get_json( - "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,), - { - "medium": medium, - "address": address, - } - ) - - if "mxid" in data: - if "signatures" not in data: - raise AuthError(401, "No signatures on 3pid binding") - self.verify_any_signature(data, id_server) - defer.returnValue(data["mxid"]) - - except IOError as e: - logger.warn("Error from identity server lookup: %s" % (e,)) - defer.returnValue(None) - - @defer.inlineCallbacks - def verify_any_signature(self, data, server_hostname): - if server_hostname not in data["signatures"]: - raise AuthError(401, "No signature from server %s" % (server_hostname,)) - for key_name, signature in data["signatures"][server_hostname].items(): - key_data = yield self.hs.get_simple_http_client().get_json( - "%s%s/_matrix/identity/api/v1/pubkey/%s" % - (id_server_scheme, server_hostname, key_name,), - ) - if "public_key" not in key_data: - raise AuthError(401, "No public key named %s from %s" % - (key_name, server_hostname,)) - verify_signed_json( - data, - server_hostname, - decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"])) - ) - return - - @defer.inlineCallbacks - def _make_and_store_3pid_invite( - self, - requester, - id_server, - medium, - address, - room_id, - user, - txn_id - ): - room_state = yield self.hs.get_state_handler().get_current_state(room_id) - - inviter_display_name = "" - inviter_avatar_url = "" - member_event = room_state.get((EventTypes.Member, user.to_string())) - if member_event: - inviter_display_name = member_event.content.get("displayname", "") - inviter_avatar_url = member_event.content.get("avatar_url", "") - - canonical_room_alias = "" - canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, "")) - if canonical_alias_event: - canonical_room_alias = canonical_alias_event.content.get("alias", "") - - room_name = "" - room_name_event = room_state.get((EventTypes.Name, "")) - if room_name_event: - room_name = room_name_event.content.get("name", "") - - room_join_rules = "" - join_rules_event = room_state.get((EventTypes.JoinRules, "")) - if join_rules_event: - room_join_rules = join_rules_event.content.get("join_rule", "") - - room_avatar_url = "" - room_avatar_event = room_state.get((EventTypes.RoomAvatar, "")) - if room_avatar_event: - room_avatar_url = room_avatar_event.content.get("url", "") - - token, public_keys, fallback_public_key, display_name = ( - yield self._ask_id_server_for_third_party_invite( - id_server=id_server, - medium=medium, - address=address, - room_id=room_id, - inviter_user_id=user.to_string(), - room_alias=canonical_room_alias, - room_avatar_url=room_avatar_url, - room_join_rules=room_join_rules, - room_name=room_name, - inviter_display_name=inviter_display_name, - inviter_avatar_url=inviter_avatar_url - ) - ) - - msg_handler = self.hs.get_handlers().message_handler - yield msg_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.ThirdPartyInvite, - "content": { - "display_name": display_name, - "public_keys": public_keys, - - # For backwards compatibility: - "key_validity_url": fallback_public_key["key_validity_url"], - "public_key": fallback_public_key["public_key"], - }, - "room_id": room_id, - "sender": user.to_string(), - "state_key": token, - }, - txn_id=txn_id, - ) - - @defer.inlineCallbacks - def _ask_id_server_for_third_party_invite( - self, - id_server, - medium, - address, - room_id, - inviter_user_id, - room_alias, - room_avatar_url, - room_join_rules, - room_name, - inviter_display_name, - inviter_avatar_url - ): - """ - Asks an identity server for a third party invite. - - :param id_server (str): hostname + optional port for the identity server. - :param medium (str): The literal string "email". - :param address (str): The third party address being invited. - :param room_id (str): The ID of the room to which the user is invited. - :param inviter_user_id (str): The user ID of the inviter. - :param room_alias (str): An alias for the room, for cosmetic - notifications. - :param room_avatar_url (str): The URL of the room's avatar, for cosmetic - notifications. - :param room_join_rules (str): The join rules of the email - (e.g. "public"). - :param room_name (str): The m.room.name of the room. - :param inviter_display_name (str): The current display name of the - inviter. - :param inviter_avatar_url (str): The URL of the inviter's avatar. - - :return: A deferred tuple containing: - token (str): The token which must be signed to prove authenticity. - public_keys ([{"public_key": str, "key_validity_url": str}]): - public_key is a base64-encoded ed25519 public key. - fallback_public_key: One element from public_keys. - display_name (str): A user-friendly name to represent the invited - user. - """ - - is_url = "%s%s/_matrix/identity/api/v1/store-invite" % ( - id_server_scheme, id_server, - ) - - invite_config = { - "medium": medium, - "address": address, - "room_id": room_id, - "room_alias": room_alias, - "room_avatar_url": room_avatar_url, - "room_join_rules": room_join_rules, - "room_name": room_name, - "sender": inviter_user_id, - "sender_display_name": inviter_display_name, - "sender_avatar_url": inviter_avatar_url, - } - - if self.hs.config.invite_3pid_guest: - registration_handler = self.hs.get_handlers().registration_handler - guest_access_token = yield registration_handler.guest_access_token_for( - medium=medium, - address=address, - inviter_user_id=inviter_user_id, - ) - - guest_user_info = yield self.hs.get_auth().get_user_by_access_token( - guest_access_token - ) - - invite_config.update({ - "guest_access_token": guest_access_token, - "guest_user_id": guest_user_info["user"].to_string(), - }) - - data = yield self.hs.get_simple_http_client().post_urlencoded_get_json( - is_url, - invite_config - ) - # TODO: Check for success - token = data["token"] - public_keys = data.get("public_keys", []) - if "public_key" in data: - fallback_public_key = { - "public_key": data["public_key"], - "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( - id_server_scheme, id_server, - ), - } - else: - fallback_public_key = public_keys[0] - - if not public_keys: - public_keys.append(fallback_public_key) - display_name = data["display_name"] - defer.returnValue((token, public_keys, fallback_public_key, display_name)) - - @defer.inlineCallbacks - def forget(self, user, room_id): - user_id = user.to_string() - - member = yield self.state_handler.get_current_state( - room_id=room_id, - event_type=EventTypes.Member, - state_key=user_id - ) - membership = member.membership if member else None - - if membership is not None and membership != Membership.LEAVE: - raise SynapseError(400, "User %s in room %s" % ( - user_id, room_id - )) - - if membership: - yield self.store.forget(user_id, room_id) - - class RoomListHandler(BaseHandler): def __init__(self, hs): super(RoomListHandler, self).__init__(hs) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py new file mode 100644 index 0000000000..5fdbd3adcc --- /dev/null +++ b/synapse/handlers/room_member.py @@ -0,0 +1,646 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from ._base import BaseHandler + +from synapse.types import UserID, RoomID, Requester +from synapse.api.constants import ( + EventTypes, Membership, +) +from synapse.api.errors import AuthError, SynapseError, Codes +from synapse.util.logcontext import preserve_context_over_fn + +from signedjson.sign import verify_signed_json +from signedjson.key import decode_verify_key_bytes + +from unpaddedbase64 import decode_base64 + +import logging + +logger = logging.getLogger(__name__) + +id_server_scheme = "https://" + + +def user_left_room(distributor, user, room_id): + return preserve_context_over_fn( + distributor.fire, + "user_left_room", user=user, room_id=room_id + ) + + +def user_joined_room(distributor, user, room_id): + return preserve_context_over_fn( + distributor.fire, + "user_joined_room", user=user, room_id=room_id + ) + + +class RoomMemberHandler(BaseHandler): + # TODO(paul): This handler currently contains a messy conflation of + # low-level API that works on UserID objects and so on, and REST-level + # API that takes ID strings and returns pagination chunks. These concerns + # ought to be separated out a lot better. + + def __init__(self, hs): + super(RoomMemberHandler, self).__init__(hs) + + self.clock = hs.get_clock() + + self.distributor = hs.get_distributor() + self.distributor.declare("user_joined_room") + self.distributor.declare("user_left_room") + + @defer.inlineCallbacks + def get_room_members(self, room_id): + users = yield self.store.get_users_in_room(room_id) + + defer.returnValue([UserID.from_string(u) for u in users]) + + @defer.inlineCallbacks + def fetch_room_distributions_into(self, room_id, localusers=None, + remotedomains=None, ignore_user=None): + """Fetch the distribution of a room, adding elements to either + 'localusers' or 'remotedomains', which should be a set() if supplied. + If ignore_user is set, ignore that user. + + This function returns nothing; its result is performed by the + side-effect on the two passed sets. This allows easy accumulation of + member lists of multiple rooms at once if required. + """ + members = yield self.get_room_members(room_id) + for member in members: + if ignore_user is not None and member == ignore_user: + continue + + if self.hs.is_mine(member): + if localusers is not None: + localusers.add(member) + else: + if remotedomains is not None: + remotedomains.add(member.domain) + + @defer.inlineCallbacks + def update_membership( + self, + requester, + target, + room_id, + action, + txn_id=None, + remote_room_hosts=None, + third_party_signed=None, + ratelimit=True, + ): + effective_membership_state = action + if action in ["kick", "unban"]: + effective_membership_state = "leave" + + if third_party_signed is not None: + replication = self.hs.get_replication_layer() + yield replication.exchange_third_party_invite( + third_party_signed["sender"], + target.to_string(), + room_id, + third_party_signed, + ) + + msg_handler = self.hs.get_handlers().message_handler + + content = {"membership": effective_membership_state} + if requester.is_guest: + content["kind"] = "guest" + + event, context = yield msg_handler.create_event( + { + "type": EventTypes.Member, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "state_key": target.to_string(), + + # For backwards compatibility: + "membership": effective_membership_state, + }, + token_id=requester.access_token_id, + txn_id=txn_id, + ) + + old_state = context.current_state.get((EventTypes.Member, event.state_key)) + old_membership = old_state.content.get("membership") if old_state else None + if action == "unban" and old_membership != "ban": + raise SynapseError( + 403, + "Cannot unban user who was not banned (membership=%s)" % old_membership, + errcode=Codes.BAD_STATE + ) + if old_membership == "ban" and action != "unban": + raise SynapseError( + 403, + "Cannot %s user who was is banned" % (action,), + errcode=Codes.BAD_STATE + ) + + member_handler = self.hs.get_handlers().room_member_handler + yield member_handler.send_membership_event( + requester, + event, + context, + ratelimit=ratelimit, + remote_room_hosts=remote_room_hosts, + ) + + @defer.inlineCallbacks + def send_membership_event( + self, + requester, + event, + context, + remote_room_hosts=None, + ratelimit=True, + ): + """ + Change the membership status of a user in a room. + + Args: + requester (Requester): The local user who requested the membership + event. If None, certain checks, like whether this homeserver can + act as the sender, will be skipped. + event (SynapseEvent): The membership event. + context: The context of the event. + is_guest (bool): Whether the sender is a guest. + room_hosts ([str]): Homeservers which are likely to already be in + the room, and could be danced with in order to join this + homeserver for the first time. + ratelimit (bool): Whether to rate limit this request. + Raises: + SynapseError if there was a problem changing the membership. + """ + remote_room_hosts = remote_room_hosts or [] + + target_user = UserID.from_string(event.state_key) + room_id = event.room_id + + if requester is not None: + sender = UserID.from_string(event.sender) + assert sender == requester.user, ( + "Sender (%s) must be same as requester (%s)" % + (sender, requester.user) + ) + assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) + else: + requester = Requester(target_user, None, False) + + message_handler = self.hs.get_handlers().message_handler + prev_event = message_handler.deduplicate_state_event(event, context) + if prev_event is not None: + return + + action = "send" + + if event.membership == Membership.JOIN: + if requester.is_guest and not self._can_guest_join(context.current_state): + # This should be an auth check, but guests are a local concept, + # so don't really fit into the general auth process. + raise AuthError(403, "Guest access not allowed") + do_remote_join_dance, remote_room_hosts = self._should_do_dance( + context, + (self.get_inviter(event.state_key, context.current_state)), + remote_room_hosts, + ) + if do_remote_join_dance: + action = "remote_join" + elif event.membership == Membership.LEAVE: + is_host_in_room = self.is_host_in_room(context.current_state) + + if not is_host_in_room: + # perhaps we've been invited + inviter = self.get_inviter( + target_user.to_string(), context.current_state + ) + if not inviter: + raise SynapseError(404, "Not a known room") + + if self.hs.is_mine(inviter): + # the inviter was on our server, but has now left. Carry on + # with the normal rejection codepath. + # + # This is a bit of a hack, because the room might still be + # active on other servers. + pass + else: + # send the rejection to the inviter's HS. + remote_room_hosts = remote_room_hosts + [inviter.domain] + action = "remote_reject" + + federation_handler = self.hs.get_handlers().federation_handler + + if action == "remote_join": + if len(remote_room_hosts) == 0: + raise SynapseError(404, "No known servers") + + # We don't do an auth check if we are doing an invite + # join dance for now, since we're kinda implicitly checking + # that we are allowed to join when we decide whether or not we + # need to do the invite/join dance. + yield federation_handler.do_invite_join( + remote_room_hosts, + event.room_id, + event.user_id, + event.content, + ) + elif action == "remote_reject": + yield federation_handler.do_remotely_reject_invite( + remote_room_hosts, + room_id, + event.user_id + ) + else: + yield self.handle_new_client_event( + requester, + event, + context, + extra_users=[target_user], + ratelimit=ratelimit, + ) + + prev_member_event = context.current_state.get( + (EventTypes.Member, target_user.to_string()), + None + ) + + if event.membership == Membership.JOIN: + if not prev_member_event or prev_member_event.membership != Membership.JOIN: + # Only fire user_joined_room if the user has acutally joined the + # room. Don't bother if the user is just changing their profile + # info. + yield user_joined_room(self.distributor, target_user, room_id) + elif event.membership == Membership.LEAVE: + if prev_member_event and prev_member_event.membership == Membership.JOIN: + user_left_room(self.distributor, target_user, room_id) + + def _can_guest_join(self, current_state): + """ + Returns whether a guest can join a room based on its current state. + """ + guest_access = current_state.get((EventTypes.GuestAccess, ""), None) + return ( + guest_access + and guest_access.content + and "guest_access" in guest_access.content + and guest_access.content["guest_access"] == "can_join" + ) + + def _should_do_dance(self, context, inviter, room_hosts=None): + # TODO: Shouldn't this be remote_room_host? + room_hosts = room_hosts or [] + + is_host_in_room = self.is_host_in_room(context.current_state) + if is_host_in_room: + return False, room_hosts + + if inviter and not self.hs.is_mine(inviter): + room_hosts.append(inviter.domain) + + return True, room_hosts + + @defer.inlineCallbacks + def lookup_room_alias(self, room_alias): + """ + Get the room ID associated with a room alias. + + Args: + room_alias (RoomAlias): The alias to look up. + Returns: + A tuple of: + The room ID as a RoomID object. + Hosts likely to be participating in the room ([str]). + Raises: + SynapseError if room alias could not be found. + """ + directory_handler = self.hs.get_handlers().directory_handler + mapping = yield directory_handler.get_association(room_alias) + + if not mapping: + raise SynapseError(404, "No such room alias") + + room_id = mapping["room_id"] + servers = mapping["servers"] + + defer.returnValue((RoomID.from_string(room_id), servers)) + + def get_inviter(self, user_id, current_state): + prev_state = current_state.get((EventTypes.Member, user_id)) + if prev_state and prev_state.membership == Membership.INVITE: + return UserID.from_string(prev_state.user_id) + return None + + @defer.inlineCallbacks + def get_joined_rooms_for_user(self, user): + """Returns a list of roomids that the user has any of the given + membership states in.""" + + rooms = yield self.store.get_rooms_for_user( + user.to_string(), + ) + + # For some reason the list of events contains duplicates + # TODO(paul): work out why because I really don't think it should + room_ids = set(r.room_id for r in rooms) + + defer.returnValue(room_ids) + + @defer.inlineCallbacks + def do_3pid_invite( + self, + room_id, + inviter, + medium, + address, + id_server, + requester, + txn_id + ): + invitee = yield self._lookup_3pid( + id_server, medium, address + ) + + if invitee: + handler = self.hs.get_handlers().room_member_handler + yield handler.update_membership( + requester, + UserID.from_string(invitee), + room_id, + "invite", + txn_id=txn_id, + ) + else: + yield self._make_and_store_3pid_invite( + requester, + id_server, + medium, + address, + room_id, + inviter, + txn_id=txn_id + ) + + @defer.inlineCallbacks + def _lookup_3pid(self, id_server, medium, address): + """Looks up a 3pid in the passed identity server. + + Args: + id_server (str): The server name (including port, if required) + of the identity server to use. + medium (str): The type of the third party identifier (e.g. "email"). + address (str): The third party identifier (e.g. "foo@example.com"). + + Returns: + (str) the matrix ID of the 3pid, or None if it is not recognized. + """ + try: + data = yield self.hs.get_simple_http_client().get_json( + "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,), + { + "medium": medium, + "address": address, + } + ) + + if "mxid" in data: + if "signatures" not in data: + raise AuthError(401, "No signatures on 3pid binding") + self.verify_any_signature(data, id_server) + defer.returnValue(data["mxid"]) + + except IOError as e: + logger.warn("Error from identity server lookup: %s" % (e,)) + defer.returnValue(None) + + @defer.inlineCallbacks + def verify_any_signature(self, data, server_hostname): + if server_hostname not in data["signatures"]: + raise AuthError(401, "No signature from server %s" % (server_hostname,)) + for key_name, signature in data["signatures"][server_hostname].items(): + key_data = yield self.hs.get_simple_http_client().get_json( + "%s%s/_matrix/identity/api/v1/pubkey/%s" % + (id_server_scheme, server_hostname, key_name,), + ) + if "public_key" not in key_data: + raise AuthError(401, "No public key named %s from %s" % + (key_name, server_hostname,)) + verify_signed_json( + data, + server_hostname, + decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"])) + ) + return + + @defer.inlineCallbacks + def _make_and_store_3pid_invite( + self, + requester, + id_server, + medium, + address, + room_id, + user, + txn_id + ): + room_state = yield self.hs.get_state_handler().get_current_state(room_id) + + inviter_display_name = "" + inviter_avatar_url = "" + member_event = room_state.get((EventTypes.Member, user.to_string())) + if member_event: + inviter_display_name = member_event.content.get("displayname", "") + inviter_avatar_url = member_event.content.get("avatar_url", "") + + canonical_room_alias = "" + canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, "")) + if canonical_alias_event: + canonical_room_alias = canonical_alias_event.content.get("alias", "") + + room_name = "" + room_name_event = room_state.get((EventTypes.Name, "")) + if room_name_event: + room_name = room_name_event.content.get("name", "") + + room_join_rules = "" + join_rules_event = room_state.get((EventTypes.JoinRules, "")) + if join_rules_event: + room_join_rules = join_rules_event.content.get("join_rule", "") + + room_avatar_url = "" + room_avatar_event = room_state.get((EventTypes.RoomAvatar, "")) + if room_avatar_event: + room_avatar_url = room_avatar_event.content.get("url", "") + + token, public_keys, fallback_public_key, display_name = ( + yield self._ask_id_server_for_third_party_invite( + id_server=id_server, + medium=medium, + address=address, + room_id=room_id, + inviter_user_id=user.to_string(), + room_alias=canonical_room_alias, + room_avatar_url=room_avatar_url, + room_join_rules=room_join_rules, + room_name=room_name, + inviter_display_name=inviter_display_name, + inviter_avatar_url=inviter_avatar_url + ) + ) + + msg_handler = self.hs.get_handlers().message_handler + yield msg_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.ThirdPartyInvite, + "content": { + "display_name": display_name, + "public_keys": public_keys, + + # For backwards compatibility: + "key_validity_url": fallback_public_key["key_validity_url"], + "public_key": fallback_public_key["public_key"], + }, + "room_id": room_id, + "sender": user.to_string(), + "state_key": token, + }, + txn_id=txn_id, + ) + + @defer.inlineCallbacks + def _ask_id_server_for_third_party_invite( + self, + id_server, + medium, + address, + room_id, + inviter_user_id, + room_alias, + room_avatar_url, + room_join_rules, + room_name, + inviter_display_name, + inviter_avatar_url + ): + """ + Asks an identity server for a third party invite. + + :param id_server (str): hostname + optional port for the identity server. + :param medium (str): The literal string "email". + :param address (str): The third party address being invited. + :param room_id (str): The ID of the room to which the user is invited. + :param inviter_user_id (str): The user ID of the inviter. + :param room_alias (str): An alias for the room, for cosmetic + notifications. + :param room_avatar_url (str): The URL of the room's avatar, for cosmetic + notifications. + :param room_join_rules (str): The join rules of the email + (e.g. "public"). + :param room_name (str): The m.room.name of the room. + :param inviter_display_name (str): The current display name of the + inviter. + :param inviter_avatar_url (str): The URL of the inviter's avatar. + + :return: A deferred tuple containing: + token (str): The token which must be signed to prove authenticity. + public_keys ([{"public_key": str, "key_validity_url": str}]): + public_key is a base64-encoded ed25519 public key. + fallback_public_key: One element from public_keys. + display_name (str): A user-friendly name to represent the invited + user. + """ + + is_url = "%s%s/_matrix/identity/api/v1/store-invite" % ( + id_server_scheme, id_server, + ) + + invite_config = { + "medium": medium, + "address": address, + "room_id": room_id, + "room_alias": room_alias, + "room_avatar_url": room_avatar_url, + "room_join_rules": room_join_rules, + "room_name": room_name, + "sender": inviter_user_id, + "sender_display_name": inviter_display_name, + "sender_avatar_url": inviter_avatar_url, + } + + if self.hs.config.invite_3pid_guest: + registration_handler = self.hs.get_handlers().registration_handler + guest_access_token = yield registration_handler.guest_access_token_for( + medium=medium, + address=address, + inviter_user_id=inviter_user_id, + ) + + guest_user_info = yield self.hs.get_auth().get_user_by_access_token( + guest_access_token + ) + + invite_config.update({ + "guest_access_token": guest_access_token, + "guest_user_id": guest_user_info["user"].to_string(), + }) + + data = yield self.hs.get_simple_http_client().post_urlencoded_get_json( + is_url, + invite_config + ) + # TODO: Check for success + token = data["token"] + public_keys = data.get("public_keys", []) + if "public_key" in data: + fallback_public_key = { + "public_key": data["public_key"], + "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( + id_server_scheme, id_server, + ), + } + else: + fallback_public_key = public_keys[0] + + if not public_keys: + public_keys.append(fallback_public_key) + display_name = data["display_name"] + defer.returnValue((token, public_keys, fallback_public_key, display_name)) + + @defer.inlineCallbacks + def forget(self, user, room_id): + user_id = user.to_string() + + member = yield self.state_handler.get_current_state( + room_id=room_id, + event_type=EventTypes.Member, + state_key=user_id + ) + membership = member.membership if member else None + + if membership is not None and membership != Membership.LEAVE: + raise SynapseError(400, "User %s in room %s" % ( + user_id, room_id + )) + + if membership: + yield self.store.forget(user_id, room_id) -- cgit 1.4.1 From bb9a2ca87c280e1c6ff6740ee9d2764e1b5226a5 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Thu, 31 Mar 2016 14:15:09 +0100 Subject: synthesise basig OG metadata from pages lacking it --- synapse/rest/media/v1/preview_url_resource.py | 47 +++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) (limited to 'synapse') diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index ca2529cc10..b1d5cabfaa 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -23,6 +23,7 @@ from synapse.http.client import SpiderHttpClient from synapse.http.server import request_handler, respond_with_json, respond_with_json_bytes import os +import re import ujson as json import logging @@ -70,6 +71,7 @@ class PreviewUrlResource(BaseMediaResource): # define our OG response for this media elif self._is_html(media_info['media_type']): + # TODO: somehow stop a big HTML tree from exploding synapse's RAM tree = html.parse(media_info['filename']) # suck it up into lxml and define our OG response. @@ -82,17 +84,58 @@ class PreviewUrlResource(BaseMediaResource): # "og:image" : "https://pbs.twimg.com/profile_images/500400952029888512/yI0qtFi7_400x400.png" # "og:description" : "Synapse 0.12 is out! Lots of polishing, performance &amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP" # "og:site_name" : "Twitter" + + # or: + + # "og:type" : "video", + # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw", + # "og:site_name" : "YouTube", + # "og:video:type" : "application/x-shockwave-flash", + # "og:description" : " ", + # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon", + # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg", + # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", + # "og:video:width" : "1280" + # "og:video:height" : "720", + # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", og = {} for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): og[tag.attrib['property']] = tag.attrib['content'] + if not og: + # do some basic spidering of the HTML + title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") + og['og:title'] = title[0].text if title else None + + images = tree.xpath("//img") + big_images = [ i for i in images if ( + 'width' in i and 'height' in i and + i.attrib['width'] > 64 and i.attrib['height'] > 64 + )] or images + og['og:image'] = images[0].attrib['src'] if images else None + + text_nodes = tree.xpath("//h1/text() | //h2/text() | //h3/text() | //p/text() | //div/text() | //span/text() | //a/text()") + text = '' + for text_node in text_nodes: + if len(text) < 1024: + text += text_node + ' ' + else: + break + text = re.sub(r'[\t ]+', ' ', text) + text = re.sub(r'[\t \r\n]*[\r\n]+', '\n', text) + text = text.strip()[:1024] + og['og:description'] = text if text else None + + # TODO: turn any OG media URLs into mxc URLs to capture and thumbnail them too # TODO: store our OG details in a cache (and expire them when stale) # TODO: delete the content to stop diskfilling, as we only ever cared about its OG else: logger.warn("Failed to find any OG data in %s", url) og = {} + logger.warn(og) + respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) except: # XXX: if we don't explicitly respond here, the request never returns. @@ -111,6 +154,10 @@ class PreviewUrlResource(BaseMediaResource): @defer.inlineCallbacks def _download_url(self, url, user): + # TODO: we should probably honour robots.txt... except in practice + # we're most likely being explicitly triggered by a human rather than a + # bot, so are we really a robot? + # XXX: horrible duplication with base_resource's _download_remote_file() file_id = random_string(24) -- cgit 1.4.1 From 76503f95ed2b618a3ff97c6b04868d8e440ef9d4 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 31 Mar 2016 15:00:42 +0100 Subject: Remove the is_new_state argument to persist event. Move the checks for whether an event is new state inside persist event itself. This was harder than expected because there wasn't enough information passed to persist event to correctly handle invites from remote servers for new rooms. --- synapse/events/__init__.py | 3 ++ synapse/handlers/federation.py | 20 ++-------- synapse/storage/events.py | 90 +++++++++++++++++++++++------------------- 3 files changed, 57 insertions(+), 56 deletions(-) (limited to 'synapse') diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 925a83c645..13154b1723 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -33,6 +33,9 @@ class _EventInternalMetadata(object): def is_outlier(self): return getattr(self, "outlier", False) + def is_invite_from_remote(self): + return getattr(self, "invite_from_remote", False) + def _event_dict_property(key): def getter(self): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 267fedf114..4a35344d32 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -102,8 +102,7 @@ class FederationHandler(BaseHandler): @log_function @defer.inlineCallbacks - def on_receive_pdu(self, origin, pdu, state=None, - auth_chain=None): + def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None): """ Called by the ReplicationLayer when we have a new pdu. We need to do auth checks and put it through the StateHandler. """ @@ -174,11 +173,7 @@ class FederationHandler(BaseHandler): }) seen_ids.add(e.event_id) - yield self._handle_new_events( - origin, - event_infos, - outliers=True - ) + yield self._handle_new_events(origin, event_infos) try: context, event_stream_id, max_stream_id = yield self._handle_new_event( @@ -761,6 +756,7 @@ class FederationHandler(BaseHandler): event = pdu event.internal_metadata.outlier = True + event.internal_metadata.invite_from_remote = True event.signatures.update( compute_event_signature( @@ -1069,9 +1065,6 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks @log_function def _handle_new_event(self, origin, event, state=None, auth_events=None): - - outlier = event.internal_metadata.is_outlier() - context = yield self._prep_event( origin, event, state=state, @@ -1087,14 +1080,12 @@ class FederationHandler(BaseHandler): event_stream_id, max_stream_id = yield self.store.persist_event( event, context=context, - is_new_state=not outlier, ) defer.returnValue((context, event_stream_id, max_stream_id)) @defer.inlineCallbacks - def _handle_new_events(self, origin, event_infos, backfilled=False, - outliers=False): + def _handle_new_events(self, origin, event_infos, backfilled=False): contexts = yield defer.gatherResults( [ self._prep_event( @@ -1113,7 +1104,6 @@ class FederationHandler(BaseHandler): for ev_info, context in itertools.izip(event_infos, contexts) ], backfilled=backfilled, - is_new_state=(not outliers and not backfilled), ) @defer.inlineCallbacks @@ -1176,7 +1166,6 @@ class FederationHandler(BaseHandler): (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) ], - is_new_state=False, ) new_event_context = yield self.state_handler.compute_event_context( @@ -1185,7 +1174,6 @@ class FederationHandler(BaseHandler): event_stream_id, max_stream_id = yield self.store.persist_event( event, new_event_context, - is_new_state=True, current_state=state, ) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index dc3e994de9..b5f9b3b900 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -61,8 +61,7 @@ class EventsStore(SQLBaseStore): ) @defer.inlineCallbacks - def persist_events(self, events_and_contexts, backfilled=False, - is_new_state=True): + def persist_events(self, events_and_contexts, backfilled=False): if not events_and_contexts: return @@ -110,13 +109,11 @@ class EventsStore(SQLBaseStore): self._persist_events_txn, events_and_contexts=chunk, backfilled=backfilled, - is_new_state=is_new_state, ) @defer.inlineCallbacks @log_function - def persist_event(self, event, context, - is_new_state=True, current_state=None): + def persist_event(self, event, context, current_state=None): try: with self._stream_id_gen.get_next() as stream_ordering: @@ -128,7 +125,6 @@ class EventsStore(SQLBaseStore): self._persist_event_txn, event=event, context=context, - is_new_state=is_new_state, current_state=current_state, ) except _RollbackButIsFineException: @@ -194,8 +190,7 @@ class EventsStore(SQLBaseStore): defer.returnValue({e.event_id: e for e in events}) @log_function - def _persist_event_txn(self, txn, event, context, - is_new_state, current_state): + def _persist_event_txn(self, txn, event, context, current_state): # We purposefully do this first since if we include a `current_state` # key, we *want* to update the `current_state_events` table if current_state: @@ -236,12 +231,10 @@ class EventsStore(SQLBaseStore): txn, [(event, context)], backfilled=False, - is_new_state=is_new_state, ) @log_function - def _persist_events_txn(self, txn, events_and_contexts, backfilled, - is_new_state): + def _persist_events_txn(self, txn, events_and_contexts, backfilled): depth_updates = {} for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids @@ -452,10 +445,9 @@ class EventsStore(SQLBaseStore): txn, [event for event, _ in events_and_contexts] ) - state_events_and_contexts = filter( - lambda i: i[0].is_state(), - events_and_contexts, - ) + state_events_and_contexts = [ + ec for ec in events_and_contexts if ec[0].is_state() + ] state_values = [] for event, context in state_events_and_contexts: @@ -493,32 +485,50 @@ class EventsStore(SQLBaseStore): ], ) - if is_new_state: - for event, _ in state_events_and_contexts: - if not context.rejected: - txn.call_after( - self._get_current_state_for_key.invalidate, - (event.room_id, event.type, event.state_key,) - ) + for event, _ in state_events_and_contexts: + if backfilled: + # Backfilled events come before the current state so shouldn't + # clobber it. + continue - if event.type in [EventTypes.Name, EventTypes.Aliases]: - txn.call_after( - self.get_room_name_and_aliases.invalidate, - (event.room_id,) - ) - - self._simple_upsert_txn( - txn, - "current_state_events", - keyvalues={ - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - }, - values={ - "event_id": event.event_id, - } - ) + if (not event.internal_metadata.is_invite_from_remote() + and event.internal_metadata.is_outlier()): + # Outlier events generally shouldn't clobber the current state. + # However invites from remote severs for rooms we aren't in + # are a bit special: they don't come with any associated + # state so are technically an outlier, however all the + # client-facing code assumes that they are in the current + # state table so we insert the event anyway. + continue + + if context.rejected: + # If the event failed it's auth checks then it shouldn't + # clobbler the current state. + continue + + txn.call_after( + self._get_current_state_for_key.invalidate, + (event.room_id, event.type, event.state_key,) + ) + + if event.type in [EventTypes.Name, EventTypes.Aliases]: + txn.call_after( + self.get_room_name_and_aliases.invalidate, + (event.room_id,) + ) + + self._simple_upsert_txn( + txn, + "current_state_events", + keyvalues={ + "room_id": event.room_id, + "type": event.type, + "state_key": event.state_key, + }, + values={ + "event_id": event.event_id, + } + ) return -- cgit 1.4.1 From 5d069291696c62c94d795506dd5a22a3dbd019e1 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 31 Mar 2016 15:09:09 +0100 Subject: Move the check for backfilled outside the for loop --- synapse/storage/events.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index b5f9b3b900..83279d65fa 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -485,12 +485,12 @@ class EventsStore(SQLBaseStore): ], ) - for event, _ in state_events_and_contexts: - if backfilled: - # Backfilled events come before the current state so shouldn't - # clobber it. - continue + if backfilled: + # Backfilled events come before the current state so we don't need + # to update the current state table + return + for event, _ in state_events_and_contexts: if (not event.internal_metadata.is_invite_from_remote() and event.internal_metadata.is_outlier()): # Outlier events generally shouldn't clobber the current state. -- cgit 1.4.1 From 72550c3803e5020aa377f8d10c0c20afd4273c0d Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Thu, 31 Mar 2016 15:14:14 +0100 Subject: prevent choking on invalid utf-8, and handle image thumbnailing smarter --- synapse/rest/media/v1/preview_url_resource.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) (limited to 'synapse') diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index b1d5cabfaa..04d02ee427 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -72,7 +72,15 @@ class PreviewUrlResource(BaseMediaResource): # define our OG response for this media elif self._is_html(media_info['media_type']): # TODO: somehow stop a big HTML tree from exploding synapse's RAM - tree = html.parse(media_info['filename']) + + # XXX: can't work out how to make lxml ignore UTF8 decoding errors + # so slurp as a string at this point. + file = open(media_info['filename']) + body = file.read() + file.close() + # FIXME: we shouldn't be forcing utf-8 if the page isn't actually utf-8... + tree = html.fromstring(body.decode('utf-8','ignore')) + # tree = html.parse(media_info['filename']) # suck it up into lxml and define our OG response. # if we see any URLs in the OG response, then spider them @@ -108,14 +116,19 @@ class PreviewUrlResource(BaseMediaResource): title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") og['og:title'] = title[0].text if title else None - images = tree.xpath("//img") + images = [ i for i in tree.xpath("//img") if 'src' in i.attrib ] big_images = [ i for i in images if ( - 'width' in i and 'height' in i and + 'width' in i.attrib and 'height' in i.attrib and i.attrib['width'] > 64 and i.attrib['height'] > 64 - )] or images - og['og:image'] = images[0].attrib['src'] if images else None + )] + big_images = big_images.sort(key=lambda i: (-1 * int(i.attrib['width']) * int(i.attrib['height']))) + images = big_images if big_images else images + + if images: + og['og:image'] = images[0].attrib['src'] text_nodes = tree.xpath("//h1/text() | //h2/text() | //h3/text() | //p/text() | //div/text() | //span/text() | //a/text()") + # text_nodes = tree.xpath("//h1/text() | //h2/text() | //h3/text() | //p/text() | //div/text()") text = '' for text_node in text_nodes: if len(text) < 1024: -- cgit 1.4.1 From dc4c1579d48b7db8264a935be3e955b779c78ab6 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 31 Mar 2016 15:32:24 +0100 Subject: Remove outlier parameter from compute_event_context Use event.internal_metadata.is_outlier instead. --- synapse/handlers/_base.py | 3 +-- synapse/handlers/federation.py | 11 ++++------- synapse/state.py | 4 ++-- 3 files changed, 7 insertions(+), 11 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 90eabb6eb7..d407eaeee9 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -261,8 +261,7 @@ class BaseHandler(object): context = yield state_handler.compute_event_context( builder, - old_state=(prev_member_event,), - outlier=True + old_state=(prev_member_event,) ) if builder.is_state(): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4a35344d32..4049c01d26 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1118,11 +1118,9 @@ class FederationHandler(BaseHandler): """ events_to_context = {} for e in itertools.chain(auth_events, state): - ctx = yield self.state_handler.compute_event_context( - e, outlier=True, - ) - events_to_context[e.event_id] = ctx e.internal_metadata.outlier = True + ctx = yield self.state_handler.compute_event_context(e) + events_to_context[e.event_id] = ctx event_map = { e.event_id: e @@ -1169,7 +1167,7 @@ class FederationHandler(BaseHandler): ) new_event_context = yield self.state_handler.compute_event_context( - event, old_state=state, outlier=False, + event, old_state=state ) event_stream_id, max_stream_id = yield self.store.persist_event( @@ -1181,10 +1179,9 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def _prep_event(self, origin, event, state=None, auth_events=None): - outlier = event.internal_metadata.is_outlier() context = yield self.state_handler.compute_event_context( - event, old_state=state, outlier=outlier, + event, old_state=state, ) if not auth_events: diff --git a/synapse/state.py b/synapse/state.py index 41d32e664a..4672ada1b3 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -100,7 +100,7 @@ class StateHandler(object): defer.returnValue(state) @defer.inlineCallbacks - def compute_event_context(self, event, old_state=None, outlier=False): + def compute_event_context(self, event, old_state=None): """ Fills out the context with the `current state` of the graph. The `current state` here is defined to be the state of the event graph just before the event - i.e. it never includes `event` @@ -115,7 +115,7 @@ class StateHandler(object): """ context = EventContext() - if outlier: + if event.internal_metadata.is_outlier(): # If this is an outlier, then we know it shouldn't have any current # state. Certainly store.get_current_state won't return any, and # persisting the event won't store the state group. -- cgit 1.4.1 From 683e564815be5f7852c417cbab06876db6122401 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Thu, 31 Mar 2016 23:52:58 +0100 Subject: handle spidered relative images correctly --- synapse/http/client.py | 2 +- synapse/rest/media/v1/preview_url_resource.py | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) (limited to 'synapse') diff --git a/synapse/http/client.py b/synapse/http/client.py index 219b734268..1b6f7cb795 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -282,7 +282,7 @@ class SimpleHttpClient(object): logger.exception("Failed to download body") raise - defer.returnValue((length, headers)) + defer.returnValue((length, headers, response.request.absoluteURI)) # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 04d02ee427..bae3905a43 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -18,6 +18,7 @@ from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET from twisted.internet import defer from lxml import html +from urlparse import urlparse, urlunparse from synapse.util.stringutils import random_string from synapse.http.client import SpiderHttpClient from synapse.http.server import request_handler, respond_with_json, respond_with_json_bytes @@ -125,7 +126,14 @@ class PreviewUrlResource(BaseMediaResource): images = big_images if big_images else images if images: - og['og:image'] = images[0].attrib['src'] + base = list(urlparse(media_info['uri'])) + src = list(urlparse(images[0].attrib['src'])) + if not src[0] and not src[1]: + src[0] = base[0] + src[1] = base[1] + if not src[2].startswith('/'): + src[2] = re.sub(r'/[^/]+$', '/', base[2]) + src[2] + og['og:image'] = urlunparse(src) text_nodes = tree.xpath("//h1/text() | //h2/text() | //h3/text() | //p/text() | //div/text() | //span/text() | //a/text()") # text_nodes = tree.xpath("//h1/text() | //h2/text() | //h3/text() | //p/text() | //div/text()") @@ -140,6 +148,7 @@ class PreviewUrlResource(BaseMediaResource): text = text.strip()[:1024] og['og:description'] = text if text else None + # TODO: extract a favicon? # TODO: turn any OG media URLs into mxc URLs to capture and thumbnail them too # TODO: store our OG details in a cache (and expire them when stale) # TODO: delete the content to stop diskfilling, as we only ever cared about its OG @@ -180,7 +189,7 @@ class PreviewUrlResource(BaseMediaResource): try: with open(fname, "wb") as f: logger.debug("Trying to get url '%s'" % url) - length, headers = yield self.client.get_file( + length, headers, uri = yield self.client.get_file( url, output_stream=f, max_size=self.max_spider_size, ) # FIXME: handle 404s sanely - don't spider an error page @@ -233,6 +242,7 @@ class PreviewUrlResource(BaseMediaResource): "created_ts": time_now_ms, "filesystem_id": file_id, "filename": fname, + "uri": uri, }) def _is_media(self, content_type): -- cgit 1.4.1 From c60b751694bbeb82105eb828d41c0b5c26d5e195 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Fri, 1 Apr 2016 02:17:48 +0100 Subject: fix assorted redirect, unicode and screenscraping bugs --- synapse/rest/media/v1/preview_url_resource.py | 174 ++++++++++++++------------ 1 file changed, 96 insertions(+), 78 deletions(-) (limited to 'synapse') diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index bae3905a43..a7ffe593b1 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -74,84 +74,93 @@ class PreviewUrlResource(BaseMediaResource): elif self._is_html(media_info['media_type']): # TODO: somehow stop a big HTML tree from exploding synapse's RAM - # XXX: can't work out how to make lxml ignore UTF8 decoding errors - # so slurp as a string at this point. - file = open(media_info['filename']) - body = file.read() - file.close() - # FIXME: we shouldn't be forcing utf-8 if the page isn't actually utf-8... - tree = html.fromstring(body.decode('utf-8','ignore')) - # tree = html.parse(media_info['filename']) - - # suck it up into lxml and define our OG response. - # if we see any URLs in the OG response, then spider them - # (although the client could choose to do this by asking for previews of those URLs to avoid DoSing the server) - - # "og:type" : "article" - # "og:url" : "https://twitter.com/matrixdotorg/status/684074366691356672" - # "og:title" : "Matrix on Twitter" - # "og:image" : "https://pbs.twimg.com/profile_images/500400952029888512/yI0qtFi7_400x400.png" - # "og:description" : "Synapse 0.12 is out! Lots of polishing, performance &amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP" - # "og:site_name" : "Twitter" - - # or: - - # "og:type" : "video", - # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw", - # "og:site_name" : "YouTube", - # "og:video:type" : "application/x-shockwave-flash", - # "og:description" : " ", - # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon", - # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg", - # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", - # "og:video:width" : "1280" - # "og:video:height" : "720", - # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", - - og = {} - for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): - og[tag.attrib['property']] = tag.attrib['content'] - - if not og: - # do some basic spidering of the HTML - title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") - og['og:title'] = title[0].text if title else None - - images = [ i for i in tree.xpath("//img") if 'src' in i.attrib ] - big_images = [ i for i in images if ( - 'width' in i.attrib and 'height' in i.attrib and - i.attrib['width'] > 64 and i.attrib['height'] > 64 - )] - big_images = big_images.sort(key=lambda i: (-1 * int(i.attrib['width']) * int(i.attrib['height']))) - images = big_images if big_images else images - - if images: - base = list(urlparse(media_info['uri'])) - src = list(urlparse(images[0].attrib['src'])) - if not src[0] and not src[1]: - src[0] = base[0] - src[1] = base[1] - if not src[2].startswith('/'): - src[2] = re.sub(r'/[^/]+$', '/', base[2]) + src[2] - og['og:image'] = urlunparse(src) - - text_nodes = tree.xpath("//h1/text() | //h2/text() | //h3/text() | //p/text() | //div/text() | //span/text() | //a/text()") - # text_nodes = tree.xpath("//h1/text() | //h2/text() | //h3/text() | //p/text() | //div/text()") - text = '' - for text_node in text_nodes: - if len(text) < 1024: - text += text_node + ' ' + def _calc_og(): + # suck it up into lxml and define our OG response. + # if we see any URLs in the OG response, then spider them + # (although the client could choose to do this by asking for previews of those URLs to avoid DoSing the server) + + # "og:type" : "article" + # "og:url" : "https://twitter.com/matrixdotorg/status/684074366691356672" + # "og:title" : "Matrix on Twitter" + # "og:image" : "https://pbs.twimg.com/profile_images/500400952029888512/yI0qtFi7_400x400.png" + # "og:description" : "Synapse 0.12 is out! Lots of polishing, performance &amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP" + # "og:site_name" : "Twitter" + + # or: + + # "og:type" : "video", + # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw", + # "og:site_name" : "YouTube", + # "og:video:type" : "application/x-shockwave-flash", + # "og:description" : " ", + # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon", + # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg", + # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", + # "og:video:width" : "1280" + # "og:video:height" : "720", + # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", + + og = {} + for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): + og[tag.attrib['property']] = tag.attrib['content'] + + if 'og:title' not in og: + # do some basic spidering of the HTML + title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") + og['og:title'] = title[0].text if title else None + + + if 'og:image' not in og: + meta_image = tree.xpath("//*/meta[@itemprop='image']/@content"); + if meta_image: + og['og:image'] = self._rebase_url(meta_image[0], media_info['uri']) else: - break - text = re.sub(r'[\t ]+', ' ', text) - text = re.sub(r'[\t \r\n]*[\r\n]+', '\n', text) - text = text.strip()[:1024] - og['og:description'] = text if text else None - - # TODO: extract a favicon? - # TODO: turn any OG media URLs into mxc URLs to capture and thumbnail them too - # TODO: store our OG details in a cache (and expire them when stale) - # TODO: delete the content to stop diskfilling, as we only ever cared about its OG + images = [ i for i in tree.xpath("//img") if 'src' in i.attrib ] + big_images = [ i for i in images if ( + 'width' in i.attrib and 'height' in i.attrib and + i.attrib['width'] > 64 and i.attrib['height'] > 64 + )] + big_images = big_images.sort(key=lambda i: (-1 * int(i.attrib['width']) * int(i.attrib['height']))) + images = big_images if big_images else images + + if images: + og['og:image'] = self._rebase_url(images[0].attrib['src'], media_info['uri']) + + if 'og:description' not in og: + meta_description = tree.xpath("//*/meta[@name='description']/@content"); + if meta_description: + og['og:description'] = meta_description[0] + else: + text_nodes = tree.xpath("//h1/text() | //h2/text() | //h3/text() | //p/text() | //div/text() | //span/text() | //a/text()") + # text_nodes = tree.xpath("//h1/text() | //h2/text() | //h3/text() | //p/text() | //div/text()") + text = '' + for text_node in text_nodes: + if len(text) < 500: + text += text_node + ' ' + else: + break + text = re.sub(r'[\t ]+', ' ', text) + text = re.sub(r'[\t \r\n]*[\r\n]+', '\n', text) + text = text.strip()[:500] + og['og:description'] = text if text else None + + # TODO: extract a favicon? + # TODO: turn any OG media URLs into mxc URLs to capture and thumbnail them too + # TODO: store our OG details in a cache (and expire them when stale) + # TODO: delete the content to stop diskfilling, as we only ever cared about its OG + return og + + try: + tree = html.parse(media_info['filename']) + og = _calc_og() + except UnicodeDecodeError: + # XXX: evil evil bodge + file = open(media_info['filename']) + body = file.read() + file.close() + tree = html.fromstring(body.decode('utf-8','ignore')) + og = _calc_og() + else: logger.warn("Failed to find any OG data in %s", url) og = {} @@ -173,6 +182,15 @@ class PreviewUrlResource(BaseMediaResource): ) raise + def _rebase_url(self, url, base): + base = list(urlparse(base)) + url = list(urlparse(url)) + if not url[0] and not url[1]: + url[0] = base[0] + url[1] = base[1] + if not url[2].startswith('/'): + url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2] + return urlunparse(url) @defer.inlineCallbacks def _download_url(self, url, user): @@ -223,7 +241,7 @@ class PreviewUrlResource(BaseMediaResource): download_name = None yield self.store.store_local_media( - media_id=fname, + media_id=file_id, media_type=media_type, time_now_ms=self.clock.time_msec(), upload_name=download_name, -- cgit 1.4.1 From 7753fc65702dc2b98885f462fe0e6ac9f8e27b06 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 1 Apr 2016 10:34:51 +0100 Subject: Fix the invalidation of the names and aliases cache --- synapse/storage/events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 83279d65fa..7468e6e00c 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -198,7 +198,7 @@ class EventsStore(SQLBaseStore): txn.call_after(self.get_rooms_for_user.invalidate_all) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) - txn.call_after(self.get_room_name_and_aliases, event.room_id) + txn.call_after(self.get_room_name_and_aliases.invalidate, (event.room_id,)) # Add an entry to the current_state_resets table to record the point # where we clobbered the current state -- cgit 1.4.1 From 35bb465b8698cb3ed9d034563d3b9ee03579e775 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 1 Apr 2016 13:10:07 +0100 Subject: Filter rooms list before chunking --- synapse/handlers/sync.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 48ab5707e1..06098f899e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -252,6 +252,18 @@ class SyncHandler(BaseHandler): archived = [] deferreds = [] + user_id = sync_config.user.to_string() + + def _should_include_room(event): + # Always send down rooms we were banned or kicked from. + if not sync_config.filter_collection.include_leave: + if event.membership == Membership.LEAVE: + if user_id == event.sender: + return False + return True + + room_list = filter(_should_include_room, room_list) + room_list_chunks = [room_list[i:i + 10] for i in xrange(0, len(room_list), 10)] for room_list_chunk in room_list_chunks: for event in room_list_chunk: @@ -276,12 +288,6 @@ class SyncHandler(BaseHandler): invite=invite, )) elif event.membership in (Membership.LEAVE, Membership.BAN): - # Always send down rooms we were banned or kicked from. - if not sync_config.filter_collection.include_leave: - if event.membership == Membership.LEAVE: - if sync_config.user.to_string() == event.sender: - continue - leave_token = now_token.copy_and_replace( "room_key", "s%d" % (event.stream_ordering,) ) -- cgit 1.4.1 From e36bfbab38def70e0fcc1bafcecb6e666dbbc1ad Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 1 Apr 2016 13:29:05 +0100 Subject: Use a stream id generator for backfilled ids --- synapse/storage/__init__.py | 20 ++++-------- synapse/storage/account_data.py | 4 +-- synapse/storage/events.py | 19 +++-------- synapse/storage/presence.py | 6 ++-- synapse/storage/push_rule.py | 2 +- synapse/storage/pusher.py | 2 +- synapse/storage/receipts.py | 6 ++-- synapse/storage/state.py | 2 +- synapse/storage/stream.py | 2 +- synapse/storage/tags.py | 6 ++-- synapse/storage/util/id_generators.py | 61 +++++++++++++++++++++++------------ 11 files changed, 69 insertions(+), 61 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index aaad38039e..f87e907cd8 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -88,15 +88,6 @@ class DataStore(RoomMemberStore, RoomStore, self.hs = hs self.database_engine = hs.database_engine - cur = db_conn.cursor() - try: - cur.execute("SELECT MIN(stream_ordering) FROM events",) - rows = cur.fetchall() - self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1 - self.min_stream_token = min(self.min_stream_token, -1) - finally: - cur.close() - self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, @@ -105,6 +96,9 @@ class DataStore(RoomMemberStore, RoomStore, self._stream_id_gen = StreamIdGenerator( db_conn, "events", "stream_ordering" ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, "events", "stream_ordering", direction=-1 + ) self._receipts_id_gen = StreamIdGenerator( db_conn, "receipts_linearized", "stream_id" ) @@ -129,7 +123,7 @@ class DataStore(RoomMemberStore, RoomStore, extra_tables=[("deleted_pushers", "stream_id")], ) - events_max = self._stream_id_gen.get_max_token() + events_max = self._stream_id_gen.get_current_token() event_cache_prefill, min_event_val = self._get_cache_dict( db_conn, "events", entity_column="room_id", @@ -145,7 +139,7 @@ class DataStore(RoomMemberStore, RoomStore, "MembershipStreamChangeCache", events_max, ) - account_max = self._account_data_id_gen.get_max_token() + account_max = self._account_data_id_gen.get_current_token() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max, ) @@ -156,7 +150,7 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "presence_stream", entity_column="user_id", stream_column="stream_id", - max_value=self._presence_id_gen.get_max_token(), + max_value=self._presence_id_gen.get_current_token(), ) self.presence_stream_cache = StreamChangeCache( "PresenceStreamChangeCache", min_presence_val, @@ -167,7 +161,7 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "push_rules_stream", entity_column="user_id", stream_column="stream_id", - max_value=self._push_rules_stream_id_gen.get_max_token()[0], + max_value=self._push_rules_stream_id_gen.get_current_token()[0], ) self.push_rules_stream_cache = StreamChangeCache( diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index faddefe219..7a7fbf1e52 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -200,7 +200,7 @@ class AccountDataStore(SQLBaseStore): "add_room_account_data", add_account_data_txn, next_id ) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) @defer.inlineCallbacks @@ -239,7 +239,7 @@ class AccountDataStore(SQLBaseStore): "add_user_account_data", add_account_data_txn, next_id ) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) def _update_max_stream_id(self, txn, next_id): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 83279d65fa..4ab23c1597 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -24,7 +24,6 @@ from synapse.util.logutils import log_function from synapse.api.constants import EventTypes from canonicaljson import encode_canonical_json -from contextlib import contextmanager from collections import namedtuple import logging @@ -66,14 +65,9 @@ class EventsStore(SQLBaseStore): return if backfilled: - start = self.min_stream_token - 1 - self.min_stream_token -= len(events_and_contexts) + 1 - stream_orderings = range(start, self.min_stream_token, -1) - - @contextmanager - def stream_ordering_manager(): - yield stream_orderings - stream_ordering_manager = stream_ordering_manager() + stream_ordering_manager = self._backfill_id_gen.get_next_mult( + len(events_and_contexts) + ) else: stream_ordering_manager = self._stream_id_gen.get_next_mult( len(events_and_contexts) @@ -130,7 +124,7 @@ class EventsStore(SQLBaseStore): except _RollbackButIsFineException: pass - max_persisted_id = yield self._stream_id_gen.get_max_token() + max_persisted_id = yield self._stream_id_gen.get_current_token() defer.returnValue((stream_ordering, max_persisted_id)) @defer.inlineCallbacks @@ -1117,10 +1111,7 @@ class EventsStore(SQLBaseStore): def get_current_backfill_token(self): """The current minimum token that backfilled events have reached""" - - # TODO: Fix race with the persit_event txn by using one of the - # stream id managers - return -self.min_stream_token + return -self._backfill_id_gen.get_current_token() def get_all_new_events(self, last_backfill_id, last_forward_id, current_backfill_id, current_forward_id, limit): diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 4cec31e316..59b4ef5ce6 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -68,7 +68,9 @@ class PresenceStore(SQLBaseStore): self._update_presence_txn, stream_orderings, presence_states, ) - defer.returnValue((stream_orderings[-1], self._presence_id_gen.get_max_token())) + defer.returnValue(( + stream_orderings[-1], self._presence_id_gen.get_current_token() + )) def _update_presence_txn(self, txn, stream_orderings, presence_states): for stream_id, state in zip(stream_orderings, presence_states): @@ -155,7 +157,7 @@ class PresenceStore(SQLBaseStore): defer.returnValue([UserPresenceState(**row) for row in rows]) def get_current_presence_token(self): - return self._presence_id_gen.get_max_token() + return self._presence_id_gen.get_current_token() def allow_presence_visible(self, observed_localpart, observer_userid): return self._simple_insert( diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 9dbad2fd5f..d2bf7f2aec 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -392,7 +392,7 @@ class PushRuleStore(SQLBaseStore): """Get the position of the push rules stream. Returns a pair of a stream id for the push_rules stream and the room stream ordering it corresponds to.""" - return self._push_rules_stream_id_gen.get_max_token() + return self._push_rules_stream_id_gen.get_current_token() def have_push_rules_changed_for_user(self, user_id, last_id): if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 87b2ac5773..d1669c778a 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -78,7 +78,7 @@ class PusherStore(SQLBaseStore): defer.returnValue(rows) def get_pushers_stream_token(self): - return self._pushers_id_gen.get_max_token() + return self._pushers_id_gen.get_current_token() def get_all_updated_pushers(self, last_id, current_id, limit): def get_all_updated_pushers_txn(txn): diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 6b9d848eaa..4befebc8e2 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -31,7 +31,7 @@ class ReceiptsStore(SQLBaseStore): super(ReceiptsStore, self).__init__(hs) self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token() + "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() ) @cached(num_args=2) @@ -221,7 +221,7 @@ class ReceiptsStore(SQLBaseStore): defer.returnValue(results) def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_max_token() + return self._receipts_id_gen.get_current_token() def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, user_id, event_id, data, stream_id): @@ -346,7 +346,7 @@ class ReceiptsStore(SQLBaseStore): room_id, receipt_type, user_id, event_ids, data ) - max_persisted_id = self._stream_id_gen.get_max_token() + max_persisted_id = self._stream_id_gen.get_current_token() defer.returnValue((stream_id, max_persisted_id)) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7fc9a4f264..8644830657 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -458,4 +458,4 @@ class StateStore(SQLBaseStore): ) def get_state_stream_token(self): - return self._state_groups_id_gen.get_max_token() + return self._state_groups_id_gen.get_current_token() diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index cf84938be5..76bcd9cd00 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -539,7 +539,7 @@ class StreamStore(SQLBaseStore): @defer.inlineCallbacks def get_room_events_max_id(self, direction='f'): - token = yield self._stream_id_gen.get_max_token() + token = yield self._stream_id_gen.get_current_token() if direction != 'b': defer.returnValue("s%d" % (token,)) else: diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index a0e6b42b30..9da23f34cb 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -30,7 +30,7 @@ class TagsStore(SQLBaseStore): Returns: A deferred int. """ - return self._account_data_id_gen.get_max_token() + return self._account_data_id_gen.get_current_token() @cached() def get_tags_for_user(self, user_id): @@ -200,7 +200,7 @@ class TagsStore(SQLBaseStore): self.get_tags_for_user.invalidate((user_id,)) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) @defer.inlineCallbacks @@ -222,7 +222,7 @@ class TagsStore(SQLBaseStore): self.get_tags_for_user.invalidate((user_id,)) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) def _update_revision_txn(self, txn, user_id, room_id, next_id): diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index a02dfc7d58..03f2aa6a5c 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -21,7 +21,7 @@ import threading class IdGenerator(object): def __init__(self, db_conn, table, column): self._lock = threading.Lock() - self._next_id = _load_max_id(db_conn, table, column) + self._next_id = _load_current_id(db_conn, table, column) def get_next(self): with self._lock: @@ -29,12 +29,16 @@ class IdGenerator(object): return self._next_id -def _load_max_id(db_conn, table, column): +def _load_current_id(db_conn, table, column, direction=1): cur = db_conn.cursor() - cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) + if direction == 1: + cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) + else: + cur.execute("SELECT MIN(%s) FROM %s" % (column, table,)) val, = cur.fetchone() cur.close() - return int(val) if val else 1 + current_id = int(val) if val else direction + return (max if direction == 1 else min)(current_id, direction) class StreamIdGenerator(object): @@ -45,17 +49,30 @@ class StreamIdGenerator(object): all ids less than or equal to it have completed. This handles the fact that persistence of events can complete out of order. + :param connection db_conn: A database connection to use to fetch the + initial value of the generator from. + :param str table: A database table to read the initial value of the id + generator from. + :param str column: The column of the database table to read the initial + value from the id generator from. + :param list extra_tables: List of pairs of database tables and columns to + use to source the initial value of the generator from. The value with + the largest magnitude is used. + :param int direction: which direction the stream ids grow in. +1 to grow + upwards, -1 to grow downwards. + Usage: with stream_id_gen.get_next() as stream_id: # ... persist event ... """ - def __init__(self, db_conn, table, column, extra_tables=[]): + def __init__(self, db_conn, table, column, extra_tables=[], direction=1): self._lock = threading.Lock() - self._current_max = _load_max_id(db_conn, table, column) + self._direction = direction + self._current = _load_current_id(db_conn, table, column, direction) for table, column in extra_tables: - self._current_max = max( - self._current_max, - _load_max_id(db_conn, table, column) + self._current = (max if direction > 0 else min)( + self._current, + _load_current_id(db_conn, table, column, direction) ) self._unfinished_ids = deque() @@ -66,8 +83,8 @@ class StreamIdGenerator(object): # ... persist event ... """ with self._lock: - self._current_max += 1 - next_id = self._current_max + self._current += self._direction + next_id = self._current self._unfinished_ids.append(next_id) @@ -88,8 +105,12 @@ class StreamIdGenerator(object): # ... persist events ... """ with self._lock: - next_ids = range(self._current_max + 1, self._current_max + n + 1) - self._current_max += n + next_ids = range( + self._current + self._direction, + self._current + self._direction * (n + 1), + self._direction + ) + self._current += n for next_id in next_ids: self._unfinished_ids.append(next_id) @@ -105,15 +126,15 @@ class StreamIdGenerator(object): return manager() - def get_max_token(self): + def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. """ with self._lock: if self._unfinished_ids: - return self._unfinished_ids[0] - 1 + return self._unfinished_ids[0] - self._direction - return self._current_max + return self._current class ChainedIdGenerator(object): @@ -125,7 +146,7 @@ class ChainedIdGenerator(object): def __init__(self, chained_generator, db_conn, table, column): self.chained_generator = chained_generator self._lock = threading.Lock() - self._current_max = _load_max_id(db_conn, table, column) + self._current_max = _load_current_id(db_conn, table, column) self._unfinished_ids = deque() def get_next(self): @@ -137,7 +158,7 @@ class ChainedIdGenerator(object): with self._lock: self._current_max += 1 next_id = self._current_max - chained_id = self.chained_generator.get_max_token() + chained_id = self.chained_generator.get_current_token() self._unfinished_ids.append((next_id, chained_id)) @@ -151,7 +172,7 @@ class ChainedIdGenerator(object): return manager() - def get_max_token(self): + def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. """ @@ -160,4 +181,4 @@ class ChainedIdGenerator(object): stream_id, chained_id = self._unfinished_ids[0] return (stream_id - 1, chained_id) - return (self._current_max, self.chained_generator.get_max_token()) + return (self._current_max, self.chained_generator.get_current_token()) -- cgit 1.4.1 From a2866e2e6a8fa60a538a98f62e1733ab062020aa Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 1 Apr 2016 13:50:54 +0100 Subject: Rename direction to step, apply checks consistently --- synapse/storage/__init__.py | 2 +- synapse/storage/util/id_generators.py | 30 +++++++++++++++--------------- 2 files changed, 16 insertions(+), 16 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index f87e907cd8..57863bba4d 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -97,7 +97,7 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "events", "stream_ordering" ) self._backfill_id_gen = StreamIdGenerator( - db_conn, "events", "stream_ordering", direction=-1 + db_conn, "events", "stream_ordering", step=-1 ) self._receipts_id_gen = StreamIdGenerator( db_conn, "receipts_linearized", "stream_id" diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 03f2aa6a5c..310b7dc6ee 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -29,16 +29,16 @@ class IdGenerator(object): return self._next_id -def _load_current_id(db_conn, table, column, direction=1): +def _load_current_id(db_conn, table, column, step=1): cur = db_conn.cursor() - if direction == 1: + if step == 1: cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) else: cur.execute("SELECT MIN(%s) FROM %s" % (column, table,)) val, = cur.fetchone() cur.close() - current_id = int(val) if val else direction - return (max if direction == 1 else min)(current_id, direction) + current_id = int(val) if val else step + return (max if step > 0 else min)(current_id, step) class StreamIdGenerator(object): @@ -58,21 +58,21 @@ class StreamIdGenerator(object): :param list extra_tables: List of pairs of database tables and columns to use to source the initial value of the generator from. The value with the largest magnitude is used. - :param int direction: which direction the stream ids grow in. +1 to grow + :param int step: which direction the stream ids grow in. +1 to grow upwards, -1 to grow downwards. Usage: with stream_id_gen.get_next() as stream_id: # ... persist event ... """ - def __init__(self, db_conn, table, column, extra_tables=[], direction=1): + def __init__(self, db_conn, table, column, extra_tables=[], step=1): self._lock = threading.Lock() - self._direction = direction - self._current = _load_current_id(db_conn, table, column, direction) + self._step = step + self._current = _load_current_id(db_conn, table, column, step) for table, column in extra_tables: - self._current = (max if direction > 0 else min)( + self._current = (max if step > 0 else min)( self._current, - _load_current_id(db_conn, table, column, direction) + _load_current_id(db_conn, table, column, step) ) self._unfinished_ids = deque() @@ -83,7 +83,7 @@ class StreamIdGenerator(object): # ... persist event ... """ with self._lock: - self._current += self._direction + self._current += self._step next_id = self._current self._unfinished_ids.append(next_id) @@ -106,9 +106,9 @@ class StreamIdGenerator(object): """ with self._lock: next_ids = range( - self._current + self._direction, - self._current + self._direction * (n + 1), - self._direction + self._current + self._step, + self._current + self._step * (n + 1), + self._step ) self._current += n @@ -132,7 +132,7 @@ class StreamIdGenerator(object): """ with self._lock: if self._unfinished_ids: - return self._unfinished_ids[0] - self._direction + return self._unfinished_ids[0] - self._step return self._current -- cgit 1.4.1 From 8d73cd502bd8ee6903c81f20f79fe5e1509692e3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 1 Apr 2016 14:06:00 +0100 Subject: Add concurrently_execute function --- synapse/handlers/message.py | 10 +---- synapse/handlers/room.py | 17 ++++---- synapse/handlers/sync.py | 98 +++++++++++++++++++-------------------------- synapse/util/async.py | 32 ++++++++++++++- 4 files changed, 82 insertions(+), 75 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5c50c611ba..0bb111d047 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -21,6 +21,7 @@ from synapse.streams.config import PaginationConfig from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.util import unwrapFirstError +from synapse.util.async import concurrently_execute from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.types import UserID, RoomStreamToken, StreamToken @@ -556,14 +557,7 @@ class MessageHandler(BaseHandler): except: logger.exception("Failed to get snapshot") - # Only do N rooms at once - n = 5 - d_list = [handle_room(e) for e in room_list] - for i in range(0, len(d_list), n): - yield defer.gatherResults( - d_list[i:i + n], - consumeErrors=True - ).addErrback(unwrapFirstError) + yield concurrently_execute(handle_room, room_list, 10) account_data_events = [] for account_data_type, content in account_data.items(): diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index ee99ded214..3e1d9282d7 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -23,7 +23,8 @@ from synapse.api.constants import ( EventTypes, JoinRules, RoomCreationPreset, ) from synapse.api.errors import AuthError, StoreError, SynapseError -from synapse.util import stringutils, unwrapFirstError +from synapse.util import stringutils +from synapse.util.async import concurrently_execute from synapse.util.logcontext import preserve_context_over_fn from synapse.util.caches.response_cache import ResponseCache @@ -368,6 +369,8 @@ class RoomListHandler(BaseHandler): def _get_public_room_list(self): room_ids = yield self.store.get_public_room_ids() + results = [] + @defer.inlineCallbacks def handle_room(room_id): aliases = yield self.store.get_aliases_for_room(room_id) @@ -428,18 +431,12 @@ class RoomListHandler(BaseHandler): joined_users = yield self.store.get_users_in_room(room_id) result["num_joined_members"] = len(joined_users) - defer.returnValue(result) + results.append(result) - result = [] - for chunk in (room_ids[i:i + 10] for i in xrange(0, len(room_ids), 10)): - chunk_result = yield defer.gatherResults([ - handle_room(room_id) - for room_id in chunk - ], consumeErrors=True).addErrback(unwrapFirstError) - result.extend(v for v in chunk_result if v) + yield concurrently_execute(handle_room, room_ids, 10) # FIXME (erikj): START is no longer a valid value - defer.returnValue({"start": "START", "end": "END", "chunk": result}) + defer.returnValue({"start": "START", "end": "END", "chunk": results}) class RoomContextHandler(BaseHandler): diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 06098f899e..e38fe1ef9c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -17,8 +17,8 @@ from ._base import BaseHandler from synapse.streams.config import PaginationConfig from synapse.api.constants import Membership, EventTypes -from synapse.util import unwrapFirstError -from synapse.util.logcontext import LoggingContext, preserve_fn +from synapse.util.async import concurrently_execute +from synapse.util.logcontext import LoggingContext from synapse.util.metrics import Measure from synapse.util.caches.response_cache import ResponseCache from synapse.push.clientformat import format_push_rules_for_user @@ -250,64 +250,50 @@ class SyncHandler(BaseHandler): joined = [] invited = [] archived = [] - deferreds = [] user_id = sync_config.user.to_string() - def _should_include_room(event): - # Always send down rooms we were banned or kicked from. - if not sync_config.filter_collection.include_leave: - if event.membership == Membership.LEAVE: - if user_id == event.sender: - return False - return True - - room_list = filter(_should_include_room, room_list) - - room_list_chunks = [room_list[i:i + 10] for i in xrange(0, len(room_list), 10)] - for room_list_chunk in room_list_chunks: - for event in room_list_chunk: - if event.membership == Membership.JOIN: - room_sync_deferred = preserve_fn( - self.full_state_sync_for_joined_room - )( - room_id=event.room_id, - sync_config=sync_config, - now_token=now_token, - timeline_since_token=timeline_since_token, - ephemeral_by_room=ephemeral_by_room, - tags_by_room=tags_by_room, - account_data_by_room=account_data_by_room, - ) - room_sync_deferred.addCallback(joined.append) - deferreds.append(room_sync_deferred) - elif event.membership == Membership.INVITE: - invite = yield self.store.get_event(event.event_id) - invited.append(InvitedSyncResult( - room_id=event.room_id, - invite=invite, - )) - elif event.membership in (Membership.LEAVE, Membership.BAN): - leave_token = now_token.copy_and_replace( - "room_key", "s%d" % (event.stream_ordering,) - ) - room_sync_deferred = preserve_fn( - self.full_state_sync_for_archived_room - )( - sync_config=sync_config, - room_id=event.room_id, - leave_event_id=event.event_id, - leave_token=leave_token, - timeline_since_token=timeline_since_token, - tags_by_room=tags_by_room, - account_data_by_room=account_data_by_room, - ) - room_sync_deferred.addCallback(archived.append) - deferreds.append(room_sync_deferred) + @defer.inlineCallbacks + def _generate_room_entry(event): + if event.membership == Membership.JOIN: + room_result = yield self.full_state_sync_for_joined_room( + room_id=event.room_id, + sync_config=sync_config, + now_token=now_token, + timeline_since_token=timeline_since_token, + ephemeral_by_room=ephemeral_by_room, + tags_by_room=tags_by_room, + account_data_by_room=account_data_by_room, + ) + joined.append(room_result) + elif event.membership == Membership.INVITE: + invite = yield self.store.get_event(event.event_id) + invited.append(InvitedSyncResult( + room_id=event.room_id, + invite=invite, + )) + elif event.membership in (Membership.LEAVE, Membership.BAN): + # Always send down rooms we were banned or kicked from. + if not sync_config.filter_collection.include_leave: + if event.membership == Membership.LEAVE: + if user_id == event.sender: + return + + leave_token = now_token.copy_and_replace( + "room_key", "s%d" % (event.stream_ordering,) + ) + room_result = yield self.full_state_sync_for_archived_room( + sync_config=sync_config, + room_id=event.room_id, + leave_event_id=event.event_id, + leave_token=leave_token, + timeline_since_token=timeline_since_token, + tags_by_room=tags_by_room, + account_data_by_room=account_data_by_room, + ) + archived.append(room_result) - yield defer.gatherResults( - deferreds, consumeErrors=True - ).addErrback(unwrapFirstError) + yield concurrently_execute(_generate_room_entry, room_list, 10) account_data_for_user = sync_config.filter_collection.filter_account_data( self.account_data_for_user(account_data) diff --git a/synapse/util/async.py b/synapse/util/async.py index 640fae3890..a75e1c71fb 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -16,7 +16,8 @@ from twisted.internet import defer, reactor -from .logcontext import PreserveLoggingContext +from .logcontext import PreserveLoggingContext, preserve_fn +from synapse.util import unwrapFirstError @defer.inlineCallbacks @@ -107,3 +108,32 @@ class ObservableDeferred(object): return "" % ( id(self), self._result, self._deferred, ) + + +def concurrently_execute(func, args, limit): + """Executes the function with each argument conncurrently while limiting + the number of concurrent executions. + + Args: + func (func): Function to execute, should return a deferred. + args (list): List of arguments to pass to func, each invocation of func + gets a signle argument. + limit (int): Maximum number of conccurent executions. + + Returns: + deferred + """ + it = iter(args) + + @defer.inlineCallbacks + def _concurrently_execute_inner(): + try: + while True: + yield func(it.next()) + except StopIteration: + pass + + return defer.gatherResults([ + preserve_fn(_concurrently_execute_inner)() + for _ in xrange(limit) + ], consumeErrors=True).addErrback(unwrapFirstError) -- cgit 1.4.1 From 3f4eb4c92402d80d0f41501bf71a60a1b94f2756 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 1 Apr 2016 14:15:27 +0100 Subject: Comment --- synapse/util/async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/util/async.py b/synapse/util/async.py index a75e1c71fb..cd4d90f3cf 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -121,7 +121,7 @@ def concurrently_execute(func, args, limit): limit (int): Maximum number of conccurent executions. Returns: - deferred + deferred: Resolved when all function invocations have finished. """ it = iter(args) -- cgit 1.4.1 From 35b5c4ba1b1892fde18f531c96d71aa58de649e1 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 1 Apr 2016 15:07:01 +0100 Subject: use google style doc strings --- synapse/storage/util/id_generators.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 310b7dc6ee..58ea54cf67 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -49,17 +49,18 @@ class StreamIdGenerator(object): all ids less than or equal to it have completed. This handles the fact that persistence of events can complete out of order. - :param connection db_conn: A database connection to use to fetch the - initial value of the generator from. - :param str table: A database table to read the initial value of the id - generator from. - :param str column: The column of the database table to read the initial - value from the id generator from. - :param list extra_tables: List of pairs of database tables and columns to - use to source the initial value of the generator from. The value with - the largest magnitude is used. - :param int step: which direction the stream ids grow in. +1 to grow - upwards, -1 to grow downwards. + Args: + db_conn(connection): A database connection to use to fetch the + initial value of the generator from. + table(str): A database table to read the initial value of the id + generator from. + column(str): The column of the database table to read the initial + value from the id generator from. + extra_tables(list): List of pairs of database tables and columns to + use to source the initial value of the generator from. The value + with the largest magnitude is used. + step(int): which direction the stream ids grow in. +1 to grow + upwards, -1 to grow downwards. Usage: with stream_id_gen.get_next() as stream_id: -- cgit 1.4.1 From 9bc5b4c663ed4bb35ef74a820c108765c7ca0f67 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 1 Apr 2016 15:08:20 +0100 Subject: Assert that the step != 0 --- synapse/storage/util/id_generators.py | 1 + 1 file changed, 1 insertion(+) (limited to 'synapse') diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 58ea54cf67..f69f1cdad4 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -67,6 +67,7 @@ class StreamIdGenerator(object): # ... persist event ... """ def __init__(self, db_conn, table, column, extra_tables=[], step=1): + assert step != 0 self._lock = threading.Lock() self._step = step self._current = _load_current_id(db_conn, table, column, step) -- cgit 1.4.1 From 2a37467fa1358eb41513893efe44cbd294dca36c Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 1 Apr 2016 16:08:59 +0100 Subject: Use google style doc strings. pycharm supports them so there is no need to use the other format. Might as well convert the existing strings to reduce the risk of people accidentally cargo culting the wrong doc string format. --- setup.cfg | 3 ++ synapse/handlers/_base.py | 27 +++++++----- synapse/handlers/auth.py | 26 +++++++---- synapse/handlers/federation.py | 23 +++++----- synapse/handlers/room_member.py | 48 ++++++++++----------- synapse/handlers/sync.py | 49 +++++++++++++-------- synapse/http/servlet.py | 81 ++++++++++++++++++++++------------- synapse/notifier.py | 15 ++++--- synapse/push/baserules.py | 8 ++-- synapse/rest/client/v2_alpha/sync.py | 79 ++++++++++++++++++---------------- synapse/state.py | 19 ++++---- synapse/storage/event_push_actions.py | 5 ++- synapse/storage/registration.py | 15 ++++--- synapse/storage/state.py | 13 +++--- 14 files changed, 242 insertions(+), 169 deletions(-) (limited to 'synapse') diff --git a/setup.cfg b/setup.cfg index f8cc13c840..5ebce1c56b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,3 +17,6 @@ ignore = [flake8] max-line-length = 90 ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it. + +[pep8] +max-line-length = 90 diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 90eabb6eb7..5601ecea6e 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -41,8 +41,9 @@ class BaseHandler(object): """ Common base class for the event handlers. - :type store: synapse.storage.events.StateStore - :type state_handler: synapse.state.StateHandler + Attributes: + store (synapse.storage.events.StateStore): + state_handler (synapse.state.StateHandler): """ def __init__(self, hs): @@ -65,11 +66,12 @@ class BaseHandler(object): """ Returns dict of user_id -> list of events that user is allowed to see. - :param (str, bool) user_tuples: (user id, is_peeking) for each - user to be checked. is_peeking should be true if: - * the user is not currently a member of the room, and: - * the user has not been a member of the room since the given - events + Args: + user_tuples (str, bool): (user id, is_peeking) for each user to be + checked. is_peeking should be true if: + * the user is not currently a member of the room, and: + * the user has not been a member of the room since the + given events """ forgotten = yield defer.gatherResults([ self.store.who_forgot_in_room( @@ -165,13 +167,16 @@ class BaseHandler(object): """ Check which events a user is allowed to see - :param str user_id: user id to be checked - :param [synapse.events.EventBase] events: list of events to be checked - :param bool is_peeking should be True if: + Args: + user_id(str): user id to be checked + events([synapse.events.EventBase]): list of events to be checked + is_peeking(bool): should be True if: * the user is not currently a member of the room, and: * the user has not been a member of the room since the given events - :rtype [synapse.events.EventBase] + + Returns: + [synapse.events.EventBase] """ types = ( (EventTypes.RoomHistoryVisibility, ""), diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 82d458b424..d5d6faa85f 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -163,9 +163,13 @@ class AuthHandler(BaseHandler): def get_session_id(self, clientdict): """ Gets the session ID for a client given the client dictionary - :param clientdict: The dictionary sent by the client in the request - :return: The string session ID the client sent. If the client did not - send a session ID, returns None. + + Args: + clientdict: The dictionary sent by the client in the request + + Returns: + str|None: The string session ID the client sent. If the client did + not send a session ID, returns None. """ sid = None if clientdict and 'auth' in clientdict: @@ -179,9 +183,11 @@ class AuthHandler(BaseHandler): Store a key-value pair into the sessions data associated with this request. This data is stored server-side and cannot be modified by the client. - :param session_id: (string) The ID of this session as returned from check_auth - :param key: (string) The key to store the data under - :param value: (any) The data to store + + Args: + session_id (string): The ID of this session as returned from check_auth + key (string): The key to store the data under + value (any): The data to store """ sess = self._get_session_info(session_id) sess.setdefault('serverdict', {})[key] = value @@ -190,9 +196,11 @@ class AuthHandler(BaseHandler): def get_session_data(self, session_id, key, default=None): """ Retrieve data stored with set_session_data - :param session_id: (string) The ID of this session as returned from check_auth - :param key: (string) The key to store the data under - :param default: (any) Value to return if the key has not been set + + Args: + session_id (string): The ID of this session as returned from check_auth + key (string): The key to store the data under + default (any): Value to return if the key has not been set """ sess = self._get_session_info(session_id) return sess.setdefault('serverdict', {}).get(key, default) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4a35344d32..092802b973 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1706,13 +1706,15 @@ class FederationHandler(BaseHandler): def _check_signature(self, event, auth_events): """ Checks that the signature in the event is consistent with its invite. - :param event (Event): The m.room.member event to check - :param auth_events (dict<(event type, state_key), event>) - :raises - AuthError if signature didn't match any keys, or key has been + Args: + event (Event): The m.room.member event to check + auth_events (dict<(event type, state_key), event>): + + Raises: + AuthError: if signature didn't match any keys, or key has been revoked, - SynapseError if a transient error meant a key couldn't be checked + SynapseError: if a transient error meant a key couldn't be checked for revocation. """ signed = event.content["third_party_invite"]["signed"] @@ -1754,12 +1756,13 @@ class FederationHandler(BaseHandler): """ Checks whether public_key has been revoked. - :param public_key (str): base-64 encoded public key. - :param url (str): Key revocation URL. + Args: + public_key (str): base-64 encoded public key. + url (str): Key revocation URL. - :raises - AuthError if they key has been revoked. - SynapseError if a transient error meant a key couldn't be checked + Raises: + AuthError: if they key has been revoked. + SynapseError: if a transient error meant a key couldn't be checked for revocation. """ try: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 5fdbd3adcc..01f833c371 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -411,7 +411,7 @@ class RoomMemberHandler(BaseHandler): address (str): The third party identifier (e.g. "foo@example.com"). Returns: - (str) the matrix ID of the 3pid, or None if it is not recognized. + str: the matrix ID of the 3pid, or None if it is not recognized. """ try: data = yield self.hs.get_simple_http_client().get_json( @@ -545,29 +545,29 @@ class RoomMemberHandler(BaseHandler): """ Asks an identity server for a third party invite. - :param id_server (str): hostname + optional port for the identity server. - :param medium (str): The literal string "email". - :param address (str): The third party address being invited. - :param room_id (str): The ID of the room to which the user is invited. - :param inviter_user_id (str): The user ID of the inviter. - :param room_alias (str): An alias for the room, for cosmetic - notifications. - :param room_avatar_url (str): The URL of the room's avatar, for cosmetic - notifications. - :param room_join_rules (str): The join rules of the email - (e.g. "public"). - :param room_name (str): The m.room.name of the room. - :param inviter_display_name (str): The current display name of the - inviter. - :param inviter_avatar_url (str): The URL of the inviter's avatar. - - :return: A deferred tuple containing: - token (str): The token which must be signed to prove authenticity. - public_keys ([{"public_key": str, "key_validity_url": str}]): - public_key is a base64-encoded ed25519 public key. - fallback_public_key: One element from public_keys. - display_name (str): A user-friendly name to represent the invited - user. + Args: + id_server (str): hostname + optional port for the identity server. + medium (str): The literal string "email". + address (str): The third party address being invited. + room_id (str): The ID of the room to which the user is invited. + inviter_user_id (str): The user ID of the inviter. + room_alias (str): An alias for the room, for cosmetic notifications. + room_avatar_url (str): The URL of the room's avatar, for cosmetic + notifications. + room_join_rules (str): The join rules of the email (e.g. "public"). + room_name (str): The m.room.name of the room. + inviter_display_name (str): The current display name of the + inviter. + inviter_avatar_url (str): The URL of the inviter's avatar. + + Returns: + A deferred tuple containing: + token (str): The token which must be signed to prove authenticity. + public_keys ([{"public_key": str, "key_validity_url": str}]): + public_key is a base64-encoded ed25519 public key. + fallback_public_key: One element from public_keys. + display_name (str): A user-friendly name to represent the invited + user. """ is_url = "%s%s/_matrix/identity/api/v1/store-invite" % ( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 48ab5707e1..20a0626574 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -671,7 +671,8 @@ class SyncHandler(BaseHandler): def load_filtered_recents(self, room_id, sync_config, now_token, since_token=None, recents=None, newly_joined_room=False): """ - :returns a Deferred TimelineBatch + Returns: + a Deferred TimelineBatch """ with Measure(self.clock, "load_filtered_recents"): filtering_factor = 2 @@ -838,8 +839,11 @@ class SyncHandler(BaseHandler): """ Get the room state after the given event - :param synapse.events.EventBase event: event of interest - :return: A Deferred map from ((type, state_key)->Event) + Args: + event(synapse.events.EventBase): event of interest + + Returns: + A Deferred map from ((type, state_key)->Event) """ state = yield self.store.get_state_for_event(event.event_id) if event.is_state(): @@ -850,9 +854,13 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def get_state_at(self, room_id, stream_position): """ Get the room state at a particular stream position - :param str room_id: room for which to get state - :param StreamToken stream_position: point at which to get state - :returns: A Deferred map from ((type, state_key)->Event) + + Args: + room_id(str): room for which to get state + stream_position(StreamToken): point at which to get state + + Returns: + A Deferred map from ((type, state_key)->Event) """ last_events, token = yield self.store.get_recent_events_for_room( room_id, end_token=stream_position.room_key, limit=1, @@ -873,15 +881,18 @@ class SyncHandler(BaseHandler): """ Works out the differnce in state between the start of the timeline and the previous sync. - :param str room_id - :param TimelineBatch batch: The timeline batch for the room that will - be sent to the user. - :param sync_config - :param str since_token: Token of the end of the previous batch. May be None. - :param str now_token: Token of the end of the current batch. - :param bool full_state: Whether to force returning the full state. + Args: + room_id(str): + batch(synapse.handlers.sync.TimelineBatch): The timeline batch for + the room that will be sent to the user. + sync_config(synapse.handlers.sync.SyncConfig): + since_token(str|None): Token of the end of the previous batch. May + be None. + now_token(str): Token of the end of the current batch. + full_state(bool): Whether to force returning the full state. - :returns A new event dictionary + Returns: + A deferred new event dictionary """ # TODO(mjark) Check if the state events were received by the server # after the previous sync, since we need to include those state @@ -953,11 +964,13 @@ class SyncHandler(BaseHandler): Check if the user has just joined the given room (so should be given the full state) - :param sync_config: - :param dict[(str,str), synapse.events.FrozenEvent] state_delta: the - difference in state since the last sync + Args: + sync_config(synapse.handlers.sync.SyncConfig): + state_delta(dict[(str,str), synapse.events.FrozenEvent]): the + difference in state since the last sync - :returns A deferred Tuple (state_delta, limited) + Returns: + A deferred Tuple (state_delta, limited) """ join_event = state_delta.get(( EventTypes.Member, sync_config.user.to_string()), None) diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 1c8bd8666f..e41afeab8e 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -26,14 +26,19 @@ logger = logging.getLogger(__name__) def parse_integer(request, name, default=None, required=False): """Parse an integer parameter from the request string - :param request: the twisted HTTP request. - :param name (str): the name of the query parameter. - :param default: value to use if the parameter is absent, defaults to None. - :param required (bool): whether to raise a 400 SynapseError if the - parameter is absent, defaults to False. - :return: An int value or the default. - :raises - SynapseError if the parameter is absent and required, or if the + Args: + request: the twisted HTTP request. + name (str): the name of the query parameter. + default (int|None): value to use if the parameter is absent, defaults + to None. + required (bool): whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + + Returns: + int|None: An int value or the default. + + Raises: + SynapseError: if the parameter is absent and required, or if the parameter is present and not an integer. """ if name in request.args: @@ -53,14 +58,19 @@ def parse_integer(request, name, default=None, required=False): def parse_boolean(request, name, default=None, required=False): """Parse a boolean parameter from the request query string - :param request: the twisted HTTP request. - :param name (str): the name of the query parameter. - :param default: value to use if the parameter is absent, defaults to None. - :param required (bool): whether to raise a 400 SynapseError if the - parameter is absent, defaults to False. - :return: A bool value or the default. - :raises - SynapseError if the parameter is absent and required, or if the + Args: + request: the twisted HTTP request. + name (str): the name of the query parameter. + default (bool|None): value to use if the parameter is absent, defaults + to None. + required (bool): whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + + Returns: + bool|None: A bool value or the default. + + Raises: + SynapseError: if the parameter is absent and required, or if the parameter is present and not one of "true" or "false". """ @@ -88,15 +98,20 @@ def parse_string(request, name, default=None, required=False, allowed_values=None, param_type="string"): """Parse a string parameter from the request query string. - :param request: the twisted HTTP request. - :param name (str): the name of the query parameter. - :param default: value to use if the parameter is absent, defaults to None. - :param required (bool): whether to raise a 400 SynapseError if the - parameter is absent, defaults to False. - :param allowed_values (list): List of allowed values for the string, - or None if any value is allowed, defaults to None - :return: A string value or the default. - :raises + Args: + request: the twisted HTTP request. + name (str): the name of the query parameter. + default (str|None): value to use if the parameter is absent, defaults + to None. + required (bool): whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + allowed_values (list[str]): List of allowed values for the string, + or None if any value is allowed, defaults to None + + Returns: + str|None: A string value or the default. + + Raises: SynapseError if the parameter is absent and required, or if the parameter is present, must be one of a list of allowed values and is not one of those allowed values. @@ -122,9 +137,13 @@ def parse_string(request, name, default=None, required=False, def parse_json_value_from_request(request): """Parse a JSON value from the body of a twisted HTTP request. - :param request: the twisted HTTP request. - :returns: The JSON value. - :raises + Args: + request: the twisted HTTP request. + + Returns: + The JSON value. + + Raises: SynapseError if the request body couldn't be decoded as JSON. """ try: @@ -143,8 +162,10 @@ def parse_json_value_from_request(request): def parse_json_object_from_request(request): """Parse a JSON object from the body of a twisted HTTP request. - :param request: the twisted HTTP request. - :raises + Args: + request: the twisted HTTP request. + + Raises: SynapseError if the request body couldn't be decoded as JSON or if it wasn't a JSON object. """ diff --git a/synapse/notifier.py b/synapse/notifier.py index f00cd8c588..6af7a8f424 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -503,13 +503,14 @@ class Notifier(object): def wait_for_replication(self, callback, timeout): """Wait for an event to happen. - :param callback: - Gets called whenever an event happens. If this returns a truthy - value then ``wait_for_replication`` returns, otherwise it waits - for another event. - :param int timeout: - How many milliseconds to wait for callback return a truthy value. - :returns: + Args: + callback: Gets called whenever an event happens. If this returns a + truthy value then ``wait_for_replication`` returns, otherwise + it waits for another event. + timeout: How many milliseconds to wait for callback return a truthy + value. + + Returns: A deferred that resolves with the value returned by the callback. """ listener = _NotificationListener(None) diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 792af70eb7..6add94beeb 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -19,9 +19,11 @@ import copy def list_with_base_rules(rawrules): """Combine the list of rules set by the user with the default push rules - :param list rawrules: The rules the user has modified or set. - :returns: A new list with the rules set by the user combined with the - defaults. + Args: + rawrules(list): The rules the user has modified or set. + + Returns: + A new list with the rules set by the user combined with the defaults. """ ruleslist = [] diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index c5785d7074..60d3dc4030 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -199,15 +199,17 @@ class SyncRestServlet(RestServlet): """ Encode the joined rooms in a sync result - :param list[synapse.handlers.sync.JoinedSyncResult] rooms: list of sync - results for rooms this user is joined to - :param int time_now: current time - used as a baseline for age - calculations - :param int token_id: ID of the user's auth token - used for namespacing - of transaction IDs - - :return: the joined rooms list, in our response format - :rtype: dict[str, dict[str, object]] + Args: + rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync + results for rooms this user is joined to + time_now(int): current time - used as a baseline for age + calculations + token_id(int): ID of the user's auth token - used for namespacing + of transaction IDs + + Returns: + dict[str, dict[str, object]]: the joined rooms list, in our + response format """ joined = {} for room in rooms: @@ -221,15 +223,17 @@ class SyncRestServlet(RestServlet): """ Encode the invited rooms in a sync result - :param list[synapse.handlers.sync.InvitedSyncResult] rooms: list of - sync results for rooms this user is joined to - :param int time_now: current time - used as a baseline for age - calculations - :param int token_id: ID of the user's auth token - used for namespacing + Args: + rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of + sync results for rooms this user is joined to + time_now(int): current time - used as a baseline for age + calculations + token_id(int): ID of the user's auth token - used for namespacing of transaction IDs - :return: the invited rooms list, in our response format - :rtype: dict[str, dict[str, object]] + Returns: + dict[str, dict[str, object]]: the invited rooms list, in our + response format """ invited = {} for room in rooms: @@ -251,15 +255,17 @@ class SyncRestServlet(RestServlet): """ Encode the archived rooms in a sync result - :param list[synapse.handlers.sync.ArchivedSyncResult] rooms: list of - sync results for rooms this user is joined to - :param int time_now: current time - used as a baseline for age - calculations - :param int token_id: ID of the user's auth token - used for namespacing - of transaction IDs - - :return: the invited rooms list, in our response format - :rtype: dict[str, dict[str, object]] + Args: + rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of + sync results for rooms this user is joined to + time_now(int): current time - used as a baseline for age + calculations + token_id(int): ID of the user's auth token - used for namespacing + of transaction IDs + + Returns: + dict[str, dict[str, object]]: The invited rooms list, in our + response format """ joined = {} for room in rooms: @@ -272,17 +278,18 @@ class SyncRestServlet(RestServlet): @staticmethod def encode_room(room, time_now, token_id, joined=True): """ - :param JoinedSyncResult|ArchivedSyncResult room: sync result for a - single room - :param int time_now: current time - used as a baseline for age - calculations - :param int token_id: ID of the user's auth token - used for namespacing - of transaction IDs - :param joined: True if the user is joined to this room - will mean - we handle ephemeral events - - :return: the room, encoded in our response format - :rtype: dict[str, object] + Args: + room (JoinedSyncResult|ArchivedSyncResult): sync result for a + single room + time_now (int): current time - used as a baseline for age + calculations + token_id (int): ID of the user's auth token - used for namespacing + of transaction IDs + joined (bool): True if the user is joined to this room - will mean + we handle ephemeral events + + Returns: + dict[str, object]: the room, encoded in our response format """ def serialize(event): # TODO(mjark): Respect formatting requirements in the filter. diff --git a/synapse/state.py b/synapse/state.py index 41d32e664a..4a9e148de7 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -86,7 +86,8 @@ class StateHandler(object): If `event_type` is specified, then the method returns only the one event (or None) with that `event_type` and `state_key`. - :returns map from (type, state_key) to event + Returns: + map from (type, state_key) to event """ event_ids = yield self.store.get_latest_event_ids_in_room(room_id) @@ -176,10 +177,11 @@ class StateHandler(object): """ Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. - :returns a Deferred tuple of (`state_group`, `state`, `prev_state`). - `state_group` is the name of a state group if one and only one is - involved. `state` is a map from (type, state_key) to event, and - `prev_state` is a list of event ids. + Returns: + a Deferred tuple of (`state_group`, `state`, `prev_state`). + `state_group` is the name of a state group if one and only one is + involved. `state` is a map from (type, state_key) to event, and + `prev_state` is a list of event ids. """ logger.debug("resolve_state_groups event_ids %s", event_ids) @@ -251,9 +253,10 @@ class StateHandler(object): def _resolve_events(self, state_sets, event_type=None, state_key=""): """ - :returns a tuple (new_state, prev_states). new_state is a map - from (type, state_key) to event. prev_states is a list of event_ids. - :rtype: (dict[(str, str), synapse.events.FrozenEvent], list[str]) + Returns + (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple + (new_state, prev_states). new_state is a map from (type, state_key) + to event. prev_states is a list of event_ids. """ with Measure(self.clock, "state._resolve_events"): state = {} diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index dc5830450a..3933b6e2c5 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -26,8 +26,9 @@ logger = logging.getLogger(__name__) class EventPushActionsStore(SQLBaseStore): def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): """ - :param event: the event set actions for - :param tuples: list of tuples of (user_id, actions) + Args: + event: the event set actions for + tuples: list of tuples of (user_id, actions) """ values = [] for uid, actions in tuples: diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index bd4eb88a92..d46a963bb8 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -458,12 +458,15 @@ class RegistrationStore(SQLBaseStore): """ Gets the 3pid's guest access token if exists, else saves access_token. - :param medium (str): Medium of the 3pid. Must be "email". - :param address (str): 3pid address. - :param access_token (str): The access token to persist if none is - already persisted. - :param inviter_user_id (str): User ID of the inviter. - :return (deferred str): Whichever access token is persisted at the end + Args: + medium (str): Medium of the 3pid. Must be "email". + address (str): 3pid address. + access_token (str): The access token to persist if none is + already persisted. + inviter_user_id (str): User ID of the inviter. + + Returns: + deferred str: Whichever access token is persisted at the end of this function call. """ def insert(txn): diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7fc9a4f264..f84fd0e30a 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -249,11 +249,14 @@ class StateStore(SQLBaseStore): """ Get the state dict corresponding to a particular event - :param str event_id: event whose state should be returned - :param list[(str, str)]|None types: List of (type, state_key) tuples - which are used to filter the state fetched. May be None, which - matches any key - :return: a deferred dict from (type, state_key) -> state_event + Args: + event_id(str): event whose state should be returned + types(list[(str, str)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. May be None, which + matches any key + + Returns: + A deferred dict from (type, state_key) -> state_event """ state_map = yield self.get_state_for_events([event_id], types) defer.returnValue(state_map[event_id]) -- cgit 1.4.1 From c906f3066152ba7a65c0765a5812b71bb8a4016c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 1 Apr 2016 16:17:32 +0100 Subject: Do checks for memberships before creating events --- synapse/handlers/room_member.py | 237 ++++++++++++++++++++++--------------- synapse/state.py | 8 +- tests/rest/client/v1/test_rooms.py | 4 +- 3 files changed, 151 insertions(+), 98 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 5fdbd3adcc..09b8f6217a 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -95,6 +95,80 @@ class RoomMemberHandler(BaseHandler): if remotedomains is not None: remotedomains.add(member.domain) + @defer.inlineCallbacks + def _local_membership_update( + self, requester, target, room_id, membership, + txn_id=None, + ratelimit=True, + ): + msg_handler = self.hs.get_handlers().message_handler + + content = {"membership": membership} + if requester.is_guest: + content["kind"] = "guest" + + event, context = yield msg_handler.create_event( + { + "type": EventTypes.Member, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "state_key": target.to_string(), + + # For backwards compatibility: + "membership": membership, + }, + token_id=requester.access_token_id, + txn_id=txn_id, + ) + + yield self.handle_new_client_event( + requester, + event, + context, + extra_users=[target], + ratelimit=ratelimit, + ) + + prev_member_event = context.current_state.get( + (EventTypes.Member, target.to_string()), + None + ) + + if event.membership == Membership.JOIN: + if not prev_member_event or prev_member_event.membership != Membership.JOIN: + # Only fire user_joined_room if the user has acutally joined the + # room. Don't bother if the user is just changing their profile + # info. + yield user_joined_room(self.distributor, target, room_id) + elif event.membership == Membership.LEAVE: + if prev_member_event and prev_member_event.membership == Membership.JOIN: + user_left_room(self.distributor, target, room_id) + + @defer.inlineCallbacks + def remote_join(self, remote_room_hosts, room_id, user, content): + if len(remote_room_hosts) == 0: + raise SynapseError(404, "No known servers") + + # We don't do an auth check if we are doing an invite + # join dance for now, since we're kinda implicitly checking + # that we are allowed to join when we decide whether or not we + # need to do the invite/join dance. + yield self.hs.get_handlers().federation_handler.do_invite_join( + remote_room_hosts, + room_id, + user.to_string(), + content, + ) + yield user_joined_room(self.distributor, user, room_id) + + def reject_remote_invite(self, user_id, room_id, remote_room_hosts): + return self.hs.get_handlers().federation_handler.do_remotely_reject_invite( + remote_room_hosts, + room_id, + user_id + ) + @defer.inlineCallbacks def update_membership( self, @@ -120,28 +194,15 @@ class RoomMemberHandler(BaseHandler): third_party_signed, ) - msg_handler = self.hs.get_handlers().message_handler - - content = {"membership": effective_membership_state} - if requester.is_guest: - content["kind"] = "guest" + if not remote_room_hosts: + remote_room_hosts = [] - event, context = yield msg_handler.create_event( - { - "type": EventTypes.Member, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - "state_key": target.to_string(), - - # For backwards compatibility: - "membership": effective_membership_state, - }, - token_id=requester.access_token_id, - txn_id=txn_id, + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + current_state = yield self.state_handler.get_current_state( + room_id, latest_event_ids=latest_event_ids, ) - old_state = context.current_state.get((EventTypes.Member, event.state_key)) + old_state = current_state.get((EventTypes.Member, target.to_string())) old_membership = old_state.content.get("membership") if old_state else None if action == "unban" and old_membership != "ban": raise SynapseError( @@ -156,13 +217,57 @@ class RoomMemberHandler(BaseHandler): errcode=Codes.BAD_STATE ) - member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.send_membership_event( - requester, - event, - context, + is_host_in_room = self.is_host_in_room(current_state) + + if effective_membership_state == Membership.JOIN: + if requester.is_guest and not self._can_guest_join(current_state): + # This should be an auth check, but guests are a local concept, + # so don't really fit into the general auth process. + raise AuthError(403, "Guest access not allowed") + + if not is_host_in_room: + inviter = yield self.get_inviter(target.to_string(), room_id) + if inviter and not self.hs.is_mine(inviter): + remote_room_hosts.append(inviter.domain) + + content = {"membership": Membership.JOIN} + if requester.is_guest: + content["kind"] = "guest" + + ret = yield self.remote_join( + remote_room_hosts, room_id, target, content + ) + defer.returnValue(ret) + + elif effective_membership_state == Membership.LEAVE: + if not is_host_in_room: + # perhaps we've been invited + inviter = yield self.get_inviter(target.to_string(), room_id) + if not inviter: + raise SynapseError(404, "Not a known room") + + if self.hs.is_mine(inviter): + # the inviter was on our server, but has now left. Carry on + # with the normal rejection codepath. + # + # This is a bit of a hack, because the room might still be + # active on other servers. + pass + else: + # send the rejection to the inviter's HS. + remote_room_hosts = remote_room_hosts + [inviter.domain] + ret = yield self.reject_remote_invite( + target.to_string(), room_id, remote_room_hosts + ) + defer.returnValue(ret) + + yield self._local_membership_update( + requester=requester, + target=target, + room_id=room_id, + membership=effective_membership_state, + txn_id=txn_id, ratelimit=ratelimit, - remote_room_hosts=remote_room_hosts, ) @defer.inlineCallbacks @@ -211,73 +316,19 @@ class RoomMemberHandler(BaseHandler): if prev_event is not None: return - action = "send" - if event.membership == Membership.JOIN: if requester.is_guest and not self._can_guest_join(context.current_state): # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") - do_remote_join_dance, remote_room_hosts = self._should_do_dance( - context, - (self.get_inviter(event.state_key, context.current_state)), - remote_room_hosts, - ) - if do_remote_join_dance: - action = "remote_join" - elif event.membership == Membership.LEAVE: - is_host_in_room = self.is_host_in_room(context.current_state) - - if not is_host_in_room: - # perhaps we've been invited - inviter = self.get_inviter( - target_user.to_string(), context.current_state - ) - if not inviter: - raise SynapseError(404, "Not a known room") - if self.hs.is_mine(inviter): - # the inviter was on our server, but has now left. Carry on - # with the normal rejection codepath. - # - # This is a bit of a hack, because the room might still be - # active on other servers. - pass - else: - # send the rejection to the inviter's HS. - remote_room_hosts = remote_room_hosts + [inviter.domain] - action = "remote_reject" - - federation_handler = self.hs.get_handlers().federation_handler - - if action == "remote_join": - if len(remote_room_hosts) == 0: - raise SynapseError(404, "No known servers") - - # We don't do an auth check if we are doing an invite - # join dance for now, since we're kinda implicitly checking - # that we are allowed to join when we decide whether or not we - # need to do the invite/join dance. - yield federation_handler.do_invite_join( - remote_room_hosts, - event.room_id, - event.user_id, - event.content, - ) - elif action == "remote_reject": - yield federation_handler.do_remotely_reject_invite( - remote_room_hosts, - room_id, - event.user_id - ) - else: - yield self.handle_new_client_event( - requester, - event, - context, - extra_users=[target_user], - ratelimit=ratelimit, - ) + yield self.handle_new_client_event( + requester, + event, + context, + extra_users=[target_user], + ratelimit=ratelimit, + ) prev_member_event = context.current_state.get( (EventTypes.Member, target_user.to_string()), @@ -306,11 +357,11 @@ class RoomMemberHandler(BaseHandler): and guest_access.content["guest_access"] == "can_join" ) - def _should_do_dance(self, context, inviter, room_hosts=None): + def _should_do_dance(self, current_state, inviter, room_hosts=None): # TODO: Shouldn't this be remote_room_host? room_hosts = room_hosts or [] - is_host_in_room = self.is_host_in_room(context.current_state) + is_host_in_room = self.is_host_in_room(current_state) if is_host_in_room: return False, room_hosts @@ -344,11 +395,11 @@ class RoomMemberHandler(BaseHandler): defer.returnValue((RoomID.from_string(room_id), servers)) - def get_inviter(self, user_id, current_state): - prev_state = current_state.get((EventTypes.Member, user_id)) - if prev_state and prev_state.membership == Membership.INVITE: - return UserID.from_string(prev_state.user_id) - return None + @defer.inlineCallbacks + def get_inviter(self, user_id, room_id): + invite = yield self.store.get_room_member(user_id=user_id, room_id=room_id) + if invite: + defer.returnValue(UserID.from_string(invite.sender)) @defer.inlineCallbacks def get_joined_rooms_for_user(self, user): diff --git a/synapse/state.py b/synapse/state.py index 4672ada1b3..5a5fd8ff12 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -75,7 +75,8 @@ class StateHandler(object): self._state_cache.start() @defer.inlineCallbacks - def get_current_state(self, room_id, event_type=None, state_key=""): + def get_current_state(self, room_id, event_type=None, state_key="", + latest_event_ids=None): """ Retrieves the current state for the room. This is done by calling `get_latest_events_in_room` to get the leading edges of the event graph and then resolving any of the state conflicts. @@ -88,9 +89,10 @@ class StateHandler(object): :returns map from (type, state_key) to event """ - event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + if not latest_event_ids: + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - res = yield self.resolve_state_groups(room_id, event_ids) + res = yield self.resolve_state_groups(room_id, latest_event_ids) state = res[1] if event_type: diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 4ab8b35e6b..8853cbb5fc 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -259,8 +259,8 @@ class RoomPermissionsTestCase(RestTestCase): # set [invite/join/left] of self, set [invite/join/left] of other, # expect all 404s because room doesn't exist on any server for usr in [self.user_id, self.rmcreator_id]: - yield self.join(room=room, user=usr, expect_code=404) - yield self.leave(room=room, user=usr, expect_code=404) + yield self.join(room=room, user=usr, expect_code=403) + yield self.leave(room=room, user=usr, expect_code=403) @defer.inlineCallbacks def test_membership_private_room_perms(self): -- cgit 1.4.1 From aa82cb38e97faa4ad1c15109cbc5b02647ea2461 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 1 Apr 2016 16:36:54 +0100 Subject: Remove state hack from _create_new_client_event --- synapse/handlers/_base.py | 43 ------------------------------------------- 1 file changed, 43 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index d407eaeee9..2d8ff79a9e 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -221,49 +221,6 @@ class BaseHandler(object): context = yield state_handler.compute_event_context(builder) - # If we've received an invite over federation, there are no latest - # events in the room, because we don't know enough about the graph - # fragment we received to treat it like a graph, so the above returned - # no relevant events. It may have returned some events (if we have - # joined and left the room), but not useful ones, like the invite. - if ( - not self.is_host_in_room(context.current_state) and - builder.type == EventTypes.Member - ): - prev_member_event = yield self.store.get_room_member( - builder.sender, builder.room_id - ) - - # The prev_member_event may already be in context.current_state, - # despite us not being present in the room; in particular, if - # inviting user, and all other local users, have already left. - # - # In that case, we have all the information we need, and we don't - # want to drop "context" - not least because we may need to handle - # the invite locally, which will require us to have the whole - # context (not just prev_member_event) to auth it. - # - context_event_ids = ( - e.event_id for e in context.current_state.values() - ) - - if ( - prev_member_event and - prev_member_event.event_id not in context_event_ids - ): - # The prev_member_event is missing from context, so it must - # have arrived over federation and is an outlier. We forcibly - # set our context to the invite we received over federation - builder.prev_events = ( - prev_member_event.event_id, - prev_member_event.prev_events - ) - - context = yield state_handler.compute_event_context( - builder, - old_state=(prev_member_event,) - ) - if builder.is_state(): builder.prev_state = yield self.store.add_event_hashes( context.prev_state_events -- cgit 1.4.1 From d76d89323c4b962b4f3ff72a3c9b40f2d2d347b3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 1 Apr 2016 17:39:32 +0100 Subject: Use computed prev event ids --- synapse/handlers/_base.py | 29 +++++++++++++++++------------ synapse/handlers/message.py | 6 +++++- synapse/handlers/room_member.py | 3 +++ synapse/storage/event_federation.py | 16 ++++++++++++++++ 4 files changed, 41 insertions(+), 13 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 2d8ff79a9e..242afa1564 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -199,20 +199,25 @@ class BaseHandler(object): ) @defer.inlineCallbacks - def _create_new_client_event(self, builder): - latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room( - builder.room_id, - ) - - if latest_ret: - depth = max([d for _, _, d in latest_ret]) + 1 + def _create_new_client_event(self, builder, prev_event_ids=None): + if prev_event_ids: + prev_events = yield self.store.add_event_hashes(prev_event_ids) + prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids) + depth = prev_max_depth + 1 else: - depth = 1 + latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room( + builder.room_id, + ) - prev_events = [ - (event_id, prev_hashes) - for event_id, prev_hashes, _ in latest_ret - ] + if latest_ret: + depth = max([d for _, _, d in latest_ret]) + 1 + else: + depth = 1 + + prev_events = [ + (event_id, prev_hashes) + for event_id, prev_hashes, _ in latest_ret + ] builder.prev_events = prev_events builder.depth = depth diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0bb111d047..10608c0dd9 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -176,7 +176,7 @@ class MessageHandler(BaseHandler): defer.returnValue(chunk) @defer.inlineCallbacks - def create_event(self, event_dict, token_id=None, txn_id=None): + def create_event(self, event_dict, token_id=None, txn_id=None, prev_event_ids=None): """ Given a dict from a client, create a new event. @@ -187,6 +187,9 @@ class MessageHandler(BaseHandler): Args: event_dict (dict): An entire event + token_id (str) + txn_id (str) + prev_event_ids (list): The prev event ids to use when creating the event Returns: Tuple of created event (FrozenEvent), Context @@ -225,6 +228,7 @@ class MessageHandler(BaseHandler): event, context = yield self._create_new_client_event( builder=builder, + prev_event_ids=prev_event_ids, ) defer.returnValue((event, context)) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 09b8f6217a..7c1bb8cfe4 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -98,6 +98,7 @@ class RoomMemberHandler(BaseHandler): @defer.inlineCallbacks def _local_membership_update( self, requester, target, room_id, membership, + prev_event_ids, txn_id=None, ratelimit=True, ): @@ -120,6 +121,7 @@ class RoomMemberHandler(BaseHandler): }, token_id=requester.access_token_id, txn_id=txn_id, + prev_event_ids=prev_event_ids, ) yield self.handle_new_client_event( @@ -268,6 +270,7 @@ class RoomMemberHandler(BaseHandler): membership=effective_membership_state, txn_id=txn_id, ratelimit=ratelimit, + prev_event_ids=latest_event_ids, ) @defer.inlineCallbacks diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 3489315e0d..0827946207 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -163,6 +163,22 @@ class EventFederationStore(SQLBaseStore): room_id, ) + @defer.inlineCallbacks + def get_max_depth_of_events(self, event_ids): + sql = ( + "SELECT MAX(depth) FROM events WHERE event_id IN (%s)" + ) % (",".join(["?"] * len(event_ids)),) + + rows = yield self._execute( + "get_max_depth_of_events", None, + sql, *event_ids + ) + + if rows: + defer.returnValue(rows[0][0]) + else: + defer.returnValue(1) + def _get_min_depth_interaction(self, txn, room_id): min_depth = self._simple_select_one_onecol_txn( txn, -- cgit 1.4.1 From 5fd07da76473f7a361db4b16b58fc4c21acc4af0 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sat, 2 Apr 2016 00:35:49 +0100 Subject: refactor calc_og; spider image URLs; fix xpath; add a (broken) expiringcache; loads of other fixes --- synapse/rest/media/v1/preview_url_resource.py | 202 +++++++++++++++----------- 1 file changed, 121 insertions(+), 81 deletions(-) (limited to 'synapse') diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index a7ffe593b1..1273472dab 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -20,6 +20,7 @@ from twisted.internet import defer from lxml import html from urlparse import urlparse, urlunparse from synapse.util.stringutils import random_string +from synapse.util.caches.expiringcache import ExpiringCache from synapse.http.client import SpiderHttpClient from synapse.http.server import request_handler, respond_with_json, respond_with_json_bytes @@ -36,6 +37,12 @@ class PreviewUrlResource(BaseMediaResource): def __init__(self, hs, filepaths): BaseMediaResource.__init__(self, hs, filepaths) self.client = SpiderHttpClient(hs) + self.cache = ExpiringCache( + cache_name = "url_previews", + clock = self.clock, + expiry_ms = 60*60*1000, # don't spider URLs more often than once an hour + ) + self.cache.start() def render_GET(self, request): self._async_render_GET(request) @@ -50,6 +57,11 @@ class PreviewUrlResource(BaseMediaResource): requester = yield self.auth.get_user_by_req(request) url = request.args.get("url")[0] + if self.cache: + og = self.cache.get(url) + respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) + return + # TODO: keep track of whether there's an ongoing request for this preview # and block and return their details if there is one. @@ -74,98 +86,25 @@ class PreviewUrlResource(BaseMediaResource): elif self._is_html(media_info['media_type']): # TODO: somehow stop a big HTML tree from exploding synapse's RAM - def _calc_og(): - # suck it up into lxml and define our OG response. - # if we see any URLs in the OG response, then spider them - # (although the client could choose to do this by asking for previews of those URLs to avoid DoSing the server) - - # "og:type" : "article" - # "og:url" : "https://twitter.com/matrixdotorg/status/684074366691356672" - # "og:title" : "Matrix on Twitter" - # "og:image" : "https://pbs.twimg.com/profile_images/500400952029888512/yI0qtFi7_400x400.png" - # "og:description" : "Synapse 0.12 is out! Lots of polishing, performance &amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP" - # "og:site_name" : "Twitter" - - # or: - - # "og:type" : "video", - # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw", - # "og:site_name" : "YouTube", - # "og:video:type" : "application/x-shockwave-flash", - # "og:description" : " ", - # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon", - # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg", - # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", - # "og:video:width" : "1280" - # "og:video:height" : "720", - # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", - - og = {} - for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): - og[tag.attrib['property']] = tag.attrib['content'] - - if 'og:title' not in og: - # do some basic spidering of the HTML - title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") - og['og:title'] = title[0].text if title else None - - - if 'og:image' not in og: - meta_image = tree.xpath("//*/meta[@itemprop='image']/@content"); - if meta_image: - og['og:image'] = self._rebase_url(meta_image[0], media_info['uri']) - else: - images = [ i for i in tree.xpath("//img") if 'src' in i.attrib ] - big_images = [ i for i in images if ( - 'width' in i.attrib and 'height' in i.attrib and - i.attrib['width'] > 64 and i.attrib['height'] > 64 - )] - big_images = big_images.sort(key=lambda i: (-1 * int(i.attrib['width']) * int(i.attrib['height']))) - images = big_images if big_images else images - - if images: - og['og:image'] = self._rebase_url(images[0].attrib['src'], media_info['uri']) - - if 'og:description' not in og: - meta_description = tree.xpath("//*/meta[@name='description']/@content"); - if meta_description: - og['og:description'] = meta_description[0] - else: - text_nodes = tree.xpath("//h1/text() | //h2/text() | //h3/text() | //p/text() | //div/text() | //span/text() | //a/text()") - # text_nodes = tree.xpath("//h1/text() | //h2/text() | //h3/text() | //p/text() | //div/text()") - text = '' - for text_node in text_nodes: - if len(text) < 500: - text += text_node + ' ' - else: - break - text = re.sub(r'[\t ]+', ' ', text) - text = re.sub(r'[\t \r\n]*[\r\n]+', '\n', text) - text = text.strip()[:500] - og['og:description'] = text if text else None - - # TODO: extract a favicon? - # TODO: turn any OG media URLs into mxc URLs to capture and thumbnail them too - # TODO: store our OG details in a cache (and expire them when stale) - # TODO: delete the content to stop diskfilling, as we only ever cared about its OG - return og - try: tree = html.parse(media_info['filename']) - og = _calc_og() + og = yield self._calc_og(tree, media_info, requester) except UnicodeDecodeError: # XXX: evil evil bodge file = open(media_info['filename']) body = file.read() file.close() tree = html.fromstring(body.decode('utf-8','ignore')) - og = _calc_og() + og = yield self._calc_og(tree, media_info, requester) else: logger.warn("Failed to find any OG data in %s", url) og = {} - logger.warn(og) + if self.cache: + self.cache[url] = og + + logger.warn(og); respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) except: @@ -182,11 +121,112 @@ class PreviewUrlResource(BaseMediaResource): ) raise + @defer.inlineCallbacks + def _calc_og(self, tree, media_info, requester): + # suck our tree into lxml and define our OG response. + + # if we see any image URLs in the OG response, then spider them + # (although the client could choose to do this by asking for previews of those URLs to avoid DoSing the server) + + # "og:type" : "article" + # "og:url" : "https://twitter.com/matrixdotorg/status/684074366691356672" + # "og:title" : "Matrix on Twitter" + # "og:image" : "https://pbs.twimg.com/profile_images/500400952029888512/yI0qtFi7_400x400.png" + # "og:description" : "Synapse 0.12 is out! Lots of polishing, performance &amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP" + # "og:site_name" : "Twitter" + + # or: + + # "og:type" : "video", + # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw", + # "og:site_name" : "YouTube", + # "og:video:type" : "application/x-shockwave-flash", + # "og:description" : " ", + # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon", + # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg", + # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", + # "og:video:width" : "1280" + # "og:video:height" : "720", + # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", + + og = {} + for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): + og[tag.attrib['property']] = tag.attrib['content'] + + # TODO: grab article: meta tags too, e.g.: + + # + # + # + # + # + # + + if 'og:title' not in og: + # do some basic spidering of the HTML + title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") + og['og:title'] = title[0].text.strip() if title else None + + + if 'og:image' not in og: + # TODO: extract a favicon failing all else + meta_image = tree.xpath("//*/meta[@itemprop='image']/@content"); + if meta_image: + og['og:image'] = self._rebase_url(meta_image[0], media_info['uri']) + else: + images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") + images = sorted(images, key=lambda i: (-1 * int(i.attrib['width']) * int(i.attrib['height']))) + if not images: + images = tree.xpath("//img[@src]") + if images: + og['og:image'] = self._rebase_url(images[0].attrib['src'], media_info['uri']) + + # pre-cache the image for posterity + if 'og:image' in og and og['og:image']: + image_info = yield self._download_url(og['og:image'], requester.user) + + if self._is_media(image_info['media_type']): + # TODO: make sure we don't choke on white-on-transparent images + dims = yield self._generate_local_thumbnails( + image_info['filesystem_id'], image_info + ) + og["og:image"] = "mxc://%s/%s" % (self.server_name, image_info['filesystem_id']) + og["og:image:type"] = image_info['media_type'] + og["og:image:width"] = dims['width'] + og["og:image:height"] = dims['height'] + else: + del og["og:image"] + + if 'og:description' not in og: + meta_description = tree.xpath("//*/meta[@name='description']/@content"); + if meta_description: + og['og:description'] = meta_description[0] + else: + # text_nodes = tree.xpath("//h1/text() | //h2/text() | //h3/text() | //p/text() | //div/text() | //span/text() | //a/text()") + text_nodes = tree.xpath("//text()[not(ancestor::header | ancestor::nav | ancestor::aside | " + + "ancestor::footer | ancestor::script | ancestor::style)]" + + "[ancestor::body]") + text = '' + for text_node in text_nodes: + if len(text) < 500: + text += text_node + ' ' + else: + break + text = re.sub(r'[\t ]+', ' ', text) + text = re.sub(r'[\t \r\n]*[\r\n]+', '\n', text) + text = text.strip()[:500] + og['og:description'] = text if text else None + + # TODO: persist a cache mapping { url, etag } -> { og, mxc of url (if we bother keeping it around), age } + # TODO: delete the url downloads to stop diskfilling, as we only ever cared about its OG + defer.returnValue(og); + def _rebase_url(self, url, base): base = list(urlparse(base)) url = list(urlparse(url)) - if not url[0] and not url[1]: - url[0] = base[0] + if not url[0]: + url[0] = base[0] or "http" + if not url[1]: url[1] = base[1] if not url[2].startswith('/'): url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2] -- cgit 1.4.1 From b26e8604f168b0f1ecc095bd0d6a717128361a41 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sat, 2 Apr 2016 01:35:44 +0100 Subject: make meta comparisons case insensitive --- synapse/rest/media/v1/preview_url_resource.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse') diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 1273472dab..77757548bd 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -170,7 +170,7 @@ class PreviewUrlResource(BaseMediaResource): if 'og:image' not in og: # TODO: extract a favicon failing all else - meta_image = tree.xpath("//*/meta[@itemprop='image']/@content"); + meta_image = tree.xpath("//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"); if meta_image: og['og:image'] = self._rebase_url(meta_image[0], media_info['uri']) else: @@ -198,7 +198,7 @@ class PreviewUrlResource(BaseMediaResource): del og["og:image"] if 'og:description' not in og: - meta_description = tree.xpath("//*/meta[@name='description']/@content"); + meta_description = tree.xpath("//*/meta[translate(@name, 'DESCRIPTION', 'description')='description']/@content"); if meta_description: og['og:description'] = meta_description[0] else: -- cgit 1.4.1 From 5037ee0d37f7e5c7a62f5af5ceef5363701e3202 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sat, 2 Apr 2016 02:29:57 +0100 Subject: handle missing dimensions without crashing --- synapse/rest/media/v1/preview_url_resource.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) (limited to 'synapse') diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 77757548bd..3ffdafce09 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -78,10 +78,14 @@ class PreviewUrlResource(BaseMediaResource): "og:description" : media_info['download_name'], "og:image" : "mxc://%s/%s" % (self.server_name, media_info['filesystem_id']), "og:image:type" : media_info['media_type'], - "og:image:width" : dims['width'], - "og:image:height" : dims['height'], } + if dims: + og["og:image:width"] = dims['width'] + og["og:image:height"] = dims['height'] + else: + logger.warn("Couldn't get dims for %s" % url) + # define our OG response for this media elif self._is_html(media_info['media_type']): # TODO: somehow stop a big HTML tree from exploding synapse's RAM @@ -174,6 +178,7 @@ class PreviewUrlResource(BaseMediaResource): if meta_image: og['og:image'] = self._rebase_url(meta_image[0], media_info['uri']) else: + # TODO: consider inlined CSS styles as well as width & height attribs images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") images = sorted(images, key=lambda i: (-1 * int(i.attrib['width']) * int(i.attrib['height']))) if not images: @@ -190,10 +195,14 @@ class PreviewUrlResource(BaseMediaResource): dims = yield self._generate_local_thumbnails( image_info['filesystem_id'], image_info ) + if dims: + og["og:image:width"] = dims['width'] + og["og:image:height"] = dims['height'] + else: + logger.warn("Couldn't get dims for %s" % og["og:image"]) + og["og:image"] = "mxc://%s/%s" % (self.server_name, image_info['filesystem_id']) og["og:image:type"] = image_info['media_type'] - og["og:image:width"] = dims['width'] - og["og:image:height"] = dims['height'] else: del og["og:image"] -- cgit 1.4.1 From 2c838f6459db35ad9812a83184d85a06ca5d940a Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sat, 2 Apr 2016 02:30:07 +0100 Subject: pass back SVGs as their own thumbnails --- synapse/rest/media/v1/thumbnail_resource.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) (limited to 'synapse') diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index ab52499785..1e71738bc4 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -72,6 +72,11 @@ class ThumbnailResource(BaseMediaResource): self._respond_404(request) return + if media_info["media_type"] == "image/svg+xml": + file_path = self.filepaths.local_media_filepath(media_id) + yield self._respond_with_file(request, media_info["media_type"], file_path) + return + thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) if thumbnail_infos: @@ -103,6 +108,11 @@ class ThumbnailResource(BaseMediaResource): self._respond_404(request) return + if media_info["media_type"] == "image/svg+xml": + file_path = self.filepaths.local_media_filepath(media_id) + yield self._respond_with_file(request, media_info["media_type"], file_path) + return + thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) for info in thumbnail_infos: t_w = info["thumbnail_width"] == desired_width @@ -138,6 +148,11 @@ class ThumbnailResource(BaseMediaResource): desired_method, desired_type): media_info = yield self._get_remote_media(server_name, media_id) + if media_info["media_type"] == "image/svg+xml": + file_path = self.filepaths.remote_media_filepath(server_name, media_id) + yield self._respond_with_file(request, media_info["media_type"], file_path) + return + thumbnail_infos = yield self.store.get_remote_media_thumbnails( server_name, media_id, ) @@ -181,6 +196,11 @@ class ThumbnailResource(BaseMediaResource): # We should proxy the thumbnail from the remote server instead. media_info = yield self._get_remote_media(server_name, media_id) + if media_info["media_type"] == "image/svg+xml": + file_path = self.filepaths.remote_media_filepath(server_name, media_id) + yield self._respond_with_file(request, media_info["media_type"], file_path) + return + thumbnail_infos = yield self.store.get_remote_media_thumbnails( server_name, media_id, ) -- cgit 1.4.1 From 93771579610d723488486f40622d6c99ed061d7f Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sat, 2 Apr 2016 02:31:45 +0100 Subject: how was _respond_default_thumbnail ever meant to work? --- synapse/rest/media/v1/thumbnail_resource.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'synapse') diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 1e71738bc4..513b445688 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -228,6 +228,8 @@ class ThumbnailResource(BaseMediaResource): @defer.inlineCallbacks def _respond_default_thumbnail(self, request, media_info, width, height, method, m_type): + # XXX: how is this meant to work? store.get_default_thumbnails + # appears to always return [] so won't this always 404? media_type = media_info["media_type"] top_level_type = media_type.split("/")[0] sub_type = media_type.split("/")[-1].split(";")[0] -- cgit 1.4.1 From d1b154a10fc0f71fb36010f784ca6570f845c8d5 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sat, 2 Apr 2016 03:06:39 +0100 Subject: support gzip compression, and don't pass through error msgs --- synapse/http/client.py | 11 ++++++++--- synapse/rest/media/v1/preview_url_resource.py | 5 +++-- 2 files changed, 11 insertions(+), 5 deletions(-) (limited to 'synapse') diff --git a/synapse/http/client.py b/synapse/http/client.py index 1b6f7cb795..b21bf17378 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -23,7 +23,8 @@ from canonicaljson import encode_canonical_json from twisted.internet import defer, reactor, ssl, protocol from twisted.web.client import ( - BrowserLikeRedirectAgent, Agent, readBody, FileBodyProducer, PartialDownloadError, + BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent, + readBody, FileBodyProducer, PartialDownloadError, ) from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers @@ -269,6 +270,10 @@ class SimpleHttpClient(object): # XXX: do we want to explicitly drop the connection here somehow? if so, how? raise # what should we be raising here? + if response.code > 299: + logger.warn("Got %d when downloading %s" % (response.code, url)) + raise + # TODO: if our Content-Type is HTML or something, just read the first # N bytes into RAM rather than saving it all to disk only to read it # straight back in again @@ -366,11 +371,11 @@ class SpiderHttpClient(SimpleHttpClient): def __init__(self, hs): SimpleHttpClient.__init__(self, hs) # clobber the base class's agent and UA: - self.agent = BrowserLikeRedirectAgent(Agent( + self.agent = ContentDecoderAgent(BrowserLikeRedirectAgent(Agent( reactor, connectTimeout=15, contextFactory=hs.get_http_client_context_factory() - )) + )), [('gzip', GzipDecoder)]) # Look like Chrome for now #self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko) Chrome Safari" % hs.version_string) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 3ffdafce09..162e09ba71 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -200,7 +200,7 @@ class PreviewUrlResource(BaseMediaResource): og["og:image:height"] = dims['height'] else: logger.warn("Couldn't get dims for %s" % og["og:image"]) - + og["og:image"] = "mxc://%s/%s" % (self.server_name, image_info['filesystem_id']) og["og:image:type"] = image_info['media_type'] else: @@ -259,7 +259,8 @@ class PreviewUrlResource(BaseMediaResource): length, headers, uri = yield self.client.get_file( url, output_stream=f, max_size=self.max_spider_size, ) - # FIXME: handle 404s sanely - don't spider an error page + # FIXME: pass through 404s and other error messages nicely + media_type = headers["Content-Type"][0] time_now_ms = self.clock.time_msec() -- cgit 1.4.1 From 7426c86eb88a7abef9af7ba544ccd709b25e8304 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sun, 3 Apr 2016 00:31:57 +0100 Subject: add a persistent cache of URL lookups, and fix up the in-memory one to work --- synapse/http/client.py | 6 +- synapse/rest/media/v1/preview_url_resource.py | 64 ++++++++++++++++++---- synapse/storage/media_repository.py | 54 +++++++++++++++++- .../delta/30/local_media_repository_url_cache.sql | 27 +++++++++ 4 files changed, 137 insertions(+), 14 deletions(-) create mode 100644 synapse/storage/schema/delta/30/local_media_repository_url_cache.sql (limited to 'synapse') diff --git a/synapse/http/client.py b/synapse/http/client.py index b21bf17378..f42a36ffa6 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -251,8 +251,8 @@ class SimpleHttpClient(object): url (str): The URL to GET output_stream (file): File to write the response body to. Returns: - A (int,dict) tuple of the file length and a dict of the response - headers. + A (int,dict,string,int) tuple of the file length, dict of the response + headers, absolute URI of the response and HTTP response code. """ response = yield self.request( @@ -287,7 +287,7 @@ class SimpleHttpClient(object): logger.exception("Failed to download body") raise - defer.returnValue((length, headers, response.request.absoluteURI)) + defer.returnValue((length, headers, response.request.absoluteURI, response.code)) # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 162e09ba71..86341cc4cc 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -37,6 +37,8 @@ class PreviewUrlResource(BaseMediaResource): def __init__(self, hs, filepaths): BaseMediaResource.__init__(self, hs, filepaths) self.client = SpiderHttpClient(hs) + + # simple memory cache mapping urls to OG metadata self.cache = ExpiringCache( cache_name = "url_previews", clock = self.clock, @@ -56,17 +58,41 @@ class PreviewUrlResource(BaseMediaResource): # XXX: if get_user_by_req fails, what should we do in an async render? requester = yield self.auth.get_user_by_req(request) url = request.args.get("url")[0] - - if self.cache: - og = self.cache.get(url) - respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) - return + ts = request.args.get("ts")[0] if "ts" in request.args else self.clock.time_msec() # TODO: keep track of whether there's an ongoing request for this preview # and block and return their details if there is one. + # first check the memory cache - good to handle all the clients on this + # HS thundering away to preview the same URL at the same time. + try: + og = self.cache[url] + respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) + return + except: + pass + + # then check the URL cache in the DB (which will also provide us with + # historical previews, if we have any) + cache_result = yield self.store.get_url_cache(url, ts) + if ( + cache_result and + cache_result["download_ts"] + cache_result["expires"] > ts and + cache_result["response_code"] / 100 == 2 + ): + respond_with_json_bytes( + request, 200, cache_result["og"].encode('utf-8'), + send_cors=True + ) + return + media_info = yield self._download_url(url, requester.user) + # FIXME: we should probably update our cache now anyway, so that + # even if the OG calculation raises, we don't keep hammering on the + # remote server. For now, leave it uncached to aid debugging OG + # calculation problems + logger.debug("got media_info of '%s'" % media_info) if self._is_media(media_info['media_type']): @@ -105,10 +131,21 @@ class PreviewUrlResource(BaseMediaResource): logger.warn("Failed to find any OG data in %s", url) og = {} - if self.cache: - self.cache[url] = og + logger.debug("Calculated OG for %s as %s" % (url, og)); + + # store OG in ephemeral in-memory cache + self.cache[url] = og - logger.warn(og); + # store OG in history-aware DB cache + yield self.store.store_url_cache( + url, + media_info["response_code"], + media_info["etag"], + media_info["expires"], + json.dumps(og), + media_info["filesystem_id"], + media_info["created_ts"], + ) respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) except: @@ -187,6 +224,9 @@ class PreviewUrlResource(BaseMediaResource): og['og:image'] = self._rebase_url(images[0].attrib['src'], media_info['uri']) # pre-cache the image for posterity + # FIXME: it might be cleaner to use the same flow as the main /preview_url request itself + # and benefit from the same caching etc. But for now we just rely on the caching + # of the master request to speed things up. if 'og:image' in og and og['og:image']: image_info = yield self._download_url(og['og:image'], requester.user) @@ -226,7 +266,6 @@ class PreviewUrlResource(BaseMediaResource): text = text.strip()[:500] og['og:description'] = text if text else None - # TODO: persist a cache mapping { url, etag } -> { og, mxc of url (if we bother keeping it around), age } # TODO: delete the url downloads to stop diskfilling, as we only ever cared about its OG defer.returnValue(og); @@ -256,7 +295,7 @@ class PreviewUrlResource(BaseMediaResource): try: with open(fname, "wb") as f: logger.debug("Trying to get url '%s'" % url) - length, headers, uri = yield self.client.get_file( + length, headers, uri, code = yield self.client.get_file( url, output_stream=f, max_size=self.max_spider_size, ) # FIXME: pass through 404s and other error messages nicely @@ -311,6 +350,11 @@ class PreviewUrlResource(BaseMediaResource): "filesystem_id": file_id, "filename": fname, "uri": uri, + "response_code": code, + # FIXME: we should calculate a proper expiration based on the + # Cache-Control and Expire headers. But for now, assume 1 hour. + "expires": 60 * 60 * 1000, + "etag": headers["ETag"] if "ETag" in headers else None, }) def _is_media(self, content_type): diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py index 9d3ba32478..bb002081ae 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/media_repository.py @@ -25,7 +25,7 @@ class MediaRepositoryStore(SQLBaseStore): def get_local_media(self, media_id): """Get the metadata for a local piece of media Returns: - None if the meia_id doesn't exist. + None if the media_id doesn't exist. """ return self._simple_select_one( "local_media_repository", @@ -50,6 +50,58 @@ class MediaRepositoryStore(SQLBaseStore): desc="store_local_media", ) + def get_url_cache(self, url, ts): + """Get the media_id and ts for a cached URL as of the given timestamp + Returns: + None if the URL isn't cached. + """ + def get_url_cache_txn(txn): + # get the most recently cached result (relative to the given ts) + sql = ( + "SELECT response_code, etag, expires, og, media_id, max(download_ts)" + " FROM local_media_repository_url_cache" + " WHERE url = ? AND download_ts <= ?" + ) + txn.execute(sql, (url, ts)) + row = txn.fetchone() + + if not row[3]: + # ...or if we've requested a timestamp older than the oldest + # copy in the cache, return the oldest copy (if any) + sql = ( + "SELECT response_code, etag, expires, og, media_id, min(download_ts)" + " FROM local_media_repository_url_cache" + " WHERE url = ? AND download_ts > ?" + ) + txn.execute(sql, (url, ts)) + row = txn.fetchone() + + if not row[3]: + return None + + return dict(zip(( + 'response_code', 'etag', 'expires', 'og', 'media_id', 'download_ts' + ), row)) + + return self.runInteraction( + "get_url_cache", get_url_cache_txn + ) + + def store_url_cache(self, url, response_code, etag, expires, og, media_id, download_ts): + return self._simple_insert( + "local_media_repository_url_cache", + { + "url": url, + "response_code": response_code, + "etag": etag, + "expires": expires, + "og": og, + "media_id": media_id, + "download_ts": download_ts, + }, + desc="store_url_cache", + ) + def get_local_media_thumbnails(self, media_id): return self._simple_select_list( "local_media_repository_thumbnails", diff --git a/synapse/storage/schema/delta/30/local_media_repository_url_cache.sql b/synapse/storage/schema/delta/30/local_media_repository_url_cache.sql new file mode 100644 index 0000000000..9efb4280eb --- /dev/null +++ b/synapse/storage/schema/delta/30/local_media_repository_url_cache.sql @@ -0,0 +1,27 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE local_media_repository_url_cache( + url TEXT, -- the URL being cached + response_code INTEGER, -- the HTTP response code of this download attempt + etag TEXT, -- the etag header of this response + expires INTEGER, -- the number of ms this response was valid for + og TEXT, -- cache of the OG metadata of this URL as JSON + media_id TEXT, -- the media_id, if any, of the URL's content in the repo + download_ts BIGINT -- the timestamp of this download attempt +); + +CREATE INDEX local_media_repository_url_cache_by_url_download_ts + ON local_media_repository_url_cache(url, download_ts); -- cgit 1.4.1 From b09e29a03ca95c577215acbe8d5037d6337e1af3 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sun, 3 Apr 2016 00:47:40 +0100 Subject: Ensure only one download for a given URL is active at a time --- synapse/rest/media/v1/preview_url_resource.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 86341cc4cc..c20de57991 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -23,6 +23,7 @@ from synapse.util.stringutils import random_string from synapse.util.caches.expiringcache import ExpiringCache from synapse.http.client import SpiderHttpClient from synapse.http.server import request_handler, respond_with_json, respond_with_json_bytes +from synapse.util.async import ObservableDeferred import os import re @@ -46,6 +47,8 @@ class PreviewUrlResource(BaseMediaResource): ) self.cache.start() + self.downloads = {} + def render_GET(self, request): self._async_render_GET(request) return NOT_DONE_YET @@ -86,7 +89,21 @@ class PreviewUrlResource(BaseMediaResource): ) return - media_info = yield self._download_url(url, requester.user) + # Ensure only one download for a given URL is active at a time + download = self.downloads.get(url) + if download is None: + download = self._download_url(url, requester.user) + download = ObservableDeferred( + download, + consumeErrors=True + ) + self.downloads[url] = download + + @download.addBoth + def callback(media_info): + del self.downloads[key] + return media_info + media_info = yield download.observe() # FIXME: we should probably update our cache now anyway, so that # even if the OG calculation raises, we don't keep hammering on the -- cgit 1.4.1 From 110780b18b029c5b6f1c34f7b4e027b88ea8b8ce Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sun, 3 Apr 2016 00:48:31 +0100 Subject: remove stale todo --- synapse/rest/media/v1/preview_url_resource.py | 3 --- 1 file changed, 3 deletions(-) (limited to 'synapse') diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index c20de57991..582dd20fa6 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -63,9 +63,6 @@ class PreviewUrlResource(BaseMediaResource): url = request.args.get("url")[0] ts = request.args.get("ts")[0] if "ts" in request.args else self.clock.time_msec() - # TODO: keep track of whether there's an ongoing request for this preview - # and block and return their details if there is one. - # first check the memory cache - good to handle all the clients on this # HS thundering away to preview the same URL at the same time. try: -- cgit 1.4.1 From c3916462f68df84df29ad924c07f8e83c0143fcc Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sun, 3 Apr 2016 01:33:12 +0100 Subject: rebase all image URLs --- synapse/rest/media/v1/preview_url_resource.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'synapse') diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 582dd20fa6..31ce2b5831 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -235,14 +235,14 @@ class PreviewUrlResource(BaseMediaResource): if not images: images = tree.xpath("//img[@src]") if images: - og['og:image'] = self._rebase_url(images[0].attrib['src'], media_info['uri']) + og['og:image'] = images[0].attrib['src'] # pre-cache the image for posterity # FIXME: it might be cleaner to use the same flow as the main /preview_url request itself # and benefit from the same caching etc. But for now we just rely on the caching # of the master request to speed things up. if 'og:image' in og and og['og:image']: - image_info = yield self._download_url(og['og:image'], requester.user) + image_info = yield self._download_url(self._rebase_url(og['og:image'], media_info['uri']), requester.user) if self._is_media(image_info['media_type']): # TODO: make sure we don't choke on white-on-transparent images @@ -286,9 +286,9 @@ class PreviewUrlResource(BaseMediaResource): def _rebase_url(self, url, base): base = list(urlparse(base)) url = list(urlparse(url)) - if not url[0]: + if not url[0]: # fix up schema url[0] = base[0] or "http" - if not url[1]: + if not url[1]: # fix up hostname url[1] = base[1] if not url[2].startswith('/'): url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2] -- cgit 1.4.1 From eab4d462f8e5d17c5ca7592d1ea15d8e4771a00c Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sun, 3 Apr 2016 02:02:46 +0100 Subject: fix etag typing error. fix timestamp typing error --- synapse/rest/media/v1/preview_url_resource.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse') diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 31ce2b5831..7c69c01a6c 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -61,7 +61,7 @@ class PreviewUrlResource(BaseMediaResource): # XXX: if get_user_by_req fails, what should we do in an async render? requester = yield self.auth.get_user_by_req(request) url = request.args.get("url")[0] - ts = request.args.get("ts")[0] if "ts" in request.args else self.clock.time_msec() + ts = int(request.args.get("ts")[0]) if "ts" in request.args else self.clock.time_msec() # first check the memory cache - good to handle all the clients on this # HS thundering away to preview the same URL at the same time. @@ -368,7 +368,7 @@ class PreviewUrlResource(BaseMediaResource): # FIXME: we should calculate a proper expiration based on the # Cache-Control and Expire headers. But for now, assume 1 hour. "expires": 60 * 60 * 1000, - "etag": headers["ETag"] if "ETag" in headers else None, + "etag": headers["ETag"][0] if "ETag" in headers else None, }) def _is_media(self, content_type): -- cgit 1.4.1 From 8b98a7e8c37f0fae09f33a6d93953584288ed394 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sun, 3 Apr 2016 12:56:29 +0100 Subject: pep8 --- synapse/http/client.py | 14 ++- synapse/rest/media/v1/media_repository.py | 1 - synapse/rest/media/v1/preview_url_resource.py | 127 +++++++++++++++----------- synapse/storage/media_repository.py | 3 +- 4 files changed, 85 insertions(+), 60 deletions(-) (limited to 'synapse') diff --git a/synapse/http/client.py b/synapse/http/client.py index f42a36ffa6..442b4bb73d 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -15,7 +15,9 @@ from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE -from synapse.api.errors import CodeMessageException +from synapse.api.errors import ( + CodeMessageException, SynapseError, Codes, +) from synapse.util.logcontext import preserve_context_over_fn import synapse.metrics @@ -268,7 +270,7 @@ class SimpleHttpClient(object): if 'Content-Length' in headers and headers['Content-Length'] > max_size: logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) # XXX: do we want to explicitly drop the connection here somehow? if so, how? - raise # what should we be raising here? + raise # what should we be raising here? if response.code > 299: logger.warn("Got %d when downloading %s" % (response.code, url)) @@ -331,6 +333,7 @@ def _readBodyToFile(response, stream, max_size): response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size)) return d + class CaptchaServerHttpClient(SimpleHttpClient): """ Separate HTTP client for talking to google's captcha servers @@ -360,6 +363,7 @@ class CaptchaServerHttpClient(SimpleHttpClient): # twisted dislikes google's response, no content length. defer.returnValue(e.response) + class SpiderHttpClient(SimpleHttpClient): """ Separate HTTP client for spidering arbitrary URLs. @@ -376,8 +380,10 @@ class SpiderHttpClient(SimpleHttpClient): connectTimeout=15, contextFactory=hs.get_http_client_context_factory() )), [('gzip', GzipDecoder)]) - # Look like Chrome for now - #self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko) Chrome Safari" % hs.version_string) + # We could look like Chrome: + # self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko) + # Chrome Safari" % hs.version_string) + def encode_urlencode_args(args): return {k: encode_urlencode_arg(v) for k, v in args.items()} diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 8f3491b91c..11f672aeab 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -80,4 +80,3 @@ class MediaRepositoryResource(Resource): self.putChild("thumbnail", ThumbnailResource(hs, filepaths)) self.putChild("identicon", IdenticonResource()) self.putChild("preview_url", PreviewUrlResource(hs, filepaths)) - diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 7c69c01a6c..29db5c7fce 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -13,25 +13,31 @@ # limitations under the License. from .base_resource import BaseMediaResource -from synapse.api.errors import Codes -from twisted.web.resource import Resource + from twisted.web.server import NOT_DONE_YET from twisted.internet import defer from lxml import html from urlparse import urlparse, urlunparse + +from synapse.api.errors import Codes from synapse.util.stringutils import random_string from synapse.util.caches.expiringcache import ExpiringCache from synapse.http.client import SpiderHttpClient -from synapse.http.server import request_handler, respond_with_json, respond_with_json_bytes +from synapse.http.server import ( + request_handler, respond_with_json, respond_with_json_bytes +) from synapse.util.async import ObservableDeferred +from synapse.util.stringutils import is_ascii import os import re +import cgi import ujson as json import logging logger = logging.getLogger(__name__) + class PreviewUrlResource(BaseMediaResource): isLeaf = True @@ -41,9 +47,10 @@ class PreviewUrlResource(BaseMediaResource): # simple memory cache mapping urls to OG metadata self.cache = ExpiringCache( - cache_name = "url_previews", - clock = self.clock, - expiry_ms = 60*60*1000, # don't spider URLs more often than once an hour + cache_name="url_previews", + clock=self.clock, + # don't spider URLs more often than once an hour + expiry_ms=60 * 60 * 1000, ) self.cache.start() @@ -56,12 +63,15 @@ class PreviewUrlResource(BaseMediaResource): @request_handler @defer.inlineCallbacks def _async_render_GET(self, request): - + try: # XXX: if get_user_by_req fails, what should we do in an async render? requester = yield self.auth.get_user_by_req(request) url = request.args.get("url")[0] - ts = int(request.args.get("ts")[0]) if "ts" in request.args else self.clock.time_msec() + if "ts" in request.args: + ts = int(request.args.get("ts")[0]) + else: + ts = self.clock.time_msec() # first check the memory cache - good to handle all the clients on this # HS thundering away to preview the same URL at the same time. @@ -98,7 +108,7 @@ class PreviewUrlResource(BaseMediaResource): @download.addBoth def callback(media_info): - del self.downloads[key] + del self.downloads[url] return media_info media_info = yield download.observe() @@ -111,13 +121,15 @@ class PreviewUrlResource(BaseMediaResource): if self._is_media(media_info['media_type']): dims = yield self._generate_local_thumbnails( - media_info['filesystem_id'], media_info - ) + media_info['filesystem_id'], media_info + ) og = { - "og:description" : media_info['download_name'], - "og:image" : "mxc://%s/%s" % (self.server_name, media_info['filesystem_id']), - "og:image:type" : media_info['media_type'], + "og:description": media_info['download_name'], + "og:image": "mxc://%s/%s" % ( + self.server_name, media_info['filesystem_id'] + ), + "og:image:type": media_info['media_type'], } if dims: @@ -138,14 +150,14 @@ class PreviewUrlResource(BaseMediaResource): file = open(media_info['filename']) body = file.read() file.close() - tree = html.fromstring(body.decode('utf-8','ignore')) + tree = html.fromstring(body.decode('utf-8', 'ignore')) og = yield self._calc_og(tree, media_info, requester) else: logger.warn("Failed to find any OG data in %s", url) og = {} - logger.debug("Calculated OG for %s as %s" % (url, og)); + logger.debug("Calculated OG for %s as %s" % (url, og)) # store OG in ephemeral in-memory cache self.cache[url] = og @@ -181,28 +193,20 @@ class PreviewUrlResource(BaseMediaResource): # suck our tree into lxml and define our OG response. # if we see any image URLs in the OG response, then spider them - # (although the client could choose to do this by asking for previews of those URLs to avoid DoSing the server) - - # "og:type" : "article" - # "og:url" : "https://twitter.com/matrixdotorg/status/684074366691356672" - # "og:title" : "Matrix on Twitter" - # "og:image" : "https://pbs.twimg.com/profile_images/500400952029888512/yI0qtFi7_400x400.png" - # "og:description" : "Synapse 0.12 is out! Lots of polishing, performance &amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP" - # "og:site_name" : "Twitter" - - # or: + # (although the client could choose to do this by asking for previews of those + # URLs to avoid DoSing the server) # "og:type" : "video", # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw", # "og:site_name" : "YouTube", # "og:video:type" : "application/x-shockwave-flash", - # "og:description" : " ", + # "og:description" : "Fun stuff happening here", # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon", # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg", # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", # "og:video:width" : "1280" # "og:video:height" : "720", - # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", + # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", og = {} for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): @@ -210,64 +214,76 @@ class PreviewUrlResource(BaseMediaResource): # TODO: grab article: meta tags too, e.g.: - # - # - # - # - # - # + # "article:publisher" : "https://www.facebook.com/thethudonline" /> + # "article:author" content="https://www.facebook.com/thethudonline" /> + # "article:tag" content="baby" /> + # "article:section" content="Breaking News" /> + # "article:published_time" content="2016-03-31T19:58:24+00:00" /> + # "article:modified_time" content="2016-04-01T18:31:53+00:00" /> if 'og:title' not in og: # do some basic spidering of the HTML title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") og['og:title'] = title[0].text.strip() if title else None - if 'og:image' not in og: # TODO: extract a favicon failing all else - meta_image = tree.xpath("//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"); + meta_image = tree.xpath( + "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content" + ) if meta_image: og['og:image'] = self._rebase_url(meta_image[0], media_info['uri']) else: # TODO: consider inlined CSS styles as well as width & height attribs images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") - images = sorted(images, key=lambda i: (-1 * int(i.attrib['width']) * int(i.attrib['height']))) + images = sorted(images, key=lambda i: ( + -1 * int(i.attrib['width']) * int(i.attrib['height']) + )) if not images: images = tree.xpath("//img[@src]") if images: og['og:image'] = images[0].attrib['src'] # pre-cache the image for posterity - # FIXME: it might be cleaner to use the same flow as the main /preview_url request itself - # and benefit from the same caching etc. But for now we just rely on the caching - # of the master request to speed things up. + # FIXME: it might be cleaner to use the same flow as the main /preview_url request + # itself and benefit from the same caching etc. But for now we just rely on the + # caching on the master request to speed things up. if 'og:image' in og and og['og:image']: - image_info = yield self._download_url(self._rebase_url(og['og:image'], media_info['uri']), requester.user) + image_info = yield self._download_url( + self._rebase_url(og['og:image'], media_info['uri']), requester.user + ) if self._is_media(image_info['media_type']): # TODO: make sure we don't choke on white-on-transparent images dims = yield self._generate_local_thumbnails( - image_info['filesystem_id'], image_info - ) + image_info['filesystem_id'], image_info + ) if dims: og["og:image:width"] = dims['width'] og["og:image:height"] = dims['height'] else: logger.warn("Couldn't get dims for %s" % og["og:image"]) - og["og:image"] = "mxc://%s/%s" % (self.server_name, image_info['filesystem_id']) + og["og:image"] = "mxc://%s/%s" % ( + self.server_name, image_info['filesystem_id'] + ) og["og:image:type"] = image_info['media_type'] else: del og["og:image"] if 'og:description' not in og: - meta_description = tree.xpath("//*/meta[translate(@name, 'DESCRIPTION', 'description')='description']/@content"); + meta_description = tree.xpath( + "//*/meta" + "[translate(@name, 'DESCRIPTION', 'description')='description']" + "/@content") if meta_description: og['og:description'] = meta_description[0] else: - # text_nodes = tree.xpath("//h1/text() | //h2/text() | //h3/text() | //p/text() | //div/text() | //span/text() | //a/text()") - text_nodes = tree.xpath("//text()[not(ancestor::header | ancestor::nav | ancestor::aside | " + - "ancestor::footer | ancestor::script | ancestor::style)]" + + # text_nodes = tree.xpath("//h1/text() | //h2/text() | //h3/text() | " + # "//p/text() | //div/text() | //span/text() | //a/text()") + text_nodes = tree.xpath("//text()[not(ancestor::header | ancestor::nav | " + "ancestor::aside | ancestor::footer | " + "ancestor::script | ancestor::style)]" + "[ancestor::body]") text = '' for text_node in text_nodes: @@ -280,15 +296,16 @@ class PreviewUrlResource(BaseMediaResource): text = text.strip()[:500] og['og:description'] = text if text else None - # TODO: delete the url downloads to stop diskfilling, as we only ever cared about its OG - defer.returnValue(og); + # TODO: delete the url downloads to stop diskfilling, + # as we only ever cared about its OG + defer.returnValue(og) def _rebase_url(self, url, base): base = list(urlparse(base)) url = list(urlparse(url)) - if not url[0]: # fix up schema + if not url[0]: # fix up schema url[0] = base[0] or "http" - if not url[1]: # fix up hostname + if not url[1]: # fix up hostname url[1] = base[1] if not url[2].startswith('/'): url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2] @@ -377,6 +394,8 @@ class PreviewUrlResource(BaseMediaResource): def _is_html(self, content_type): content_type = content_type.lower() - if (content_type.startswith("text/html") or - content_type.startswith("application/xhtml")): + if ( + content_type.startswith("text/html") or + content_type.startswith("application/xhtml") + ): return True diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py index bb002081ae..c9dd20eed8 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/media_repository.py @@ -87,7 +87,8 @@ class MediaRepositoryStore(SQLBaseStore): "get_url_cache", get_url_cache_txn ) - def store_url_cache(self, url, response_code, etag, expires, og, media_id, download_ts): + def store_url_cache(self, url, response_code, etag, expires, og, media_id, + download_ts): return self._simple_insert( "local_media_repository_url_cache", { -- cgit 1.4.1 From 0834b152fb05e110428a4834a2e5dc51b6f7d327 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sun, 3 Apr 2016 12:59:27 +0100 Subject: char encoding --- synapse/rest/media/v1/preview_url_resource.py | 1 + 1 file changed, 1 insertion(+) (limited to 'synapse') diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 29db5c7fce..ff522c5fb8 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); -- cgit 1.4.1 From cf51c4120e79a59a798fcf88c5c7d9f95dc6e76d Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sun, 3 Apr 2016 23:57:05 +0100 Subject: report image size (bytewise) in OG meta --- synapse/rest/media/v1/preview_url_resource.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'synapse') diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index ff522c5fb8..f5ec32d8f2 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -131,6 +131,7 @@ class PreviewUrlResource(BaseMediaResource): self.server_name, media_info['filesystem_id'] ), "og:image:type": media_info['media_type'], + "matrix:image:size": media_info['media_length'], } if dims: @@ -269,6 +270,7 @@ class PreviewUrlResource(BaseMediaResource): self.server_name, image_info['filesystem_id'] ) og["og:image:type"] = image_info['media_type'] + og["matrix:image:size"] = image_info['media_length'] else: del og["og:image"] -- cgit 1.4.1 From 3d76b7cb2ba05fbf17be0a6647f39c419f428c16 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 4 Apr 2016 15:52:01 +0100 Subject: Store invites in a separate table. --- synapse/handlers/room_member.py | 2 +- synapse/storage/events.py | 13 +--- synapse/storage/prepare_database.py | 2 +- synapse/storage/roommember.py | 111 ++++++++++++++++++++++------ synapse/storage/schema/delta/31/invites.sql | 28 +++++++ 5 files changed, 124 insertions(+), 32 deletions(-) create mode 100644 synapse/storage/schema/delta/31/invites.sql (limited to 'synapse') diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 7c1bb8cfe4..98e346d48e 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -400,7 +400,7 @@ class RoomMemberHandler(BaseHandler): @defer.inlineCallbacks def get_inviter(self, user_id, room_id): - invite = yield self.store.get_room_member(user_id=user_id, room_id=room_id) + invite = yield self.store.get_inviter(user_id=user_id, room_id=room_id) if invite: defer.returnValue(UserID.from_string(invite.sender)) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index c4dc3b3d51..5d299a1132 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -367,7 +367,8 @@ class EventsStore(SQLBaseStore): event for event, _ in events_and_contexts if event.type == EventTypes.Member - ] + ], + backfilled=backfilled, ) def event_dict(event): @@ -485,14 +486,8 @@ class EventsStore(SQLBaseStore): return for event, _ in state_events_and_contexts: - if (not event.internal_metadata.is_invite_from_remote() - and event.internal_metadata.is_outlier()): - # Outlier events generally shouldn't clobber the current state. - # However invites from remote severs for rooms we aren't in - # are a bit special: they don't come with any associated - # state so are technically an outlier, however all the - # client-facing code assumes that they are in the current - # state table so we insert the event anyway. + if event.internal_metadata.is_outlier(): + # Outlier events shouldn't clobber the current state. continue if context.rejected: diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 3f29aad1e8..4099387ba7 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 30 +SCHEMA_VERSION = 31 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 430b49c12e..4c026b33ae 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -36,7 +36,7 @@ RoomsForUser = namedtuple( class RoomMemberStore(SQLBaseStore): - def _store_room_members_txn(self, txn, events): + def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. """ self._simple_insert_many_txn( @@ -62,6 +62,41 @@ class RoomMemberStore(SQLBaseStore): self._membership_stream_cache.entity_has_changed, event.state_key, event.internal_metadata.stream_ordering ) + txn.call_after( + self.get_invited_rooms_for_user.invalidate, (event.state_key,) + ) + + is_mine = self.hs.is_mine_id(event.state_key) + is_new_state = not backfilled and ( + not event.internal_metadata.is_outlier() + or event.internal_metadata.is_invite_from_remote() + ) + if is_new_state and is_mine: + if event.membership == Membership.INVITE: + self._simple_insert_txn( + txn, + table="invites", + values={ + "event_id": event.event_id, + "invitee": event.state_key, + "inviter": event.sender, + "room_id": event.room_id, + "stream_id": event.internal_metadata.stream_ordering, + } + ) + else: + sql = ( + "UPDATE invites SET stream_id = ?, replaced_by = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + txn.execute(sql, ( + event.internal_metadata.stream_ordering, + event.event_id, + event.room_id, + event.state_key, + )) def get_room_member(self, user_id, room_id): """Retrieve the current state of a room member. @@ -127,6 +162,14 @@ class RoomMemberStore(SQLBaseStore): user_id, [Membership.INVITE] ) + @defer.inlineCallbacks + def get_inviter(self, user_id, room_id): + invites = yield self.get_invited_rooms_for_user(user_id) + for invite in invites: + if invite.room_id == room_id: + defer.returnValue(invite) + defer.returnValue(None) + def get_leave_and_ban_events_for_user(self, user_id): """ Get all the leave events for a user Args: @@ -163,29 +206,55 @@ class RoomMemberStore(SQLBaseStore): def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id, membership_list): - where_clause = "user_id = ? AND (%s) AND forgotten = 0" % ( - " OR ".join(["membership = ?" for _ in membership_list]), - ) - args = [user_id] - args.extend(membership_list) + do_invite = Membership.INVITE in membership_list + membership_list = [m for m in membership_list if m != Membership.INVITE] - sql = ( - "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering" - " FROM current_state_events as c" - " INNER JOIN room_memberships as m" - " ON m.event_id = c.event_id" - " INNER JOIN events as e" - " ON e.event_id = c.event_id" - " AND m.room_id = c.room_id" - " AND m.user_id = c.state_key" - " WHERE %s" - ) % (where_clause,) + results = [] + if membership_list: + where_clause = "user_id = ? AND (%s) AND forgotten = 0" % ( + " OR ".join(["membership = ?" for _ in membership_list]), + ) + + args = [user_id] + args.extend(membership_list) + + sql = ( + "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering" + " FROM current_state_events as c" + " INNER JOIN room_memberships as m" + " ON m.event_id = c.event_id" + " INNER JOIN events as e" + " ON e.event_id = c.event_id" + " AND m.room_id = c.room_id" + " AND m.user_id = c.state_key" + " WHERE %s" + ) % (where_clause,) + + txn.execute(sql, args) + results = [ + RoomsForUser(**r) for r in self.cursor_to_dict(txn) + ] + + if do_invite: + sql = ( + "SELECT i.room_id, inviter, i.event_id, e.stream_ordering" + " FROM invites as i" + " INNER JOIN events as e USING (event_id)" + " WHERE invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + txn.execute(sql, (user_id,)) + results.extend(RoomsForUser( + room_id=r["room_id"], + sender=r["inviter"], + event_id=r["event_id"], + stream_ordering=r["stream_ordering"], + membership=Membership.INVITE, + ) for r in self.cursor_to_dict(txn)) - txn.execute(sql, args) - return [ - RoomsForUser(**r) for r in self.cursor_to_dict(txn) - ] + return results @cached(max_entries=5000) def get_joined_hosts_for_room(self, room_id): diff --git a/synapse/storage/schema/delta/31/invites.sql b/synapse/storage/schema/delta/31/invites.sql new file mode 100644 index 0000000000..4f6fb9ea63 --- /dev/null +++ b/synapse/storage/schema/delta/31/invites.sql @@ -0,0 +1,28 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE invites( + stream_id BIGINT NOT NULL, + inviter TEXT NOT NULL, + invitee TEXT NOT NULL, + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + locally_rejected TEXT, + replaced_by TEXT +); + +CREATE INDEX invites_id ON invites(stream_id); +CREATE INDEX invites_for_user_idx ON invites(invitee, locally_rejected, replaced_by, room_id); -- cgit 1.4.1 From 92ab45a330c2d6c4e896786135e93b6cabfad1ea Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 4 Apr 2016 17:07:43 +0100 Subject: Add upgrade path, rename table --- synapse/storage/roommember.py | 6 +++--- synapse/storage/schema/delta/31/invites.sql | 20 +++++++++++++++++--- 2 files changed, 20 insertions(+), 6 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 4c026b33ae..abe5942744 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -75,7 +75,7 @@ class RoomMemberStore(SQLBaseStore): if event.membership == Membership.INVITE: self._simple_insert_txn( txn, - table="invites", + table="local_invites", values={ "event_id": event.event_id, "invitee": event.state_key, @@ -86,7 +86,7 @@ class RoomMemberStore(SQLBaseStore): ) else: sql = ( - "UPDATE invites SET stream_id = ?, replaced_by = ? WHERE" + "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" " room_id = ? AND invitee = ? AND locally_rejected is NULL" " AND replaced_by is NULL" ) @@ -239,7 +239,7 @@ class RoomMemberStore(SQLBaseStore): if do_invite: sql = ( "SELECT i.room_id, inviter, i.event_id, e.stream_ordering" - " FROM invites as i" + " FROM local_invites as i" " INNER JOIN events as e USING (event_id)" " WHERE invitee = ? AND locally_rejected is NULL" " AND replaced_by is NULL" diff --git a/synapse/storage/schema/delta/31/invites.sql b/synapse/storage/schema/delta/31/invites.sql index 4f6fb9ea63..1c83430da4 100644 --- a/synapse/storage/schema/delta/31/invites.sql +++ b/synapse/storage/schema/delta/31/invites.sql @@ -14,7 +14,7 @@ */ -CREATE TABLE invites( +CREATE TABLE local_invites( stream_id BIGINT NOT NULL, inviter TEXT NOT NULL, invitee TEXT NOT NULL, @@ -24,5 +24,19 @@ CREATE TABLE invites( replaced_by TEXT ); -CREATE INDEX invites_id ON invites(stream_id); -CREATE INDEX invites_for_user_idx ON invites(invitee, locally_rejected, replaced_by, room_id); +-- Insert all invites for local users into new `invites` table +INSERT INTO local_invites SELECT + stream_ordering as stream_id, + sender as inviter, + state_key as invitee, + event_id, + room_id, + NULL as locally_rejected, + NULL as replaced_by +FROM events +NATURAL JOIN current_state_events +NATURAL JOIN room_memberships +WHERE membership = 'invite' AND state_key IN (SELECT name FROM users); + +CREATE INDEX local_invites_id ON local_invites(stream_id); +CREATE INDEX local_invites_for_user_idx ON local_invites(invitee, locally_rejected, replaced_by, room_id); -- cgit 1.4.1 From 0c53d750e7145f57ed97c544efdb846cc9e37b67 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 4 Apr 2016 18:02:48 +0100 Subject: Docs and indents --- synapse/handlers/room_member.py | 5 ++++- synapse/storage/roommember.py | 18 ++++++++++++++++-- synapse/storage/schema/delta/31/invites.sql | 22 +++++++++++----------- 3 files changed, 31 insertions(+), 14 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 98e346d48e..f1c3e90ecd 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -400,7 +400,10 @@ class RoomMemberHandler(BaseHandler): @defer.inlineCallbacks def get_inviter(self, user_id, room_id): - invite = yield self.store.get_inviter(user_id=user_id, room_id=room_id) + invite = yield self.store.get_invite_for_user_in_room( + user_id=user_id, + room_id=room_id, + ) if invite: defer.returnValue(UserID.from_string(invite.sender)) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index abe5942744..36456a75fc 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -66,11 +66,15 @@ class RoomMemberStore(SQLBaseStore): self.get_invited_rooms_for_user.invalidate, (event.state_key,) ) - is_mine = self.hs.is_mine_id(event.state_key) + # We update the local_invites table only if the event is "current", + # i.e., its something that has just happened. + # The only current event that can also be an outlier is if its an + # invite that has come in across federation. is_new_state = not backfilled and ( not event.internal_metadata.is_outlier() or event.internal_metadata.is_invite_from_remote() ) + is_mine = self.hs.is_mine_id(event.state_key) if is_new_state and is_mine: if event.membership == Membership.INVITE: self._simple_insert_txn( @@ -163,7 +167,17 @@ class RoomMemberStore(SQLBaseStore): ) @defer.inlineCallbacks - def get_inviter(self, user_id, room_id): + def get_invite_for_user_in_room(self, user_id, room_id): + """Gets the invite for the given user and room + + Args: + user_id (str) + room_id (str) + + Returns: + Deferred: Resolves to either a RoomsForUser or None if no invite was + found. + """ invites = yield self.get_invited_rooms_for_user(user_id) for invite in invites: if invite.room_id == room_id: diff --git a/synapse/storage/schema/delta/31/invites.sql b/synapse/storage/schema/delta/31/invites.sql index 1c83430da4..2c57846d5a 100644 --- a/synapse/storage/schema/delta/31/invites.sql +++ b/synapse/storage/schema/delta/31/invites.sql @@ -26,17 +26,17 @@ CREATE TABLE local_invites( -- Insert all invites for local users into new `invites` table INSERT INTO local_invites SELECT - stream_ordering as stream_id, - sender as inviter, - state_key as invitee, - event_id, - room_id, - NULL as locally_rejected, - NULL as replaced_by -FROM events -NATURAL JOIN current_state_events -NATURAL JOIN room_memberships -WHERE membership = 'invite' AND state_key IN (SELECT name FROM users); + stream_ordering as stream_id, + sender as inviter, + state_key as invitee, + event_id, + room_id, + NULL as locally_rejected, + NULL as replaced_by + FROM events + NATURAL JOIN current_state_events + NATURAL JOIN room_memberships + WHERE membership = 'invite' AND state_key IN (SELECT name FROM users); CREATE INDEX local_invites_id ON local_invites(stream_id); CREATE INDEX local_invites_for_user_idx ON local_invites(invitee, locally_rejected, replaced_by, room_id); -- cgit 1.4.1 From df727f212606b771b1410c8e322fb8a99d159de4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 5 Apr 2016 11:13:24 +0100 Subject: Fix stuck invites If rejecting a remote invite fails with an error response don't fail the entire request; instead mark the invite as locally rejected. This fixes the bug where users can get stuck invites which they can neither accept nor reject. --- synapse/handlers/federation.py | 34 +++++++++++++++++++++++----------- synapse/handlers/room_member.py | 18 ++++++++++++++---- synapse/storage/__init__.py | 3 ++- synapse/storage/roommember.py | 19 +++++++++++++++++++ 4 files changed, 58 insertions(+), 16 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4049c01d26..19769eecd7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -784,13 +784,19 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def do_remotely_reject_invite(self, target_hosts, room_id, user_id): - origin, event = yield self._make_and_verify_event( - target_hosts, - room_id, - user_id, - "leave" - ) - signed_event = self._sign_event(event) + try: + origin, event = yield self._make_and_verify_event( + target_hosts, + room_id, + user_id, + "leave" + ) + signed_event = self._sign_event(event) + except SynapseError: + raise + except CodeMessageException as e: + logger.warn("Failed to reject invite: %s", e) + raise SynapseError(500, "Failed to reject invite") # Try the host we successfully got a response to /make_join/ # request first. @@ -800,10 +806,16 @@ class FederationHandler(BaseHandler): except ValueError: pass - yield self.replication_layer.send_leave( - target_hosts, - signed_event - ) + try: + yield self.replication_layer.send_leave( + target_hosts, + signed_event + ) + except SynapseError: + raise + except CodeMessageException as e: + logger.warn("Failed to reject invite: %s", e) + raise SynapseError(500, "Failed to reject invite") context = yield self.state_handler.compute_event_context(event) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index f1c3e90ecd..6c7409215a 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -258,10 +258,20 @@ class RoomMemberHandler(BaseHandler): else: # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] - ret = yield self.reject_remote_invite( - target.to_string(), room_id, remote_room_hosts - ) - defer.returnValue(ret) + + try: + ret = yield self.reject_remote_invite( + target.to_string(), room_id, remote_room_hosts + ) + defer.returnValue(ret) + except SynapseError as e: + logger.warn("Failed to reject invite: %s", e) + + yield self.store.locally_reject_invite( + target.to_string(), room_id + ) + + defer.returnValue({}) yield self._local_membership_update( requester=requester, diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 57863bba4d..07916b292d 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -94,7 +94,8 @@ class DataStore(RoomMemberStore, RoomStore, ) self._stream_id_gen = StreamIdGenerator( - db_conn, "events", "stream_ordering" + db_conn, "events", "stream_ordering", + extra_tables=[("local_invites", "stream_id")] ) self._backfill_id_gen = StreamIdGenerator( db_conn, "events", "stream_ordering", step=-1 diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 36456a75fc..66e7a40e3c 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -102,6 +102,25 @@ class RoomMemberStore(SQLBaseStore): event.state_key, )) + @defer.inlineCallbacks + def locally_reject_invite(self, user_id, room_id): + sql = ( + "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + def f(txn, stream_ordering): + txn.execute(sql, ( + stream_ordering, + True, + room_id, + user_id, + )) + + with self._stream_id_gen.get_next() as stream_ordering: + yield self.runInteraction("locally_reject_invite", f, stream_ordering) + def get_room_member(self, user_id, room_id): """Retrieve the current state of a room member. -- cgit 1.4.1 From 1d4deff25a1edce73fb3d2f1b327d672a75581b0 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 5 Apr 2016 11:23:57 +0100 Subject: Separate generating the replication response... from doing the http request parsing to make it easier to write unit tests for replication. --- synapse/replication/resource.py | 99 +++++++++++++++++++++++------------------ 1 file changed, 55 insertions(+), 44 deletions(-) (limited to 'synapse') diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index c51a6fa103..a543af68f8 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -145,32 +145,43 @@ class ReplicationResource(Resource): timeout = parse_integer(request, "timeout", 10 * 1000) request.setHeader(b"Content-Type", b"application/json") - writer = _Writer(request) - @defer.inlineCallbacks - def replicate(): - current_token = yield self.current_replication_token() - logger.info("Replicating up to %r", current_token) - - yield self.account_data(writer, current_token, limit) - yield self.events(writer, current_token, limit) - yield self.presence(writer, current_token) # TODO: implement limit - yield self.typing(writer, current_token) # TODO: implement limit - yield self.receipts(writer, current_token, limit) - yield self.push_rules(writer, current_token, limit) - yield self.pushers(writer, current_token, limit) - yield self.state(writer, current_token, limit) - self.streams(writer, current_token) + request_streams = { + name: parse_integer(request, name) + for names in STREAM_NAMES for name in names + } + request_streams["streams"] = parse_string(request, "streams") - logger.info("Replicated %d rows", writer.total) - defer.returnValue(writer.total) + def replicate(): + return self.replicate(request_streams, limit) - yield self.notifier.wait_for_replication(replicate, timeout) + result = yield self.notifier.wait_for_replication(replicate, timeout) - writer.finish() + request.write(json.dumps(result, ensure_ascii=False)) + finish_request(request) - def streams(self, writer, current_token): - request_token = parse_string(writer.request, "streams") + @defer.inlineCallbacks + def replicate(self, request_streams, limit): + writer = _Writer() + current_token = yield self.current_replication_token() + logger.info("Replicating up to %r", current_token) + + yield self.account_data(writer, current_token, limit, request_streams) + yield self.events(writer, current_token, limit, request_streams) + # TODO: implement limit + yield self.presence(writer, current_token, request_streams) + yield self.typing(writer, current_token, request_streams) + yield self.receipts(writer, current_token, limit, request_streams) + yield self.push_rules(writer, current_token, limit, request_streams) + yield self.pushers(writer, current_token, limit, request_streams) + yield self.state(writer, current_token, limit, request_streams) + self.streams(writer, current_token, request_streams) + + logger.info("Replicated %d rows", writer.total) + defer.returnValue(writer.finish()) + + def streams(self, writer, current_token, request_streams): + request_token = request_streams.get("streams") streams = [] @@ -195,9 +206,9 @@ class ReplicationResource(Resource): ) @defer.inlineCallbacks - def events(self, writer, current_token, limit): - request_events = parse_integer(writer.request, "events") - request_backfill = parse_integer(writer.request, "backfill") + def events(self, writer, current_token, limit, request_streams): + request_events = request_streams.get("events") + request_backfill = request_streams.get("backfill") if request_events is not None or request_backfill is not None: if request_events is None: @@ -228,10 +239,10 @@ class ReplicationResource(Resource): ) @defer.inlineCallbacks - def presence(self, writer, current_token): + def presence(self, writer, current_token, request_streams): current_position = current_token.presence - request_presence = parse_integer(writer.request, "presence") + request_presence = request_streams.get("presence") if request_presence is not None: presence_rows = yield self.presence_handler.get_all_presence_updates( @@ -244,10 +255,10 @@ class ReplicationResource(Resource): )) @defer.inlineCallbacks - def typing(self, writer, current_token): + def typing(self, writer, current_token, request_streams): current_position = current_token.presence - request_typing = parse_integer(writer.request, "typing") + request_typing = request_streams.get("typing") if request_typing is not None: typing_rows = yield self.typing_handler.get_all_typing_updates( @@ -258,10 +269,10 @@ class ReplicationResource(Resource): )) @defer.inlineCallbacks - def receipts(self, writer, current_token, limit): + def receipts(self, writer, current_token, limit, request_streams): current_position = current_token.receipts - request_receipts = parse_integer(writer.request, "receipts") + request_receipts = request_streams.get("receipts") if request_receipts is not None: receipts_rows = yield self.store.get_all_updated_receipts( @@ -272,12 +283,12 @@ class ReplicationResource(Resource): )) @defer.inlineCallbacks - def account_data(self, writer, current_token, limit): + def account_data(self, writer, current_token, limit, request_streams): current_position = current_token.account_data - user_account_data = parse_integer(writer.request, "user_account_data") - room_account_data = parse_integer(writer.request, "room_account_data") - tag_account_data = parse_integer(writer.request, "tag_account_data") + user_account_data = request_streams.get("user_account_data") + room_account_data = request_streams.get("room_account_data") + tag_account_data = request_streams.get("tag_account_data") if user_account_data is not None or room_account_data is not None: if user_account_data is None: @@ -303,10 +314,10 @@ class ReplicationResource(Resource): )) @defer.inlineCallbacks - def push_rules(self, writer, current_token, limit): + def push_rules(self, writer, current_token, limit, request_streams): current_position = current_token.push_rules - push_rules = parse_integer(writer.request, "push_rules") + push_rules = request_streams.get("push_rules") if push_rules is not None: rows = yield self.store.get_all_push_rule_updates( @@ -318,10 +329,11 @@ class ReplicationResource(Resource): )) @defer.inlineCallbacks - def pushers(self, writer, current_token, limit): + def pushers(self, writer, current_token, limit, request_streams): current_position = current_token.pushers - pushers = parse_integer(writer.request, "pushers") + pushers = request_streams.get("pushers") + if pushers is not None: updated, deleted = yield self.store.get_all_updated_pushers( pushers, current_position, limit @@ -336,10 +348,11 @@ class ReplicationResource(Resource): )) @defer.inlineCallbacks - def state(self, writer, current_token, limit): + def state(self, writer, current_token, limit, request_streams): current_position = current_token.state - state = parse_integer(writer.request, "state") + state = request_streams.get("state") + if state is not None: state_groups, state_group_state = ( yield self.store.get_all_new_state_groups( @@ -356,9 +369,8 @@ class ReplicationResource(Resource): class _Writer(object): """Writes the streams as a JSON object as the response to the request""" - def __init__(self, request): + def __init__(self): self.streams = {} - self.request = request self.total = 0 def write_header_and_rows(self, name, rows, fields, position=None): @@ -377,8 +389,7 @@ class _Writer(object): self.total += len(rows) def finish(self): - self.request.write(json.dumps(self.streams, ensure_ascii=False)) - finish_request(self.request) + return self.streams class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( -- cgit 1.4.1 From 6222ae51cee230fc746d0706db13d8928f28234b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 5 Apr 2016 12:56:29 +0100 Subject: Don't backfill from self --- synapse/handlers/federation.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 267fedf114..edffa560bf 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -289,6 +289,9 @@ class FederationHandler(BaseHandler): def backfill(self, dest, room_id, limit, extremities=[]): """ Trigger a backfill request to `dest` for the given `room_id` """ + if dest == self.server_name: + raise SynapseError(400, "Can't backfill from self.") + if not extremities: extremities = yield self.store.get_oldest_events_in_room(room_id) @@ -455,7 +458,7 @@ class FederationHandler(BaseHandler): likely_domains = [ domain for domain, depth in curr_domains - if domain is not self.server_name + if domain != self.server_name ] @defer.inlineCallbacks -- cgit 1.4.1 From a1e0d316ea354fce07939073d9afc9c5d1013939 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 6 Apr 2016 13:05:19 +0100 Subject: Move _get_cache_dict into the SQLBaseStore --- synapse/storage/__init__.py | 33 --------------------------------- synapse/storage/_base.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 33 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 07916b292d..045ae6c03f 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -177,39 +177,6 @@ class DataStore(RoomMemberStore, RoomStore, self.__presence_on_startup = None return active_on_startup - def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value): - # Fetch a mapping of room_id -> max stream position for "recent" rooms. - # It doesn't really matter how many we get, the StreamChangeCache will - # do the right thing to ensure it respects the max size of cache. - sql = ( - "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" - " WHERE %(stream)s > ? - 100000" - " GROUP BY %(entity)s" - ) % { - "table": table, - "entity": entity_column, - "stream": stream_column, - } - - sql = self.database_engine.convert_param_style(sql) - - txn = db_conn.cursor() - txn.execute(sql, (int(max_value),)) - rows = txn.fetchall() - txn.close() - - cache = { - row[0]: int(row[1]) - for row in rows - } - - if cache: - min_val = min(cache.values()) - else: - min_val = max_value - - return cache, min_val - def _get_active_presence(self, db_conn): """Fetch non-offline presence from the database so that we can register the appropriate time outs. diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b75b79df36..04d7fcf6d6 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -816,6 +816,40 @@ class SQLBaseStore(object): self._next_stream_id += 1 return i + def _get_cache_dict(self, db_conn, table, entity_column, stream_column, + max_value): + # Fetch a mapping of room_id -> max stream position for "recent" rooms. + # It doesn't really matter how many we get, the StreamChangeCache will + # do the right thing to ensure it respects the max size of cache. + sql = ( + "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" + " WHERE %(stream)s > ? - 100000" + " GROUP BY %(entity)s" + ) % { + "table": table, + "entity": entity_column, + "stream": stream_column, + } + + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, (int(max_value),)) + rows = txn.fetchall() + txn.close() + + cache = { + row[0]: int(row[1]) + for row in rows + } + + if cache: + min_val = min(cache.values()) + else: + min_val = max_value + + return cache, min_val + class _RollbackButIsFineException(Exception): """ This exception is used to rollback a transaction without implying -- cgit 1.4.1 From 87f2dec8d475f038beb138bc56e3ef76fcb83ec6 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 6 Apr 2016 13:08:05 +0100 Subject: Make the cache objects be per instance rather than being global --- synapse/storage/receipts.py | 4 ++-- synapse/storage/registration.py | 2 +- synapse/storage/state.py | 4 ++-- synapse/util/caches/descriptors.py | 45 ++++++++++++++++++++------------------ 4 files changed, 29 insertions(+), 26 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 4befebc8e2..7fdd84bbdc 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -160,8 +160,8 @@ class ReceiptsStore(SQLBaseStore): "content": content, }]) - @cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids", - num_args=3, inlineCallbacks=True) + @cachedList(cached_method_name="get_linearized_receipts_for_room", + list_name="room_ids", num_args=3, inlineCallbacks=True) def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): if not room_ids: defer.returnValue({}) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index d46a963bb8..1f71773aaa 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -319,7 +319,7 @@ class RegistrationStore(SQLBaseStore): defer.returnValue(res if res else False) - @cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1, + @cachedList(cached_method_name="is_guest", list_name="user_ids", num_args=1, inlineCallbacks=True) def are_guests(self, user_ids): sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % ( diff --git a/synapse/storage/state.py b/synapse/storage/state.py index e9f9406014..c5d2a3a6df 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -273,8 +273,8 @@ class StateStore(SQLBaseStore): desc="_get_state_group_for_event", ) - @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", - num_args=1, inlineCallbacks=True) + @cachedList(cached_method_name="_get_state_group_for_event", + list_name="event_ids", num_args=1, inlineCallbacks=True) def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 35544b19fd..758f5982b0 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -167,7 +167,8 @@ class CacheDescriptor(object): % (orig.__name__,) ) - self.cache = Cache( + def __get__(self, obj, objtype=None): + cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, @@ -175,14 +176,12 @@ class CacheDescriptor(object): tree=self.tree, ) - def __get__(self, obj, objtype=None): - @functools.wraps(self.orig) def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) try: - cached_result_d = self.cache.get(cache_key) + cached_result_d = cache.get(cache_key) observer = cached_result_d.observe() if DEBUG_CACHES: @@ -204,7 +203,7 @@ class CacheDescriptor(object): # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated # while the SELECT is executing (SYN-369) - sequence = self.cache.sequence + sequence = cache.sequence ret = defer.maybeDeferred( preserve_context_over_fn, @@ -213,20 +212,21 @@ class CacheDescriptor(object): ) def onErr(f): - self.cache.invalidate(cache_key) + cache.invalidate(cache_key) return f ret.addErrback(onErr) ret = ObservableDeferred(ret, consumeErrors=True) - self.cache.update(sequence, cache_key, ret) + cache.update(sequence, cache_key, ret) return preserve_context_over_deferred(ret.observe()) - wrapped.invalidate = self.cache.invalidate - wrapped.invalidate_all = self.cache.invalidate_all - wrapped.invalidate_many = self.cache.invalidate_many - wrapped.prefill = self.cache.prefill + wrapped.invalidate = cache.invalidate + wrapped.invalidate_all = cache.invalidate_all + wrapped.invalidate_many = cache.invalidate_many + wrapped.prefill = cache.prefill + wrapped.cache = cache obj.__dict__[self.orig.__name__] = wrapped @@ -240,11 +240,12 @@ class CacheListDescriptor(object): the list of missing keys to the wrapped fucntion. """ - def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): + def __init__(self, orig, cached_method_name, list_name, num_args=1, + inlineCallbacks=False): """ Args: orig (function) - cache (Cache) + method_name (str); The name of the chached method. list_name (str): Name of the argument which is the bulk lookup list num_args (int) inlineCallbacks (bool): Whether orig is a generator that should @@ -263,7 +264,7 @@ class CacheListDescriptor(object): self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] self.list_pos = self.arg_names.index(self.list_name) - self.cache = cache + self.cached_method_name = cached_method_name self.sentinel = object() @@ -277,11 +278,13 @@ class CacheListDescriptor(object): if self.list_name not in self.arg_names: raise Exception( "Couldn't see arguments %r for %r." - % (self.list_name, cache.name,) + % (self.list_name, cached_method_name,) ) def __get__(self, obj, objtype=None): + cache = getattr(obj, self.cached_method_name).cache + @functools.wraps(self.orig) def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) @@ -297,14 +300,14 @@ class CacheListDescriptor(object): key[self.list_pos] = arg try: - res = self.cache.get(tuple(key)).observe() + res = cache.get(tuple(key)).observe() res.addCallback(lambda r, arg: (arg, r), arg) cached[arg] = res except KeyError: missing.append(arg) if missing: - sequence = self.cache.sequence + sequence = cache.sequence args_to_call = dict(arg_dict) args_to_call[self.list_name] = missing @@ -327,10 +330,10 @@ class CacheListDescriptor(object): key = list(keyargs) key[self.list_pos] = arg - self.cache.update(sequence, tuple(key), observer) + cache.update(sequence, tuple(key), observer) def invalidate(f, key): - self.cache.invalidate(key) + cache.invalidate(key) return f observer.addErrback(invalidate, tuple(key)) @@ -370,7 +373,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False): ) -def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): +def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False): """Creates a descriptor that wraps a function in a `CacheListDescriptor`. Used to do batch lookups for an already created cache. A single argument @@ -400,7 +403,7 @@ def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): """ return lambda orig: CacheListDescriptor( orig, - cache=cache, + cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, inlineCallbacks=inlineCallbacks, -- cgit 1.4.1 From 8aab9d87fa6739345810f0edf3982fe7f898ee30 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 6 Apr 2016 14:08:18 +0100 Subject: Don't require config to create database --- scripts/synapse_port_db | 13 ++--- synapse/app/homeserver.py | 15 +++-- synapse/storage/engines/__init__.py | 6 +- synapse/storage/engines/postgres.py | 8 +-- synapse/storage/engines/sqlite3.py | 13 +---- synapse/storage/prepare_database.py | 64 +++++++--------------- .../schema/delta/14/upgrade_appservice_db.py | 6 +- synapse/storage/schema/delta/20/pushers.py | 6 +- synapse/storage/schema/delta/25/fts.py | 6 +- synapse/storage/schema/delta/27/ts.py | 6 +- synapse/storage/schema/delta/30/as_users.py | 4 +- tests/storage/test_base.py | 2 +- tests/utils.py | 6 +- 13 files changed, 69 insertions(+), 86 deletions(-) (limited to 'synapse') diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index a2a0f364cf..253a6ef6c7 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -19,6 +19,7 @@ from twisted.enterprise import adbapi from synapse.storage._base import LoggingTransaction, SQLBaseStore from synapse.storage.engines import create_engine +from synapse.storage.prepare_database import prepare_database import argparse import curses @@ -37,6 +38,7 @@ BOOLEAN_COLUMNS = { "rooms": ["is_public"], "event_edges": ["is_state"], "presence_list": ["accepted"], + "presence_stream": ["currently_active"], } @@ -292,7 +294,7 @@ class Porter(object): } ) - database_engine.prepare_database(db_conn) + prepare_database(db_conn, database_engine, config=None) db_conn.commit() @@ -309,8 +311,8 @@ class Porter(object): **self.postgres_config["args"] ) - sqlite_engine = create_engine(FakeConfig(sqlite_config)) - postgres_engine = create_engine(FakeConfig(postgres_config)) + sqlite_engine = create_engine(sqlite_config) + postgres_engine = create_engine(postgres_config) self.sqlite_store = Store(sqlite_db_pool, sqlite_engine) self.postgres_store = Store(postgres_db_pool, postgres_engine) @@ -792,8 +794,3 @@ if __name__ == "__main__": if end_error_exec_info: exc_type, exc_value, exc_traceback = end_error_exec_info traceback.print_exception(exc_type, exc_value, exc_traceback) - - -class FakeConfig: - def __init__(self, database_config): - self.database_config = database_config diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index fcdc8e6e10..2b4473b9ac 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -33,7 +33,7 @@ from synapse.python_dependencies import ( from synapse.rest import ClientRestResource from synapse.storage.engines import create_engine, IncorrectDatabaseSetup from synapse.storage import are_all_users_on_domain -from synapse.storage.prepare_database import UpgradeDatabaseException +from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database from synapse.server import HomeServer @@ -245,7 +245,7 @@ class SynapseHomeServer(HomeServer): except IncorrectDatabaseSetup as e: quit_with_error(e.message) - def get_db_conn(self): + def get_db_conn(self, run_new_connection=True): # Any param beginning with cp_ is a parameter for adbapi, and should # not be passed to the database engine. db_params = { @@ -254,7 +254,8 @@ class SynapseHomeServer(HomeServer): } db_conn = self.database_engine.module.connect(**db_params) - self.database_engine.on_new_connection(db_conn) + if run_new_connection: + self.database_engine.on_new_connection(db_conn) return db_conn @@ -386,7 +387,7 @@ def setup(config_options): tls_server_context_factory = context_factory.ServerContextFactory(config) - database_engine = create_engine(config) + database_engine = create_engine(config.database_config) config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection hs = SynapseHomeServer( @@ -402,8 +403,10 @@ def setup(config_options): logger.info("Preparing database: %s...", config.database_config['name']) try: - db_conn = hs.get_db_conn() - database_engine.prepare_database(db_conn) + db_conn = hs.get_db_conn(run_new_connection=False) + prepare_database(db_conn, database_engine, config=config) + database_engine.on_new_connection(db_conn) + hs.run_startup_checks(db_conn, database_engine) db_conn.commit() diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index a48230b93f..7bb5de1fe7 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -26,13 +26,13 @@ SUPPORTED_MODULE = { } -def create_engine(config): - name = config.database_config["name"] +def create_engine(database_config): + name = database_config["name"] engine_class = SUPPORTED_MODULE.get(name, None) if engine_class: module = importlib.import_module(name) - return engine_class(module, config=config) + return engine_class(module) raise RuntimeError( "Unsupported database engine '%s'" % (name,) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index a09685b4df..c2290943b4 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -13,18 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.prepare_database import prepare_database - from ._base import IncorrectDatabaseSetup class PostgresEngine(object): single_threaded = False - def __init__(self, database_module, config): + def __init__(self, database_module): self.module = database_module self.module.extensions.register_type(self.module.extensions.UNICODE) - self.config = config def check_database(self, txn): txn.execute("SHOW SERVER_ENCODING") @@ -44,9 +41,6 @@ class PostgresEngine(object): self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ ) - def prepare_database(self, db_conn): - prepare_database(db_conn, self, config=self.config) - def is_deadlock(self, error): if isinstance(error, self.module.DatabaseError): return error.pgcode in ["40001", "40P01"] diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index 522b905949..14203aa500 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.prepare_database import ( - prepare_database, prepare_sqlite3_database -) +from synapse.storage.prepare_database import prepare_database import struct @@ -23,9 +21,8 @@ import struct class Sqlite3Engine(object): single_threaded = True - def __init__(self, database_module, config): + def __init__(self, database_module): self.module = database_module - self.config = config def check_database(self, txn): pass @@ -34,13 +31,9 @@ class Sqlite3Engine(object): return sql def on_new_connection(self, db_conn): - self.prepare_database(db_conn) + prepare_database(db_conn, self, config=None) db_conn.create_function("rank", 1, _rank) - def prepare_database(self, db_conn): - prepare_sqlite3_database(db_conn) - prepare_database(db_conn, self, config=self.config) - def is_deadlock(self, error): return False diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 4099387ba7..00833422af 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -53,6 +53,9 @@ class UpgradeDatabaseException(PrepareDatabaseException): def prepare_database(db_conn, database_engine, config): """Prepares a database for usage. Will either create all necessary tables or upgrade from an older schema version. + + If `config` is None then prepare_database will assert that no upgrade is + necessary, *or* will create a fresh database if the database is empty. """ try: cur = db_conn.cursor() @@ -60,13 +63,18 @@ def prepare_database(db_conn, database_engine, config): if version_info: user_version, delta_files, upgraded = version_info - _upgrade_existing_database( - cur, user_version, delta_files, upgraded, database_engine, config - ) - else: - _setup_new_database(cur, database_engine, config) - # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) + if config is None: + if user_version != SCHEMA_VERSION: + # If we don't pass in a config file then we are expecting to + # have already upgraded the DB. + raise UpgradeDatabaseException("Database needs to be upgraded") + else: + _upgrade_existing_database( + cur, user_version, delta_files, upgraded, database_engine, config + ) + else: + _setup_new_database(cur, database_engine) cur.close() db_conn.commit() @@ -75,7 +83,7 @@ def prepare_database(db_conn, database_engine, config): raise -def _setup_new_database(cur, database_engine, config): +def _setup_new_database(cur, database_engine): """Sets up the database by finding a base set of "full schemas" and then applying any necessary deltas. @@ -148,12 +156,13 @@ def _setup_new_database(cur, database_engine, config): applied_delta_files=[], upgraded=False, database_engine=database_engine, - config=config, + config=None, + is_empty=True, ) def _upgrade_existing_database(cur, current_version, applied_delta_files, - upgraded, database_engine, config): + upgraded, database_engine, config, is_empty=False): """Upgrades an existing database. Delta files can either be SQL stored in *.sql files, or python modules @@ -246,7 +255,9 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, module_name, absolute_path, python_file ) logger.debug("Running script %s", relative_path) - module.run_upgrade(cur, database_engine, config=config) + module.run_create(cur, database_engine) + if not is_empty: + module.run_upgrade(cur, database_engine, config=config) elif ext == ".pyc": # Sometimes .pyc files turn up anyway even though we've # disabled their generation; e.g. from distribution package @@ -361,36 +372,3 @@ def _get_or_create_schema_state(txn, database_engine): return current_version, applied_deltas, upgraded return None - - -def prepare_sqlite3_database(db_conn): - """This function should be called before `prepare_database` on sqlite3 - databases. - - Since we changed the way we store the current schema version and handle - updates to schemas, we need a way to upgrade from the old method to the - new. This only affects sqlite databases since they were the only ones - supported at the time. - """ - with db_conn: - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - create_schema = read_schema(schema_path) - db_conn.executescript(create_schema) - - c = db_conn.execute("SELECT * FROM schema_version") - rows = c.fetchall() - c.close() - - if not rows: - c = db_conn.execute("PRAGMA user_version") - row = c.fetchone() - c.close() - - if row and row[0]: - db_conn.execute( - "REPLACE INTO schema_version (version, upgraded)" - " VALUES (?,?)", - (row[0], False) - ) diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py index 5c40a77757..8755bb2e49 100644 --- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py +++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py @@ -18,7 +18,7 @@ import logging logger = logging.getLogger(__name__) -def run_upgrade(cur, *args, **kwargs): +def run_create(cur, *args, **kwargs): cur.execute("SELECT id, regex FROM application_services_regex") for row in cur.fetchall(): try: @@ -35,3 +35,7 @@ def run_upgrade(cur, *args, **kwargs): "UPDATE application_services_regex SET regex=? WHERE id=?", (new_regex, row[0]) ) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/20/pushers.py b/synapse/storage/schema/delta/20/pushers.py index 29164732af..147496a38b 100644 --- a/synapse/storage/schema/delta/20/pushers.py +++ b/synapse/storage/schema/delta/20/pushers.py @@ -27,7 +27,7 @@ import logging logger = logging.getLogger(__name__) -def run_upgrade(cur, database_engine, *args, **kwargs): +def run_create(cur, database_engine, *args, **kwargs): logger.info("Porting pushers table...") cur.execute(""" CREATE TABLE IF NOT EXISTS pushers2 ( @@ -74,3 +74,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs): cur.execute("DROP TABLE pushers") cur.execute("ALTER TABLE pushers2 RENAME TO pushers") logger.info("Moved %d pushers to new table", count) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py index d3ff2b1779..4269ac69ad 100644 --- a/synapse/storage/schema/delta/25/fts.py +++ b/synapse/storage/schema/delta/25/fts.py @@ -43,7 +43,7 @@ SQLITE_TABLE = ( ) -def run_upgrade(cur, database_engine, *args, **kwargs): +def run_create(cur, database_engine, *args, **kwargs): if isinstance(database_engine, PostgresEngine): for statement in get_statements(POSTGRES_TABLE.splitlines()): cur.execute(statement) @@ -76,3 +76,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs): sql = database_engine.convert_param_style(sql) cur.execute(sql, ("event_search", progress_json)) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/schema/delta/27/ts.py index f8c16391a2..71b12a2731 100644 --- a/synapse/storage/schema/delta/27/ts.py +++ b/synapse/storage/schema/delta/27/ts.py @@ -27,7 +27,7 @@ ALTER_TABLE = ( ) -def run_upgrade(cur, database_engine, *args, **kwargs): +def run_create(cur, database_engine, *args, **kwargs): for statement in get_statements(ALTER_TABLE.splitlines()): cur.execute(statement) @@ -55,3 +55,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs): sql = database_engine.convert_param_style(sql) cur.execute(sql, ("event_origin_server_ts", progress_json)) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py index 4f6e9dd540..b417e3ac08 100644 --- a/synapse/storage/schema/delta/30/as_users.py +++ b/synapse/storage/schema/delta/30/as_users.py @@ -18,7 +18,7 @@ from synapse.storage.appservice import ApplicationServiceStore logger = logging.getLogger(__name__) -def run_upgrade(cur, database_engine, config, *args, **kwargs): +def run_create(cur, database_engine, *args, **kwargs): # NULL indicates user was not registered by an appservice. try: cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT") @@ -26,6 +26,8 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): # Maybe we already added the column? Hope so... pass + +def run_upgrade(cur, database_engine, config, *args, **kwargs): cur.execute("SELECT name FROM users") rows = cur.fetchall() diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 2e33beb07c..afbefb2e2d 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -53,7 +53,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): "test", db_pool=self.db_pool, config=config, - database_engine=create_engine(config), + database_engine=create_engine(config.database_config), ) self.datastore = SQLBaseStore(hs) diff --git a/tests/utils.py b/tests/utils.py index 52405502e9..c179df31ee 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -64,7 +64,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): hs = HomeServer( name, db_pool=db_pool, config=config, version_string="Synapse/tests", - database_engine=create_engine(config), + database_engine=create_engine(config.database_config), get_db_conn=db_pool.get_db_conn, **kargs ) @@ -73,7 +73,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): hs = HomeServer( name, db_pool=None, datastore=datastore, config=config, version_string="Synapse/tests", - database_engine=create_engine(config), + database_engine=create_engine(config.database_config), **kargs ) @@ -298,7 +298,7 @@ class SQLiteMemoryDbPool(ConnectionPool, object): return conn def create_engine(self): - return create_engine(self.config) + return create_engine(self.config.database_config) class MemoryDataStore(object): -- cgit 1.4.1 From 75fb9ac1be0fada60cdde38153ac0e3fe2b19a0a Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 6 Apr 2016 14:12:51 +0100 Subject: Add a slaved events store class Add a test to check that get_room_names_and_aliases does the same thing on both the master and on the slave data store. --- synapse/replication/slave/__init__.py | 14 ++ synapse/replication/slave/storage/__init__.py | 14 ++ synapse/replication/slave/storage/_base.py | 28 +++ .../slave/storage/_slaved_id_tracker.py | 30 ++++ synapse/replication/slave/storage/events.py | 198 +++++++++++++++++++++ synapse/storage/events.py | 4 +- tests/replication/slave/__init__.py | 14 ++ tests/replication/slave/storage/__init__.py | 14 ++ tests/replication/slave/storage/_base.py | 57 ++++++ tests/replication/slave/storage/test_events.py | 114 ++++++++++++ 10 files changed, 485 insertions(+), 2 deletions(-) create mode 100644 synapse/replication/slave/__init__.py create mode 100644 synapse/replication/slave/storage/__init__.py create mode 100644 synapse/replication/slave/storage/_base.py create mode 100644 synapse/replication/slave/storage/_slaved_id_tracker.py create mode 100644 synapse/replication/slave/storage/events.py create mode 100644 tests/replication/slave/__init__.py create mode 100644 tests/replication/slave/storage/__init__.py create mode 100644 tests/replication/slave/storage/_base.py create mode 100644 tests/replication/slave/storage/test_events.py (limited to 'synapse') diff --git a/synapse/replication/slave/__init__.py b/synapse/replication/slave/__init__.py new file mode 100644 index 0000000000..b7df13c9ee --- /dev/null +++ b/synapse/replication/slave/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/replication/slave/storage/__init__.py b/synapse/replication/slave/storage/__init__.py new file mode 100644 index 0000000000..b7df13c9ee --- /dev/null +++ b/synapse/replication/slave/storage/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py new file mode 100644 index 0000000000..46e43ce1c7 --- /dev/null +++ b/synapse/replication/slave/storage/_base.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage._base import SQLBaseStore +from twisted.internet import defer + + +class BaseSlavedStore(SQLBaseStore): + def __init__(self, db_conn, hs): + super(BaseSlavedStore, self).__init__(hs) + + def stream_positions(self): + return {} + + def process_replication(self, result): + return defer.succeed(None) diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py new file mode 100644 index 0000000000..24b5c79d4a --- /dev/null +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.util.id_generators import _load_current_id + + +class SlavedIdTracker(object): + def __init__(self, db_conn, table, column, extra_tables=[], step=1): + self.step = step + self._current = _load_current_id(db_conn, table, column, step) + for table, column in extra_tables: + self.advance(_load_current_id(db_conn, table, column)) + + def advance(self, new_id): + self._current = (max if self.step > 0 else min)(self._current, new_id) + + def get_current_token(self): + return self._current diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py new file mode 100644 index 0000000000..68b924e37b --- /dev/null +++ b/synapse/replication/slave/storage/events.py @@ -0,0 +1,198 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker + +from synapse.api.constants import EventTypes +from synapse.events import FrozenEvent +from synapse.storage import DataStore +from synapse.storage.room import RoomStore +from synapse.storage.roommember import RoomMemberStore +from synapse.storage.event_federation import EventFederationStore +from synapse.storage.state import StateStore +from synapse.util.caches.stream_change_cache import StreamChangeCache + +import ujson as json + +# So, um, we want to borrow a load of functions intended for reading from +# a DataStore, but we don't want to take functions that either write to the +# DataStore or are cached and don't have cache invalidation logic. +# +# Rather than write duplicate versions of those functions, or lift them to +# a common base class, we going to grab the underlying __func__ object from +# the method descriptor on the DataStore and chuck them into our class. + + +class SlavedEventStore(BaseSlavedStore): + + def __init__(self, db_conn, hs): + super(SlavedEventStore, self).__init__(db_conn, hs) + self._stream_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", + ) + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) + events_max = self._stream_id_gen.get_current_token() + event_cache_prefill, min_event_val = self._get_cache_dict( + db_conn, "events", + entity_column="room_id", + stream_column="stream_ordering", + max_value=events_max, + ) + self._events_stream_cache = StreamChangeCache( + "EventsRoomStreamChangeCache", min_event_val, + prefilled_cache=event_cache_prefill, + ) + + # Cached functions can't be accessed through a class instance so we need + # to reach inside the __dict__ to extract them. + get_room_name_and_aliases = RoomStore.__dict__["get_room_name_and_aliases"] + get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] + get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"] + get_latest_event_ids_in_room = EventFederationStore.__dict__[ + "get_latest_event_ids_in_room" + ] + _get_current_state_for_key = StateStore.__dict__[ + "_get_current_state_for_key" + ] + + get_current_state = DataStore.get_current_state.__func__ + get_current_state_for_key = DataStore.get_current_state_for_key.__func__ + _get_rooms_for_user_where_membership_is_txn = ( + DataStore._get_rooms_for_user_where_membership_is_txn.__func__ + ) + get_rooms_for_user_where_membership_is = ( + DataStore.get_rooms_for_user_where_membership_is.__func__ + ) + get_membership_changes_for_user = ( + DataStore.get_membership_changes_for_user.__func__ + ) + get_room_events_max_id = DataStore.get_room_events_max_id.__func__ + get_room_events_stream_for_room = ( + DataStore.get_room_events_stream_for_room.__func__ + ) + _set_before_and_after = DataStore._set_before_and_after + + _get_events = DataStore._get_events.__func__ + _get_events_from_cache = DataStore._get_events_from_cache.__func__ + + _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__ + _parse_events_txn = DataStore._parse_events_txn.__func__ + _get_events_txn = DataStore._get_events_txn.__func__ + _fetch_events_txn = DataStore._fetch_events_txn.__func__ + _fetch_event_rows = DataStore._fetch_event_rows.__func__ + _get_event_from_row_txn = DataStore._get_event_from_row_txn.__func__ + + def stream_positions(self): + result = super(SlavedEventStore, self).stream_positions() + result["events"] = self._stream_id_gen.get_current_token() + result["backfilled"] = self._backfill_id_gen.get_current_token() + return result + + def process_replication(self, result): + state_resets = set( + r[0] for r in result.get("state_resets", {"rows": []})["rows"] + ) + + stream = result.get("events") + if stream: + self._stream_id_gen.advance(stream["position"]) + for row in stream["rows"]: + self._process_replication_row( + row, backfilled=False, state_resets=state_resets + ) + + stream = result.get("backfill") + if stream: + self._backfill_id_gen.advance(stream["position"]) + for row in stream["rows"]: + self._process_replication_row( + row, backfilled=True, state_resets=state_resets + ) + + stream = result.get("forward_ex_outliers") + if stream: + for row in stream["rows"]: + event_id = row[1] + self._invalidate_get_event_cache(event_id) + + stream = result.get("backward_ex_outliers") + if stream: + for row in stream["rows"]: + event_id = row[1] + self._invalidate_get_event_cache(event_id) + + return super(SlavedEventStore, self).process_replication(result) + + def _process_replication_row(self, row, backfilled, state_resets): + position = row[0] + internal = json.loads(row[1]) + event_json = json.loads(row[2]) + + event = FrozenEvent(event_json, internal_metadata_dict=internal) + self._invalidate_caches_for_event( + event, backfilled, reset_state=position in state_resets + ) + + def _invalidate_caches_for_event(self, event, backfilled, reset_state): + if reset_state: + self._get_current_state_for_key.invalidate_all() + self.get_rooms_for_user.invalidate_all() + self.get_users_in_room.invalidate((event.room_id,)) + # self.get_joined_hosts_for_room.invalidate((event.room_id,)) + self.get_room_name_and_aliases.invalidate((event.room_id,)) + + self._invalidate_get_event_cache(event.event_id) + + if not backfilled: + self._events_stream_cache.entity_has_changed( + event.room_id, event.internal_metadata.stream_ordering + ) + + # self.get_unread_event_push_actions_by_room_for_user.invalidate_many( + # (event.room_id,) + # ) + + if event.type == EventTypes.Redaction: + self._invalidate_get_event_cache(event.redacts) + + if event.type == EventTypes.Member: + self.get_rooms_for_user.invalidate((event.state_key,)) + # self.get_joined_hosts_for_room.invalidate((event.room_id,)) + self.get_users_in_room.invalidate((event.room_id,)) + # self._membership_stream_cache.entity_has_changed( + # event.state_key, event.internal_metadata.stream_ordering + # ) + + if not event.is_state(): + return + + if backfilled: + return + + if (not event.internal_metadata.is_invite_from_remote() + and event.internal_metadata.is_outlier()): + return + + self._get_current_state_for_key.invalidate(( + event.room_id, event.type, event.state_key + )) + + if event.type in [EventTypes.Name, EventTypes.Aliases]: + self.get_room_name_and_aliases.invalidate( + (event.room_id,) + ) + pass diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 5d299a1132..ee87a71719 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1134,7 +1134,7 @@ class EventsStore(SQLBaseStore): upper_bound = current_forward_id sql = ( - "SELECT -event_stream_ordering FROM current_state_resets" + "SELECT event_stream_ordering FROM current_state_resets" " WHERE ? < event_stream_ordering" " AND event_stream_ordering <= ?" " ORDER BY event_stream_ordering ASC" @@ -1143,7 +1143,7 @@ class EventsStore(SQLBaseStore): state_resets = txn.fetchall() sql = ( - "SELECT -event_stream_ordering, event_id, state_group" + "SELECT event_stream_ordering, event_id, state_group" " FROM ex_outlier_stream" " WHERE ? > event_stream_ordering" " AND event_stream_ordering >= ?" diff --git a/tests/replication/slave/__init__.py b/tests/replication/slave/__init__.py new file mode 100644 index 0000000000..b7df13c9ee --- /dev/null +++ b/tests/replication/slave/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/replication/slave/storage/__init__.py b/tests/replication/slave/storage/__init__.py new file mode 100644 index 0000000000..b7df13c9ee --- /dev/null +++ b/tests/replication/slave/storage/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py new file mode 100644 index 0000000000..0f525a8943 --- /dev/null +++ b/tests/replication/slave/storage/_base.py @@ -0,0 +1,57 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer +from tests import unittest + +from synapse.replication.slave.storage.events import SlavedEventStore + +from mock import Mock, NonCallableMock +from tests.utils import setup_test_homeserver +from synapse.replication.resource import ReplicationResource + + +class BaseSlavedStoreTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver( + "blue", + http_client=None, + replication_layer=Mock(), + ratelimiter=NonCallableMock(spec_set=[ + "send_message", + ]), + ) + self.hs.get_ratelimiter().send_message.return_value = (True, 0) + + self.replication = ReplicationResource(self.hs) + + self.master_store = self.hs.get_datastore() + self.slaved_store = SlavedEventStore(self.hs.get_db_conn(), self.hs) + self.event_id = 0 + + @defer.inlineCallbacks + def replicate(self): + streams = self.slaved_store.stream_positions() + result = yield self.replication.replicate(streams, 100) + yield self.slaved_store.process_replication(result) + + @defer.inlineCallbacks + def check(self, method, args, expected_result=None): + master_result = yield getattr(self.master_store, method)(*args) + slaved_result = yield getattr(self.slaved_store, method)(*args) + self.assertEqual(master_result, slaved_result) + if expected_result is not None: + self.assertEqual(master_result, expected_result) + self.assertEqual(slaved_result, expected_result) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py new file mode 100644 index 0000000000..c30c7c6063 --- /dev/null +++ b/tests/replication/slave/storage/test_events.py @@ -0,0 +1,114 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStoreTestCase + +from synapse.types import UserID +from synapse.events import FrozenEvent +from synapse.events.snapshot import EventContext + +from twisted.internet import defer + +USER_ID = "@feeling:blue" +USER = UserID.from_string(USER_ID) +OUTLIER = {"outlier": True} +ROOM_ID = "!room:blue" + + +class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): + + @defer.inlineCallbacks + def test_room_name_and_aliases(self): + create = yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.persist(type="m.room.member", key=USER_ID, membership="join") + yield self.persist(type="m.room.name", key="", name="name1") + yield self.persist( + type="m.room.aliases", key="blue", aliases=["#1:blue"] + ) + yield self.replicate() + yield self.check( + "get_room_name_and_aliases", (ROOM_ID,), ("name1", ["#1:blue"]) + ) + + # Set the room name. + yield self.persist(type="m.room.name", key="", name="name2") + yield self.replicate() + yield self.check( + "get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#1:blue"]) + ) + + # Set the room aliases. + yield self.persist( + type="m.room.aliases", key="blue", aliases=["#2:blue"] + ) + yield self.replicate() + yield self.check( + "get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#2:blue"]) + ) + + # Leave and join the room clobbering the state. + yield self.persist(type="m.room.member", key=USER_ID, membership="leave") + yield self.persist( + type="m.room.member", key=USER_ID, membership="join", + reset_state=[create] + ) + yield self.replicate() + + yield self.check( + "get_room_name_and_aliases", (ROOM_ID,), (None, []) + ) + + event_id = 0 + + @defer.inlineCallbacks + def persist( + self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, + internal={}, + state=None, reset_state=False, backfill=False, + depth=None, prev_events=[], auth_events=[], prev_state=[], + **content + ): + if depth is None: + depth = self.event_id + + event_dict = { + "sender": sender, + "type": type, + "content": content, + "event_id": "$%d:blue" % (self.event_id,), + "room_id": room_id, + "depth": depth, + "origin_server_ts": self.event_id, + "prev_events": prev_events, + "auth_events": auth_events, + } + if key is not None: + event_dict["state_key"] = key + event_dict["prev_state"] = prev_state + + event = FrozenEvent(event_dict, internal_metadata_dict=internal) + + self.event_id += 1 + + context = EventContext(current_state=state) + + if backfill: + yield self.master_store.persist_events( + [(event, context)], backfilled=True + ) + else: + yield self.master_store.persist_event( + event, context, current_state=reset_state + ) + defer.returnValue(event) -- cgit 1.4.1 From 1e05637e37f62445d84e43ae89e441f1833a32e2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 6 Apr 2016 15:19:45 +0100 Subject: Let users see their own leave events ... otherwise clients get confused. Fixes https://matrix.org/jira/browse/SYN-662, https://github.com/vector-im/vector-web/issues/368 --- synapse/handlers/_base.py | 51 +++++++++++++++++++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 11 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index c77afe7f51..88d8b9ba54 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -37,6 +37,15 @@ VISIBILITY_PRIORITY = ( ) +MEMBERSHIP_PRIORITY = ( + Membership.JOIN, + Membership.INVITE, + Membership.KNOCK, + Membership.LEAVE, + Membership.BAN, +) + + class BaseHandler(object): """ Common base class for the event handlers. @@ -72,6 +81,7 @@ class BaseHandler(object): * the user is not currently a member of the room, and: * the user has not been a member of the room since the given events + events ([synapse.events.EventBase]): list of events to filter """ forgotten = yield defer.gatherResults([ self.store.who_forgot_in_room( @@ -86,6 +96,12 @@ class BaseHandler(object): ) def allowed(event, user_id, is_peeking): + """ + Args: + event (synapse.events.EventBase): event to check + user_id (str) + is_peeking (bool) + """ state = event_id_to_state[event.event_id] # get the room_visibility at the time of the event. @@ -117,17 +133,30 @@ class BaseHandler(object): if old_priority < new_priority: visibility = prev_visibility - # get the user's membership at the time of the event. (or rather, - # just *after* the event. Which means that people can see their - # own join events, but not (currently) their own leave events.) - membership_event = state.get((EventTypes.Member, user_id), None) - if membership_event: - if membership_event.event_id in event_id_forgotten: - membership = None - else: - membership = membership_event.membership - else: - membership = None + # likewise, if the event is the user's own membership event, use + # the 'most joined' membership + membership = None + if event.type == EventTypes.Member and event.state_key == user_id: + membership = event.content.get("membership", None) + if membership not in MEMBERSHIP_PRIORITY: + membership = "leave" + + prev_content = event.unsigned.get("prev_content", {}) + prev_membership = prev_content.get("membership", None) + if prev_membership not in MEMBERSHIP_PRIORITY: + prev_membership = "leave" + + new_priority = MEMBERSHIP_PRIORITY.index(membership) + old_priority = MEMBERSHIP_PRIORITY.index(prev_membership) + if old_priority < new_priority: + membership = prev_membership + + # otherwise, get the user's membership at the time of the event. + if membership is None: + membership_event = state.get((EventTypes.Member, user_id), None) + if membership_event: + if membership_event.event_id not in event_id_forgotten: + membership = membership_event.membership # if the user was a member of the room at the time of the event, # they can see it. -- cgit 1.4.1 From 7e2c89a37f3a5261f43b4d472b36219ac41dfb16 Mon Sep 17 00:00:00 2001 From: David Baker Date: Wed, 6 Apr 2016 15:42:15 +0100 Subject: Make pushers use the event_push_actions table instead of listening on an event stream & running the rules again. Sytest passes, but remaining to do: * Make badges work again * Remove old, unused code --- synapse/handlers/_base.py | 8 +- synapse/handlers/federation.py | 8 +- synapse/push/bulk_push_rule_evaluator.py | 25 +++- synapse/push/httppusher.py | 204 +++++++++++++++++++++++------ synapse/push/push_tools.py | 66 ++++++++++ synapse/push/pusher.py | 10 ++ synapse/push/pusherpool.py | 75 ++++++----- synapse/storage/event_push_actions.py | 48 +++++++ synapse/storage/events.py | 12 ++ synapse/storage/pusher.py | 81 ++++++++---- synapse/storage/registration.py | 20 --- synapse/storage/roommember.py | 1 + synapse/storage/schema/delta/31/pushers.py | 75 +++++++++++ 13 files changed, 503 insertions(+), 130 deletions(-) create mode 100644 synapse/push/push_tools.py create mode 100644 synapse/push/pusher.py create mode 100644 synapse/storage/schema/delta/31/pushers.py (limited to 'synapse') diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index c77afe7f51..9c92ea01ed 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -21,7 +21,7 @@ from synapse.api.constants import Membership, EventTypes from synapse.types import UserID, RoomAlias, Requester from synapse.push.action_generator import ActionGenerator -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import PreserveLoggingContext, preserve_fn import logging @@ -377,6 +377,12 @@ class BaseHandler(object): event, context=context ) + # this intentionally does not yield: we don't care about the result + # and don't need to wait for it. + preserve_fn(self.hs.get_pusherpool().on_new_notifications)( + event_stream_id, max_stream_id + ) + destinations = set() for k, s in context.current_state.items(): try: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 026ebe52be..fc5e0b0590 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -26,7 +26,7 @@ from synapse.api.errors import ( from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.events.validator import EventValidator from synapse.util import unwrapFirstError -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import PreserveLoggingContext, preserve_fn from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor from synapse.util.frozenutils import unfreeze @@ -1094,6 +1094,12 @@ class FederationHandler(BaseHandler): context=context, ) + # this intentionally does not yield: we don't care about the result + # and don't need to wait for it. + preserve_fn(self.hs.get_pusherpool().on_new_notifications)( + event_stream_id, max_stream_id + ) + defer.returnValue((context, event_stream_id, max_stream_id)) @defer.inlineCallbacks diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 76d7eb7ce0..7f94591dcb 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -70,11 +70,17 @@ def _get_rules(room_id, user_ids, store): @defer.inlineCallbacks def evaluator_for_room_id(room_id, hs, store): - results = yield store.get_receipts_for_room(room_id, "m.read") - user_ids = [ - row["user_id"] for row in results - if hs.is_mine_id(row["user_id"]) - ] + users_with_pushers = yield store.get_users_with_pushers_in_room(room_id) + receipts = yield store.get_receipts_for_room(room_id, "m.read") + + # any users with pushers must be ours: they have pushers + user_ids = set(users_with_pushers) + for r in receipts: + if hs.is_mine_id(r['user_id']): + user_ids.add(r['user_id']) + + user_ids = list(user_ids) + rules_by_user = yield _get_rules(room_id, user_ids, store) defer.returnValue(BulkPushRuleEvaluator( @@ -101,10 +107,15 @@ class BulkPushRuleEvaluator: def action_for_event_by_user(self, event, handler, current_state): actions_by_user = {} - users_dict = yield self.store.are_guests(self.rules_by_user.keys()) + # None of these users can be peeking since this list of users comes + # from the set of users in the room, so we know for sure they're all + # actually in the room. + user_tuples = [ + (u, False) for u in self.rules_by_user.keys() + ] filtered_by_user = yield handler.filter_events_for_clients( - users_dict.items(), [event], {event.event_id: current_state} + user_tuples, [event], {event.event_id: current_state} ) room_members = yield self.store.get_users_in_room(self.room_id) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 9be4869360..d695885649 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -13,60 +13,188 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.push import Pusher, PusherConfigException +from synapse.push import PusherConfigException -from twisted.internet import defer +from twisted.internet import defer, reactor import logging +import push_rule_evaluator +import push_tools logger = logging.getLogger(__name__) -class HttpPusher(Pusher): - def __init__(self, _hs, user_id, app_id, - app_display_name, device_display_name, pushkey, pushkey_ts, - data, last_token, last_success, failing_since): - super(HttpPusher, self).__init__( - _hs, - user_id, - app_id, - app_display_name, - device_display_name, - pushkey, - pushkey_ts, - data, - last_token, - last_success, - failing_since +class HttpPusher(object): + INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes + MAX_BACKOFF_SEC = 60 * 60 + + # This one's in ms because we compare it against the clock + GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000 + + def __init__(self, hs, pusherdict): + self.hs = hs + self.store = self.hs.get_datastore() + self.clock = self.hs.get_clock() + self.user_id = pusherdict['user_name'] + self.app_id = pusherdict['app_id'] + self.app_display_name = pusherdict['app_display_name'] + self.device_display_name = pusherdict['device_display_name'] + self.pushkey = pusherdict['pushkey'] + self.pushkey_ts = pusherdict['ts'] + self.data = pusherdict['data'] + self.last_stream_ordering = pusherdict['last_stream_ordering'] + self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC + self.failing_since = pusherdict['failing_since'] + self.timed_call = None + + # This is the highest stream ordering we know it's safe to process. + # When new events arrive, we'll be given a window of new events: we + # should honour this rather than just looking for anything higher + # because of potential out-of-order event serialisation. This starts + # off as None though as we don't know any better. + self.max_stream_ordering = None + + if 'data' not in pusherdict: + raise PusherConfigException( + "No 'data' key for HTTP pusher" + ) + self.data = pusherdict['data'] + + self.name = "%s/%s/%s" % ( + pusherdict['user_name'], + pusherdict['app_id'], + pusherdict['pushkey'], ) - if 'url' not in data: + + if 'url' not in self.data: raise PusherConfigException( "'url' required in data for HTTP pusher" ) - self.url = data['url'] - self.http_client = _hs.get_simple_http_client() + self.url = self.data['url'] + self.http_client = hs.get_simple_http_client() self.data_minus_url = {} self.data_minus_url.update(self.data) del self.data_minus_url['url'] + def on_started(self): + self._process() + + def on_new_notifications(self, min_stream_ordering, max_stream_ordering): + self.max_stream_ordering = max_stream_ordering + self._process() + + def on_timer(self): + self._process() + + def on_stop(self): + if self.timed_call: + self.timed_call.cancel() + @defer.inlineCallbacks - def _build_notification_dict(self, event, tweaks, badge): - # we probably do not want to push for every presence update - # (we may want to be able to set up notifications when specific - # people sign in, but we'd want to only deliver the pertinent ones) - # Actually, presence events will not get this far now because we - # need to filter them out in the main Pusher code. - if 'event_id' not in event: - defer.returnValue(None) + def _process(self): + unprocessed = yield self.store.get_unread_push_actions_for_user_in_range( + self.user_id, self.last_stream_ordering, self.max_stream_ordering + ) + + for push_action in unprocessed: + processed = yield self._process_one(push_action) + if processed: + self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC + self.last_stream_ordering = push_action['stream_ordering'] + self.store.update_pusher_last_stream_ordering_and_success( + self.app_id, self.pushkey, self.user_id, + self.last_stream_ordering, + self.clock.time_msec() + ) + self.failing_since = None + yield self.store.update_pusher_failing_since( + self.app_id, self.pushkey, self.user_id, + self.failing_since + ) + else: + self.failing_since = self.clock.time_msec() + yield self.store.update_pusher_failing_since( + self.app_id, self.pushkey, self.user_id, + self.failing_since + ) + + if ( + self.failing_since and + self.failing_since < + self.clock.time_msec() - HttpPusher.GIVE_UP_AFTER + ): + # we really only give up so that if the URL gets + # fixed, we don't suddenly deliver a load + # of old notifications. + logger.warn("Giving up on a notification to user %s, " + "pushkey %s", + self.user_id, self.pushkey) + self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC + self.last_stream_ordering = push_action['stream_ordering'] + yield self.store.update_pusher_last_stream_ordering( + self.app_id, + self.pushkey, + self.user_id, + self.last_stream_ordering + ) + + self.failing_since = None + yield self.store.update_pusher_failing_since( + self.app_id, + self.pushkey, + self.user_id, + self.failing_since + ) + else: + logger.info("Push failed: delaying for %ds", self.backoff_delay) + self.timed_call = reactor.callLater(self.backoff_delay, self.on_timer) + self.backoff_delay = min(self.backoff_delay, self.MAX_BACKOFF_SEC) + break + + @defer.inlineCallbacks + def _process_one(self, push_action): + if 'notify' not in push_action['actions']: + defer.returnValue(True) - ctx = yield self.get_context_for_event(event) + tweaks = push_rule_evaluator.PushRuleEvaluator.tweaks_for_actions(push_action['actions']) + badge = yield push_tools.get_badge_count(self.hs, self.user_id) + + event = yield self.store.get_event(push_action['event_id'], allow_none=True) + if event is None: + defer.returnValue(True) # It's been redacted + rejected = yield self.dispatch_push(event, tweaks, badge) + if rejected is False: + defer.returnValue(False) + + if isinstance(rejected, list) or isinstance(rejected, tuple): + for pk in rejected: + if pk != self.pushkey: + # for sanity, we only remove the pushkey if it + # was the one we actually sent... + logger.warn( + ("Ignoring rejected pushkey %s because we" + " didn't send it"), pk + ) + else: + logger.info( + "Pushkey %s was rejected: removing", + pk + ) + yield self.hs.get_pusherpool().remove_pusher( + self.app_id, pk, self.user_id + ) + defer.returnValue(True) + + @defer.inlineCallbacks + def _build_notification_dict(self, event, tweaks, badge): + ctx = yield push_tools.get_context_for_event(self.hs, event) d = { 'notification': { - 'id': event['event_id'], - 'room_id': event['room_id'], - 'type': event['type'], - 'sender': event['user_id'], + 'id': event.event_id, + 'room_id': event.room_id, + 'type': event.type, + 'sender': event.user_id, 'counts': { # -- we don't mark messages as read yet so # we have no way of knowing # Just set the badge to 1 until we have read receipts @@ -84,11 +212,11 @@ class HttpPusher(Pusher): ] } } - if event['type'] == 'm.room.member': - d['notification']['membership'] = event['content']['membership'] - d['notification']['user_is_target'] = event['state_key'] == self.user_id + if event.type == 'm.room.member': + d['notification']['membership'] = event.content['membership'] + d['notification']['user_is_target'] = event.state_key == self.user_id if 'content' in event: - d['notification']['content'] = event['content'] + d['notification']['content'] = event.content if len(ctx['aliases']): d['notification']['room_alias'] = ctx['aliases'][0] diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py new file mode 100644 index 0000000000..e1e61e49e8 --- /dev/null +++ b/synapse/push/push_tools.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + + +@defer.inlineCallbacks +def get_badge_count(hs, user_id): + invites, joins = yield defer.gatherResults([ + hs.get_datastore().get_invited_rooms_for_user(user_id), + hs.get_datastore().get_rooms_for_user(user_id), + ], consumeErrors=True) + + my_receipts_by_room = yield hs.get_datastore().get_receipts_for_user( + user_id, "m.read", + ) + + badge = len(invites) + + for r in joins: + if r.room_id in my_receipts_by_room: + last_unread_event_id = my_receipts_by_room[r.room_id] + + notifs = yield ( + hs.get_datastore().get_unread_event_push_actions_by_room_for_user( + r.room_id, user_id, last_unread_event_id + ) + ) + badge += notifs["notify_count"] + defer.returnValue(badge) + + +@defer.inlineCallbacks +def get_context_for_event(hs, ev): + name_aliases = yield hs.get_datastore().get_room_name_and_aliases( + ev.room_id + ) + + ctx = {'aliases': name_aliases[1]} + if name_aliases[0] is not None: + ctx['name'] = name_aliases[0] + + their_member_events_for_room = yield hs.get_datastore().get_current_state( + room_id=ev.room_id, + event_type='m.room.member', + state_key=ev.user_id + ) + for mev in their_member_events_for_room: + if mev.content['membership'] == 'join' and 'displayname' in mev.content: + dn = mev.content['displayname'] + if dn is not None: + ctx['sender_display_name'] = dn + + defer.returnValue(ctx) diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py new file mode 100644 index 0000000000..4960837504 --- /dev/null +++ b/synapse/push/pusher.py @@ -0,0 +1,10 @@ +from httppusher import HttpPusher + +PUSHER_TYPES = { + 'http': HttpPusher +} + + +def create_pusher(hs, pusherdict): + if pusherdict['kind'] in PUSHER_TYPES: + return PUSHER_TYPES[pusherdict['kind']](hs, pusherdict) diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 0b463c6fdb..b67ad455ea 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -16,9 +16,10 @@ from twisted.internet import defer -from .httppusher import HttpPusher +import pusher from synapse.push import PusherConfigException from synapse.util.logcontext import preserve_fn +from synapse.util.async import run_on_reactor import logging @@ -48,7 +49,7 @@ class PusherPool: # will then get pulled out of the database, # recreated, added and started: this means we have only one # code path adding pushers. - self._create_pusher({ + pusher.create_pusher(self.hs, { "user_name": user_id, "kind": kind, "app_id": app_id, @@ -58,10 +59,18 @@ class PusherPool: "ts": time_now_msec, "lang": lang, "data": data, - "last_token": None, + "last_stream_ordering": None, "last_success": None, "failing_since": None }) + + # create the pusher setting last_stream_ordering to the current maximum + # stream ordering in event_push_actions, so it will process + # pushes from this point onwards. + last_stream_ordering = ( + yield self.store.get_latest_push_action_stream_ordering() + ) + yield self.store.add_pusher( user_id=user_id, access_token=access_token, @@ -73,6 +82,7 @@ class PusherPool: pushkey_ts=time_now_msec, lang=lang, data=data, + last_stream_ordering=last_stream_ordering, profile_tag=profile_tag, ) yield self._refresh_pusher(app_id, pushkey, user_id) @@ -106,26 +116,19 @@ class PusherPool: ) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) - def _create_pusher(self, pusherdict): - if pusherdict['kind'] == 'http': - return HttpPusher( - self.hs, - user_id=pusherdict['user_name'], - app_id=pusherdict['app_id'], - app_display_name=pusherdict['app_display_name'], - device_display_name=pusherdict['device_display_name'], - pushkey=pusherdict['pushkey'], - pushkey_ts=pusherdict['ts'], - data=pusherdict['data'], - last_token=pusherdict['last_token'], - last_success=pusherdict['last_success'], - failing_since=pusherdict['failing_since'] - ) - else: - raise PusherConfigException( - "Unknown pusher type '%s' for user %s" % - (pusherdict['kind'], pusherdict['user_name']) + @defer.inlineCallbacks + def on_new_notifications(self, min_stream_id, max_stream_id): + yield run_on_reactor() + try: + users_affected = yield self.store.get_push_action_users_in_range( + min_stream_id, max_stream_id ) + for u in users_affected: + if u in self.pushers: + for p in self.pushers[u].values(): + p.on_new_notifications(min_stream_id, max_stream_id) + except: + logger.exception("Exception in pusher on_new_notifications") @defer.inlineCallbacks def _refresh_pusher(self, app_id, pushkey, user_id): @@ -146,30 +149,34 @@ class PusherPool: logger.info("Starting %d pushers", len(pushers)) for pusherdict in pushers: try: - p = self._create_pusher(pusherdict) + p = pusher.create_pusher(self.hs, pusherdict) except PusherConfigException: logger.exception("Couldn't start a pusher: caught PusherConfigException") continue if p: - fullid = "%s:%s:%s" % ( + appid_pushkey = "%s:%s" % ( pusherdict['app_id'], pusherdict['pushkey'], - pusherdict['user_name'] ) - if fullid in self.pushers: - self.pushers[fullid].stop() - self.pushers[fullid] = p - preserve_fn(p.start)() + byuser = self.pushers.setdefault(pusherdict['user_name'], {}) + + if appid_pushkey in byuser: + byuser[appid_pushkey].on_stop() + byuser[appid_pushkey] = p + preserve_fn(p.on_started)() logger.info("Started pushers") @defer.inlineCallbacks def remove_pusher(self, app_id, pushkey, user_id): - fullid = "%s:%s:%s" % (app_id, pushkey, user_id) - if fullid in self.pushers: - logger.info("Stopping pusher %s", fullid) - self.pushers[fullid].stop() - del self.pushers[fullid] + appid_pushkey = "%s:%s" % (app_id, pushkey) + + byuser = self.pushers.get(user_id, {}) + + if appid_pushkey in byuser: + logger.info("Stopping pusher %s / %s", user_id, appid_pushkey) + byuser[appid_pushkey].on_stop() + del byuser[appid_pushkey] yield self.store.delete_pusher_by_app_id_pushkey_user_id( app_id, pushkey, user_id ) diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index 3933b6e2c5..5f61743e34 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -100,6 +100,54 @@ class EventPushActionsStore(SQLBaseStore): ) defer.returnValue(ret) + @defer.inlineCallbacks + def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering): + def f(txn): + sql = ( + "SELECT DISTINCT(user_id) FROM event_push_actions WHERE" + " stream_ordering >= ? AND stream_ordering >= ?" + ) + txn.execute(sql, (min_stream_ordering, max_stream_ordering)) + return [r[0] for r in txn.fetchall()] + ret = yield self.runInteraction("get_push_action_users_in_range", f) + defer.returnValue(ret) + + @defer.inlineCallbacks + def get_unread_push_actions_for_user_in_range(self, user_id, + min_stream_ordering, + max_stream_ordering=None): + def f(txn): + sql = ( + "SELECT event_id, stream_ordering, actions" + " FROM event_push_actions" + " WHERE user_id = ? AND stream_ordering > ?" + ) + args = [user_id, min_stream_ordering] + if max_stream_ordering is not None: + sql += " AND stream_ordering <= ?" + args.append(max_stream_ordering) + sql += " ORDER BY stream_ordering ASC" + txn.execute(sql, args) + return txn.fetchall() + ret = yield self.runInteraction("get_unread_push_actions_for_user_in_range", f) + defer.returnValue([ + { + "event_id": row[0], + "stream_ordering": row[1], + "actions": json.loads(row[2]), + } for row in ret + ]) + + @defer.inlineCallbacks + def get_latest_push_action_stream_ordering(self): + def f(txn): + txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") + return txn.fetchone() + result = yield self.runInteraction( + "get_latest_push_action_stream_ordering", f + ) + defer.returnValue(result[0] or 0) + def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): # Sad that we have to blow away the cache for the whole room here txn.call_after( diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 5d299a1132..ceae8715ce 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -61,6 +61,17 @@ class EventsStore(SQLBaseStore): @defer.inlineCallbacks def persist_events(self, events_and_contexts, backfilled=False): + """ + Write events to the database + Args: + events_and_contexts: list of tuples of (event, context) + backfilled: ? + + Returns: Tuple of stream_orderings where the first is the minimum and + last is the maximum stream ordering assigned to the events when + persisting. + + """ if not events_and_contexts: return @@ -191,6 +202,7 @@ class EventsStore(SQLBaseStore): txn.call_after(self._get_current_state_for_key.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) + txn.call_after(self.get_users_with_pushers_in_room.invalidate, (event.room_id,)) txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) txn.call_after(self.get_room_name_and_aliases.invalidate, (event.room_id,)) diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index d1669c778a..f7886dd1bb 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -18,6 +18,8 @@ from twisted.internet import defer from canonicaljson import encode_canonical_json +from synapse.util.caches.descriptors import cachedInlineCallbacks + import logging import simplejson as json import types @@ -107,31 +109,46 @@ class PusherStore(SQLBaseStore): "get_all_updated_pushers", get_all_updated_pushers_txn ) + @cachedInlineCallbacks(num_args=1) + def get_users_with_pushers_in_room(self, room_id): + users = yield self.get_users_in_room(room_id) + + result = yield self._simple_select_many_batch( + 'pushers', 'user_name', users, ['user_name'] + ) + + defer.returnValue([r['user_name'] for r in result]) + @defer.inlineCallbacks def add_pusher(self, user_id, access_token, kind, app_id, app_display_name, device_display_name, - pushkey, pushkey_ts, lang, data, profile_tag=""): - with self._pushers_id_gen.get_next() as stream_id: - yield self._simple_upsert( - "pushers", - dict( - app_id=app_id, - pushkey=pushkey, - user_name=user_id, - ), - dict( - access_token=access_token, - kind=kind, - app_display_name=app_display_name, - device_display_name=device_display_name, - ts=pushkey_ts, - lang=lang, - data=encode_canonical_json(data), - profile_tag=profile_tag, - id=stream_id, - ), - desc="add_pusher", - ) + pushkey, pushkey_ts, lang, data, last_stream_ordering, + profile_tag=""): + def f(txn): + txn.call_after(self.get_users_with_pushers_in_room.invalidate_all) + with self._pushers_id_gen.get_next() as stream_id: + return self._simple_upsert_txn( + txn, + "pushers", + dict( + app_id=app_id, + pushkey=pushkey, + user_name=user_id, + ), + dict( + access_token=access_token, + kind=kind, + app_display_name=app_display_name, + device_display_name=device_display_name, + ts=pushkey_ts, + lang=lang, + data=encode_canonical_json(data), + last_stream_ordering=last_stream_ordering, + profile_tag=profile_tag, + id=stream_id, + ), + ) + defer.returnValue((yield self.runInteraction("add_pusher", f))) @defer.inlineCallbacks def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id): @@ -153,22 +170,28 @@ class PusherStore(SQLBaseStore): ) @defer.inlineCallbacks - def update_pusher_last_token(self, app_id, pushkey, user_id, last_token): + def update_pusher_last_stream_ordering(self, app_id, pushkey, user_id, + last_stream_ordering): yield self._simple_update_one( "pushers", {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id}, - {'last_token': last_token}, - desc="update_pusher_last_token", + {'last_stream_ordering': last_stream_ordering}, + desc="update_pusher_last_stream_ordering", ) @defer.inlineCallbacks - def update_pusher_last_token_and_success(self, app_id, pushkey, user_id, - last_token, last_success): + def update_pusher_last_stream_ordering_and_success(self, app_id, pushkey, + user_id, + last_stream_ordering, + last_success): yield self._simple_update_one( "pushers", {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id}, - {'last_token': last_token, 'last_success': last_success}, - desc="update_pusher_last_token_and_success", + { + 'last_stream_ordering': last_stream_ordering, + 'last_success': last_success + }, + desc="update_pusher_last_stream_ordering_and_success", ) @defer.inlineCallbacks diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index d46a963bb8..701dd2f656 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -319,26 +319,6 @@ class RegistrationStore(SQLBaseStore): defer.returnValue(res if res else False) - @cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1, - inlineCallbacks=True) - def are_guests(self, user_ids): - sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % ( - ",".join("?" for _ in user_ids), - ) - - rows = yield self._execute( - "are_guests", self.cursor_to_dict, sql, *user_ids - ) - - result = {user_id: False for user_id in user_ids} - - result.update({ - row["name"]: bool(row["is_guest"]) - for row in rows - }) - - defer.returnValue(result) - def _query_for_auth(self, txn, token): sql = ( "SELECT users.name, users.is_guest, access_tokens.id as token_id" diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 66e7a40e3c..22a690aa8d 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -58,6 +58,7 @@ class RoomMemberStore(SQLBaseStore): txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,)) txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) + txn.call_after(self.get_users_with_pushers_in_room.invalidate, (event.room_id,)) txn.call_after( self._membership_stream_cache.entity_has_changed, event.state_key, event.internal_metadata.stream_ordering diff --git a/synapse/storage/schema/delta/31/pushers.py b/synapse/storage/schema/delta/31/pushers.py new file mode 100644 index 0000000000..7e0e385fb5 --- /dev/null +++ b/synapse/storage/schema/delta/31/pushers.py @@ -0,0 +1,75 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Change the last_token to last_stream_ordering now that pushers no longer +# listen on an event stream but instead select out of the event_push_actions +# table. + + +import logging + +logger = logging.getLogger(__name__) + + +def token_to_stream_ordering(token): + return int(token[1:].split('_')[0]) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + logger.info("Porting pushers table, delta 31...") + cur.execute(""" + CREATE TABLE IF NOT EXISTS pushers2 ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + access_token BIGINT DEFAULT NULL, + profile_tag VARCHAR(32) NOT NULL, + kind VARCHAR(8) NOT NULL, + app_id VARCHAR(64) NOT NULL, + app_display_name VARCHAR(64) NOT NULL, + device_display_name VARCHAR(128) NOT NULL, + pushkey TEXT NOT NULL, + ts BIGINT NOT NULL, + lang VARCHAR(8), + data TEXT, + last_stream_ordering INTEGER, + last_success BIGINT, + failing_since BIGINT, + UNIQUE (app_id, pushkey, user_name) + ) + """) + cur.execute("""SELECT + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_token, last_success, + failing_since + FROM pushers + """) + count = 0 + for row in cur.fetchall(): + row = list(row) + row[12] = token_to_stream_ordering(row[12]) + cur.execute(database_engine.convert_param_style(""" + INSERT into pushers2 ( + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_stream_ordering, last_success, + failing_since + ) values (%s)""" % (','.join(['?' for _ in range(len(row))]))), + row + ) + count += 1 + cur.execute("DROP TABLE pushers") + cur.execute("ALTER TABLE pushers2 RENAME TO pushers") + logger.info("Moved %d pushers to new table", count) -- cgit 1.4.1 From 6bfec56796132520ad864ad00f8dd6773512f9f4 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 6 Apr 2016 16:17:15 +0100 Subject: Test that room membership is replicated --- synapse/replication/slave/storage/events.py | 7 +-- tests/replication/slave/storage/test_events.py | 71 +++++++++++++++++++++++--- 2 files changed, 67 insertions(+), 11 deletions(-) (limited to 'synapse') diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 68b924e37b..680dc89536 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -71,9 +71,6 @@ class SlavedEventStore(BaseSlavedStore): get_current_state = DataStore.get_current_state.__func__ get_current_state_for_key = DataStore.get_current_state_for_key.__func__ - _get_rooms_for_user_where_membership_is_txn = ( - DataStore._get_rooms_for_user_where_membership_is_txn.__func__ - ) get_rooms_for_user_where_membership_is = ( DataStore.get_rooms_for_user_where_membership_is.__func__ ) @@ -95,6 +92,10 @@ class SlavedEventStore(BaseSlavedStore): _fetch_events_txn = DataStore._fetch_events_txn.__func__ _fetch_event_rows = DataStore._fetch_event_rows.__func__ _get_event_from_row_txn = DataStore._get_event_from_row_txn.__func__ + _get_rooms_for_user_where_membership_is_txn = ( + DataStore._get_rooms_for_user_where_membership_is_txn.__func__ + ) + _get_members_rows_txn = DataStore._get_members_rows_txn.__func__ def stream_positions(self): result = super(SlavedEventStore, self).stream_positions() diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index c30c7c6063..351d777fb2 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -14,14 +14,14 @@ from ._base import BaseSlavedStoreTestCase -from synapse.types import UserID from synapse.events import FrozenEvent from synapse.events.snapshot import EventContext +from synapse.storage.roommember import RoomsForUser from twisted.internet import defer USER_ID = "@feeling:blue" -USER = UserID.from_string(USER_ID) +USER_ID_2 = "@bright:blue" OUTLIER = {"outlier": True} ROOM_ID = "!room:blue" @@ -69,16 +69,66 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): "get_room_name_and_aliases", (ROOM_ID,), (None, []) ) + @defer.inlineCallbacks + def test_room_members(self): + create = yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.replicate() + yield self.check("get_rooms_for_user", (USER_ID,), []) + yield self.check("get_users_in_room", (ROOM_ID,), []) + + # Join the room. + join = yield self.persist(type="m.room.member", key=USER_ID, membership="join") + yield self.replicate() + yield self.check("get_rooms_for_user", (USER_ID,), [RoomsForUser( + room_id=ROOM_ID, + sender=USER_ID, + membership="join", + event_id=join.event_id, + stream_ordering=join.internal_metadata.stream_ordering, + )]) + yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID]) + + # Leave the room. + yield self.persist(type="m.room.member", key=USER_ID, membership="leave") + yield self.replicate() + yield self.check("get_rooms_for_user", (USER_ID,), []) + yield self.check("get_users_in_room", (ROOM_ID,), []) + + # Add some other user to the room. + join = yield self.persist(type="m.room.member", key=USER_ID_2, membership="join") + yield self.replicate() + yield self.check("get_rooms_for_user", (USER_ID_2,), [RoomsForUser( + room_id=ROOM_ID, + sender=USER_ID, + membership="join", + event_id=join.event_id, + stream_ordering=join.internal_metadata.stream_ordering, + )]) + yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2]) + + # Join the room clobbering the state. + # This should remove any evidence of the other user being in the room. + yield self.persist( + type="m.room.member", key=USER_ID, membership="join", + reset_state=[create] + ) + yield self.replicate() + yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID]) + yield self.check("get_rooms_for_user", (USER_ID_2,), []) + event_id = 0 @defer.inlineCallbacks def persist( - self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, - internal={}, - state=None, reset_state=False, backfill=False, - depth=None, prev_events=[], auth_events=[], prev_state=[], - **content + self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={}, + state=None, reset_state=False, backfill=False, + depth=None, prev_events=[], auth_events=[], prev_state=[], + **content ): + """ + Returns: + synapse.events.FrozenEvent: The event that was persisted. + """ if depth is None: depth = self.event_id @@ -103,12 +153,17 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): context = EventContext(current_state=state) + ordering = None if backfill: yield self.master_store.persist_events( [(event, context)], backfilled=True ) else: - yield self.master_store.persist_event( + ordering, _ = yield self.master_store.persist_event( event, context, current_state=reset_state ) + + if ordering: + event.internal_metadata.stream_ordering = ordering + defer.returnValue(event) -- cgit 1.4.1 From 0fd1cd24003b54e475985cf90db4223c3098375d Mon Sep 17 00:00:00 2001 From: David Baker Date: Wed, 6 Apr 2016 16:50:47 +0100 Subject: pep8 --- synapse/storage/event_push_actions.py | 2 +- synapse/storage/events.py | 4 +++- synapse/storage/registration.py | 2 +- synapse/storage/roommember.py | 4 +++- 4 files changed, 8 insertions(+), 4 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index 5f61743e34..4d72e4a85e 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -144,7 +144,7 @@ class EventPushActionsStore(SQLBaseStore): txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") return txn.fetchone() result = yield self.runInteraction( - "get_latest_push_action_stream_ordering", f + "get_latest_push_action_stream_ordering", f ) defer.returnValue(result[0] or 0) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index ceae8715ce..5be5bc01b1 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -202,7 +202,9 @@ class EventsStore(SQLBaseStore): txn.call_after(self._get_current_state_for_key.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) - txn.call_after(self.get_users_with_pushers_in_room.invalidate, (event.room_id,)) + txn.call_after( + self.get_users_with_pushers_in_room.invalidate, (event.room_id,) + ) txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) txn.call_after(self.get_room_name_and_aliases.invalidate, (event.room_id,)) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 701dd2f656..7af0cae6a5 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -20,7 +20,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError, Codes from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks class RegistrationStore(SQLBaseStore): diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 22a690aa8d..088ad0f914 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -58,7 +58,9 @@ class RoomMemberStore(SQLBaseStore): txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,)) txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) - txn.call_after(self.get_users_with_pushers_in_room.invalidate, (event.room_id,)) + txn.call_after( + self.get_users_with_pushers_in_room.invalidate, (event.room_id,) + ) txn.call_after( self._membership_stream_cache.entity_has_changed, event.state_key, event.internal_metadata.stream_ordering -- cgit 1.4.1 From 3d95405e5fe4ef1b795dbfd63a3532cac65b8cd4 Mon Sep 17 00:00:00 2001 From: Christoph Witzany Date: Tue, 5 Apr 2016 17:26:37 +0200 Subject: Introduce LDAP authentication --- synapse/config/homeserver.py | 3 ++- synapse/config/ldap.py | 48 +++++++++++++++++++++++++++++++++ synapse/python_dependencies.py | 1 + synapse/rest/client/v1/login.py | 60 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 111 insertions(+), 1 deletion(-) create mode 100644 synapse/config/ldap.py (limited to 'synapse') diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index acf74c8761..9a80ac39ec 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -30,13 +30,14 @@ from .saml2 import SAML2Config from .cas import CasConfig from .password import PasswordConfig from .jwt import JWTConfig +from .ldap import LDAPConfig class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig, - JWTConfig, PasswordConfig,): + JWTConfig, LDAPConfig, PasswordConfig,): pass diff --git a/synapse/config/ldap.py b/synapse/config/ldap.py new file mode 100644 index 0000000000..86528139e2 --- /dev/null +++ b/synapse/config/ldap.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 Niklas Riekenbrauck +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + + +class LDAPConfig(Config): + def read_config(self, config): + ldap_config = config.get("ldap_config", None) + if ldap_config: + self.ldap_enabled = ldap_config.get("enabled", False) + self.ldap_server = ldap_config["server"] + self.ldap_port = ldap_config["port"] + self.ldap_search_base = ldap_config["search_base"] + self.ldap_search_property = ldap_config["search_property"] + self.ldap_email_property = ldap_config["email_property"] + self.ldap_full_name_property = ldap_config["full_name_property"] + else: + self.ldap_enabled = False + self.ldap_server = None + self.ldap_port = None + self.ldap_search_base = None + self.ldap_search_property = None + self.ldap_email_property = None + self.ldap_full_name_property = None + + def default_config(self, **kwargs): + return """\ + # ldap_config: + # server: "ldap://localhost" + # port: 389 + # search_base: "ou=Users,dc=example,dc=com" + # search_property: "cn" + # email_property: "email" + # full_name_property: "givenName" + """ diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index cf1414b4db..d6b6e82bd7 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -37,6 +37,7 @@ REQUIREMENTS = { "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"], "pymacaroons-pynacl": ["pymacaroons"], "pyjwt": ["jwt"], + "python-ldap": ["ldap"], } CONDITIONAL_REQUIREMENTS = { "web_client": { diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index d14ce3efa2..13720973be 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -36,6 +36,8 @@ import xml.etree.ElementTree as ET import jwt from jwt.exceptions import InvalidTokenError +import ldap + logger = logging.getLogger(__name__) @@ -47,6 +49,7 @@ class LoginRestServlet(ClientV1RestServlet): CAS_TYPE = "m.login.cas" TOKEN_TYPE = "m.login.token" JWT_TYPE = "m.login.jwt" + LDAP_TYPE = "m.login.ldap" def __init__(self, hs): super(LoginRestServlet, self).__init__(hs) @@ -56,6 +59,13 @@ class LoginRestServlet(ClientV1RestServlet): self.jwt_enabled = hs.config.jwt_enabled self.jwt_secret = hs.config.jwt_secret self.jwt_algorithm = hs.config.jwt_algorithm + self.ldap_enabled = hs.config.ldap_enabled + self.ldap_server = hs.config.ldap_server + self.ldap_port = hs.config.ldap_port + self.ldap_search_base = hs.config.ldap_search_base + self.ldap_search_property = hs.config.ldap_search_property + self.ldap_email_property = hs.config.ldap_email_property + self.ldap_full_name_property = hs.config.ldap_full_name_property self.cas_enabled = hs.config.cas_enabled self.cas_server_url = hs.config.cas_server_url self.cas_required_attributes = hs.config.cas_required_attributes @@ -64,6 +74,8 @@ class LoginRestServlet(ClientV1RestServlet): def on_GET(self, request): flows = [] + if self.ldap_enabled: + flows.append({"type": LoginRestServlet.LDAP_TYPE}) if self.jwt_enabled: flows.append({"type": LoginRestServlet.JWT_TYPE}) if self.saml2_enabled: @@ -107,6 +119,10 @@ class LoginRestServlet(ClientV1RestServlet): "uri": "%s%s" % (self.idp_redirect_url, relay_state) } defer.returnValue((200, result)) + elif self.ldap_enabled and (login_submission["type"] == + LoginRestServlet.JWT_TYPE): + result = yield self.do_ldap_login(login_submission) + defer.returnValue(result) elif self.jwt_enabled and (login_submission["type"] == LoginRestServlet.JWT_TYPE): result = yield self.do_jwt_login(login_submission) @@ -160,6 +176,50 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) + @defer.inlineCallbacks + def do_ldap_login(self, login_submission): + if 'medium' in login_submission and 'address' in login_submission: + user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + login_submission['medium'], login_submission['address'] + ) + if not user_id: + raise LoginError(403, "", errcode=Codes.FORBIDDEN) + else: + user_id = login_submission['user'] + + if not user_id.startswith('@'): + user_id = UserID.create( + user_id, self.hs.hostname + ).to_string() + + # FIXME check against LDAP Server!! + + auth_handler = self.handlers.auth_handler + user_exists = yield auth_handler.does_user_exist(user_id) + if user_exists: + user_id, access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id(user_id) + ) + result = { + "user_id": user_id, # may have changed + "access_token": access_token, + "refresh_token": refresh_token, + "home_server": self.hs.hostname, + } + + else: + user_id, access_token = ( + yield self.handlers.registration_handler.register(localpart=user_id.localpart) + ) + result = { + "user_id": user_id, # may have changed + "access_token": access_token, + "home_server": self.hs.hostname, + } + + defer.returnValue((200, result)) + + @defer.inlineCallbacks def do_token_login(self, login_submission): token = login_submission['token'] -- cgit 1.4.1 From 7b9319b1c837991ab187e2f280ff267c35a7c653 Mon Sep 17 00:00:00 2001 From: Christoph Witzany Date: Wed, 6 Apr 2016 13:02:49 +0200 Subject: move LDAP authentication to AuthenticationHandler --- synapse/handlers/auth.py | 54 +++++++++++++++++++++++++++++++++++----- synapse/rest/client/v1/login.py | 55 ----------------------------------------- 2 files changed, 48 insertions(+), 61 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d5d6faa85f..eeca820845 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -30,6 +30,8 @@ import simplejson import synapse.util.stringutils as stringutils +import ldap + logger = logging.getLogger(__name__) @@ -49,6 +51,15 @@ class AuthHandler(BaseHandler): self.sessions = {} self.INVALID_TOKEN_HTTP_STATUS = 401 + self.ldap_enabled = hs.config.ldap_enabled + self.ldap_server = hs.config.ldap_server + self.ldap_port = hs.config.ldap_port + self.ldap_search_base = hs.config.ldap_search_base + self.ldap_search_property = hs.config.ldap_search_property + self.ldap_email_property = hs.config.ldap_email_property + self.ldap_full_name_property = hs.config.ldap_full_name_property + + @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): """ @@ -215,8 +226,8 @@ class AuthHandler(BaseHandler): if not user_id.startswith('@'): user_id = UserID.create(user_id, self.hs.hostname).to_string() - user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) - self._check_password(user_id, password, password_hash) + self._check_password(user_id, password) + defer.returnValue(user_id) @defer.inlineCallbacks @@ -340,8 +351,8 @@ class AuthHandler(BaseHandler): StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ - user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) - self._check_password(user_id, password, password_hash) + + self._check_password(user_id, password) logger.info("Logging in user %s", user_id) access_token = yield self.issue_access_token(user_id) @@ -407,12 +418,43 @@ class AuthHandler(BaseHandler): else: defer.returnValue(user_infos.popitem()) - def _check_password(self, user_id, password, stored_hash): + def _check_password(self, user_id, password): """Checks that user_id has passed password, raises LoginError if not.""" - if not self.validate_hash(password, stored_hash): + + if not (self._check_ldap_password(user_id, password) or self._check_local_password(user_id, password)): logger.warn("Failed password login for user %s", user_id) raise LoginError(403, "", errcode=Codes.FORBIDDEN) + def _check_local_password(self, user_id, password): + user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) + return not self.validate_hash(password, password_hash) + + def _check_ldap_password(self, user_id, password): + if not self.ldap_enabled: + return False + + logger.info("Authenticating %s with LDAP" % user_id) + try: + l = ldap.initialize("%s:%s" % (ldap_server, ldap_port)) + if self.ldap_tls: + logger.debug("Initiating TLS") + self._connection.start_tls_s() + + dn = "%s=%s, %s" % (ldap_search_property, user_id.localpart, ldap_search_base) + logger.debug("DN for LDAP authentication: %s" % dn) + + l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8')) + + if not self.does_user_exist(user_id): + user_id, access_token = ( + yield self.handlers.registration_handler.register(localpart=user_id.localpart) + ) + + return True + except ldap.LDAPError, e: + logger.info(e) + return False + @defer.inlineCallbacks def issue_access_token(self, user_id): access_token = self.generate_access_token(user_id) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 13720973be..da0fd2a8e0 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -36,8 +36,6 @@ import xml.etree.ElementTree as ET import jwt from jwt.exceptions import InvalidTokenError -import ldap - logger = logging.getLogger(__name__) @@ -49,7 +47,6 @@ class LoginRestServlet(ClientV1RestServlet): CAS_TYPE = "m.login.cas" TOKEN_TYPE = "m.login.token" JWT_TYPE = "m.login.jwt" - LDAP_TYPE = "m.login.ldap" def __init__(self, hs): super(LoginRestServlet, self).__init__(hs) @@ -59,13 +56,6 @@ class LoginRestServlet(ClientV1RestServlet): self.jwt_enabled = hs.config.jwt_enabled self.jwt_secret = hs.config.jwt_secret self.jwt_algorithm = hs.config.jwt_algorithm - self.ldap_enabled = hs.config.ldap_enabled - self.ldap_server = hs.config.ldap_server - self.ldap_port = hs.config.ldap_port - self.ldap_search_base = hs.config.ldap_search_base - self.ldap_search_property = hs.config.ldap_search_property - self.ldap_email_property = hs.config.ldap_email_property - self.ldap_full_name_property = hs.config.ldap_full_name_property self.cas_enabled = hs.config.cas_enabled self.cas_server_url = hs.config.cas_server_url self.cas_required_attributes = hs.config.cas_required_attributes @@ -74,8 +64,6 @@ class LoginRestServlet(ClientV1RestServlet): def on_GET(self, request): flows = [] - if self.ldap_enabled: - flows.append({"type": LoginRestServlet.LDAP_TYPE}) if self.jwt_enabled: flows.append({"type": LoginRestServlet.JWT_TYPE}) if self.saml2_enabled: @@ -176,49 +164,6 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) - @defer.inlineCallbacks - def do_ldap_login(self, login_submission): - if 'medium' in login_submission and 'address' in login_submission: - user_id = yield self.hs.get_datastore().get_user_id_by_threepid( - login_submission['medium'], login_submission['address'] - ) - if not user_id: - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - else: - user_id = login_submission['user'] - - if not user_id.startswith('@'): - user_id = UserID.create( - user_id, self.hs.hostname - ).to_string() - - # FIXME check against LDAP Server!! - - auth_handler = self.handlers.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if user_exists: - user_id, access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) - ) - result = { - "user_id": user_id, # may have changed - "access_token": access_token, - "refresh_token": refresh_token, - "home_server": self.hs.hostname, - } - - else: - user_id, access_token = ( - yield self.handlers.registration_handler.register(localpart=user_id.localpart) - ) - result = { - "user_id": user_id, # may have changed - "access_token": access_token, - "home_server": self.hs.hostname, - } - - defer.returnValue((200, result)) - @defer.inlineCallbacks def do_token_login(self, login_submission): -- cgit 1.4.1 From 92767dd70313c81c12b91bf35ed44044969b4ef6 Mon Sep 17 00:00:00 2001 From: Christoph Witzany Date: Wed, 6 Apr 2016 16:57:54 +0200 Subject: add tls property --- synapse/config/ldap.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) (limited to 'synapse') diff --git a/synapse/config/ldap.py b/synapse/config/ldap.py index 86528139e2..9c14593a99 100644 --- a/synapse/config/ldap.py +++ b/synapse/config/ldap.py @@ -23,6 +23,7 @@ class LDAPConfig(Config): self.ldap_enabled = ldap_config.get("enabled", False) self.ldap_server = ldap_config["server"] self.ldap_port = ldap_config["port"] + self.ldap_tls = ldap_config.get("tls", False) self.ldap_search_base = ldap_config["search_base"] self.ldap_search_property = ldap_config["search_property"] self.ldap_email_property = ldap_config["email_property"] @@ -31,6 +32,7 @@ class LDAPConfig(Config): self.ldap_enabled = False self.ldap_server = None self.ldap_port = None + self.ldap_tls = False self.ldap_search_base = None self.ldap_search_property = None self.ldap_email_property = None @@ -39,10 +41,12 @@ class LDAPConfig(Config): def default_config(self, **kwargs): return """\ # ldap_config: - # server: "ldap://localhost" - # port: 389 - # search_base: "ou=Users,dc=example,dc=com" - # search_property: "cn" - # email_property: "email" - # full_name_property: "givenName" + # enabled: true + # server: "ldap://localhost" + # port: 389 + # tls: false + # search_base: "ou=Users,dc=example,dc=com" + # search_property: "cn" + # email_property: "email" + # full_name_property: "givenName" """ -- cgit 1.4.1 From 823b8be4b706b54457d6f1d8f1065ba37a14026d Mon Sep 17 00:00:00 2001 From: Christoph Witzany Date: Wed, 6 Apr 2016 16:58:50 +0200 Subject: add tls property and twist my head around twisted --- synapse/handlers/auth.py | 44 +++++++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 15 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index eeca820845..14a2a4d8b9 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -54,11 +54,13 @@ class AuthHandler(BaseHandler): self.ldap_enabled = hs.config.ldap_enabled self.ldap_server = hs.config.ldap_server self.ldap_port = hs.config.ldap_port + self.ldap_tls = hs.config.ldap_tls self.ldap_search_base = hs.config.ldap_search_base self.ldap_search_property = hs.config.ldap_search_property self.ldap_email_property = hs.config.ldap_email_property self.ldap_full_name_property = hs.config.ldap_full_name_property + self.hs = hs # FIXME better possibility to access registrationHandler later? @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): @@ -352,7 +354,10 @@ class AuthHandler(BaseHandler): LoginError if there was an authentication problem. """ - self._check_password(user_id, password) + if not self._check_password(user_id, password): + logger.warn("Failed password login for user %s", user_id) + raise LoginError(403, "", errcode=Codes.FORBIDDEN) + logger.info("Logging in user %s", user_id) access_token = yield self.issue_access_token(user_id) @@ -418,42 +423,51 @@ class AuthHandler(BaseHandler): else: defer.returnValue(user_infos.popitem()) + @defer.inlineCallbacks def _check_password(self, user_id, password): - """Checks that user_id has passed password, raises LoginError if not.""" + defer.returnValue(not ((yield self._check_ldap_password(user_id, password)) or (yield self._check_local_password(user_id, password)))) - if not (self._check_ldap_password(user_id, password) or self._check_local_password(user_id, password)): - logger.warn("Failed password login for user %s", user_id) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) + @defer.inlineCallbacks def _check_local_password(self, user_id, password): - user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) - return not self.validate_hash(password, password_hash) + try: + user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) + defer.returnValue(not self.validate_hash(password, password_hash)) + except: + defer.returnValue(False) + + @defer.inlineCallbacks def _check_ldap_password(self, user_id, password): if not self.ldap_enabled: - return False + logger.info("LDAP not configured") + defer.returnValue(False) logger.info("Authenticating %s with LDAP" % user_id) try: - l = ldap.initialize("%s:%s" % (ldap_server, ldap_port)) + ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port) + logger.debug("Connecting LDAP server at %s" % ldap_url) + l = ldap.initialize(ldap_url) if self.ldap_tls: logger.debug("Initiating TLS") self._connection.start_tls_s() - dn = "%s=%s, %s" % (ldap_search_property, user_id.localpart, ldap_search_base) + local_name = UserID.from_string(user_id).localpart + + dn = "%s=%s, %s" % (self.ldap_search_property, local_name, self.ldap_search_base) logger.debug("DN for LDAP authentication: %s" % dn) l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8')) - if not self.does_user_exist(user_id): + if not (yield self.does_user_exist(user_id)): user_id, access_token = ( - yield self.handlers.registration_handler.register(localpart=user_id.localpart) + yield self.hs.get_handlers().registration_handler.register(localpart=local_name) ) - return True + defer.returnValue(True) except ldap.LDAPError, e: - logger.info(e) - return False + logger.info("LDAP error: %s" % e) + defer.returnValue(False) @defer.inlineCallbacks def issue_access_token(self, user_id): -- cgit 1.4.1 From 8f0e47fae81c314f8d6e664e60d5ce5b136d99d4 Mon Sep 17 00:00:00 2001 From: Christoph Witzany Date: Wed, 6 Apr 2016 17:04:53 +0200 Subject: cleanup --- synapse/rest/client/v1/login.py | 5 ----- 1 file changed, 5 deletions(-) (limited to 'synapse') diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index da0fd2a8e0..d14ce3efa2 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -107,10 +107,6 @@ class LoginRestServlet(ClientV1RestServlet): "uri": "%s%s" % (self.idp_redirect_url, relay_state) } defer.returnValue((200, result)) - elif self.ldap_enabled and (login_submission["type"] == - LoginRestServlet.JWT_TYPE): - result = yield self.do_ldap_login(login_submission) - defer.returnValue(result) elif self.jwt_enabled and (login_submission["type"] == LoginRestServlet.JWT_TYPE): result = yield self.do_jwt_login(login_submission) @@ -164,7 +160,6 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) - @defer.inlineCallbacks def do_token_login(self, login_submission): token = login_submission['token'] -- cgit 1.4.1 From afff321e9a4d4a4d27841cc5cd737720d78dbffd Mon Sep 17 00:00:00 2001 From: Christoph Witzany Date: Wed, 6 Apr 2016 17:32:06 +0200 Subject: code style --- synapse/handlers/auth.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 14a2a4d8b9..37cbaa0b46 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -60,7 +60,7 @@ class AuthHandler(BaseHandler): self.ldap_email_property = hs.config.ldap_email_property self.ldap_full_name_property = hs.config.ldap_full_name_property - self.hs = hs # FIXME better possibility to access registrationHandler later? + self.hs = hs # FIXME better possibility to access registrationHandler later? @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): @@ -425,8 +425,12 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def _check_password(self, user_id, password): - defer.returnValue(not ((yield self._check_ldap_password(user_id, password)) or (yield self._check_local_password(user_id, password)))) - + defer.returnValue( + not ( + (yield self._check_ldap_password(user_id, password)) + or + (yield self._check_local_password(user_id, password)) + )) @defer.inlineCallbacks def _check_local_password(self, user_id, password): @@ -436,7 +440,6 @@ class AuthHandler(BaseHandler): except: defer.returnValue(False) - @defer.inlineCallbacks def _check_ldap_password(self, user_id, password): if not self.ldap_enabled: @@ -454,14 +457,18 @@ class AuthHandler(BaseHandler): local_name = UserID.from_string(user_id).localpart - dn = "%s=%s, %s" % (self.ldap_search_property, local_name, self.ldap_search_base) + dn = "%s=%s, %s" % ( + self.ldap_search_property, + local_name, + self.ldap_search_base) logger.debug("DN for LDAP authentication: %s" % dn) l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8')) if not (yield self.does_user_exist(user_id)): user_id, access_token = ( - yield self.hs.get_handlers().registration_handler.register(localpart=local_name) + handler = self.hs.get_handlers().registration_handler + yield handler.register(localpart=local_name) ) defer.returnValue(True) -- cgit 1.4.1 From 67f3a50e9ae68b66b465b5b4b86bc81da625d1e6 Mon Sep 17 00:00:00 2001 From: Christoph Witzany Date: Wed, 6 Apr 2016 17:44:03 +0200 Subject: fix exception handling --- synapse/handlers/auth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 37cbaa0b46..0e6b7c9e26 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -437,7 +437,7 @@ class AuthHandler(BaseHandler): try: user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) defer.returnValue(not self.validate_hash(password, password_hash)) - except: + except LoginError: defer.returnValue(False) @defer.inlineCallbacks @@ -473,7 +473,7 @@ class AuthHandler(BaseHandler): defer.returnValue(True) except ldap.LDAPError, e: - logger.info("LDAP error: %s" % e) + logger.warn("LDAP error: %s", e) defer.returnValue(False) @defer.inlineCallbacks -- cgit 1.4.1 From 875ed05bdcfdee72452a3eab196e3935a79e4004 Mon Sep 17 00:00:00 2001 From: Christoph Witzany Date: Wed, 6 Apr 2016 17:48:36 +0200 Subject: fix pep8 --- synapse/handlers/auth.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 0e6b7c9e26..ee2b285cc1 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -358,7 +358,6 @@ class AuthHandler(BaseHandler): logger.warn("Failed password login for user %s", user_id) raise LoginError(403, "", errcode=Codes.FORBIDDEN) - logger.info("Logging in user %s", user_id) access_token = yield self.issue_access_token(user_id) refresh_token = yield self.issue_refresh_token(user_id) @@ -466,8 +465,8 @@ class AuthHandler(BaseHandler): l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8')) if not (yield self.does_user_exist(user_id)): + handler = self.hs.get_handlers().registration_handler user_id, access_token = ( - handler = self.hs.get_handlers().registration_handler yield handler.register(localpart=local_name) ) -- cgit 1.4.1 From 4c5e8adf8b326798ec71a1cc1caac49f63980ba8 Mon Sep 17 00:00:00 2001 From: Christoph Witzany Date: Wed, 6 Apr 2016 17:56:12 +0200 Subject: conditionally import ldap --- synapse/handlers/auth.py | 7 +++++-- synapse/python_dependencies.py | 1 - 2 files changed, 5 insertions(+), 3 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index ee2b285cc1..12585abb1b 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -30,8 +30,6 @@ import simplejson import synapse.util.stringutils as stringutils -import ldap - logger = logging.getLogger(__name__) @@ -60,6 +58,9 @@ class AuthHandler(BaseHandler): self.ldap_email_property = hs.config.ldap_email_property self.ldap_full_name_property = hs.config.ldap_full_name_property + if self.ldap_enabled: + import ldap + self.hs = hs # FIXME better possibility to access registrationHandler later? @defer.inlineCallbacks @@ -445,6 +446,8 @@ class AuthHandler(BaseHandler): logger.info("LDAP not configured") defer.returnValue(False) + import ldap + logger.info("Authenticating %s with LDAP" % user_id) try: ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index d6b6e82bd7..cf1414b4db 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -37,7 +37,6 @@ REQUIREMENTS = { "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"], "pymacaroons-pynacl": ["pymacaroons"], "pyjwt": ["jwt"], - "python-ldap": ["ldap"], } CONDITIONAL_REQUIREMENTS = { "web_client": { -- cgit 1.4.1 From 3555a659ec5a78ef1dad2a9fb1e28d2fcb4f06b5 Mon Sep 17 00:00:00 2001 From: Christoph Witzany Date: Wed, 6 Apr 2016 18:03:55 +0200 Subject: output ldap version for info and to pacify pep8 --- synapse/handlers/auth.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'synapse') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 12585abb1b..cae81dbd67 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -60,6 +60,8 @@ class AuthHandler(BaseHandler): if self.ldap_enabled: import ldap + logger.info("Import ldap version: %s", ldap.__version__) + self.hs = hs # FIXME better possibility to access registrationHandler later? -- cgit 1.4.1 From 27a0c21c38f83572f984b9556ab5740a91428caf Mon Sep 17 00:00:00 2001 From: Christoph Witzany Date: Wed, 6 Apr 2016 18:10:14 +0200 Subject: make tests for ldap more specific to not be fooled by Mocks --- synapse/handlers/auth.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index cae81dbd67..f3acdf00da 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -58,7 +58,7 @@ class AuthHandler(BaseHandler): self.ldap_email_property = hs.config.ldap_email_property self.ldap_full_name_property = hs.config.ldap_full_name_property - if self.ldap_enabled: + if self.ldap_enabled is True: import ldap logger.info("Import ldap version: %s", ldap.__version__) @@ -444,8 +444,8 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def _check_ldap_password(self, user_id, password): - if not self.ldap_enabled: - logger.info("LDAP not configured") + if self.ldap_enabled is not True: + logger.debug("LDAP not configured") defer.returnValue(False) import ldap -- cgit 1.4.1 From 9c62fcdb688d889c6d3deffbc82ac4bbfbd4ffc4 Mon Sep 17 00:00:00 2001 From: Christoph Witzany Date: Wed, 6 Apr 2016 18:16:35 +0200 Subject: remove line --- synapse/handlers/auth.py | 1 - 1 file changed, 1 deletion(-) (limited to 'synapse') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index f3acdf00da..7c62f833ae 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -62,7 +62,6 @@ class AuthHandler(BaseHandler): import ldap logger.info("Import ldap version: %s", ldap.__version__) - self.hs = hs # FIXME better possibility to access registrationHandler later? @defer.inlineCallbacks -- cgit 1.4.1 From ed4d18f516385c2d367388aed00d13879273e99c Mon Sep 17 00:00:00 2001 From: Christoph Witzany Date: Wed, 6 Apr 2016 18:30:11 +0200 Subject: fix check for failed authentication --- synapse/handlers/auth.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 7c62f833ae..7a13a8b11c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -230,7 +230,9 @@ class AuthHandler(BaseHandler): if not user_id.startswith('@'): user_id = UserID.create(user_id, self.hs.hostname).to_string() - self._check_password(user_id, password) + if not (yield self._check_password(user_id, password)): + logger.warn("Failed password login for user %s", user_id) + raise LoginError(403, "", errcode=Codes.FORBIDDEN) defer.returnValue(user_id) @@ -356,7 +358,7 @@ class AuthHandler(BaseHandler): LoginError if there was an authentication problem. """ - if not self._check_password(user_id, password): + if not (yield self._check_password(user_id, password)): logger.warn("Failed password login for user %s", user_id) raise LoginError(403, "", errcode=Codes.FORBIDDEN) -- cgit 1.4.1 From 1ef036567051218c38de5529472d3c9000c6960d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 7 Apr 2016 09:42:52 +0100 Subject: Set profile information when joining rooms remotely --- synapse/handlers/room_member.py | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'synapse') diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index fe2315df8f..8c41cb6f3c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -233,6 +233,11 @@ class RoomMemberHandler(BaseHandler): remote_room_hosts.append(inviter.domain) content = {"membership": Membership.JOIN} + + profile = self.hs.get_handlers().profile_handler + content["displayname"] = yield profile.get_displayname(target) + content["avatar_url"] = yield profile.get_avatar_url(target) + if requester.is_guest: content["kind"] = "guest" -- cgit 1.4.1 From 60ec9793fb44ad445dd1233594957baeede60e4f Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 7 Apr 2016 13:17:56 +0100 Subject: Add tests for get_latest_event_ids_in_room and get_current_state --- synapse/events/__init__.py | 9 ++++ synapse/replication/slave/storage/events.py | 5 +++ tests/replication/slave/storage/test_events.py | 62 ++++++++++++++++++++++++++ 3 files changed, 76 insertions(+) (limited to 'synapse') diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 13154b1723..81e2126202 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -36,6 +36,10 @@ class _EventInternalMetadata(object): def is_invite_from_remote(self): return getattr(self, "invite_from_remote", False) + def __eq__(self, other): + "Equality check for unit tests." + return self.__dict__ == other.__dict__ + def _event_dict_property(key): def getter(self): @@ -180,3 +184,8 @@ class FrozenEvent(EventBase): self.get("type", None), self.get("state_key", None), ) + + def __eq__(self, other): + """Equality check for unit tests. Compares internal_metadata as well + as the event fields""" + return self.__dict__ == other.__dict__ diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 680dc89536..707ddd248a 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -89,8 +89,11 @@ class SlavedEventStore(BaseSlavedStore): _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__ _parse_events_txn = DataStore._parse_events_txn.__func__ _get_events_txn = DataStore._get_events_txn.__func__ + _enqueue_events = DataStore._enqueue_events.__func__ + _do_fetch = DataStore._do_fetch.__func__ _fetch_events_txn = DataStore._fetch_events_txn.__func__ _fetch_event_rows = DataStore._fetch_event_rows.__func__ + _get_event_from_row = DataStore._get_event_from_row.__func__ _get_event_from_row_txn = DataStore._get_event_from_row_txn.__func__ _get_rooms_for_user_where_membership_is_txn = ( DataStore._get_rooms_for_user_where_membership_is_txn.__func__ @@ -158,6 +161,8 @@ class SlavedEventStore(BaseSlavedStore): self._invalidate_get_event_cache(event.event_id) + self.get_latest_event_ids_in_room.invalidate((event.room_id,)) + if not backfilled: self._events_stream_cache.entity_has_changed( event.room_id, event.internal_metadata.stream_ordering diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 351d777fb2..d5d0ef1148 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -116,6 +116,68 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID]) yield self.check("get_rooms_for_user", (USER_ID_2,), []) + @defer.inlineCallbacks + def test_get_latest_event_ids_in_room(self): + create = yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.replicate() + yield self.check( + "get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id] + ) + + join = yield self.persist( + type="m.room.member", key=USER_ID, membership="join", + prev_events=[(create.event_id, {})], + ) + yield self.replicate() + yield self.check( + "get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id] + ) + + @defer.inlineCallbacks + def test_get_current_state(self): + # Create the room. + create = yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.replicate() + yield self.check( + "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), [] + ) + + # Join the room. + join1 = yield self.persist( + type="m.room.member", key=USER_ID, membership="join", + ) + yield self.replicate() + yield self.check( + "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), + [join1] + ) + + # Add some other user to the room. + join2 = yield self.persist( + type="m.room.member", key=USER_ID_2, membership="join", + ) + yield self.replicate() + yield self.check( + "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2), + [join2] + ) + + # Leave the room, then rejoin the room clobbering state. + yield self.persist(type="m.room.member", key=USER_ID, membership="leave") + join3 = yield self.persist( + type="m.room.member", key=USER_ID, membership="join", + reset_state=[create] + ) + yield self.replicate() + yield self.check( + "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2), + [] + ) + yield self.check( + "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), + [join3] + ) + event_id = 0 @defer.inlineCallbacks -- cgit 1.4.1 From af03ecf35223f93971596f38393c62f4694705fa Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 6 Apr 2016 15:44:22 +0100 Subject: Deduplicate joins --- synapse/handlers/room_member.py | 31 ++++++++++++++++++++++++ synapse/util/async.py | 42 +++++++++++++++++++++++++++++++++ synapse/util/caches/response_cache.py | 2 +- tests/util/test_linearizer.py | 44 +++++++++++++++++++++++++++++++++++ 4 files changed, 118 insertions(+), 1 deletion(-) create mode 100644 tests/util/test_linearizer.py (limited to 'synapse') diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index fe2315df8f..0fcc9445a8 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -24,6 +24,7 @@ from synapse.api.constants import ( ) from synapse.api.errors import AuthError, SynapseError, Codes from synapse.util.logcontext import preserve_context_over_fn +from synapse.util.async import Linearizer from signedjson.sign import verify_signed_json from signedjson.key import decode_verify_key_bytes @@ -60,6 +61,8 @@ class RoomMemberHandler(BaseHandler): def __init__(self, hs): super(RoomMemberHandler, self).__init__(hs) + self.member_linearizer = Linearizer() + self.clock = hs.get_clock() self.distributor = hs.get_distributor() @@ -182,6 +185,34 @@ class RoomMemberHandler(BaseHandler): remote_room_hosts=None, third_party_signed=None, ratelimit=True, + ): + key = (target, room_id,) + + with (yield self.member_linearizer.queue(key)): + result = yield self._update_membership( + requester, + target, + room_id, + action, + txn_id=txn_id, + remote_room_hosts=remote_room_hosts, + third_party_signed=third_party_signed, + ratelimit=ratelimit, + ) + + defer.returnValue(result) + + @defer.inlineCallbacks + def _update_membership( + self, + requester, + target, + room_id, + action, + txn_id=None, + remote_room_hosts=None, + third_party_signed=None, + ratelimit=True, ): effective_membership_state = action if action in ["kick", "unban"]: diff --git a/synapse/util/async.py b/synapse/util/async.py index cd4d90f3cf..408c86be91 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -19,6 +19,8 @@ from twisted.internet import defer, reactor from .logcontext import PreserveLoggingContext, preserve_fn from synapse.util import unwrapFirstError +from contextlib import contextmanager + @defer.inlineCallbacks def sleep(seconds): @@ -137,3 +139,43 @@ def concurrently_execute(func, args, limit): preserve_fn(_concurrently_execute_inner)() for _ in xrange(limit) ], consumeErrors=True).addErrback(unwrapFirstError) + + +@contextmanager +def _trigger_defer_manager(d): + try: + yield + finally: + d.callback(None) + + +class Linearizer(object): + """Linearizes access to resources based on a key. Useful to ensure only one + thing is happening at a time on a given resource. + + Example: + + with (yield linearizer.queue("test_key")): + # do some work. + + """ + def __init__(self): + self.key_to_defer = {} + + @defer.inlineCallbacks + def queue(self, key): + current_defer = self.key_to_defer.get(key) + + new_defer = defer.Deferred() + self.key_to_defer[key] = new_defer + + def remove_if_current(_): + d = self.key_to_defer.get(key) + if d is new_defer: + self.key_to_defer.pop(key, None) + + new_defer.addBoth(remove_if_current) + + yield current_defer + + defer.returnValue(_trigger_defer_manager(new_defer)) diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index be310ba320..36686b479e 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -35,7 +35,7 @@ class ResponseCache(object): return None def set(self, key, deferred): - result = ObservableDeferred(deferred) + result = ObservableDeferred(deferred, consumeErrors=True) self.pending_result_cache[key] = result def remove(r): diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py new file mode 100644 index 0000000000..afcba482f9 --- /dev/null +++ b/tests/util/test_linearizer.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tests import unittest + +from twisted.internet import defer + +from synapse.util.async import Linearizer + + +class LinearizerTestCase(unittest.TestCase): + + @defer.inlineCallbacks + def test_linearizer(self): + linearizer = Linearizer() + + key = object() + + d1 = linearizer.queue(key) + cm1 = yield d1 + + d2 = linearizer.queue(key) + self.assertFalse(d2.called) + + with cm1: + self.assertFalse(d2.called) + + self.assertTrue(d2.called) + + with (yield d2): + pass -- cgit 1.4.1 From 639cd07d6d4e22e3413349bbd3bfb33db37a8d2f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 7 Apr 2016 14:24:12 +0100 Subject: Add comment --- synapse/util/async.py | 8 ++++++++ 1 file changed, 8 insertions(+) (limited to 'synapse') diff --git a/synapse/util/async.py b/synapse/util/async.py index 408c86be91..14a3dfd43f 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -164,6 +164,14 @@ class Linearizer(object): @defer.inlineCallbacks def queue(self, key): + # If there is already a deferred in the queue, we pull it out so that + # we can wait on it later. + # Then we replace it with a deferred that we resolve *after* the + # context manager has exited. + # We only return the context manager after the previous deferred has + # resolved. + # This all has the net effect of creating a chain of deferreds that + # wait for the previous deferred before starting their work. current_defer = self.key_to_defer.get(key) new_defer = defer.Deferred() -- cgit 1.4.1 From ee5aef6c72575045fc441076b29b0c06eb46a28c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 7 Apr 2016 15:29:34 +0100 Subject: Log contexts and squash things together --- synapse/util/async.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) (limited to 'synapse') diff --git a/synapse/util/async.py b/synapse/util/async.py index 14a3dfd43f..072b6362b5 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -16,7 +16,9 @@ from twisted.internet import defer, reactor -from .logcontext import PreserveLoggingContext, preserve_fn +from .logcontext import ( + PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, +) from synapse.util import unwrapFirstError from contextlib import contextmanager @@ -141,14 +143,6 @@ def concurrently_execute(func, args, limit): ], consumeErrors=True).addErrback(unwrapFirstError) -@contextmanager -def _trigger_defer_manager(d): - try: - yield - finally: - d.callback(None) - - class Linearizer(object): """Linearizes access to resources based on a key. Useful to ensure only one thing is happening at a time on a given resource. @@ -177,13 +171,17 @@ class Linearizer(object): new_defer = defer.Deferred() self.key_to_defer[key] = new_defer - def remove_if_current(_): - d = self.key_to_defer.get(key) - if d is new_defer: - self.key_to_defer.pop(key, None) - - new_defer.addBoth(remove_if_current) + if current_defer: + yield preserve_context_over_deferred(current_defer) - yield current_defer + @contextmanager + def _ctx_manager(d): + try: + yield + finally: + d.callback(None) + d = self.key_to_defer.get(key) + if d is new_defer: + self.key_to_defer.pop(key, None) - defer.returnValue(_trigger_defer_manager(new_defer)) + defer.returnValue(_ctx_manager(new_defer)) -- cgit 1.4.1 From 92e3071623c34350bf072bb77e089d5d6d5f41c2 Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 7 Apr 2016 15:39:53 +0100 Subject: Send badge count pushes. Also fix bugs with retrying. --- synapse/handlers/receipts.py | 21 +++++++++++++++++---- synapse/push/httppusher.py | 45 ++++++++++++++++++++++++++++---------------- synapse/push/pusherpool.py | 20 +++++++++++++++++++- synapse/storage/receipts.py | 9 ++++++--- 4 files changed, 71 insertions(+), 24 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 935c339707..26b0368080 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -80,6 +80,9 @@ class ReceiptsHandler(BaseHandler): def _handle_new_receipts(self, receipts): """Takes a list of receipts, stores them and informs the notifier. """ + min_batch_id = None + max_batch_id = None + for receipt in receipts: room_id = receipt["room_id"] receipt_type = receipt["receipt_type"] @@ -97,10 +100,20 @@ class ReceiptsHandler(BaseHandler): stream_id, max_persisted_id = res - with PreserveLoggingContext(): - self.notifier.on_new_event( - "receipt_key", max_persisted_id, rooms=[room_id] - ) + if min_batch_id is None or stream_id < min_batch_id: + min_batch_id = stream_id + if max_batch_id is None or max_persisted_id > max_batch_id: + max_batch_id = max_persisted_id + + affected_room_ids = list(set([r["room_id"] for r in receipts])) + + with PreserveLoggingContext(): + self.notifier.on_new_event( + "receipt_key", max_batch_id, rooms=affected_room_ids + ) + self.hs.get_pusherpool().on_new_receipts( + min_batch_id, max_batch_id, affected_room_ids + ) defer.returnValue(True) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index d695885649..0d5450bc01 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -76,15 +76,25 @@ class HttpPusher(object): self.data_minus_url.update(self.data) del self.data_minus_url['url'] + @defer.inlineCallbacks def on_started(self): - self._process() + yield self._process() + @defer.inlineCallbacks def on_new_notifications(self, min_stream_ordering, max_stream_ordering): self.max_stream_ordering = max_stream_ordering - self._process() + yield self._process() + + @defer.inlineCallbacks + def on_new_receipts(self, min_stream_id, max_stream_id): + # We could check the receipts are actually m.read receipts here, + # but currently that's the only type of receipt anyway... + badge = yield push_tools.get_badge_count(self.hs, self.user_id) + yield self.send_badge(badge) + @defer.inlineCallbacks def on_timer(self): - self._process() + yield self._process() def on_stop(self): if self.timed_call: @@ -106,22 +116,24 @@ class HttpPusher(object): self.last_stream_ordering, self.clock.time_msec() ) - self.failing_since = None - yield self.store.update_pusher_failing_since( - self.app_id, self.pushkey, self.user_id, - self.failing_since - ) + if self.failing_since: + self.failing_since = None + yield self.store.update_pusher_failing_since( + self.app_id, self.pushkey, self.user_id, + self.failing_since + ) else: - self.failing_since = self.clock.time_msec() - yield self.store.update_pusher_failing_since( - self.app_id, self.pushkey, self.user_id, - self.failing_since - ) + if not self.failing_since: + self.failing_since = self.clock.time_msec() + yield self.store.update_pusher_failing_since( + self.app_id, self.pushkey, self.user_id, + self.failing_since + ) if ( self.failing_since and self.failing_since < - self.clock.time_msec() - HttpPusher.GIVE_UP_AFTER + self.clock.time_msec() - HttpPusher.GIVE_UP_AFTER_MS ): # we really only give up so that if the URL gets # fixed, we don't suddenly deliver a load @@ -148,7 +160,7 @@ class HttpPusher(object): else: logger.info("Push failed: delaying for %ds", self.backoff_delay) self.timed_call = reactor.callLater(self.backoff_delay, self.on_timer) - self.backoff_delay = min(self.backoff_delay, self.MAX_BACKOFF_SEC) + self.backoff_delay = min(self.backoff_delay * 2, self.MAX_BACKOFF_SEC) break @defer.inlineCallbacks @@ -191,7 +203,8 @@ class HttpPusher(object): d = { 'notification': { - 'id': event.event_id, + 'id': event.event_id, # deprecated: remove soon + 'event_id': event.event_id, 'room_id': event.room_id, 'type': event.type, 'sender': event.user_id, diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index b67ad455ea..7b1ce81e9a 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -126,10 +126,28 @@ class PusherPool: for u in users_affected: if u in self.pushers: for p in self.pushers[u].values(): - p.on_new_notifications(min_stream_id, max_stream_id) + yield p.on_new_notifications(min_stream_id, max_stream_id) except: logger.exception("Exception in pusher on_new_notifications") + @defer.inlineCallbacks + def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): + yield run_on_reactor() + try: + # Need to subtract 1 from the minimum because the lower bound here + # is not inclusive + updated_receipts = yield self.store.get_all_updated_receipts( + min_stream_id - 1, max_stream_id + ) + # This returns a tuple, user_id is at index 3 + users_affected = set([r[3] for r in updated_receipts]) + for u in users_affected: + if u in self.pushers: + for p in self.pushers[u].values(): + yield p.on_new_receipts(min_stream_id, max_stream_id) + except: + logger.exception("Exception in pusher on_new_receipts") + @defer.inlineCallbacks def _refresh_pusher(self, app_id, pushkey, user_id): resultlist = yield self.store.get_pushers_by_app_id_and_pushkey( diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 4befebc8e2..59d1ac0314 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -390,16 +390,19 @@ class ReceiptsStore(SQLBaseStore): } ) - def get_all_updated_receipts(self, last_id, current_id, limit): + def get_all_updated_receipts(self, last_id, current_id, limit=None): def get_all_updated_receipts_txn(txn): sql = ( "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" " FROM receipts_linearized" " WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC" - " LIMIT ?" ) - txn.execute(sql, (last_id, current_id, limit)) + args = [last_id, current_id] + if limit is not None: + sql += " LIMIT ?" + args.append(limit) + txn.execute(sql, args) return txn.fetchall() return self.runInteraction( -- cgit 1.4.1 From 95ac3078da54908855721361b1305ed0c41215d5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 7 Apr 2016 16:07:16 +0100 Subject: Rename things --- synapse/util/async.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'synapse') diff --git a/synapse/util/async.py b/synapse/util/async.py index 072b6362b5..0d6f48e2d8 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -175,13 +175,13 @@ class Linearizer(object): yield preserve_context_over_deferred(current_defer) @contextmanager - def _ctx_manager(d): + def _ctx_manager(): try: yield finally: - d.callback(None) - d = self.key_to_defer.get(key) - if d is new_defer: + new_defer.callback(None) + current_d = self.key_to_defer.get(key) + if current_d is new_defer: self.key_to_defer.pop(key, None) - defer.returnValue(_ctx_manager(new_defer)) + defer.returnValue(_ctx_manager()) -- cgit 1.4.1 From d549fdfa22f6927479d2a185f7420cadbfbf5607 Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 7 Apr 2016 16:31:38 +0100 Subject: Remove code that's now been obsoleted or moved elsewhere --- synapse/push/__init__.py | 327 ------------------------------------ synapse/push/httppusher.py | 2 +- synapse/push/push_rule_evaluator.py | 134 +-------------- 3 files changed, 9 insertions(+), 454 deletions(-) (limited to 'synapse') diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 296c4447ec..edf45dc599 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -13,333 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - -from synapse.streams.config import PaginationConfig -from synapse.types import StreamToken -from synapse.util.logcontext import LoggingContext -from synapse.util.metrics import Measure - -import synapse.util.async -from .push_rule_evaluator import evaluator_for_user_id - -import logging -import random - -logger = logging.getLogger(__name__) - - -_NEXT_ID = 1 - - -def _get_next_id(): - global _NEXT_ID - _id = _NEXT_ID - _NEXT_ID += 1 - return _id - - -# Pushers could now be moved to pull out of the event_push_actions table instead -# of listening on the event stream: this would avoid them having to run the -# rules again. -class Pusher(object): - INITIAL_BACKOFF = 1000 - MAX_BACKOFF = 60 * 60 * 1000 - GIVE_UP_AFTER = 24 * 60 * 60 * 1000 - - def __init__(self, _hs, user_id, app_id, - app_display_name, device_display_name, pushkey, pushkey_ts, - data, last_token, last_success, failing_since): - self.hs = _hs - self.evStreamHandler = self.hs.get_handlers().event_stream_handler - self.store = self.hs.get_datastore() - self.clock = self.hs.get_clock() - self.user_id = user_id - self.app_id = app_id - self.app_display_name = app_display_name - self.device_display_name = device_display_name - self.pushkey = pushkey - self.pushkey_ts = pushkey_ts - self.data = data - self.last_token = last_token - self.last_success = last_success # not actually used - self.backoff_delay = Pusher.INITIAL_BACKOFF - self.failing_since = failing_since - self.alive = True - self.badge = None - - self.name = "Pusher-%d" % (_get_next_id(),) - - # The last value of last_active_time that we saw - self.last_last_active_time = 0 - self.has_unread = True - - @defer.inlineCallbacks - def get_context_for_event(self, ev): - name_aliases = yield self.store.get_room_name_and_aliases( - ev['room_id'] - ) - - ctx = {'aliases': name_aliases[1]} - if name_aliases[0] is not None: - ctx['name'] = name_aliases[0] - - their_member_events_for_room = yield self.store.get_current_state( - room_id=ev['room_id'], - event_type='m.room.member', - state_key=ev['user_id'] - ) - for mev in their_member_events_for_room: - if mev.content['membership'] == 'join' and 'displayname' in mev.content: - dn = mev.content['displayname'] - if dn is not None: - ctx['sender_display_name'] = dn - - defer.returnValue(ctx) - - @defer.inlineCallbacks - def start(self): - with LoggingContext(self.name): - if not self.last_token: - # First-time setup: get a token to start from (we can't - # just start from no token, ie. 'now' - # because we need the result to be reproduceable in case - # we fail to dispatch the push) - config = PaginationConfig(from_token=None, limit='1') - chunk = yield self.evStreamHandler.get_stream( - self.user_id, config, timeout=0, affect_presence=False - ) - self.last_token = chunk['end'] - yield self.store.update_pusher_last_token( - self.app_id, self.pushkey, self.user_id, self.last_token - ) - logger.info("New pusher %s for user %s starting from token %s", - self.pushkey, self.user_id, self.last_token) - - else: - logger.info( - "Old pusher %s for user %s starting", - self.pushkey, self.user_id, - ) - - wait = 0 - while self.alive: - try: - if wait > 0: - yield synapse.util.async.sleep(wait) - with Measure(self.clock, "push"): - yield self.get_and_dispatch() - wait = 0 - except: - if wait == 0: - wait = 1 - else: - wait = min(wait * 2, 1800) - logger.exception( - "Exception in pusher loop for pushkey %s. Pausing for %ds", - self.pushkey, wait - ) - - @defer.inlineCallbacks - def get_and_dispatch(self): - from_tok = StreamToken.from_string(self.last_token) - config = PaginationConfig(from_token=from_tok, limit='1') - timeout = (300 + random.randint(-60, 60)) * 1000 - chunk = yield self.evStreamHandler.get_stream( - self.user_id, config, timeout=timeout, affect_presence=False, - only_keys=("room", "receipt",), - ) - - # limiting to 1 may get 1 event plus 1 presence event, so - # pick out the actual event - single_event = None - read_receipt = None - for c in chunk['chunk']: - if 'event_id' in c: # Hmmm... - single_event = c - elif c['type'] == 'm.receipt': - read_receipt = c - - have_updated_badge = False - if read_receipt: - for receipt_part in read_receipt['content'].values(): - if 'm.read' in receipt_part: - if self.user_id in receipt_part['m.read'].keys(): - have_updated_badge = True - - if not single_event: - if have_updated_badge: - yield self.update_badge() - self.last_token = chunk['end'] - yield self.store.update_pusher_last_token( - self.app_id, - self.pushkey, - self.user_id, - self.last_token - ) - return - - if not self.alive: - return - - processed = False - - rule_evaluator = yield \ - evaluator_for_user_id( - self.user_id, single_event['room_id'], self.store - ) - - actions = yield rule_evaluator.actions_for_event(single_event) - tweaks = rule_evaluator.tweaks_for_actions(actions) - - if 'notify' in actions: - self.badge = yield self._get_badge_count() - rejected = yield self.dispatch_push(single_event, tweaks, self.badge) - self.has_unread = True - if isinstance(rejected, list) or isinstance(rejected, tuple): - processed = True - for pk in rejected: - if pk != self.pushkey: - # for sanity, we only remove the pushkey if it - # was the one we actually sent... - logger.warn( - ("Ignoring rejected pushkey %s because we" - " didn't send it"), pk - ) - else: - logger.info( - "Pushkey %s was rejected: removing", - pk - ) - yield self.hs.get_pusherpool().remove_pusher( - self.app_id, pk, self.user_id - ) - else: - if have_updated_badge: - yield self.update_badge() - processed = True - - if not self.alive: - return - - if processed: - self.backoff_delay = Pusher.INITIAL_BACKOFF - self.last_token = chunk['end'] - yield self.store.update_pusher_last_token_and_success( - self.app_id, - self.pushkey, - self.user_id, - self.last_token, - self.clock.time_msec() - ) - if self.failing_since: - self.failing_since = None - yield self.store.update_pusher_failing_since( - self.app_id, - self.pushkey, - self.user_id, - self.failing_since) - else: - if not self.failing_since: - self.failing_since = self.clock.time_msec() - yield self.store.update_pusher_failing_since( - self.app_id, - self.pushkey, - self.user_id, - self.failing_since - ) - - if (self.failing_since and - self.failing_since < - self.clock.time_msec() - Pusher.GIVE_UP_AFTER): - # we really only give up so that if the URL gets - # fixed, we don't suddenly deliver a load - # of old notifications. - logger.warn("Giving up on a notification to user %s, " - "pushkey %s", - self.user_id, self.pushkey) - self.backoff_delay = Pusher.INITIAL_BACKOFF - self.last_token = chunk['end'] - yield self.store.update_pusher_last_token( - self.app_id, - self.pushkey, - self.user_id, - self.last_token - ) - - self.failing_since = None - yield self.store.update_pusher_failing_since( - self.app_id, - self.pushkey, - self.user_id, - self.failing_since - ) - else: - logger.warn("Failed to dispatch push for user %s " - "(failing for %dms)." - "Trying again in %dms", - self.user_id, - self.clock.time_msec() - self.failing_since, - self.backoff_delay) - yield synapse.util.async.sleep(self.backoff_delay / 1000.0) - self.backoff_delay *= 2 - if self.backoff_delay > Pusher.MAX_BACKOFF: - self.backoff_delay = Pusher.MAX_BACKOFF - - def stop(self): - self.alive = False - - def dispatch_push(self, p, tweaks, badge): - """ - Overridden by implementing classes to actually deliver the notification - Args: - p: The event to notify for as a single event from the event stream - Returns: If the notification was delivered, an array containing any - pushkeys that were rejected by the push gateway. - False if the notification could not be delivered (ie. - should be retried). - """ - pass - - @defer.inlineCallbacks - def update_badge(self): - new_badge = yield self._get_badge_count() - if self.badge != new_badge: - self.badge = new_badge - yield self.send_badge(self.badge) - - def send_badge(self, badge): - """ - Overridden by implementing classes to send an updated badge count - """ - pass - - @defer.inlineCallbacks - def _get_badge_count(self): - invites, joins = yield defer.gatherResults([ - self.store.get_invited_rooms_for_user(self.user_id), - self.store.get_rooms_for_user(self.user_id), - ], consumeErrors=True) - - my_receipts_by_room = yield self.store.get_receipts_for_user( - self.user_id, - "m.read", - ) - - badge = len(invites) - - for r in joins: - if r.room_id in my_receipts_by_room: - last_unread_event_id = my_receipts_by_room[r.room_id] - - notifs = yield ( - self.store.get_unread_event_push_actions_by_room_for_user( - r.room_id, self.user_id, last_unread_event_id - ) - ) - badge += notifs["notify_count"] - defer.returnValue(badge) - class PusherConfigException(Exception): def __init__(self, msg): diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 0d5450bc01..cc030a57a0 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -168,7 +168,7 @@ class HttpPusher(object): if 'notify' not in push_action['actions']: defer.returnValue(True) - tweaks = push_rule_evaluator.PushRuleEvaluator.tweaks_for_actions(push_action['actions']) + tweaks = push_rule_evaluator.tweaks_for_actions(push_action['actions']) badge = yield push_tools.get_badge_count(self.hs, self.user_id) event = yield self.store.get_event(push_action['event_id'], allow_none=True) diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index c3c2877629..4db76f18bd 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -13,12 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - -from .baserules import list_with_base_rules - import logging -import simplejson as json import re from synapse.types import UserID @@ -32,22 +27,6 @@ IS_GLOB = re.compile(r'[\?\*\[\]]') INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") -@defer.inlineCallbacks -def evaluator_for_user_id(user_id, room_id, store): - rawrules = yield store.get_push_rules_for_user(user_id) - enabled_map = yield store.get_push_rules_enabled_for_user(user_id) - our_member_event = yield store.get_current_state( - room_id=room_id, - event_type='m.room.member', - state_key=user_id, - ) - - defer.returnValue(PushRuleEvaluator( - user_id, rawrules, enabled_map, - room_id, our_member_event, store - )) - - def _room_member_count(ev, condition, room_member_count): if 'is' not in condition: return False @@ -74,111 +53,14 @@ def _room_member_count(ev, condition, room_member_count): return False -class PushRuleEvaluator: - DEFAULT_ACTIONS = [] - - def __init__(self, user_id, raw_rules, enabled_map, room_id, - our_member_event, store): - self.user_id = user_id - self.room_id = room_id - self.our_member_event = our_member_event - self.store = store - - rules = [] - for raw_rule in raw_rules: - rule = dict(raw_rule) - rule['conditions'] = json.loads(raw_rule['conditions']) - rule['actions'] = json.loads(raw_rule['actions']) - rules.append(rule) - - self.rules = list_with_base_rules(rules) - - self.enabled_map = enabled_map - - @staticmethod - def tweaks_for_actions(actions): - tweaks = {} - for a in actions: - if not isinstance(a, dict): - continue - if 'set_tweak' in a and 'value' in a: - tweaks[a['set_tweak']] = a['value'] - return tweaks - - @defer.inlineCallbacks - def actions_for_event(self, ev): - """ - This should take into account notification settings that the user - has configured both globally and per-room when we have the ability - to do such things. - """ - if ev['user_id'] == self.user_id: - # let's assume you probably know about messages you sent yourself - defer.returnValue([]) - - room_id = ev['room_id'] - - # get *our* member event for display name matching - my_display_name = None - - if self.our_member_event: - my_display_name = self.our_member_event[0].content.get("displayname") - - room_members = yield self.store.get_users_in_room(room_id) - room_member_count = len(room_members) - - evaluator = PushRuleEvaluatorForEvent(ev, room_member_count) - - for r in self.rules: - enabled = self.enabled_map.get(r['rule_id'], None) - if enabled is not None and not enabled: - continue - elif enabled is None and not r.get("enabled", True): - # if no override, check enabled on the rule itself - # (may have come from a base rule) - continue - - conditions = r['conditions'] - actions = r['actions'] - - # ignore rules with no actions (we have an explict 'dont_notify') - if len(actions) == 0: - logger.warn( - "Ignoring rule id %s with no actions for user %s", - r['rule_id'], self.user_id - ) - continue - - matches = True - for c in conditions: - matches = evaluator.matches( - c, self.user_id, my_display_name - ) - if not matches: - break - - logger.debug( - "Rule %s %s", - r['rule_id'], "matches" if matches else "doesn't match" - ) - - if matches: - logger.debug( - "%s matches for user %s, event %s", - r['rule_id'], self.user_id, ev['event_id'] - ) - - # filter out dont_notify as we treat an empty actions list - # as dont_notify, and this doesn't take up a row in our database - actions = [x for x in actions if x != 'dont_notify'] - - defer.returnValue(actions) - - logger.debug( - "No rules match for user %s, event %s", - self.user_id, ev['event_id'] - ) - defer.returnValue(PushRuleEvaluator.DEFAULT_ACTIONS) +def tweaks_for_actions(actions): + tweaks = {} + for a in actions: + if not isinstance(a, dict): + continue + if 'set_tweak' in a and 'value' in a: + tweaks[a['set_tweak']] = a['value'] + return tweaks class PushRuleEvaluatorForEvent(object): -- cgit 1.4.1 From 57fa1801c336603c6c710d6d868aaae596a7f5b8 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 7 Apr 2016 16:41:37 +0100 Subject: Add sensible __eq__ operators inside the tests. Rather than adding them globally. This limits the changes to only affect the tests. --- synapse/events/__init__.py | 9 -------- tests/replication/slave/storage/test_events.py | 29 +++++++++++++++++++++++++- 2 files changed, 28 insertions(+), 10 deletions(-) (limited to 'synapse') diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 81e2126202..13154b1723 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -36,10 +36,6 @@ class _EventInternalMetadata(object): def is_invite_from_remote(self): return getattr(self, "invite_from_remote", False) - def __eq__(self, other): - "Equality check for unit tests." - return self.__dict__ == other.__dict__ - def _event_dict_property(key): def getter(self): @@ -184,8 +180,3 @@ class FrozenEvent(EventBase): self.get("type", None), self.get("state_key", None), ) - - def __eq__(self, other): - """Equality check for unit tests. Compares internal_metadata as well - as the event fields""" - return self.__dict__ == other.__dict__ diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index d5d0ef1148..9af62702b3 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -14,20 +14,47 @@ from ._base import BaseSlavedStoreTestCase -from synapse.events import FrozenEvent +from synapse.events import FrozenEvent, _EventInternalMetadata from synapse.events.snapshot import EventContext from synapse.storage.roommember import RoomsForUser from twisted.internet import defer + USER_ID = "@feeling:blue" USER_ID_2 = "@bright:blue" OUTLIER = {"outlier": True} ROOM_ID = "!room:blue" +def dict_equals(self, other): + return self.__dict__ == other.__dict__ + + +def patch__eq__(cls): + eq = getattr(cls, "__eq__", None) + cls.__eq__ = dict_equals + + def unpatch(): + if eq is not None: + cls.__eq__ = eq + return unpatch + + class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): + def setUp(self): + # Patch up the equality operator for events so that we can check + # whether lists of events match using assertEquals + self.unpatches = [ + patch__eq__(_EventInternalMetadata), + patch__eq__(FrozenEvent), + ] + return super(SlavedEventStoreTestCase, self).setUp() + + def tearDown(self): + [unpatch() for unpatch in self.unpatches] + @defer.inlineCallbacks def test_room_name_and_aliases(self): create = yield self.persist(type="m.room.create", key="", creator=USER_ID) -- cgit 1.4.1 From 2d5c693fd3c72800980f906b8255e3619ac524e2 Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 7 Apr 2016 16:43:54 +0100 Subject: Fix port script for changes merged from develop --- synapse/storage/schema/delta/31/pushers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/storage/schema/delta/31/pushers.py b/synapse/storage/schema/delta/31/pushers.py index 7e0e385fb5..d07bab012f 100644 --- a/synapse/storage/schema/delta/31/pushers.py +++ b/synapse/storage/schema/delta/31/pushers.py @@ -27,7 +27,7 @@ def token_to_stream_ordering(token): return int(token[1:].split('_')[0]) -def run_upgrade(cur, database_engine, *args, **kwargs): +def run_create(cur, database_engine, *args, **kwargs): logger.info("Porting pushers table, delta 31...") cur.execute(""" CREATE TABLE IF NOT EXISTS pushers2 ( @@ -73,3 +73,6 @@ def run_upgrade(cur, database_engine, *args, **kwargs): cur.execute("DROP TABLE pushers") cur.execute("ALTER TABLE pushers2 RENAME TO pushers") logger.info("Moved %d pushers to new table", count) + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass -- cgit 1.4.1 From 05d044aac396de9dff64ffb47e8b9a3f43ad0919 Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 7 Apr 2016 16:45:38 +0100 Subject: pep8 --- synapse/storage/schema/delta/31/pushers.py | 1 + 1 file changed, 1 insertion(+) (limited to 'synapse') diff --git a/synapse/storage/schema/delta/31/pushers.py b/synapse/storage/schema/delta/31/pushers.py index d07bab012f..93367fa09e 100644 --- a/synapse/storage/schema/delta/31/pushers.py +++ b/synapse/storage/schema/delta/31/pushers.py @@ -74,5 +74,6 @@ def run_create(cur, database_engine, *args, **kwargs): cur.execute("ALTER TABLE pushers2 RENAME TO pushers") logger.info("Moved %d pushers to new table", count) + def run_upgrade(cur, database_engine, *args, **kwargs): pass -- cgit 1.4.1 From ceb599e789ef34a3733a99d17701a75fc5409d17 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 7 Apr 2016 16:26:52 +0100 Subject: Add tests for redactions --- synapse/replication/slave/storage/events.py | 4 +- synapse/storage/util/id_generators.py | 2 +- tests/replication/slave/storage/_base.py | 2 +- tests/replication/slave/storage/test_events.py | 51 +++++++++++++++++++++++++- 4 files changed, 54 insertions(+), 5 deletions(-) (limited to 'synapse') diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 707ddd248a..cfc728a038 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -69,6 +69,7 @@ class SlavedEventStore(BaseSlavedStore): "_get_current_state_for_key" ] + get_event = DataStore.get_event.__func__ get_current_state = DataStore.get_current_state.__func__ get_current_state_for_key = DataStore.get_current_state_for_key.__func__ get_rooms_for_user_where_membership_is = ( @@ -103,7 +104,7 @@ class SlavedEventStore(BaseSlavedStore): def stream_positions(self): result = super(SlavedEventStore, self).stream_positions() result["events"] = self._stream_id_gen.get_current_token() - result["backfilled"] = self._backfill_id_gen.get_current_token() + result["backfill"] = self._backfill_id_gen.get_current_token() return result def process_replication(self, result): @@ -145,7 +146,6 @@ class SlavedEventStore(BaseSlavedStore): position = row[0] internal = json.loads(row[1]) event_json = json.loads(row[2]) - event = FrozenEvent(event_json, internal_metadata_dict=internal) self._invalidate_caches_for_event( event, backfilled, reset_state=position in state_resets diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index f69f1cdad4..46cf93ff87 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -112,7 +112,7 @@ class StreamIdGenerator(object): self._current + self._step * (n + 1), self._step ) - self._current += n + self._current += n * self._step for next_id in next_ids: self._unfinished_ids.append(next_id) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 0f525a8943..983caafe8a 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -51,7 +51,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase): def check(self, method, args, expected_result=None): master_result = yield getattr(self.master_store, method)(*args) slaved_result = yield getattr(self.slaved_store, method)(*args) - self.assertEqual(master_result, slaved_result) if expected_result is not None: self.assertEqual(master_result, expected_result) self.assertEqual(slaved_result, expected_result) + self.assertEqual(master_result, slaved_result) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 9af62702b3..baa4a26eb5 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -205,13 +205,59 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): [join3] ) + @defer.inlineCallbacks + def test_redactions(self): + yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.persist(type="m.room.member", key=USER_ID, membership="join") + + msg = yield self.persist( + type="m.room.message", msgtype="m.text", body="Hello" + ) + yield self.replicate() + yield self.check("get_event", [msg.event_id], msg) + + redaction = yield self.persist( + type="m.room.redaction", redacts=msg.event_id + ) + yield self.replicate() + + msg_dict = msg.get_dict() + msg_dict["content"] = {} + msg_dict["unsigned"]["redacted_by"] = redaction.event_id + msg_dict["unsigned"]["redacted_because"] = redaction + redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) + yield self.check("get_event", [msg.event_id], redacted) + + @defer.inlineCallbacks + def test_backfilled_redactions(self): + yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.persist(type="m.room.member", key=USER_ID, membership="join") + + msg = yield self.persist( + type="m.room.message", msgtype="m.text", body="Hello" + ) + yield self.replicate() + yield self.check("get_event", [msg.event_id], msg) + + redaction = yield self.persist( + type="m.room.redaction", redacts=msg.event_id, backfill=True + ) + yield self.replicate() + + msg_dict = msg.get_dict() + msg_dict["content"] = {} + msg_dict["unsigned"]["redacted_by"] = redaction.event_id + msg_dict["unsigned"]["redacted_because"] = redaction + redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) + yield self.check("get_event", [msg.event_id], redacted) + event_id = 0 @defer.inlineCallbacks def persist( self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={}, state=None, reset_state=False, backfill=False, - depth=None, prev_events=[], auth_events=[], prev_state=[], + depth=None, prev_events=[], auth_events=[], prev_state=[], redacts=None, **content ): """ @@ -236,6 +282,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): event_dict["state_key"] = key event_dict["prev_state"] = prev_state + if redacts is not None: + event_dict["redacts"] = redacts + event = FrozenEvent(event_dict, internal_metadata_dict=internal) self.event_id += 1 -- cgit 1.4.1 From e1e042f2a1bc489c922f0deccfac54572788c933 Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 7 Apr 2016 17:09:36 +0100 Subject: Add comments on min_stream_id saying that the min stream id won't be completely accurate all the time --- synapse/handlers/receipts.py | 1 + synapse/push/httppusher.py | 2 ++ 2 files changed, 3 insertions(+) (limited to 'synapse') diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 26b0368080..a390a1b8bd 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -111,6 +111,7 @@ class ReceiptsHandler(BaseHandler): self.notifier.on_new_event( "receipt_key", max_batch_id, rooms=affected_room_ids ) + # Note that the min here shouldn't be relied upon to be accurate. self.hs.get_pusherpool().on_new_receipts( min_batch_id, max_batch_id, affected_room_ids ) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index cc030a57a0..0a1d3817de 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -87,6 +87,8 @@ class HttpPusher(object): @defer.inlineCallbacks def on_new_receipts(self, min_stream_id, max_stream_id): + # Note that the min here shouldn't be relied upon to be accurate. + # We could check the receipts are actually m.read receipts here, # but currently that's the only type of receipt anyway... badge = yield push_tools.get_badge_count(self.hs, self.user_id) -- cgit 1.4.1 From fa129ce5b5cfa2b6cbbb0f1a884f47d740ba1300 Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 7 Apr 2016 17:12:29 +0100 Subject: Add measure blocks --- synapse/push/httppusher.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) (limited to 'synapse') diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 0a1d3817de..ea45b603c6 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -21,6 +21,8 @@ import logging import push_rule_evaluator import push_tools +from synapse.util.metrics import Measure + logger = logging.getLogger(__name__) @@ -82,8 +84,9 @@ class HttpPusher(object): @defer.inlineCallbacks def on_new_notifications(self, min_stream_ordering, max_stream_ordering): - self.max_stream_ordering = max_stream_ordering - yield self._process() + with Measure(self.clock, "push.on_new_notifications"): + self.max_stream_ordering = max_stream_ordering + yield self._process() @defer.inlineCallbacks def on_new_receipts(self, min_stream_id, max_stream_id): @@ -91,12 +94,14 @@ class HttpPusher(object): # We could check the receipts are actually m.read receipts here, # but currently that's the only type of receipt anyway... - badge = yield push_tools.get_badge_count(self.hs, self.user_id) - yield self.send_badge(badge) + with Measure(self.clock, "push.on_new_receipts"): + badge = yield push_tools.get_badge_count(self.hs, self.user_id) + yield self.send_badge(badge) @defer.inlineCallbacks def on_timer(self): - yield self._process() + with Measure(self.clock, "push.on_timer"): + yield self._process() def on_stop(self): if self.timed_call: -- cgit 1.4.1 From 25cd5bb697996c1764c7746e4dfc1d8fffaaf8b2 Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 7 Apr 2016 17:22:14 +0100 Subject: defer.gatherResults rather than doing all the pokes in series --- synapse/push/pusherpool.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) (limited to 'synapse') diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 7b1ce81e9a..8da444179c 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -123,10 +123,17 @@ class PusherPool: users_affected = yield self.store.get_push_action_users_in_range( min_stream_id, max_stream_id ) + + deferreds = [] + for u in users_affected: if u in self.pushers: for p in self.pushers[u].values(): - yield p.on_new_notifications(min_stream_id, max_stream_id) + deferreds.append( + p.on_new_notifications(min_stream_id, max_stream_id) + ) + + yield defer.gatherResults(deferreds) except: logger.exception("Exception in pusher on_new_notifications") @@ -141,10 +148,17 @@ class PusherPool: ) # This returns a tuple, user_id is at index 3 users_affected = set([r[3] for r in updated_receipts]) + + deferreds = [] + for u in users_affected: if u in self.pushers: for p in self.pushers[u].values(): - yield p.on_new_receipts(min_stream_id, max_stream_id) + deferreds.append( + p.on_new_receipts(min_stream_id, max_stream_id) + ) + + yield defer.gatherResults(deferreds) except: logger.exception("Exception in pusher on_new_receipts") -- cgit 1.4.1 From 6ec02e9ecf7ffe3d3737a69a480939a07d62428b Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 7 Apr 2016 17:24:05 +0100 Subject: indenting --- synapse/push/pusherpool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 8da444179c..ba513601e7 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -155,7 +155,7 @@ class PusherPool: if u in self.pushers: for p in self.pushers[u].values(): deferreds.append( - p.on_new_receipts(min_stream_id, max_stream_id) + p.on_new_receipts(min_stream_id, max_stream_id) ) yield defer.gatherResults(deferreds) -- cgit 1.4.1 From 15e0f1696f2556f72b65f14466df51e9a9f00c4b Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 7 Apr 2016 17:31:08 +0100 Subject: Wrap process in a flag so we don't process whist already processing. --- synapse/push/httppusher.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'synapse') diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index ea45b603c6..a0d0234e2e 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -48,6 +48,7 @@ class HttpPusher(object): self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.failing_since = pusherdict['failing_since'] self.timed_call = None + self.processing = False # This is the highest stream ordering we know it's safe to process. # When new events arrive, we'll be given a window of new events: we @@ -109,6 +110,14 @@ class HttpPusher(object): @defer.inlineCallbacks def _process(self): + try: + self.processing = True + yield self._unsafe_process() + finally: + self.processing = False + + @defer.inlineCallbacks + def _unsafe_process(self): unprocessed = yield self.store.get_unread_push_actions_for_user_in_range( self.user_id, self.last_stream_ordering, self.max_stream_ordering ) -- cgit 1.4.1 From 3fb35cbd6fc52905c88344fd3ea55a4ee1d1c478 Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 7 Apr 2016 17:33:37 +0100 Subject: Oops, inequality fail --- synapse/storage/event_push_actions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index 4d72e4a85e..355478957d 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -105,7 +105,7 @@ class EventPushActionsStore(SQLBaseStore): def f(txn): sql = ( "SELECT DISTINCT(user_id) FROM event_push_actions WHERE" - " stream_ordering >= ? AND stream_ordering >= ?" + " stream_ordering >= ? AND stream_ordering <= ?" ) txn.execute(sql, (min_stream_ordering, max_stream_ordering)) return [r[0] for r in txn.fetchall()] -- cgit 1.4.1 From a4a31fa8dca0199e8958b71a39b749eb4f6efb9e Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 7 Apr 2016 17:37:19 +0100 Subject: Only pass in what we need --- synapse/push/httppusher.py | 8 +++++--- synapse/push/push_tools.py | 16 ++++++++-------- 2 files changed, 13 insertions(+), 11 deletions(-) (limited to 'synapse') diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index a0d0234e2e..9f51106d0f 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -96,7 +96,9 @@ class HttpPusher(object): # We could check the receipts are actually m.read receipts here, # but currently that's the only type of receipt anyway... with Measure(self.clock, "push.on_new_receipts"): - badge = yield push_tools.get_badge_count(self.hs, self.user_id) + badge = yield push_tools.get_badge_count( + self.hs.get_datastore(), self.user_id + ) yield self.send_badge(badge) @defer.inlineCallbacks @@ -185,7 +187,7 @@ class HttpPusher(object): defer.returnValue(True) tweaks = push_rule_evaluator.tweaks_for_actions(push_action['actions']) - badge = yield push_tools.get_badge_count(self.hs, self.user_id) + badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) event = yield self.store.get_event(push_action['event_id'], allow_none=True) if event is None: @@ -215,7 +217,7 @@ class HttpPusher(object): @defer.inlineCallbacks def _build_notification_dict(self, event, tweaks, badge): - ctx = yield push_tools.get_context_for_event(self.hs, event) + ctx = yield push_tools.get_context_for_event(self.hs.get_datastore(), event) d = { 'notification': { diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index e1e61e49e8..e71d01e77d 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -17,13 +17,13 @@ from twisted.internet import defer @defer.inlineCallbacks -def get_badge_count(hs, user_id): +def get_badge_count(store, user_id): invites, joins = yield defer.gatherResults([ - hs.get_datastore().get_invited_rooms_for_user(user_id), - hs.get_datastore().get_rooms_for_user(user_id), + store.get_invited_rooms_for_user(user_id), + store.get_rooms_for_user(user_id), ], consumeErrors=True) - my_receipts_by_room = yield hs.get_datastore().get_receipts_for_user( + my_receipts_by_room = yield store.get_receipts_for_user( user_id, "m.read", ) @@ -34,7 +34,7 @@ def get_badge_count(hs, user_id): last_unread_event_id = my_receipts_by_room[r.room_id] notifs = yield ( - hs.get_datastore().get_unread_event_push_actions_by_room_for_user( + store.get_unread_event_push_actions_by_room_for_user( r.room_id, user_id, last_unread_event_id ) ) @@ -43,8 +43,8 @@ def get_badge_count(hs, user_id): @defer.inlineCallbacks -def get_context_for_event(hs, ev): - name_aliases = yield hs.get_datastore().get_room_name_and_aliases( +def get_context_for_event(store, ev): + name_aliases = yield store.get_room_name_and_aliases( ev.room_id ) @@ -52,7 +52,7 @@ def get_context_for_event(hs, ev): if name_aliases[0] is not None: ctx['name'] = name_aliases[0] - their_member_events_for_room = yield hs.get_datastore().get_current_state( + their_member_events_for_room = yield store.get_current_state( room_id=ev.room_id, event_type='m.room.member', state_key=ev.user_id -- cgit 1.4.1 From 4836864f5681fbcca34a4c40384a7c4c8309b4e2 Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 7 Apr 2016 17:38:48 +0100 Subject: generate id in the main thread --- synapse/storage/pusher.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index f7886dd1bb..b314e3ab4f 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -124,9 +124,9 @@ class PusherStore(SQLBaseStore): app_display_name, device_display_name, pushkey, pushkey_ts, lang, data, last_stream_ordering, profile_tag=""): - def f(txn): - txn.call_after(self.get_users_with_pushers_in_room.invalidate_all) - with self._pushers_id_gen.get_next() as stream_id: + with self._pushers_id_gen.get_next() as stream_id: + def f(txn): + txn.call_after(self.get_users_with_pushers_in_room.invalidate_all) return self._simple_upsert_txn( txn, "pushers", -- cgit 1.4.1 From d9f38561c8855fa6893868069f0ec00d802618df Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 7 Apr 2016 17:45:01 +0100 Subject: Literally a dictionary --- synapse/storage/pusher.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index b314e3ab4f..b34a30a8fb 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -130,23 +130,23 @@ class PusherStore(SQLBaseStore): return self._simple_upsert_txn( txn, "pushers", - dict( - app_id=app_id, - pushkey=pushkey, - user_name=user_id, - ), - dict( - access_token=access_token, - kind=kind, - app_display_name=app_display_name, - device_display_name=device_display_name, - ts=pushkey_ts, - lang=lang, - data=encode_canonical_json(data), - last_stream_ordering=last_stream_ordering, - profile_tag=profile_tag, - id=stream_id, - ), + { + "app_id": app_id, + "pushkey": pushkey, + "user_name": user_id, + }, + { + "access_token": access_token, + "kind": kind, + "app_display_name": app_display_name, + "device_display_name": device_display_name, + "ts": pushkey_ts, + "lang": lang, + "data": encode_canonical_json(data), + "last_stream_ordering": last_stream_ordering, + "profile_tag": profile_tag, + "id": stream_id, + }, ) defer.returnValue((yield self.runInteraction("add_pusher", f))) -- cgit 1.4.1 From 86be915ccef824176b4aa127e5d62d30e00bb6b7 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 7 Apr 2016 18:11:49 +0100 Subject: Call profile handler get_displayname directly rather than using collect_presencelike_data --- synapse/handlers/message.py | 10 +--------- synapse/handlers/profile.py | 28 ---------------------------- 2 files changed, 1 insertion(+), 37 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 10608c0dd9..fa78d4acec 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -34,10 +34,6 @@ import logging logger = logging.getLogger(__name__) -def collect_presencelike_data(distributor, user, content): - return distributor.fire("collect_presencelike_data", user, content) - - class MessageHandler(BaseHandler): def __init__(self, hs): @@ -202,12 +198,8 @@ class MessageHandler(BaseHandler): membership = builder.content.get("membership", None) target = UserID.from_string(builder.state_key) - if membership == Membership.JOIN: + if membership in {Membership.JOIN, Membership.INVITE}: # If event doesn't include a display name, add one. - yield collect_presencelike_data( - self.distributor, target, builder.content - ) - elif membership == Membership.INVITE: profile = self.hs.get_handlers().profile_handler content = builder.content diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index b45eafbb49..165e020694 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -17,7 +17,6 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.types import UserID, Requester -from synapse.util import unwrapFirstError from ._base import BaseHandler @@ -31,10 +30,6 @@ def changed_presencelike_data(distributor, user, state): return distributor.fire("changed_presencelike_data", user, state) -def collect_presencelike_data(distributor, user, content): - return distributor.fire("collect_presencelike_data", user, content) - - class ProfileHandler(BaseHandler): def __init__(self, hs): @@ -48,15 +43,10 @@ class ProfileHandler(BaseHandler): distributor = hs.get_distributor() self.distributor = distributor - distributor.declare("collect_presencelike_data") distributor.declare("changed_presencelike_data") distributor.observe("registered_user", self.registered_user) - distributor.observe( - "collect_presencelike_data", self.collect_presencelike_data - ) - def registered_user(self, user): return self.store.create_profile(user.localpart) @@ -158,24 +148,6 @@ class ProfileHandler(BaseHandler): yield self._update_join_states(requester) - @defer.inlineCallbacks - def collect_presencelike_data(self, user, state): - if not self.hs.is_mine(user): - defer.returnValue(None) - - (displayname, avatar_url) = yield defer.gatherResults( - [ - self.store.get_profile_displayname(user.localpart), - self.store.get_profile_avatar_url(user.localpart), - ], - consumeErrors=True - ).addErrback(unwrapFirstError) - - state["displayname"] = displayname - state["avatar_url"] = avatar_url - - defer.returnValue(None) - @defer.inlineCallbacks def on_profile_query(self, args): user = UserID.from_string(args["user_id"]) -- cgit 1.4.1 From caef3375874d6ad7c5e55fc59d7e9b9a79952bd9 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 7 Apr 2016 17:46:47 +0100 Subject: changed_presencelike_data isn't observed anywhere in synapse so can be removed --- synapse/handlers/profile.py | 15 --------------- 1 file changed, 15 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 165e020694..e37409170d 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -26,10 +26,6 @@ import logging logger = logging.getLogger(__name__) -def changed_presencelike_data(distributor, user, state): - return distributor.fire("changed_presencelike_data", user, state) - - class ProfileHandler(BaseHandler): def __init__(self, hs): @@ -41,9 +37,6 @@ class ProfileHandler(BaseHandler): ) distributor = hs.get_distributor() - self.distributor = distributor - - distributor.declare("changed_presencelike_data") distributor.observe("registered_user", self.registered_user) @@ -95,10 +88,6 @@ class ProfileHandler(BaseHandler): target_user.localpart, new_displayname ) - yield changed_presencelike_data(self.distributor, target_user, { - "displayname": new_displayname, - }) - yield self._update_join_states(requester) @defer.inlineCallbacks @@ -142,10 +131,6 @@ class ProfileHandler(BaseHandler): target_user.localpart, new_avatar_url ) - yield changed_presencelike_data(self.distributor, target_user, { - "avatar_url": new_avatar_url, - }) - yield self._update_join_states(requester) @defer.inlineCallbacks -- cgit 1.4.1 From b9ee5650b0027b664aa700a7ce451a546f404350 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 8 Apr 2016 11:01:38 +0100 Subject: Move all the wrapper functions for distributor.fire Move the functions inside the distributor and import them where needed. This reduces duplication and makes it possible for flake8 to detect when the functions aren't used in a given file. --- synapse/handlers/federation.py | 5 +---- synapse/handlers/register.py | 5 +---- synapse/handlers/room.py | 15 --------------- synapse/handlers/room_member.py | 16 +--------------- synapse/util/distributor.py | 22 +++++++++++++++++++++- 5 files changed, 24 insertions(+), 39 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index eb02f0e000..c28226f840 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -40,6 +40,7 @@ from synapse.events.utils import prune_event from synapse.util.retryutils import NotRetryingDestination from synapse.push.action_generator import ActionGenerator +from synapse.util.distributor import user_joined_room from twisted.internet import defer @@ -49,10 +50,6 @@ import logging logger = logging.getLogger(__name__) -def user_joined_room(distributor, user, room_id): - return distributor.fire("user_joined_room", user, room_id) - - class FederationHandler(BaseHandler): """Handles events that originated from federation. Responsible for: diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index f287ee247b..b0862067e1 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -23,6 +23,7 @@ from synapse.api.errors import ( from ._base import BaseHandler from synapse.util.async import run_on_reactor from synapse.http.client import CaptchaServerHttpClient +from synapse.util.distributor import registered_user import logging import urllib @@ -30,10 +31,6 @@ import urllib logger = logging.getLogger(__name__) -def registered_user(distributor, user): - return distributor.fire("registered_user", user) - - class RegistrationHandler(BaseHandler): def __init__(self, hs): diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3e1d9282d7..ea306cd42a 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -25,7 +25,6 @@ from synapse.api.constants import ( from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.util import stringutils from synapse.util.async import concurrently_execute -from synapse.util.logcontext import preserve_context_over_fn from synapse.util.caches.response_cache import ResponseCache from collections import OrderedDict @@ -39,20 +38,6 @@ logger = logging.getLogger(__name__) id_server_scheme = "https://" -def user_left_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_left_room", user=user, room_id=room_id - ) - - -def user_joined_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_joined_room", user=user, room_id=room_id - ) - - class RoomCreationHandler(BaseHandler): PRESETS_DICT = { diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index b6ef3c91af..753c75d9c1 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -23,8 +23,8 @@ from synapse.api.constants import ( EventTypes, Membership, ) from synapse.api.errors import AuthError, SynapseError, Codes -from synapse.util.logcontext import preserve_context_over_fn from synapse.util.async import Linearizer +from synapse.util.distributor import user_left_room, user_joined_room from signedjson.sign import verify_signed_json from signedjson.key import decode_verify_key_bytes @@ -38,20 +38,6 @@ logger = logging.getLogger(__name__) id_server_scheme = "https://" -def user_left_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_left_room", user=user, room_id=room_id - ) - - -def user_joined_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_joined_room", user=user, room_id=room_id - ) - - class RoomMemberHandler(BaseHandler): # TODO(paul): This handler currently contains a messy conflation of # low-level API that works on UserID objects and so on, and REST-level diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 8875813de4..d7cccc06b1 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -15,7 +15,9 @@ from twisted.internet import defer -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import ( + PreserveLoggingContext, preserve_context_over_fn +) from synapse.util import unwrapFirstError @@ -25,6 +27,24 @@ import logging logger = logging.getLogger(__name__) +def registered_user(distributor, user): + return distributor.fire("registered_user", user) + + +def user_left_room(distributor, user, room_id): + return preserve_context_over_fn( + distributor.fire, + "user_left_room", user=user, room_id=room_id + ) + + +def user_joined_room(distributor, user, room_id): + return preserve_context_over_fn( + distributor.fire, + "user_joined_room", user=user, room_id=room_id + ) + + class Distributor(object): """A central dispatch point for loosely-connected pieces of code to register, observe, and fire signals. -- cgit 1.4.1 From 7e2f971c08250cf432d43dd6244faefb2074ff8c Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 8 Apr 2016 14:01:56 +0100 Subject: Remove some unused functions (#711) * Remove some unused functions * get_room_events_stream is only used in tests * is_exclusive_room might actually be something we want --- synapse/appservice/api.py | 5 - synapse/handlers/message.py | 29 ------ synapse/handlers/room_member.py | 13 --- synapse/storage/_base.py | 6 -- synapse/storage/prepare_database.py | 12 --- synapse/storage/presence.py | 10 -- synapse/storage/roommember.py | 33 ------- synapse/storage/stream.py | 90 ------------------ synapse/util/__init__.py | 3 - synapse/util/ratelimitutils.py | 14 --- synapse/util/stringutils.py | 4 - tests/storage/test_presence.py | 27 ------ tests/storage/test_redaction.py | 51 +--------- tests/storage/test_roommember.py | 7 -- tests/storage/test_stream.py | 185 ------------------------------------ 15 files changed, 4 insertions(+), 485 deletions(-) delete mode 100644 tests/storage/test_stream.py (limited to 'synapse') diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index bc90605324..6da6a1b62e 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -100,11 +100,6 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning("push_bulk to %s threw exception %s", uri, ex) defer.returnValue(False) - @defer.inlineCallbacks - def push(self, service, event, txn_id=None): - response = yield self.push_bulk(service, [event], txn_id) - defer.returnValue(response) - def _serialize(self, events): time_now = self.clock.time_msec() return [ diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index fa78d4acec..f51feda2f4 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -44,35 +44,6 @@ class MessageHandler(BaseHandler): self.validator = EventValidator() self.snapshot_cache = SnapshotCache() - @defer.inlineCallbacks - def get_message(self, msg_id=None, room_id=None, sender_id=None, - user_id=None): - """ Retrieve a message. - - Args: - msg_id (str): The message ID to obtain. - room_id (str): The room where the message resides. - sender_id (str): The user ID of the user who sent the message. - user_id (str): The user ID of the user making this request. - Returns: - The message, or None if no message exists. - Raises: - SynapseError if something went wrong. - """ - yield self.auth.check_joined_room(room_id, user_id) - - # Pull out the message from the db -# msg = yield self.store.get_message( -# room_id=room_id, -# msg_id=msg_id, -# user_id=sender_id -# ) - - # TODO (erikj): Once we work out the correct c-s api we need to think - # on how to do this. - - defer.returnValue(None) - @defer.inlineCallbacks def get_messages(self, requester, room_id=None, pagin_config=None, as_client_event=True): diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 753c75d9c1..b69f36aefe 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -392,19 +392,6 @@ class RoomMemberHandler(BaseHandler): and guest_access.content["guest_access"] == "can_join" ) - def _should_do_dance(self, current_state, inviter, room_hosts=None): - # TODO: Shouldn't this be remote_room_host? - room_hosts = room_hosts or [] - - is_host_in_room = self.is_host_in_room(current_state) - if is_host_in_room: - return False, room_hosts - - if inviter and not self.hs.is_mine(inviter): - room_hosts.append(inviter.domain) - - return True, room_hosts - @defer.inlineCallbacks def lookup_room_alias(self, room_alias): """ diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 04d7fcf6d6..1e27c2c0ce 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -810,12 +810,6 @@ class SQLBaseStore(object): return txn.execute(sql, keyvalues.values()) - def get_next_stream_id(self): - with self._next_stream_id_lock: - i = self._next_stream_id - self._next_stream_id += 1 - return i - def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value): # Fetch a mapping of room_id -> max stream position for "recent" rooms. diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 00833422af..57f14fd12b 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -30,18 +30,6 @@ SCHEMA_VERSION = 31 dir_path = os.path.abspath(os.path.dirname(__file__)) -def read_schema(path): - """ Read the named database schema. - - Args: - path: Path of the database schema. - Returns: - A string containing the database schema. - """ - with open(path) as schema_file: - return schema_file.read() - - class PrepareDatabaseException(Exception): pass diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 59b4ef5ce6..07f5fae8dd 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -176,16 +176,6 @@ class PresenceStore(SQLBaseStore): desc="disallow_presence_visible", ) - def is_presence_visible(self, observed_localpart, observer_userid): - return self._simple_select_one( - table="presence_allow_inbound", - keyvalues={"observed_user_id": observed_localpart, - "observer_user_id": observer_userid}, - retcols=["observed_user_id"], - allow_none=True, - desc="is_presence_visible", - ) - def add_presence_list_pending(self, observer_localpart, observed_userid): return self._simple_insert( table="presence_list", diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 66e7a40e3c..77518e893f 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -121,26 +121,6 @@ class RoomMemberStore(SQLBaseStore): with self._stream_id_gen.get_next() as stream_ordering: yield self.runInteraction("locally_reject_invite", f, stream_ordering) - def get_room_member(self, user_id, room_id): - """Retrieve the current state of a room member. - - Args: - user_id (str): The member's user ID. - room_id (str): The room the member is in. - Returns: - Deferred: Results in a MembershipEvent or None. - """ - return self.runInteraction( - "get_room_member", - self._get_members_events_txn, - room_id, - user_id=user_id, - ).addCallback( - self._get_events - ).addCallback( - lambda events: events[0] if events else None - ) - @cached(max_entries=5000) def get_users_in_room(self, room_id): def f(txn): @@ -203,19 +183,6 @@ class RoomMemberStore(SQLBaseStore): defer.returnValue(invite) defer.returnValue(None) - def get_leave_and_ban_events_for_user(self, user_id): - """ Get all the leave events for a user - Args: - user_id (str): The user ID. - Returns: - A deferred list of event objects. - """ - return self.get_rooms_for_user_where_membership_is( - user_id, (Membership.LEAVE, Membership.BAN) - ).addCallback(lambda leaves: self._get_events([ - leave.event_id for leave in leaves - ])) - def get_rooms_for_user_where_membership_is(self, user_id, membership_list): """ Get all the rooms for this user where the membership for this user matches one in the membership list. diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 76bcd9cd00..95b12559a6 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -303,96 +303,6 @@ class StreamStore(SQLBaseStore): defer.returnValue(ret) - def get_room_events_stream( - self, - user_id, - from_key, - to_key, - limit=0, - is_guest=False, - room_ids=None - ): - room_ids = room_ids or [] - room_ids = [r for r in room_ids] - if is_guest: - current_room_membership_sql = ( - "SELECT c.room_id FROM history_visibility AS h" - " INNER JOIN current_state_events AS c" - " ON h.event_id = c.event_id" - " WHERE c.room_id IN (%s)" - " AND h.history_visibility = 'world_readable'" % ( - ",".join(map(lambda _: "?", room_ids)) - ) - ) - current_room_membership_args = room_ids - else: - current_room_membership_sql = ( - "SELECT m.room_id FROM room_memberships as m " - " INNER JOIN current_state_events as c" - " ON m.event_id = c.event_id AND c.state_key = m.user_id" - " WHERE m.user_id = ? AND m.membership = 'join'" - ) - current_room_membership_args = [user_id] - - # We also want to get any membership events about that user, e.g. - # invites or leave notifications. - membership_sql = ( - "SELECT m.event_id FROM room_memberships as m " - "INNER JOIN current_state_events as c ON m.event_id = c.event_id " - "WHERE m.user_id = ? " - ) - membership_args = [user_id] - - if limit: - limit = max(limit, MAX_STREAM_SIZE) - else: - limit = MAX_STREAM_SIZE - - # From and to keys should be integers from ordering. - from_id = RoomStreamToken.parse_stream_token(from_key) - to_id = RoomStreamToken.parse_stream_token(to_key) - - if from_key == to_key: - return defer.succeed(([], to_key)) - - sql = ( - "SELECT e.event_id, e.stream_ordering FROM events AS e WHERE " - "(e.outlier = ? AND (room_id IN (%(current)s)) OR " - "(event_id IN (%(invites)s))) " - "AND e.stream_ordering > ? AND e.stream_ordering <= ? " - "ORDER BY stream_ordering ASC LIMIT %(limit)d " - ) % { - "current": current_room_membership_sql, - "invites": membership_sql, - "limit": limit - } - - def f(txn): - args = ([False] + current_room_membership_args + membership_args + - [from_id.stream, to_id.stream]) - txn.execute(sql, args) - - rows = self.cursor_to_dict(txn) - - ret = self._get_events_txn( - txn, - [r["event_id"] for r in rows], - get_prev_content=True - ) - - self._set_before_and_after(ret, rows) - - if rows: - key = "s%d" % max(r["stream_ordering"] for r in rows) - else: - # Assume we didn't get anything because there was nothing to - # get. - key = to_key - - return ret, key - - return self.runInteraction("get_room_events_stream", f) - @defer.inlineCallbacks def paginate_room_events(self, room_id, from_key, to_key=None, direction='b', limit=-1): diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 3b9da5b34a..b462495eb8 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -49,9 +49,6 @@ class Clock(object): l.start(msec / 1000.0, now=False) return l - def stop_looping_call(self, loop): - loop.stop() - def call_later(self, delay, callback, *args, **kwargs): """Call something later diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 4076eed269..1101881a2d 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -100,20 +100,6 @@ class _PerHostRatelimiter(object): self.current_processing = set() self.request_times = [] - def is_empty(self): - time_now = self.clock.time_msec() - self.request_times[:] = [ - r for r in self.request_times - if time_now - r < self.window_size - ] - - return not ( - self.ready_request_queue - or self.sleeping_requests - or self.current_processing - or self.request_times - ) - @contextlib.contextmanager def ratelimit(self): # `contextlib.contextmanager` takes a generator and turns it into a diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index b490bb8725..a100f151d4 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -21,10 +21,6 @@ _string_with_symbols = ( ) -def origin_from_ucid(ucid): - return ucid.split("@", 1)[1] - - def random_string(length): return ''.join(random.choice(string.ascii_letters) for _ in xrange(length)) diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py index ec78f007ca..63203cea35 100644 --- a/tests/storage/test_presence.py +++ b/tests/storage/test_presence.py @@ -34,33 +34,6 @@ class PresenceStoreTestCase(unittest.TestCase): self.u_apple = UserID.from_string("@apple:test") self.u_banana = UserID.from_string("@banana:test") - @defer.inlineCallbacks - def test_visibility(self): - self.assertFalse((yield self.store.is_presence_visible( - observed_localpart=self.u_apple.localpart, - observer_userid=self.u_banana.to_string(), - ))) - - yield self.store.allow_presence_visible( - observed_localpart=self.u_apple.localpart, - observer_userid=self.u_banana.to_string(), - ) - - self.assertTrue((yield self.store.is_presence_visible( - observed_localpart=self.u_apple.localpart, - observer_userid=self.u_banana.to_string(), - ))) - - yield self.store.disallow_presence_visible( - observed_localpart=self.u_apple.localpart, - observer_userid=self.u_banana.to_string(), - ) - - self.assertFalse((yield self.store.is_presence_visible( - observed_localpart=self.u_apple.localpart, - observer_userid=self.u_banana.to_string(), - ))) - @defer.inlineCallbacks def test_presence_list(self): self.assertEquals( diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 5880409867..6afaca3a61 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -110,22 +110,10 @@ class RedactionTestCase(unittest.TestCase): self.room1, self.u_alice, Membership.JOIN ) - start = yield self.store.get_room_events_max_id() - msg_event = yield self.inject_message(self.room1, self.u_alice, u"t") - end = yield self.store.get_room_events_max_id() - - results, _ = yield self.store.get_room_events_stream( - self.u_alice.to_string(), - start, - end, - ) - - self.assertEqual(1, len(results)) - # Check event has not been redacted: - event = results[0] + event = yield self.store.get_event(msg_event.event_id) self.assertObjectHasAttributes( { @@ -144,17 +132,7 @@ class RedactionTestCase(unittest.TestCase): self.room1, msg_event.event_id, self.u_alice, reason ) - results, _ = yield self.store.get_room_events_stream( - self.u_alice.to_string(), - start, - end, - ) - - self.assertEqual(1, len(results)) - - # Check redaction - - event = results[0] + event = yield self.store.get_event(msg_event.event_id) self.assertEqual(msg_event.event_id, event.event_id) @@ -184,25 +162,12 @@ class RedactionTestCase(unittest.TestCase): self.room1, self.u_alice, Membership.JOIN ) - start = yield self.store.get_room_events_max_id() - msg_event = yield self.inject_room_member( self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}, ) - end = yield self.store.get_room_events_max_id() - - results, _ = yield self.store.get_room_events_stream( - self.u_alice.to_string(), - start, - end, - ) - - self.assertEqual(1, len(results)) - - # Check event has not been redacted: - event = results[0] + event = yield self.store.get_event(msg_event.event_id) self.assertObjectHasAttributes( { @@ -221,17 +186,9 @@ class RedactionTestCase(unittest.TestCase): self.room1, msg_event.event_id, self.u_alice, reason ) - results, _ = yield self.store.get_room_events_stream( - self.u_alice.to_string(), - start, - end, - ) - - self.assertEqual(1, len(results)) - # Check redaction - event = results[0] + event = yield self.store.get_event(msg_event.event_id) self.assertTrue("redacted_because" in event.unsigned) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index b029ff0584..997090fe35 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -70,13 +70,6 @@ class RoomMemberStoreTestCase(unittest.TestCase): def test_one_member(self): yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN) - self.assertEquals( - Membership.JOIN, - (yield self.store.get_room_member( - user_id=self.u_alice.to_string(), - room_id=self.room.to_string(), - )).membership - ) self.assertEquals( [self.u_alice.to_string()], [m.user_id for m in ( diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py deleted file mode 100644 index da322152c7..0000000000 --- a/tests/storage/test_stream.py +++ /dev/null @@ -1,185 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from tests import unittest -from twisted.internet import defer - -from synapse.api.constants import EventTypes, Membership -from synapse.types import UserID, RoomID -from tests.storage.event_injector import EventInjector - -from tests.utils import setup_test_homeserver - -from mock import Mock - - -class StreamStoreTestCase(unittest.TestCase): - - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver( - resource_for_federation=Mock(), - http_client=None, - ) - - self.store = hs.get_datastore() - self.event_builder_factory = hs.get_event_builder_factory() - self.event_injector = EventInjector(hs) - self.handlers = hs.get_handlers() - self.message_handler = self.handlers.message_handler - - self.u_alice = UserID.from_string("@alice:test") - self.u_bob = UserID.from_string("@bob:test") - - self.room1 = RoomID.from_string("!abc123:test") - self.room2 = RoomID.from_string("!xyx987:test") - - @defer.inlineCallbacks - def test_event_stream_get_other(self): - # Both bob and alice joins the room - yield self.event_injector.inject_room_member( - self.room1, self.u_alice, Membership.JOIN - ) - yield self.event_injector.inject_room_member( - self.room1, self.u_bob, Membership.JOIN - ) - - # Initial stream key: - start = yield self.store.get_room_events_max_id() - - yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") - - end = yield self.store.get_room_events_max_id() - - results, _ = yield self.store.get_room_events_stream( - self.u_bob.to_string(), - start, - end, - ) - - self.assertEqual(1, len(results)) - - event = results[0] - - self.assertObjectHasAttributes( - { - "type": EventTypes.Message, - "user_id": self.u_alice.to_string(), - "content": {"body": "test", "msgtype": "message"}, - }, - event, - ) - - @defer.inlineCallbacks - def test_event_stream_get_own(self): - # Both bob and alice joins the room - yield self.event_injector.inject_room_member( - self.room1, self.u_alice, Membership.JOIN - ) - yield self.event_injector.inject_room_member( - self.room1, self.u_bob, Membership.JOIN - ) - - # Initial stream key: - start = yield self.store.get_room_events_max_id() - - yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") - - end = yield self.store.get_room_events_max_id() - - results, _ = yield self.store.get_room_events_stream( - self.u_alice.to_string(), - start, - end, - ) - - self.assertEqual(1, len(results)) - - event = results[0] - - self.assertObjectHasAttributes( - { - "type": EventTypes.Message, - "user_id": self.u_alice.to_string(), - "content": {"body": "test", "msgtype": "message"}, - }, - event, - ) - - @defer.inlineCallbacks - def test_event_stream_join_leave(self): - # Both bob and alice joins the room - yield self.event_injector.inject_room_member( - self.room1, self.u_alice, Membership.JOIN - ) - yield self.event_injector.inject_room_member( - self.room1, self.u_bob, Membership.JOIN - ) - - # Then bob leaves again. - yield self.event_injector.inject_room_member( - self.room1, self.u_bob, Membership.LEAVE - ) - - # Initial stream key: - start = yield self.store.get_room_events_max_id() - - yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") - - end = yield self.store.get_room_events_max_id() - - results, _ = yield self.store.get_room_events_stream( - self.u_bob.to_string(), - start, - end, - ) - - # We should not get the message, as it happened *after* bob left. - self.assertEqual(0, len(results)) - - @defer.inlineCallbacks - def test_event_stream_prev_content(self): - yield self.event_injector.inject_room_member( - self.room1, self.u_bob, Membership.JOIN - ) - - yield self.event_injector.inject_room_member( - self.room1, self.u_alice, Membership.JOIN - ) - - start = yield self.store.get_room_events_max_id() - - yield self.event_injector.inject_room_member( - self.room1, self.u_alice, Membership.JOIN, - ) - - end = yield self.store.get_room_events_max_id() - - results, _ = yield self.store.get_room_events_stream( - self.u_bob.to_string(), - start, - end, - ) - - # We should not get the message, as it happened *after* bob left. - self.assertEqual(1, len(results)) - - event = results[0] - - self.assertTrue( - "prev_content" in event.unsigned, - msg="No prev_content key" - ) -- cgit 1.4.1 From ce3fe52498547144191fadf8ff0a8cb6d244334e Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 8 Apr 2016 14:02:38 +0100 Subject: Comment why unsafe process is unsafe --- synapse/push/httppusher.py | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'synapse') diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 9f51106d0f..685c5e48df 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -120,6 +120,11 @@ class HttpPusher(object): @defer.inlineCallbacks def _unsafe_process(self): + """ + Looks for unset notifications and dispatch them, in order + Never call this directly: use _process which will only allow this to + run once per pusher. + """ unprocessed = yield self.store.get_unread_push_actions_for_user_in_range( self.user_id, self.last_stream_ordering, self.max_stream_ordering ) -- cgit 1.4.1 From 52d1008661cebe8551bffd97b938369550851bc6 Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 8 Apr 2016 14:06:54 +0100 Subject: Unsafe process should call itself if the max has changed --- synapse/push/httppusher.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'synapse') diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 685c5e48df..6ef8bf62b3 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -125,6 +125,7 @@ class HttpPusher(object): Never call this directly: use _process which will only allow this to run once per pusher. """ + starting_max_ordering = self.max_stream_ordering unprocessed = yield self.store.get_unread_push_actions_for_user_in_range( self.user_id, self.last_stream_ordering, self.max_stream_ordering ) @@ -185,6 +186,8 @@ class HttpPusher(object): self.timed_call = reactor.callLater(self.backoff_delay, self.on_timer) self.backoff_delay = min(self.backoff_delay * 2, self.MAX_BACKOFF_SEC) break + if self.max_stream_ordering != starting_max_ordering: + self._unsafe_process() @defer.inlineCallbacks def _process_one(self, push_action): -- cgit 1.4.1 From 7b6d51948241090799d3bb9f52a9082b29a27d73 Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 8 Apr 2016 14:08:16 +0100 Subject: Make sure max stream ordering only increases --- synapse/push/httppusher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 6ef8bf62b3..38b758e6af 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -86,7 +86,7 @@ class HttpPusher(object): @defer.inlineCallbacks def on_new_notifications(self, min_stream_ordering, max_stream_ordering): with Measure(self.clock, "push.on_new_notifications"): - self.max_stream_ordering = max_stream_ordering + self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering) yield self._process() @defer.inlineCallbacks -- cgit 1.4.1 From ed3979df5faac6d63990f4230662ff8cdcf59584 Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 8 Apr 2016 15:29:59 +0100 Subject: Fix invite pushes * If the event is an invite event, add the invitee to list of user we run push rules for (if they have a pusher etc) * Move invite_for_me to be higher prio than member events otherwise member events matches them * Spell override right --- synapse/push/action_generator.py | 6 +-- synapse/push/baserules.py | 72 ++++++++++++++++---------------- synapse/push/bulk_push_rule_evaluator.py | 12 +++++- synapse/storage/pusher.py | 7 ++++ 4 files changed, 58 insertions(+), 39 deletions(-) (limited to 'synapse') diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index 84efcdd184..59e512f507 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from .bulk_push_rule_evaluator import evaluator_for_room_id +from .bulk_push_rule_evaluator import evaluator_for_event import logging @@ -35,8 +35,8 @@ class ActionGenerator: @defer.inlineCallbacks def handle_push_actions_for_event(self, event, context, handler): - bulk_evaluator = yield evaluator_for_room_id( - event.room_id, self.hs, self.store + bulk_evaluator = yield evaluator_for_event( + event, self.hs, self.store ) actions_by_user = yield bulk_evaluator.action_for_event_by_user( diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 6add94beeb..8a174feeaf 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -79,7 +79,7 @@ def make_base_append_rules(kind, modified_base_rules): rules = [] if kind == 'override': - rules = BASE_APPEND_OVRRIDE_RULES + rules = BASE_APPEND_OVERRIDE_RULES elif kind == 'underride': rules = BASE_APPEND_UNDERRIDE_RULES elif kind == 'content': @@ -148,7 +148,7 @@ BASE_PREPEND_OVERRIDE_RULES = [ ] -BASE_APPEND_OVRRIDE_RULES = [ +BASE_APPEND_OVERRIDE_RULES = [ { 'rule_id': 'global/override/.m.rule.suppress_notices', 'conditions': [ @@ -163,6 +163,40 @@ BASE_APPEND_OVRRIDE_RULES = [ 'dont_notify', ] }, + # NB. .m.rule.invite_for_me must be higher prio than .m.rule.member_event + # otherwise invites will be matched by .m.rule.member_event + { + 'rule_id': 'global/underride/.m.rule.invite_for_me', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'type', + 'pattern': 'm.room.member', + '_id': '_member', + }, + { + 'kind': 'event_match', + 'key': 'content.membership', + 'pattern': 'invite', + '_id': '_invite_member', + }, + { + 'kind': 'event_match', + 'key': 'state_key', + 'pattern_type': 'user_id' + }, + ], + 'actions': [ + 'notify', + { + 'set_tweak': 'sound', + 'value': 'default' + }, { + 'set_tweak': 'highlight', + 'value': False + } + ] + }, # Will we sometimes want to know about people joining and leaving? # Perhaps: if so, this could be expanded upon. Seems the most usual case # is that we don't though. We add this override rule so that even if @@ -251,38 +285,6 @@ BASE_APPEND_UNDERRIDE_RULES = [ } ] }, - { - 'rule_id': 'global/underride/.m.rule.invite_for_me', - 'conditions': [ - { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.member', - '_id': '_member', - }, - { - 'kind': 'event_match', - 'key': 'content.membership', - 'pattern': 'invite', - '_id': '_invite_member', - }, - { - 'kind': 'event_match', - 'key': 'state_key', - 'pattern_type': 'user_id' - }, - ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'default' - }, { - 'set_tweak': 'highlight', - 'value': False - } - ] - }, { 'rule_id': 'global/underride/.m.rule.message', 'conditions': [ @@ -315,7 +317,7 @@ for r in BASE_PREPEND_OVERRIDE_RULES: r['default'] = True BASE_RULE_IDS.add(r['rule_id']) -for r in BASE_APPEND_OVRRIDE_RULES: +for r in BASE_APPEND_OVERRIDE_RULES: r['priority_class'] = PRIORITY_CLASS_MAP['override'] r['default'] = True BASE_RULE_IDS.add(r['rule_id']) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 7f94591dcb..49216f0c15 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -69,7 +69,8 @@ def _get_rules(room_id, user_ids, store): @defer.inlineCallbacks -def evaluator_for_room_id(room_id, hs, store): +def evaluator_for_event(event, hs, store): + room_id = event.room_id users_with_pushers = yield store.get_users_with_pushers_in_room(room_id) receipts = yield store.get_receipts_for_room(room_id, "m.read") @@ -79,6 +80,15 @@ def evaluator_for_room_id(room_id, hs, store): if hs.is_mine_id(r['user_id']): user_ids.add(r['user_id']) + # if this event is an invite event, we may need to run rules for the user + # who's been invited, otherwise they won't get told they've been invited + if event.type == 'm.room.member' and event.content['membership'] == 'invite': + invited_user = event.state_key + if invited_user and hs.is_mine_id(invited_user): + has_pusher = yield store.user_has_pusher(invited_user) + if has_pusher: + user_ids.add(invited_user) + user_ids = list(user_ids) rules_by_user = yield _get_rules(room_id, user_ids, store) diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index b34a30a8fb..19888a8e76 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -49,6 +49,13 @@ class PusherStore(SQLBaseStore): return rows + @defer.inlineCallbacks + def user_has_pusher(self, user_id): + ret = yield self._simple_select_one_onecol( + "pushers", {"user_name": user_id}, "id", allow_none=True + ) + defer.returnValue(ret is not None) + @defer.inlineCallbacks def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey): def r(txn): -- cgit 1.4.1 From d96a070a3a6da4e2ff868f656a28f1bfd5f3ea82 Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 8 Apr 2016 16:49:39 +0100 Subject: Actually check if we;re processing --- synapse/push/httppusher.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'synapse') diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 38b758e6af..b3b11c5f43 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -112,6 +112,8 @@ class HttpPusher(object): @defer.inlineCallbacks def _process(self): + if self.processing: + return try: self.processing = True yield self._unsafe_process() -- cgit 1.4.1 From dafef5a688b8684232346a26a789a2da600ec58e Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Fri, 8 Apr 2016 18:37:15 +0100 Subject: Add url_preview_enabled config option to turn on/off preview_url endpoint. defaults to off. Add url_preview_ip_range_blacklist to let admins specify internal IP ranges that must not be spidered. Add url_preview_url_blacklist to let admins specify URL patterns that must not be spidered. Implement a custom SpiderEndpoint and associated support classes to implement url_preview_ip_range_blacklist Add commentary and generally address PR feedback --- synapse/config/repository.py | 77 +++++++++++++++++++++++++-- synapse/http/client.py | 44 +++++++++++++-- synapse/http/endpoint.py | 35 +++++++++++- synapse/python_dependencies.py | 7 ++- synapse/rest/media/v1/media_repository.py | 7 ++- synapse/rest/media/v1/preview_url_resource.py | 75 ++++++++++++++++++++------ 6 files changed, 214 insertions(+), 31 deletions(-) (limited to 'synapse') diff --git a/synapse/config/repository.py b/synapse/config/repository.py index f4ab705701..da1007d767 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -16,6 +16,8 @@ from ._base import Config from collections import namedtuple +import sys + ThumbnailRequirement = namedtuple( "ThumbnailRequirement", ["width", "height", "method", "media_type"] ) @@ -23,7 +25,7 @@ ThumbnailRequirement = namedtuple( def parse_thumbnail_requirements(thumbnail_sizes): """ Takes a list of dictionaries with "width", "height", and "method" keys - and creates a map from image media types to the thumbnail size, thumnailing + and creates a map from image media types to the thumbnail size, thumbnailing method, and thumbnail media type to precalculate Args: @@ -60,6 +62,18 @@ class ContentRepositoryConfig(Config): self.thumbnail_requirements = parse_thumbnail_requirements( config["thumbnail_sizes"] ) + self.url_preview_enabled = config["url_preview_enabled"] + if self.url_preview_enabled: + try: + from netaddr import IPSet + if "url_preview_ip_range_blacklist" in config: + self.url_preview_ip_range_blacklist = IPSet( + config["url_preview_ip_range_blacklist"] + ) + if "url_preview_url_blacklist" in config: + self.url_preview_url_blacklist = config["url_preview_url_blacklist"] + except ImportError: + sys.stderr.write("\nmissing netaddr dep - disabling preview_url API\n") def default_config(self, **kwargs): media_store = self.default_path("media_store") @@ -74,9 +88,6 @@ class ContentRepositoryConfig(Config): # The largest allowed upload size in bytes max_upload_size: "10M" - # The largest allowed URL preview spidering size in bytes - max_spider_size: "10M" - # Maximum number of pixels that will be thumbnailed max_image_pixels: "32M" @@ -104,4 +115,62 @@ class ContentRepositoryConfig(Config): - width: 800 height: 600 method: scale + + # Is the preview URL API enabled? If enabled, you *must* specify + # an explicit url_preview_ip_range_blacklist of IPs that the spider is + # denied from accessing. + url_preview_enabled: False + + # List of IP address CIDR ranges that the URL preview spider is denied + # from accessing. There are no defaults: you must explicitly + # specify a list for URL previewing to work. You should specify any + # internal services in your network that you do not want synapse to try + # to connect to, otherwise anyone in any Matrix room could cause your + # synapse to issue arbitrary GET requests to your internal services, + # causing serious security issues. + # + # url_preview_ip_range_blacklist: + # - '127.0.0.0/8' + # - '10.0.0.0/8' + # - '172.16.0.0/12' + # - '192.168.0.0/16' + + # Optional list of URL matches that the URL preview spider is + # denied from accessing. You should use url_preview_ip_range_blacklist + # in preference to this, otherwise someone could define a public DNS + # entry that points to a private IP address and circumvent the blacklist. + # This is more useful if you know there is an entire shape of URL that + # you know that will never want synapse to try to spider. + # + # Each list entry is a dictionary of url component attributes as returned + # by urlparse.urlsplit as applied to the absolute form of the URL. See + # https://docs.python.org/2/library/urlparse.html#urlparse.urlsplit + # The values of the dictionary are treated as an filename match pattern + # applied to that component of URLs, unless they start with a ^ in which + # case they are treated as a regular expression match. If all the + # specified component matches for a given list item succeed, the URL is + # blacklisted. + # + # url_preview_url_blacklist: + # # blacklist any URL with a username in its URI + # - username: '*'' + # + # # blacklist all *.google.com URLs + # - netloc: 'google.com' + # - netloc: '*.google.com' + # + # # blacklist all plain HTTP URLs + # - scheme: 'http' + # + # # blacklist http(s)://www.acme.com/foo + # - netloc: 'www.acme.com' + # path: '/foo' + # + # # blacklist any URL with a literal IPv4 address + # - netloc: '^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$' + + # The largest allowed URL preview spidering size in bytes + max_spider_size: "10M" + + """ % locals() diff --git a/synapse/http/client.py b/synapse/http/client.py index 442b4bb73d..3b8ffcd3ef 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -20,10 +20,12 @@ from synapse.api.errors import ( ) from synapse.util.logcontext import preserve_context_over_fn import synapse.metrics +from synapse.http.endpoint import SpiderEndpoint from canonicaljson import encode_canonical_json from twisted.internet import defer, reactor, ssl, protocol +from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint from twisted.web.client import ( BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent, readBody, FileBodyProducer, PartialDownloadError, @@ -364,6 +366,35 @@ class CaptchaServerHttpClient(SimpleHttpClient): defer.returnValue(e.response) +class SpiderEndpointFactory(object): + def __init__(self, hs): + self.blacklist = hs.config.url_preview_ip_range_blacklist + self.policyForHTTPS = hs.get_http_client_context_factory() + + def endpointForURI(self, uri): + logger.info("Getting endpoint for %s", uri.toBytes()) + if uri.scheme == "http": + return SpiderEndpoint( + reactor, uri.host, uri.port, self.blacklist, + endpoint=TCP4ClientEndpoint, + endpoint_kw_args={ + 'timeout': 15 + }, + ) + elif uri.scheme == "https": + tlsPolicy = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port) + return SpiderEndpoint( + reactor, uri.host, uri.port, self.blacklist, + endpoint=SSL4ClientEndpoint, + endpoint_kw_args={ + 'sslContextFactory': tlsPolicy, + 'timeout': 15 + }, + ) + else: + logger.warn("Can't get endpoint for unrecognised scheme %s", uri.scheme) + + class SpiderHttpClient(SimpleHttpClient): """ Separate HTTP client for spidering arbitrary URLs. @@ -375,11 +406,14 @@ class SpiderHttpClient(SimpleHttpClient): def __init__(self, hs): SimpleHttpClient.__init__(self, hs) # clobber the base class's agent and UA: - self.agent = ContentDecoderAgent(BrowserLikeRedirectAgent(Agent( - reactor, - connectTimeout=15, - contextFactory=hs.get_http_client_context_factory() - )), [('gzip', GzipDecoder)]) + self.agent = ContentDecoderAgent( + BrowserLikeRedirectAgent( + Agent.usingEndpointFactory( + reactor, + SpiderEndpointFactory(hs) + ) + ), [('gzip', GzipDecoder)] + ) # We could look like Chrome: # self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko) # Chrome Safari" % hs.version_string) diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 4775f6707d..de5c762f50 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -74,6 +74,37 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, return transport_endpoint(reactor, domain, port, **endpoint_kw_args) +class SpiderEndpoint(object): + """An endpoint which refuses to connect to blacklisted IP addresses + Implements twisted.internet.interfaces.IStreamClientEndpoint. + """ + def __init__(self, reactor, host, port, blacklist, + endpoint=TCP4ClientEndpoint, endpoint_kw_args={}): + self.reactor = reactor + self.host = host + self.port = port + self.blacklist = blacklist + self.endpoint = endpoint + self.endpoint_kw_args = endpoint_kw_args + + @defer.inlineCallbacks + def connect(self, protocolFactory): + address = yield self.reactor.resolve(self.host) + + from netaddr import IPAddress + if IPAddress(address) in self.blacklist: + raise ConnectError( + "Refusing to spider blacklisted IP address %s" % address + ) + + logger.info("Connecting to %s:%s", address, self.port) + endpoint = self.endpoint( + self.reactor, address, self.port, **self.endpoint_kw_args + ) + connection = yield endpoint.connect(protocolFactory) + defer.returnValue(connection) + + class SRVClientEndpoint(object): """An endpoint which looks up SRV records for a service. Cycles through the list of servers starting with each call to connect @@ -118,7 +149,7 @@ class SRVClientEndpoint(object): return self.default_server else: raise ConnectError( - "Not server available for %s", self.service_name + "Not server available for %s" % self.service_name ) min_priority = self.servers[0].priority @@ -166,7 +197,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): and answers[0].type == dns.SRV and answers[0].payload and answers[0].payload.target == dns.Name('.')): - raise ConnectError("Service %s unavailable", service_name) + raise ConnectError("Service %s unavailable" % service_name) for answer in answers: if answer.type != dns.SRV or not answer.payload: diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 86b8331760..1adbdd9421 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -36,13 +36,16 @@ REQUIREMENTS = { "blist": ["blist"], "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"], "pymacaroons-pynacl": ["pymacaroons"], - "lxml>=3.6.0": ["lxml"], "pyjwt": ["jwt"], } CONDITIONAL_REQUIREMENTS = { "web_client": { "matrix_angular_sdk>=0.6.8": ["syweb>=0.6.8"], - } + }, + "preview_url": { + "lxml>=3.6.0": ["lxml"], + "netaddr>=0.7.18": ["netaddr"], + }, } diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 11f672aeab..97b7e84af9 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -79,4 +79,9 @@ class MediaRepositoryResource(Resource): self.putChild("download", DownloadResource(hs, filepaths)) self.putChild("thumbnail", ThumbnailResource(hs, filepaths)) self.putChild("identicon", IdenticonResource()) - self.putChild("preview_url", PreviewUrlResource(hs, filepaths)) + if hs.config.url_preview_enabled: + try: + self.putChild("preview_url", PreviewUrlResource(hs, filepaths)) + except Exception as e: + logger.warn("Failed to mount preview_url") + logger.exception(e) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index f5ec32d8f2..faa88deb6e 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -17,34 +17,52 @@ from .base_resource import BaseMediaResource from twisted.web.server import NOT_DONE_YET from twisted.internet import defer -from lxml import html -from urlparse import urlparse, urlunparse +from urlparse import urlparse, urlsplit, urlunparse -from synapse.api.errors import Codes from synapse.util.stringutils import random_string from synapse.util.caches.expiringcache import ExpiringCache from synapse.http.client import SpiderHttpClient from synapse.http.server import ( - request_handler, respond_with_json, respond_with_json_bytes + request_handler, respond_with_json_bytes ) from synapse.util.async import ObservableDeferred from synapse.util.stringutils import is_ascii import os import re +import fnmatch import cgi import ujson as json import logging logger = logging.getLogger(__name__) +try: + from lxml import html +except ImportError: + pass + class PreviewUrlResource(BaseMediaResource): isLeaf = True def __init__(self, hs, filepaths): + if not html: + logger.warn("Disabling PreviewUrlResource as lxml not available") + raise + + if not hasattr(hs.config, "url_preview_ip_range_blacklist"): + logger.warn( + "For security, you must specify an explicit target IP address " + "blacklist in url_preview_ip_range_blacklist for url previewing " + "to work" + ) + raise + BaseMediaResource.__init__(self, hs, filepaths) self.client = SpiderHttpClient(hs) + if hasattr(hs.config, "url_preview_url_blacklist"): + self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist # simple memory cache mapping urls to OG metadata self.cache = ExpiringCache( @@ -74,6 +92,36 @@ class PreviewUrlResource(BaseMediaResource): else: ts = self.clock.time_msec() + # impose the URL pattern blacklist + if hasattr(self, "url_preview_url_blacklist"): + url_tuple = urlsplit(url) + for entry in self.url_preview_url_blacklist: + match = True + for attrib in entry: + pattern = entry[attrib] + value = getattr(url_tuple, attrib) + logger.debug("Matching attrib '%s' with value '%s' against pattern '%s'" % ( + attrib, value, pattern + )) + + if value is None: + match = False + continue + + if pattern.startswith('^'): + if not re.match(pattern, getattr(url_tuple, attrib)): + match = False + continue + else: + if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern): + match = False + continue + if match: + logger.warn( + "URL %s blocked by url_blacklist entry %s", url, entry + ) + raise + # first check the memory cache - good to handle all the clients on this # HS thundering away to preview the same URL at the same time. try: @@ -177,17 +225,6 @@ class PreviewUrlResource(BaseMediaResource): respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) except: - # XXX: if we don't explicitly respond here, the request never returns. - # isn't this what server.py's wrapper is meant to be doing for us? - respond_with_json( - request, - 500, - { - "error": "Internal server error", - "errcode": Codes.UNKNOWN, - }, - send_cors=True - ) raise @defer.inlineCallbacks @@ -282,8 +319,12 @@ class PreviewUrlResource(BaseMediaResource): if meta_description: og['og:description'] = meta_description[0] else: - # text_nodes = tree.xpath("//h1/text() | //h2/text() | //h3/text() | " - # "//p/text() | //div/text() | //span/text() | //a/text()") + # grab any text nodes which are inside the tag... + # unless they are within an HTML5 semantic markup tag... + #
,