From e3e3fbc23aab45c50ca3c605568de10c8e04a518 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Wed, 17 Aug 2016 12:46:49 +0100 Subject: Initial empty implementation that just registers an API endpoint handler --- synapse/rest/__init__.py | 2 ++ synapse/rest/client/v2_alpha/thirdparty.py | 38 ++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 synapse/rest/client/v2_alpha/thirdparty.py (limited to 'synapse') diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 14227f1cdb..2e0e6babef 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -47,6 +47,7 @@ from synapse.rest.client.v2_alpha import ( report_event, openid, devices, + thirdparty, ) from synapse.http.server import JsonResource @@ -92,3 +93,4 @@ class ClientRestResource(JsonResource): report_event.register_servlets(hs, client_resource) openid.register_servlets(hs, client_resource) devices.register_servlets(hs, client_resource) + thirdparty.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py new file mode 100644 index 0000000000..9be88b2ba1 --- /dev/null +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + +from synapse.http.servlet import RestServlet +from ._base import client_v2_patterns + +logger = logging.getLogger(__name__) + + +class ThirdPartyUserServlet(RestServlet): + PATTERNS = client_v2_patterns("/3pu(/(?P[^/]+))?$", + releases=()) + + def __init__(self, hs): + super(ThirdPartyUserServlet, self).__init__() + pass + + def on_GET(self, request, protocol): + return (200, {"TODO":"TODO"}) + + +def register_servlets(hs, http_server): + ThirdPartyUserServlet(hs).register(http_server) -- cgit 1.4.1 From fa87c981e1efabe85a88144479b0dd7131b7da12 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Wed, 17 Aug 2016 13:15:06 +0100 Subject: Thread 3PU lookup through as far as the AS API object; which currently noƶps it MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- synapse/appservice/api.py | 3 +++ synapse/handlers/appservice.py | 21 +++++++++++++++++++++ synapse/rest/client/v2_alpha/thirdparty.py | 11 +++++++++-- 3 files changed, 33 insertions(+), 2 deletions(-) (limited to 'synapse') diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 6da6a1b62e..6e5f7dc404 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -71,6 +71,9 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning("query_alias to %s threw exception %s", uri, ex) defer.returnValue(False) + def query_3pu(self, service, protocol, fields): + return False + @defer.inlineCallbacks def push_bulk(self, service, events, txn_id=None): events = self._serialize(events) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 051ccdb380..69fd766613 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -120,6 +120,21 @@ class ApplicationServicesHandler(object): ) defer.returnValue(result) + @defer.inlineCallbacks + def query_3pu(self, protocol, fields): + services = yield self._get_services_for_3pn(protocol) + + # TODO(paul): scattergather + results = [] + for service in services: + result = yield self.appservice_api.query_3pu( + service, protocol, fields + ) + if result: + results.append(result) + + defer.returnValue(results) + @defer.inlineCallbacks def _get_services_for_event(self, event, restrict_to="", alias_list=None): """Retrieve a list of application services interested in this event. @@ -163,6 +178,12 @@ class ApplicationServicesHandler(object): ] defer.returnValue(interested_list) + @defer.inlineCallbacks + def _get_services_for_3pn(self, protocol): + # TODO(paul): Filter by protocol + services = yield self.store.get_app_services() + defer.returnValue(services) + @defer.inlineCallbacks def _is_unknown_user(self, user_id): if not self.is_mine_id(user_id): diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 9be88b2ba1..0180b73e9f 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -16,6 +16,8 @@ import logging +from twisted.internet import defer + from synapse.http.servlet import RestServlet from ._base import client_v2_patterns @@ -28,10 +30,15 @@ class ThirdPartyUserServlet(RestServlet): def __init__(self, hs): super(ThirdPartyUserServlet, self).__init__() - pass + self.appservice_handler = hs.get_application_service_handler() + + @defer.inlineCallbacks def on_GET(self, request, protocol): - return (200, {"TODO":"TODO"}) + fields = {} # TODO + results = yield self.appservice_handler.query_3pu(protocol, fields) + + defer.returnValue((200, results)) def register_servlets(hs, http_server): -- cgit 1.4.1 From 3ec10dffd6105e8fc78cb60f08c4636abf9d76e6 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 18 Aug 2016 00:39:09 +0100 Subject: Actually make 3PU lookup calls out to ASes --- synapse/appservice/api.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 6e5f7dc404..bfc1866591 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -71,8 +71,17 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning("query_alias to %s threw exception %s", uri, ex) defer.returnValue(False) + @defer.inlineCallbacks def query_3pu(self, service, protocol, fields): - return False + uri = service.url + ("/3pu/%s" % urllib.quote(protocol)) + response = None + try: + response = yield self.get_json(uri, fields) + defer.returnValue(response) + except: + # TODO: would be noisy to log lookup failures, but we want to log + # other things. Hrm. + defer.returnValue([]) @defer.inlineCallbacks def push_bulk(self, service, events, txn_id=None): -- cgit 1.4.1 From b3511adb920f81f8847e4cf3112018df08466ad6 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 18 Aug 2016 13:45:18 +0100 Subject: Don't catch the return-value-as-exception that @defer.inlineCallbacks will use --- synapse/appservice/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index bfc1866591..39b4bff556 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -78,7 +78,7 @@ class ApplicationServiceApi(SimpleHttpClient): try: response = yield self.get_json(uri, fields) defer.returnValue(response) - except: + except Exception: # TODO: would be noisy to log lookup failures, but we want to log # other things. Hrm. defer.returnValue([]) -- cgit 1.4.1 From f0c73a1e7a723585d5ca983d6743a64cab92d1f5 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 18 Aug 2016 13:53:54 +0100 Subject: Extend individual list results into the main return list, don't append --- synapse/handlers/appservice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 69fd766613..f124590e4a 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -131,7 +131,7 @@ class ApplicationServicesHandler(object): service, protocol, fields ) if result: - results.append(result) + results.extend(result) defer.returnValue(results) -- cgit 1.4.1 From 38565827418b71e12b3ff37e1482ff71ffe170d9 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 18 Aug 2016 14:06:02 +0100 Subject: Ensure that 3PU lookup request fields actually get passed in --- synapse/rest/client/v2_alpha/thirdparty.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 0180b73e9f..4b2a93f1bb 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -35,7 +35,11 @@ class ThirdPartyUserServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, protocol): - fields = {} # TODO + fields = request.args + del fields["access_token"] + + # TODO(paul): Some type checking on the request args might be nice + # They should probably all be strings results = yield self.appservice_handler.query_3pu(protocol, fields) defer.returnValue((200, results)) -- cgit 1.4.1 From 718ffcf8bbf83975a211a1b840de696c0eabec01 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 18 Aug 2016 14:18:37 +0100 Subject: Since empty lookups now return 200/empty list not 404, we can safely log failures as exceptions --- synapse/appservice/api.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'synapse') diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 39b4bff556..e05570cc8b 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -78,9 +78,8 @@ class ApplicationServiceApi(SimpleHttpClient): try: response = yield self.get_json(uri, fields) defer.returnValue(response) - except Exception: - # TODO: would be noisy to log lookup failures, but we want to log - # other things. Hrm. + except Exception as ex: + logger.warning("query_3pu to %s threw exception %s", uri, ex) defer.returnValue([]) @defer.inlineCallbacks -- cgit 1.4.1 From 434bbf2cb5b31f5a8430a06f53549248f7306cfd Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 18 Aug 2016 14:56:02 +0100 Subject: Filter 3PU lookups by only ASes that declare knowledge of that protocol --- synapse/appservice/__init__.py | 9 ++++++++- synapse/config/appservice.py | 10 ++++++++++ synapse/handlers/appservice.py | 6 ++++-- 3 files changed, 22 insertions(+), 3 deletions(-) (limited to 'synapse') diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index b1b91d0a55..bde9b51b2e 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -81,13 +81,17 @@ class ApplicationService(object): NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS] def __init__(self, token, url=None, namespaces=None, hs_token=None, - sender=None, id=None): + sender=None, id=None, protocols=None): self.token = token self.url = url self.hs_token = hs_token self.sender = sender self.namespaces = self._check_namespaces(namespaces) self.id = id + if protocols: + self.protocols = set(protocols) + else: + self.protocols = set() def _check_namespaces(self, namespaces): # Sanity check that it is of the form: @@ -219,6 +223,9 @@ class ApplicationService(object): or user_id == self.sender ) + def is_interested_in_protocol(self, protocol): + return protocol in self.protocols + def is_exclusive_alias(self, alias): return self._is_exclusive(ApplicationService.NS_ALIASES, alias) diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index eade803909..3184d2084c 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -122,6 +122,15 @@ def _load_appservice(hostname, as_info, config_filename): raise ValueError( "Missing/bad type 'exclusive' key in %s", regex_obj ) + # protocols check + protocols = as_info.get("protocols") + if protocols: + # Because strings are lists in python + if isinstance(protocols, str) or not isinstance(protocols, list): + raise KeyError("Optional 'protocols' must be a list if present.") + for p in protocols: + if not isinstance(p, str): + raise KeyError("Bad value for 'protocols' item") return ApplicationService( token=as_info["as_token"], url=as_info["url"], @@ -129,4 +138,5 @@ def _load_appservice(hostname, as_info, config_filename): hs_token=as_info["hs_token"], sender=user_id, id=as_info["id"], + protocols=protocols, ) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 52e897d8d9..e0a6c9f19d 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -191,9 +191,11 @@ class ApplicationServicesHandler(object): @defer.inlineCallbacks def _get_services_for_3pn(self, protocol): - # TODO(paul): Filter by protocol services = yield self.store.get_app_services() - defer.returnValue(services) + interested_list = [ + s for s in services if s.is_interested_in_protocol(protocol) + ] + defer.returnValue(interested_list) @defer.inlineCallbacks def _is_unknown_user(self, user_id): -- cgit 1.4.1 From 80f4740c8f638e7d07b72d87fcb608435f3f9c15 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 18 Aug 2016 15:40:41 +0100 Subject: Scattergather the call out to ASes; validate received results --- synapse/handlers/appservice.py | 41 ++++++++++++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 7 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index e0a6c9f19d..cd55f6b7f1 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -34,6 +34,26 @@ def log_failure(failure): ) ) +def _is_valid_3pu_result(r): + if not isinstance(r, dict): + return False + + for k in ("userid", "protocol"): + if k not in r: + return False + if not isinstance(r[k], str): + return False + + if "fields" not in r: + return False + fields = r["fields"] + if not isinstance(fields, dict): + return False + for k in fields.keys(): + if not isinstance(fields[k], str): + return False + + return True class ApplicationServicesHandler(object): @@ -150,16 +170,23 @@ class ApplicationServicesHandler(object): def query_3pu(self, protocol, fields): services = yield self._get_services_for_3pn(protocol) - # TODO(paul): scattergather - results = [] + deferreds = [] for service in services: - result = yield self.appservice_api.query_3pu( + deferreds.append(self.appservice_api.query_3pu( service, protocol, fields - ) - if result: - results.extend(result) + )) + + results = yield defer.DeferredList(deferreds, consumeErrors=True) + + ret = [] + for (success, result) in results: + if not success: + continue + if not isinstance(result, list): + continue + ret.extend(r for r in result if _is_valid_3pu_result(r)) - defer.returnValue(results) + defer.returnValue(ret) @defer.inlineCallbacks def _get_services_for_event(self, event): -- cgit 1.4.1 From d7b42afc74662afef983bc42ff6e50b2deb91e0e Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 18 Aug 2016 15:49:55 +0100 Subject: Log a warning if an AS yields an invalid 3PU lookup result --- synapse/handlers/appservice.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index cd55f6b7f1..5ed694e711 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -184,7 +184,12 @@ class ApplicationServicesHandler(object): continue if not isinstance(result, list): continue - ret.extend(r for r in result if _is_valid_3pu_result(r)) + for r in result: + if _is_valid_3pu_result(r): + ret.append(r) + else: + logger.warn("Application service returned an " + + "invalid result %r", r) defer.returnValue(ret) -- cgit 1.4.1 From f3afd6ef1a44ef8b87a3f7257a5e42e69c75523e Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 18 Aug 2016 15:53:01 +0100 Subject: Remove TODO note about request fields being strings - they're always strings --- synapse/rest/client/v2_alpha/thirdparty.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'synapse') diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 4b2a93f1bb..bce104c545 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -38,8 +38,6 @@ class ThirdPartyUserServlet(RestServlet): fields = request.args del fields["access_token"] - # TODO(paul): Some type checking on the request args might be nice - # They should probably all be strings results = yield self.appservice_handler.query_3pu(protocol, fields) defer.returnValue((200, results)) -- cgit 1.4.1 From 06964c4a0adabf7d983cbd0d2c6d83eba6fcaf79 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 18 Aug 2016 16:09:50 +0100 Subject: Copypasta the 3PU support code to also do 3PL --- synapse/appservice/api.py | 11 ++++++++++ synapse/handlers/appservice.py | 33 +++++++++++++++++++++++++++--- synapse/rest/client/v2_alpha/thirdparty.py | 20 ++++++++++++++++++ 3 files changed, 61 insertions(+), 3 deletions(-) (limited to 'synapse') diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index e05570cc8b..4ccb5c43c1 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -82,6 +82,17 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning("query_3pu to %s threw exception %s", uri, ex) defer.returnValue([]) + @defer.inlineCallbacks + def query_3pl(self, service, protocol, fields): + uri = service.url + ("/3pl/%s" % urllib.quote(protocol)) + response = None + try: + response = yield self.get_json(uri, fields) + defer.returnValue(response) + except Exception as ex: + logger.warning("query_3pl to %s threw exception %s", uri, ex) + defer.returnValue([]) + @defer.inlineCallbacks def push_bulk(self, service, events, txn_id=None): events = self._serialize(events) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 5ed694e711..72c36615df 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -34,11 +34,11 @@ def log_failure(failure): ) ) -def _is_valid_3pu_result(r): +def _is_valid_3pentity_result(r, field): if not isinstance(r, dict): return False - for k in ("userid", "protocol"): + for k in (field, "protocol"): if k not in r: return False if not isinstance(r[k], str): @@ -185,7 +185,34 @@ class ApplicationServicesHandler(object): if not isinstance(result, list): continue for r in result: - if _is_valid_3pu_result(r): + if _is_valid_3pentity_result(r, field="userid"): + ret.append(r) + else: + logger.warn("Application service returned an " + + "invalid result %r", r) + + defer.returnValue(ret) + + @defer.inlineCallbacks + def query_3pl(self, protocol, fields): + services = yield self._get_services_for_3pn(protocol) + + deferreds = [] + for service in services: + deferreds.append(self.appservice_api.query_3pl( + service, protocol, fields + )) + + results = yield defer.DeferredList(deferreds, consumeErrors=True) + + ret = [] + for (success, result) in results: + if not success: + continue + if not isinstance(result, list): + continue + for r in result: + if _is_valid_3pentity_result(r, field="alias"): ret.append(r) else: logger.warn("Application service returned an " + diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index bce104c545..eec08425e6 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -43,5 +43,25 @@ class ThirdPartyUserServlet(RestServlet): defer.returnValue((200, results)) +class ThirdPartyLocationServlet(RestServlet): + PATTERNS = client_v2_patterns("/3pl(/(?P[^/]+))?$", + releases=()) + + def __init__(self, hs): + super(ThirdPartyLocationServlet, self).__init__() + + self.appservice_handler = hs.get_application_service_handler() + + @defer.inlineCallbacks + def on_GET(self, request, protocol): + fields = request.args + del fields["access_token"] + + results = yield self.appservice_handler.query_3pl(protocol, fields) + + defer.returnValue((200, results)) + + def register_servlets(hs, http_server): ThirdPartyUserServlet(hs).register(http_server) + ThirdPartyLocationServlet(hs).register(http_server) -- cgit 1.4.1 From 105ff162d4ae3776674cb1cbec6581e1511871d2 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 18 Aug 2016 16:19:23 +0100 Subject: Authenticate 3PE lookup requests --- synapse/rest/client/v2_alpha/thirdparty.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'synapse') diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index eec08425e6..d229e4b818 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -31,10 +31,13 @@ class ThirdPartyUserServlet(RestServlet): def __init__(self, hs): super(ThirdPartyUserServlet, self).__init__() + self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() @defer.inlineCallbacks def on_GET(self, request, protocol): + yield self.auth.get_user_by_req(request) + fields = request.args del fields["access_token"] @@ -50,10 +53,13 @@ class ThirdPartyLocationServlet(RestServlet): def __init__(self, hs): super(ThirdPartyLocationServlet, self).__init__() + self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() @defer.inlineCallbacks def on_GET(self, request, protocol): + yield self.auth.get_user_by_req(request) + fields = request.args del fields["access_token"] -- cgit 1.4.1 From fcf1dec809e35826b50ed6841730dc0bfeff724a Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 18 Aug 2016 16:26:19 +0100 Subject: Appease pep8 --- synapse/handlers/appservice.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'synapse') diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 72c36615df..a2715e5cf6 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -34,6 +34,7 @@ def log_failure(failure): ) ) + def _is_valid_3pentity_result(r, field): if not isinstance(r, dict): return False @@ -55,6 +56,7 @@ def _is_valid_3pentity_result(r, field): return True + class ApplicationServicesHandler(object): def __init__(self, hs): -- cgit 1.4.1 From 2a91799fccd0791083131e8b23ac0b900e42b7f4 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 18 Aug 2016 16:58:25 +0100 Subject: Minor syntax neatenings --- synapse/appservice/api.py | 4 ++-- synapse/handlers/appservice.py | 22 ++++++++-------------- 2 files changed, 10 insertions(+), 16 deletions(-) (limited to 'synapse') diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 4ccb5c43c1..d4cad1b1ed 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -73,7 +73,7 @@ class ApplicationServiceApi(SimpleHttpClient): @defer.inlineCallbacks def query_3pu(self, service, protocol, fields): - uri = service.url + ("/3pu/%s" % urllib.quote(protocol)) + uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol)) response = None try: response = yield self.get_json(uri, fields) @@ -84,7 +84,7 @@ class ApplicationServiceApi(SimpleHttpClient): @defer.inlineCallbacks def query_3pl(self, service, protocol, fields): - uri = service.url + ("/3pl/%s" % urllib.quote(protocol)) + uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol)) response = None try: response = yield self.get_json(uri, fields) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index a2715e5cf6..03452f6bb0 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -172,13 +172,10 @@ class ApplicationServicesHandler(object): def query_3pu(self, protocol, fields): services = yield self._get_services_for_3pn(protocol) - deferreds = [] - for service in services: - deferreds.append(self.appservice_api.query_3pu( - service, protocol, fields - )) - - results = yield defer.DeferredList(deferreds, consumeErrors=True) + results = yield defer.DeferredList([ + self.appservice_api.query_3pu(service, protocol, fields) + for service in services + ], consumeErrors=True) ret = [] for (success, result) in results: @@ -199,13 +196,10 @@ class ApplicationServicesHandler(object): def query_3pl(self, protocol, fields): services = yield self._get_services_for_3pn(protocol) - deferreds = [] - for service in services: - deferreds.append(self.appservice_api.query_3pl( - service, protocol, fields - )) - - results = yield defer.DeferredList(deferreds, consumeErrors=True) + results = yield defer.DeferredList([ + self.appservice_api.query_3pl(service, protocol, fields) + for service in services + ], consumeErrors=True) ret = [] for (success, result) in results: -- cgit 1.4.1 From b515f844ee07c7d6aa1d7e56faa8b65d282e9341 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 18 Aug 2016 17:19:55 +0100 Subject: Avoid so much copypasta between 3PU and 3PL query by unifying around a ThirdPartyEntityKind enumeration --- synapse/appservice/api.py | 25 ++++++++++----------- synapse/handlers/appservice.py | 35 ++++++++---------------------- synapse/rest/client/v2_alpha/thirdparty.py | 9 ++++++-- synapse/types.py | 7 ++++++ 4 files changed, 34 insertions(+), 42 deletions(-) (limited to 'synapse') diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index d4cad1b1ed..dd5e762e0d 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.api.errors import CodeMessageException from synapse.http.client import SimpleHttpClient from synapse.events.utils import serialize_event +from synapse.types import ThirdPartyEntityKind import logging import urllib @@ -72,25 +73,21 @@ class ApplicationServiceApi(SimpleHttpClient): defer.returnValue(False) @defer.inlineCallbacks - def query_3pu(self, service, protocol, fields): - uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol)) - response = None - try: - response = yield self.get_json(uri, fields) - defer.returnValue(response) - except Exception as ex: - logger.warning("query_3pu to %s threw exception %s", uri, ex) - defer.returnValue([]) + def query_3pe(self, service, kind, protocol, fields): + if kind == ThirdPartyEntityKind.USER: + uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol)) + elif kind == ThirdPartyEntityKind.LOCATION: + uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol)) + else: + raise ValueError( + "Unrecognised 'kind' argument %r to query_3pe()", kind + ) - @defer.inlineCallbacks - def query_3pl(self, service, protocol, fields): - uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol)) - response = None try: response = yield self.get_json(uri, fields) defer.returnValue(response) except Exception as ex: - logger.warning("query_3pl to %s threw exception %s", uri, ex) + logger.warning("query_3pe to %s threw exception %s", uri, ex) defer.returnValue([]) @defer.inlineCallbacks diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 03452f6bb0..52c127d2c1 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.util.metrics import Measure from synapse.util.logcontext import preserve_fn +from synapse.types import ThirdPartyEntityKind import logging @@ -169,37 +170,19 @@ class ApplicationServicesHandler(object): defer.returnValue(result) @defer.inlineCallbacks - def query_3pu(self, protocol, fields): + def query_3pe(self, kind, protocol, fields): services = yield self._get_services_for_3pn(protocol) results = yield defer.DeferredList([ - self.appservice_api.query_3pu(service, protocol, fields) + self.appservice_api.query_3pe(service, kind, protocol, fields) for service in services ], consumeErrors=True) - ret = [] - for (success, result) in results: - if not success: - continue - if not isinstance(result, list): - continue - for r in result: - if _is_valid_3pentity_result(r, field="userid"): - ret.append(r) - else: - logger.warn("Application service returned an " + - "invalid result %r", r) - - defer.returnValue(ret) - - @defer.inlineCallbacks - def query_3pl(self, protocol, fields): - services = yield self._get_services_for_3pn(protocol) - - results = yield defer.DeferredList([ - self.appservice_api.query_3pl(service, protocol, fields) - for service in services - ], consumeErrors=True) + required_field = ( + "userid" if kind == ThirdPartyEntityKind.USER else + "alias" if kind == ThirdPartyEntityKind.LOCATION else + None + ) ret = [] for (success, result) in results: @@ -208,7 +191,7 @@ class ApplicationServicesHandler(object): if not isinstance(result, list): continue for r in result: - if _is_valid_3pentity_result(r, field="alias"): + if _is_valid_3pentity_result(r, field=required_field): ret.append(r) else: logger.warn("Application service returned an " + diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index d229e4b818..9abca3a8ad 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -19,6 +19,7 @@ import logging from twisted.internet import defer from synapse.http.servlet import RestServlet +from synapse.types import ThirdPartyEntityKind from ._base import client_v2_patterns logger = logging.getLogger(__name__) @@ -41,7 +42,9 @@ class ThirdPartyUserServlet(RestServlet): fields = request.args del fields["access_token"] - results = yield self.appservice_handler.query_3pu(protocol, fields) + results = yield self.appservice_handler.query_3pe( + ThirdPartyEntityKind.USER, protocol, fields + ) defer.returnValue((200, results)) @@ -63,7 +66,9 @@ class ThirdPartyLocationServlet(RestServlet): fields = request.args del fields["access_token"] - results = yield self.appservice_handler.query_3pl(protocol, fields) + results = yield self.appservice_handler.query_3pe( + ThirdPartyEntityKind.LOCATION, protocol, fields + ) defer.returnValue((200, results)) diff --git a/synapse/types.py b/synapse/types.py index 5349b0c450..fd17ecbbe0 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -269,3 +269,10 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): return "t%d-%d" % (self.topological, self.stream) else: return "s%d" % (self.stream,) + + +# Some arbitrary constants used for internal API enumerations. Don't rely on +# exact values; always pass or compare symbolically +class ThirdPartyEntityKind(object): + USER = 'user' + LOCATION = 'location' -- cgit 1.4.1 From 697872cf087d983d77e7c2174ad71361f703fb49 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 18 Aug 2016 17:24:39 +0100 Subject: More warnings about invalid results from AS 3PE query --- synapse/handlers/appservice.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 52c127d2c1..6f162a3c00 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -187,15 +187,20 @@ class ApplicationServicesHandler(object): ret = [] for (success, result) in results: if not success: + logger.warn("Application service failed %r", result) continue if not isinstance(result, list): + logger.warn( + "Application service returned an invalid response %r", result + ) continue for r in result: if _is_valid_3pentity_result(r, field=required_field): ret.append(r) else: - logger.warn("Application service returned an " + - "invalid result %r", r) + logger.warn( + "Application service returned an invalid result %r", r + ) defer.returnValue(ret) -- cgit 1.4.1 From 65201631a407b71087bb52da8b591e0975c463ec Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 18 Aug 2016 17:33:56 +0100 Subject: Move validation logic for AS 3PE query response into ApplicationServiceApi class, to keep the handler logic neater --- synapse/appservice/api.py | 43 ++++++++++++++++++++++++++++++++++++++- synapse/handlers/appservice.py | 46 ++---------------------------------------- 2 files changed, 44 insertions(+), 45 deletions(-) (limited to 'synapse') diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index dd5e762e0d..066127b666 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -25,6 +25,28 @@ import urllib logger = logging.getLogger(__name__) +def _is_valid_3pe_result(r, field): + if not isinstance(r, dict): + return False + + for k in (field, "protocol"): + if k not in r: + return False + if not isinstance(r[k], str): + return False + + if "fields" not in r: + return False + fields = r["fields"] + if not isinstance(fields, dict): + return False + for k in fields.keys(): + if not isinstance(fields[k], str): + return False + + return True + + class ApplicationServiceApi(SimpleHttpClient): """This class manages HS -> AS communications, including querying and pushing. @@ -76,8 +98,10 @@ class ApplicationServiceApi(SimpleHttpClient): def query_3pe(self, service, kind, protocol, fields): if kind == ThirdPartyEntityKind.USER: uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol)) + required_field = "userid" elif kind == ThirdPartyEntityKind.LOCATION: uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol)) + required_field = "alias" else: raise ValueError( "Unrecognised 'kind' argument %r to query_3pe()", kind @@ -85,7 +109,24 @@ class ApplicationServiceApi(SimpleHttpClient): try: response = yield self.get_json(uri, fields) - defer.returnValue(response) + if not isinstance(response, list): + logger.warning( + "query_3pe to %s returned an invalid response %r", + uri, response + ) + defer.returnValue([]) + + ret = [] + for r in response: + if _is_valid_3pe_result(r, field=required_field): + ret.append(r) + else: + logger.warning( + "query_3pe to %s returned an invalid result %r", + uri, r + ) + + defer.returnValue(ret) except Exception as ex: logger.warning("query_3pe to %s threw exception %s", uri, ex) defer.returnValue([]) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 6f162a3c00..18dca462a9 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -18,7 +18,6 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.util.metrics import Measure from synapse.util.logcontext import preserve_fn -from synapse.types import ThirdPartyEntityKind import logging @@ -36,28 +35,6 @@ def log_failure(failure): ) -def _is_valid_3pentity_result(r, field): - if not isinstance(r, dict): - return False - - for k in (field, "protocol"): - if k not in r: - return False - if not isinstance(r[k], str): - return False - - if "fields" not in r: - return False - fields = r["fields"] - if not isinstance(fields, dict): - return False - for k in fields.keys(): - if not isinstance(fields[k], str): - return False - - return True - - class ApplicationServicesHandler(object): def __init__(self, hs): @@ -178,29 +155,10 @@ class ApplicationServicesHandler(object): for service in services ], consumeErrors=True) - required_field = ( - "userid" if kind == ThirdPartyEntityKind.USER else - "alias" if kind == ThirdPartyEntityKind.LOCATION else - None - ) - ret = [] for (success, result) in results: - if not success: - logger.warn("Application service failed %r", result) - continue - if not isinstance(result, list): - logger.warn( - "Application service returned an invalid response %r", result - ) - continue - for r in result: - if _is_valid_3pentity_result(r, field=required_field): - ret.append(r) - else: - logger.warn( - "Application service returned an invalid result %r", r - ) + if success: + ret.extend(result) defer.returnValue(ret) -- cgit 1.4.1 From b770435389a9c827582884912b0a2761d0eed812 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 19 Aug 2016 10:19:29 +0100 Subject: Make get_new_events_for_appservice use indices --- synapse/storage/appservice.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index b496b918b7..a854a87eab 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -366,8 +366,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore): def get_new_events_for_appservice_txn(txn): sql = ( "SELECT e.stream_ordering, e.event_id" - " FROM events AS e, appservice_stream_position AS a" - " WHERE a.stream_ordering < e.stream_ordering AND e.stream_ordering <= ?" + " FROM events AS e" + " WHERE" + " (SELECT stream_ordering FROM appservice_stream_position)" + " < e.stream_ordering" + " AND e.stream_ordering <= ?" " ORDER BY e.stream_ordering ASC" " LIMIT ?" ) -- cgit 1.4.1 From 4161ff2fc45781dd69623f95721533e0a594f807 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 19 Aug 2016 11:18:26 +0100 Subject: Add concept of cache contexts --- synapse/util/caches/descriptors.py | 77 ++++++++++++++++++++----- synapse/util/caches/lrucache.py | 39 ++++++++++--- synapse/util/caches/treecache.py | 3 + tests/storage/test__base.py | 66 ++++++++++++++++++++++ tests/util/test_lrucache.py | 113 +++++++++++++++++++++++++++++++++++++ 5 files changed, 278 insertions(+), 20 deletions(-) (limited to 'synapse') diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index f31dfb22b7..5cd277f2f2 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -55,7 +55,7 @@ class Cache(object): ) def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False): - if lru: + if True: cache_type = TreeCache if tree else dict self.cache = LruCache( max_size=max_entries, keylen=keylen, cache_type=cache_type @@ -81,8 +81,8 @@ class Cache(object): "Cache objects can only be accessed from the main thread" ) - def get(self, key, default=_CacheSentinel): - val = self.cache.get(key, _CacheSentinel) + def get(self, key, default=_CacheSentinel, callback=None): + val = self.cache.get(key, _CacheSentinel, callback=callback) if val is not _CacheSentinel: self.metrics.inc_hits() return val @@ -94,19 +94,19 @@ class Cache(object): else: return default - def update(self, sequence, key, value): + def update(self, sequence, key, value, callback=None): self.check_thread() if self.sequence == sequence: # Only update the cache if the caches sequence number matches the # number that the cache had before the SELECT was started (SYN-369) - self.prefill(key, value) + self.prefill(key, value, callback=callback) - def prefill(self, key, value): + def prefill(self, key, value, callback=None): if self.max_entries is not None: while len(self.cache) >= self.max_entries: - self.cache.popitem(last=False) + self.cache.popitem(last=False, callback=None) - self.cache[key] = value + self.cache.set(key, value, callback=callback) def invalidate(self, key): self.check_thread() @@ -151,6 +151,18 @@ class CacheDescriptor(object): The wrapped function has another additional callable, called "prefill", which can be used to insert values into the cache specifically, without calling the calculation function. + + Cached functions can be "chained" (i.e. a cached function can call other cached + functions and get appropriately invalidated when they called caches are + invalidated) by adding a special "cache_context" argument to the function + and passing that as a kwarg to all caches called. For example:: + + @cachedInlineCallbacks() + def foo(self, key, cache_context): + r1 = yield self.bar1(key, cache_context=cache_context) + r2 = yield self.bar2(key, cache_context=cache_context) + defer.returnValue(r1 + r2) + """ def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False, inlineCallbacks=False): @@ -168,7 +180,13 @@ class CacheDescriptor(object): self.lru = lru self.tree = tree - self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] + all_args = inspect.getargspec(orig) + self.arg_names = all_args.args[1:num_args + 1] + + if "cache_context" in self.arg_names: + self.arg_names.remove("cache_context") + + self.add_cache_context = "cache_context" in all_args.args if len(self.arg_names) < self.num_args: raise Exception( @@ -188,10 +206,23 @@ class CacheDescriptor(object): @functools.wraps(self.orig) def wrapped(*args, **kwargs): + cache_context = kwargs.pop("cache_context", None) + if cache_context: + context_callback = cache_context.invalidate + else: + context_callback = None + + self_context = _CacheContext(cache, None) + if self.add_cache_context: + kwargs["cache_context"] = self_context + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) + + self_context.key = cache_key + try: - cached_result_d = cache.get(cache_key) + cached_result_d = cache.get(cache_key, callback=context_callback) observer = cached_result_d.observe() if DEBUG_CACHES: @@ -228,7 +259,7 @@ class CacheDescriptor(object): ret.addErrback(onErr) ret = ObservableDeferred(ret, consumeErrors=True) - cache.update(sequence, cache_key, ret) + cache.update(sequence, cache_key, ret, callback=context_callback) return preserve_context_over_deferred(ret.observe()) @@ -297,6 +328,12 @@ class CacheListDescriptor(object): @functools.wraps(self.orig) def wrapped(*args, **kwargs): + cache_context = kwargs.pop("cache_context", None) + if cache_context: + context_callback = cache_context.invalidate + else: + context_callback = None + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] list_args = arg_dict[self.list_name] @@ -311,7 +348,7 @@ class CacheListDescriptor(object): key[self.list_pos] = arg try: - res = cache.get(tuple(key)) + res = cache.get(tuple(key), callback=context_callback) if not res.has_succeeded(): res = res.observe() res.addCallback(lambda r, arg: (arg, r), arg) @@ -345,7 +382,10 @@ class CacheListDescriptor(object): key = list(keyargs) key[self.list_pos] = arg - cache.update(sequence, tuple(key), observer) + cache.update( + sequence, tuple(key), observer, + callback=context_callback + ) def invalidate(f, key): cache.invalidate(key) @@ -376,6 +416,17 @@ class CacheListDescriptor(object): return wrapped +class _CacheContext(object): + __slots__ = ["cache", "key"] + + def __init__(self, cache, key): + self.cache = cache + self.key = key + + def invalidate(self): + self.cache.invalidate(self.key) + + def cached(max_entries=1000, num_args=1, lru=True, tree=False): return lambda orig: CacheDescriptor( orig, diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index f9df445a8d..a5a827b4d1 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -30,13 +30,14 @@ def enumerate_leaves(node, depth): class _Node(object): - __slots__ = ["prev_node", "next_node", "key", "value"] + __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"] - def __init__(self, prev_node, next_node, key, value): + def __init__(self, prev_node, next_node, key, value, callbacks=[]): self.prev_node = prev_node self.next_node = next_node self.key = key self.value = value + self.callbacks = callbacks class LruCache(object): @@ -44,6 +45,9 @@ class LruCache(object): Least-recently-used cache. Supports del_multi only if cache_type=TreeCache If cache_type=TreeCache, all keys must be tuples. + + Can also set callbacks on objects when getting/setting which are fired + when that key gets invalidated/evicted. """ def __init__(self, max_size, keylen=1, cache_type=dict): cache = cache_type() @@ -62,10 +66,10 @@ class LruCache(object): return inner - def add_node(key, value): + def add_node(key, value, callbacks=[]): prev_node = list_root next_node = prev_node.next_node - node = _Node(prev_node, next_node, key, value) + node = _Node(prev_node, next_node, key, value, callbacks) prev_node.next_node = node next_node.prev_node = node cache[key] = node @@ -88,23 +92,41 @@ class LruCache(object): prev_node.next_node = next_node next_node.prev_node = prev_node + for cb in node.callbacks: + cb() + node.callbacks = [] + @synchronized - def cache_get(key, default=None): + def cache_get(key, default=None, callback=None): node = cache.get(key, None) if node is not None: move_node_to_front(node) + if callback: + node.callbacks.append(callback) return node.value else: return default @synchronized - def cache_set(key, value): + def cache_set(key, value, callback=None): node = cache.get(key, None) if node is not None: + if value != node.value: + for cb in node.callbacks: + cb() + node.callbacks = [] + + if callback: + node.callbacks.append(callback) + move_node_to_front(node) node.value = value else: - add_node(key, value) + if callback: + callbacks = [callback] + else: + callbacks = [] + add_node(key, value, callbacks) if len(cache) > max_size: todelete = list_root.prev_node delete_node(todelete) @@ -148,6 +170,9 @@ class LruCache(object): def cache_clear(): list_root.next_node = list_root list_root.prev_node = list_root + for node in cache.values(): + for cb in node.callbacks: + cb() cache.clear() @synchronized diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index 03bc1401b7..c31585aea3 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -64,6 +64,9 @@ class TreeCache(object): self.size -= cnt return popped + def values(self): + return [e.value for e in self.root.values()] + def __len__(self): return self.size diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 96b7dba5fe..9d99eea8d0 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -199,3 +199,69 @@ class CacheDecoratorTestCase(unittest.TestCase): self.assertEquals(a.func("foo").result, d.result) self.assertEquals(callcount[0], 0) + + @defer.inlineCallbacks + def test_invalidate_context(self): + callcount = [0] + callcount2 = [0] + + class A(object): + @cached() + def func(self, key): + callcount[0] += 1 + return key + + @cached() + def func2(self, key, cache_context): + callcount2[0] += 1 + return self.func(key, cache_context=cache_context) + + a = A() + yield a.func2("foo") + + self.assertEquals(callcount[0], 1) + self.assertEquals(callcount2[0], 1) + + a.func.invalidate(("foo",)) + yield a.func("foo") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 1) + + yield a.func2("foo") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + + @defer.inlineCallbacks + def test_eviction_context(self): + callcount = [0] + callcount2 = [0] + + class A(object): + @cached(max_entries=2) + def func(self, key): + callcount[0] += 1 + return key + + @cached() + def func2(self, key, cache_context): + callcount2[0] += 1 + return self.func(key, cache_context=cache_context) + + a = A() + yield a.func2("foo") + yield a.func2("foo2") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + + yield a.func("foo3") + + self.assertEquals(callcount[0], 3) + self.assertEquals(callcount2[0], 2) + + yield a.func2("foo") + + self.assertEquals(callcount[0], 4) + self.assertEquals(callcount2[0], 3) diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index bab366fb7f..bacec2f465 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -19,6 +19,8 @@ from .. import unittest from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache +from mock import Mock + class LruCacheTestCase(unittest.TestCase): @@ -79,3 +81,114 @@ class LruCacheTestCase(unittest.TestCase): cache["key"] = 1 cache.clear() self.assertEquals(len(cache), 0) + + +class LruCacheCallbacksTestCase(unittest.TestCase): + def test_set(self): + m = Mock() + cache = LruCache(1) + + cache.set("key", "value", m) + self.assertFalse(m.called) + + cache.set("key", "value") + self.assertFalse(m.called) + + cache.set("key", "value2") + self.assertEquals(m.call_count, 1) + + cache.set("key", "value") + self.assertEquals(m.call_count, 1) + + def test_pop(self): + m = Mock() + cache = LruCache(1) + + cache.set("key", "value", m) + self.assertFalse(m.called) + + cache.pop("key") + self.assertEquals(m.call_count, 1) + + cache.set("key", "value") + self.assertEquals(m.call_count, 1) + + cache.pop("key") + self.assertEquals(m.call_count, 1) + + def test_del_multi(self): + m1 = Mock() + m2 = Mock() + m3 = Mock() + m4 = Mock() + cache = LruCache(4, 2, cache_type=TreeCache) + + cache.set(("a", "1"), "value", m1) + cache.set(("a", "2"), "value", m2) + cache.set(("b", "1"), "value", m3) + cache.set(("b", "2"), "value", m4) + + self.assertEquals(m1.call_count, 0) + self.assertEquals(m2.call_count, 0) + self.assertEquals(m3.call_count, 0) + self.assertEquals(m4.call_count, 0) + + cache.del_multi(("a",)) + + self.assertEquals(m1.call_count, 1) + self.assertEquals(m2.call_count, 1) + self.assertEquals(m3.call_count, 0) + self.assertEquals(m4.call_count, 0) + + def test_clear(self): + m1 = Mock() + m2 = Mock() + cache = LruCache(5) + + cache.set("key1", "value", m1) + cache.set("key2", "value", m2) + + self.assertEquals(m1.call_count, 0) + self.assertEquals(m2.call_count, 0) + + cache.clear() + + self.assertEquals(m1.call_count, 1) + self.assertEquals(m2.call_count, 1) + + def test_eviction(self): + m1 = Mock(name="m1") + m2 = Mock(name="m2") + m3 = Mock(name="m3") + cache = LruCache(2) + + cache.set("key1", "value", m1) + cache.set("key2", "value", m2) + + self.assertEquals(m1.call_count, 0) + self.assertEquals(m2.call_count, 0) + self.assertEquals(m3.call_count, 0) + + cache.set("key3", "value", m3) + + self.assertEquals(m1.call_count, 1) + self.assertEquals(m2.call_count, 0) + self.assertEquals(m3.call_count, 0) + + cache.set("key3", "value") + + self.assertEquals(m1.call_count, 1) + self.assertEquals(m2.call_count, 0) + self.assertEquals(m3.call_count, 0) + + cache.get("key2") + + self.assertEquals(m1.call_count, 1) + self.assertEquals(m2.call_count, 0) + self.assertEquals(m3.call_count, 0) + + cache.set("key1", "value", m1) + + self.assertEquals(m1.call_count, 1) + self.assertEquals(m2.call_count, 0) + self.assertEquals(m3.call_count, 1) -- cgit 1.4.1 From ba214a5e325adbf8ab430cb15f55d2c7544eba8b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 19 Aug 2016 11:59:29 +0100 Subject: Remove lru option --- synapse/storage/_base.py | 2 +- synapse/storage/event_push_actions.py | 2 +- synapse/storage/push_rule.py | 4 ++-- synapse/storage/pusher.py | 2 +- synapse/storage/receipts.py | 2 +- synapse/storage/signatures.py | 2 +- synapse/storage/state.py | 4 ++-- synapse/util/caches/descriptors.py | 31 ++++++++----------------------- tests/storage/test__base.py | 2 +- 9 files changed, 18 insertions(+), 33 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 029f6612e6..49fa8614f2 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -166,7 +166,7 @@ class SQLBaseStore(object): self._txn_perf_counters = PerformanceCounters() self._get_event_counters = PerformanceCounters() - self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, + self._get_event_cache = Cache("*getEvent*", keylen=3, max_entries=hs.config.event_cache_size) self._state_group_cache = DictionaryCache( diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index df4000d0da..c65c9c9c47 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -56,7 +56,7 @@ class EventPushActionsStore(SQLBaseStore): ) self._simple_insert_many_txn(txn, "event_push_actions", values) - @cachedInlineCallbacks(num_args=3, lru=True, tree=True, max_entries=5000) + @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000) def get_unread_event_push_actions_by_room_for_user( self, room_id, user_id, last_read_event_id ): diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 8183b7f1b0..86e4a3a81d 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -48,7 +48,7 @@ def _load_rules(rawrules, enabled_map): class PushRuleStore(SQLBaseStore): - @cachedInlineCallbacks(lru=True) + @cachedInlineCallbacks() def get_push_rules_for_user(self, user_id): rows = yield self._simple_select_list( table="push_rules", @@ -72,7 +72,7 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(rules) - @cachedInlineCallbacks(lru=True) + @cachedInlineCallbacks() def get_push_rules_enabled_for_user(self, user_id): results = yield self._simple_select_list( table="push_rules_enable", diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index a7d7c54d7e..8f5f8f24a9 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -135,7 +135,7 @@ class PusherStore(SQLBaseStore): "get_all_updated_pushers", get_all_updated_pushers_txn ) - @cachedInlineCallbacks(lru=True, num_args=1, max_entries=15000) + @cachedInlineCallbacks(num_args=1, max_entries=15000) def get_if_user_has_pusher(self, user_id): result = yield self._simple_select_many_batch( table='pushers', diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 8c26f39fbb..3ad916103f 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -120,7 +120,7 @@ class ReceiptsStore(SQLBaseStore): defer.returnValue([ev for res in results.values() for ev in res]) - @cachedInlineCallbacks(num_args=3, max_entries=5000, lru=True, tree=True) + @cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True) def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): """Get receipts for a single room for sending to clients. diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index ea6823f18d..e1dca927d7 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -25,7 +25,7 @@ from synapse.util.caches.descriptors import cached, cachedList class SignatureStore(SQLBaseStore): """Persistence for event signatures and hashes""" - @cached(lru=True) + @cached() def get_event_reference_hash(self, event_id): return self._get_event_reference_hashes_txn(event_id) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 5b743db67a..0e8fa93e1f 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -174,7 +174,7 @@ class StateStore(SQLBaseStore): return [r[0] for r in results] return self.runInteraction("get_current_state_for_key", f) - @cached(num_args=2, lru=True, max_entries=1000) + @cached(num_args=2, max_entries=1000) def _get_state_group_from_group(self, group, types): raise NotImplementedError() @@ -272,7 +272,7 @@ class StateStore(SQLBaseStore): state_map = yield self.get_state_for_events([event_id], types) defer.returnValue(state_map[event_id]) - @cached(num_args=2, lru=True, max_entries=10000) + @cached(num_args=2, max_entries=10000) def _get_state_group_for_event(self, room_id, event_id): return self._simple_select_one_onecol( table="event_to_state_groups", diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 5cd277f2f2..c38f01ead0 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -26,8 +26,6 @@ from . import DEBUG_CACHES, register_cache from twisted.internet import defer -from collections import OrderedDict - import os import functools import inspect @@ -54,16 +52,11 @@ class Cache(object): "metrics", ) - def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False): - if True: - cache_type = TreeCache if tree else dict - self.cache = LruCache( - max_size=max_entries, keylen=keylen, cache_type=cache_type - ) - self.max_entries = None - else: - self.cache = OrderedDict() - self.max_entries = max_entries + def __init__(self, name, max_entries=1000, keylen=1, tree=False): + cache_type = TreeCache if tree else dict + self.cache = LruCache( + max_size=max_entries, keylen=keylen, cache_type=cache_type + ) self.name = name self.keylen = keylen @@ -102,10 +95,6 @@ class Cache(object): self.prefill(key, value, callback=callback) def prefill(self, key, value, callback=None): - if self.max_entries is not None: - while len(self.cache) >= self.max_entries: - self.cache.popitem(last=False, callback=None) - self.cache.set(key, value, callback=callback) def invalidate(self, key): @@ -164,7 +153,7 @@ class CacheDescriptor(object): defer.returnValue(r1 + r2) """ - def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False, + def __init__(self, orig, max_entries=1000, num_args=1, tree=False, inlineCallbacks=False): max_entries = int(max_entries * CACHE_SIZE_FACTOR) @@ -177,7 +166,6 @@ class CacheDescriptor(object): self.max_entries = max_entries self.num_args = num_args - self.lru = lru self.tree = tree all_args = inspect.getargspec(orig) @@ -200,7 +188,6 @@ class CacheDescriptor(object): name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, - lru=self.lru, tree=self.tree, ) @@ -427,22 +414,20 @@ class _CacheContext(object): self.cache.invalidate(self.key) -def cached(max_entries=1000, num_args=1, lru=True, tree=False): +def cached(max_entries=1000, num_args=1, tree=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, - lru=lru, tree=tree, ) -def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False): +def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, - lru=lru, tree=tree, inlineCallbacks=True, ) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 9d99eea8d0..ed074ce9ec 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -72,7 +72,7 @@ class CacheTestCase(unittest.TestCase): cache.get(3) def test_eviction_lru(self): - cache = Cache("test", max_entries=2, lru=True) + cache = Cache("test", max_entries=2) cache.prefill(1, "one") cache.prefill(2, "two") -- cgit 1.4.1 From f164fd922024308e702269a881328f7de980e9eb Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 19 Aug 2016 14:07:27 +0100 Subject: Move _bulk_get_push_rules_for_room to storage layer --- synapse/push/action_generator.py | 2 +- synapse/push/bulk_push_rule_evaluator.py | 41 +++++------------------ synapse/storage/push_rule.py | 56 ++++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 34 deletions(-) (limited to 'synapse') diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index b2c94bfaac..ed2ccc4dfb 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -40,7 +40,7 @@ class ActionGenerator: def handle_push_actions_for_event(self, event, context): with Measure(self.clock, "evaluator_for_event"): bulk_evaluator = yield evaluator_for_event( - event, self.hs, self.store, context.current_state + event, self.hs, self.store, context.state_group, context.current_state ) with Measure(self.clock, "action_for_event_by_user"): diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 756e5da513..004eded61f 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -36,35 +36,11 @@ def _get_rules(room_id, user_ids, store): @defer.inlineCallbacks -def evaluator_for_event(event, hs, store, current_state): - room_id = event.room_id - # We also will want to generate notifs for other people in the room so - # their unread countss are correct in the event stream, but to avoid - # generating them for bot / AS users etc, we only do so for people who've - # sent a read receipt into the room. - - local_users_in_room = set( - e.state_key for e in current_state.values() - if e.type == EventTypes.Member and e.membership == Membership.JOIN - and hs.is_mine_id(e.state_key) +def evaluator_for_event(event, hs, store, state_group, current_state): + rules_by_user = yield store.bulk_get_push_rules_for_room( + event.room_id, state_group, current_state ) - # users in the room who have pushers need to get push rules run because - # that's how their pushers work - if_users_with_pushers = yield store.get_if_users_have_pushers( - local_users_in_room - ) - user_ids = set( - uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher - ) - - users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id) - - # any users with pushers must be ours: they have pushers - for uid in users_with_receipts: - if uid in local_users_in_room: - user_ids.add(uid) - # if this event is an invite event, we may need to run rules for the user # who's been invited, otherwise they won't get told they've been invited if event.type == 'm.room.member' and event.content['membership'] == 'invite': @@ -72,12 +48,12 @@ def evaluator_for_event(event, hs, store, current_state): if invited_user and hs.is_mine_id(invited_user): has_pusher = yield store.user_has_pusher(invited_user) if has_pusher: - user_ids.add(invited_user) - - rules_by_user = yield _get_rules(room_id, user_ids, store) + rules_by_user[invited_user] = yield store.get_push_rules_for_user( + invited_user + ) defer.returnValue(BulkPushRuleEvaluator( - room_id, rules_by_user, user_ids, store + event.room_id, rules_by_user, store )) @@ -90,10 +66,9 @@ class BulkPushRuleEvaluator: the same logic to run the actual rules, but could be optimised further (see https://matrix.org/jira/browse/SYN-562) """ - def __init__(self, room_id, rules_by_user, users_in_room, store): + def __init__(self, room_id, rules_by_user, store): self.room_id = room_id self.rules_by_user = rules_by_user - self.users_in_room = users_in_room self.store = store @defer.inlineCallbacks diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 86e4a3a81d..ca929bc239 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -16,6 +16,7 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.push.baserules import list_with_base_rules +from synapse.api.constants import EventTypes, Membership from twisted.internet import defer import logging @@ -123,6 +124,61 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(results) + def bulk_get_push_rules_for_room(self, room_id, state_group, current_state): + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._bulk_get_push_rules_for_room(room_id, state_group, current_state) + + @cachedInlineCallbacks(num_args=2) + def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state, + cache_context): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + # We also will want to generate notifs for other people in the room so + # their unread countss are correct in the event stream, but to avoid + # generating them for bot / AS users etc, we only do so for people who've + # sent a read receipt into the room. + local_users_in_room = set( + e.state_key for e in current_state.values() + if e.type == EventTypes.Member and e.membership == Membership.JOIN + and self.hs.is_mine_id(e.state_key) + ) + + # users in the room who have pushers need to get push rules run because + # that's how their pushers work + if_users_with_pushers = yield self.get_if_users_have_pushers( + local_users_in_room, cache_context=cache_context, + ) + user_ids = set( + uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher + ) + + users_with_receipts = yield self.get_users_with_read_receipts_in_room( + room_id, cache_context=cache_context, + ) + + # any users with pushers must be ours: they have pushers + for uid in users_with_receipts: + if uid in local_users_in_room: + user_ids.add(uid) + + rules_by_user = yield self.bulk_get_push_rules( + user_ids, cache_context=cache_context + ) + + rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} + + defer.returnValue(rules_by_user) + @cachedList(cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids", num_args=1, inlineCallbacks=True) def bulk_get_push_rules_enabled(self, user_ids): -- cgit 1.4.1 From dc76a3e909535d99f0b6b4a76279a14685324dc4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 19 Aug 2016 15:02:38 +0100 Subject: Make cache_context an explicit option --- synapse/storage/push_rule.py | 2 +- synapse/util/caches/descriptors.py | 35 +++++++++++++++++++++++++++-------- tests/storage/test__base.py | 4 ++-- 3 files changed, 30 insertions(+), 11 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index ca929bc239..247dd15694 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -134,7 +134,7 @@ class PushRuleStore(SQLBaseStore): return self._bulk_get_push_rules_for_room(room_id, state_group, current_state) - @cachedInlineCallbacks(num_args=2) + @cachedInlineCallbacks(num_args=2, cache_context=True) def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state, cache_context): # We don't use `state_group`, its there so that we can cache based diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index c38f01ead0..e7a74d3da8 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -146,7 +146,7 @@ class CacheDescriptor(object): invalidated) by adding a special "cache_context" argument to the function and passing that as a kwarg to all caches called. For example:: - @cachedInlineCallbacks() + @cachedInlineCallbacks(cache_context=True) def foo(self, key, cache_context): r1 = yield self.bar1(key, cache_context=cache_context) r2 = yield self.bar2(key, cache_context=cache_context) @@ -154,7 +154,7 @@ class CacheDescriptor(object): """ def __init__(self, orig, max_entries=1000, num_args=1, tree=False, - inlineCallbacks=False): + inlineCallbacks=False, cache_context=False): max_entries = int(max_entries * CACHE_SIZE_FACTOR) self.orig = orig @@ -171,15 +171,28 @@ class CacheDescriptor(object): all_args = inspect.getargspec(orig) self.arg_names = all_args.args[1:num_args + 1] - if "cache_context" in self.arg_names: - self.arg_names.remove("cache_context") + if "cache_context" in all_args.args: + if not cache_context: + raise ValueError( + "Cannot have a 'cache_context' arg without setting" + " cache_context=True" + ) + try: + self.arg_names.remove("cache_context") + except ValueError: + pass + elif cache_context: + raise ValueError( + "Cannot have cache_context=True without having an arg" + " named `cache_context`" + ) - self.add_cache_context = "cache_context" in all_args.args + self.add_cache_context = cache_context if len(self.arg_names) < self.num_args: raise Exception( "Not enough explicit positional arguments to key off of for %r." - " (@cached cannot key off of *args or **kwars)" + " (@cached cannot key off of *args or **kwargs)" % (orig.__name__,) ) @@ -193,12 +206,16 @@ class CacheDescriptor(object): @functools.wraps(self.orig) def wrapped(*args, **kwargs): + # If we're passed a cache_context then we'll want to call its invalidate() + # whenever we are invalidated cache_context = kwargs.pop("cache_context", None) if cache_context: context_callback = cache_context.invalidate else: context_callback = None + # Add our own `cache_context` to argument list if the wrapped function + # has asked for one self_context = _CacheContext(cache, None) if self.add_cache_context: kwargs["cache_context"] = self_context @@ -414,22 +431,24 @@ class _CacheContext(object): self.cache.invalidate(self.key) -def cached(max_entries=1000, num_args=1, tree=False): +def cached(max_entries=1000, num_args=1, tree=False, cache_context=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, tree=tree, + cache_context=cache_context, ) -def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False): +def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, tree=tree, inlineCallbacks=True, + cache_context=cache_context, ) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index ed074ce9ec..eab0c8d219 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -211,7 +211,7 @@ class CacheDecoratorTestCase(unittest.TestCase): callcount[0] += 1 return key - @cached() + @cached(cache_context=True) def func2(self, key, cache_context): callcount2[0] += 1 return self.func(key, cache_context=cache_context) @@ -244,7 +244,7 @@ class CacheDecoratorTestCase(unittest.TestCase): callcount[0] += 1 return key - @cached() + @cached(cache_context=True) def func2(self, key, cache_context): callcount2[0] += 1 return self.func(key, cache_context=cache_context) -- cgit 1.4.1 From c0d7d9d6429584f51a8174a331e72a894009f3c8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 19 Aug 2016 15:13:58 +0100 Subject: Rename to on_invalidate --- synapse/storage/push_rule.py | 6 +++--- synapse/util/caches/descriptors.py | 26 ++++++++++---------------- tests/storage/test__base.py | 4 ++-- 3 files changed, 15 insertions(+), 21 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 247dd15694..78334a98cf 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -156,14 +156,14 @@ class PushRuleStore(SQLBaseStore): # users in the room who have pushers need to get push rules run because # that's how their pushers work if_users_with_pushers = yield self.get_if_users_have_pushers( - local_users_in_room, cache_context=cache_context, + local_users_in_room, on_invalidate=cache_context.invalidate, ) user_ids = set( uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher ) users_with_receipts = yield self.get_users_with_read_receipts_in_room( - room_id, cache_context=cache_context, + room_id, on_invalidate=cache_context.invalidate, ) # any users with pushers must be ours: they have pushers @@ -172,7 +172,7 @@ class PushRuleStore(SQLBaseStore): user_ids.add(uid) rules_by_user = yield self.bulk_get_push_rules( - user_ids, cache_context=cache_context + user_ids, on_invalidate=cache_context.invalidate, ) rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index e7a74d3da8..e93ff40dc0 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -148,8 +148,8 @@ class CacheDescriptor(object): @cachedInlineCallbacks(cache_context=True) def foo(self, key, cache_context): - r1 = yield self.bar1(key, cache_context=cache_context) - r2 = yield self.bar2(key, cache_context=cache_context) + r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate) + r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate) defer.returnValue(r1 + r2) """ @@ -208,11 +208,7 @@ class CacheDescriptor(object): def wrapped(*args, **kwargs): # If we're passed a cache_context then we'll want to call its invalidate() # whenever we are invalidated - cache_context = kwargs.pop("cache_context", None) - if cache_context: - context_callback = cache_context.invalidate - else: - context_callback = None + invalidate_callback = kwargs.pop("on_invalidate", None) # Add our own `cache_context` to argument list if the wrapped function # has asked for one @@ -226,7 +222,7 @@ class CacheDescriptor(object): self_context.key = cache_key try: - cached_result_d = cache.get(cache_key, callback=context_callback) + cached_result_d = cache.get(cache_key, callback=invalidate_callback) observer = cached_result_d.observe() if DEBUG_CACHES: @@ -263,7 +259,7 @@ class CacheDescriptor(object): ret.addErrback(onErr) ret = ObservableDeferred(ret, consumeErrors=True) - cache.update(sequence, cache_key, ret, callback=context_callback) + cache.update(sequence, cache_key, ret, callback=invalidate_callback) return preserve_context_over_deferred(ret.observe()) @@ -332,11 +328,9 @@ class CacheListDescriptor(object): @functools.wraps(self.orig) def wrapped(*args, **kwargs): - cache_context = kwargs.pop("cache_context", None) - if cache_context: - context_callback = cache_context.invalidate - else: - context_callback = None + # If we're passed a cache_context then we'll want to call its invalidate() + # whenever we are invalidated + invalidate_callback = kwargs.pop("on_invalidate", None) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] @@ -352,7 +346,7 @@ class CacheListDescriptor(object): key[self.list_pos] = arg try: - res = cache.get(tuple(key), callback=context_callback) + res = cache.get(tuple(key), callback=invalidate_callback) if not res.has_succeeded(): res = res.observe() res.addCallback(lambda r, arg: (arg, r), arg) @@ -388,7 +382,7 @@ class CacheListDescriptor(object): key[self.list_pos] = arg cache.update( sequence, tuple(key), observer, - callback=context_callback + callback=invalidate_callback ) def invalidate(f, key): diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index eab0c8d219..4fc3639de0 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -214,7 +214,7 @@ class CacheDecoratorTestCase(unittest.TestCase): @cached(cache_context=True) def func2(self, key, cache_context): callcount2[0] += 1 - return self.func(key, cache_context=cache_context) + return self.func(key, on_invalidate=cache_context.invalidate) a = A() yield a.func2("foo") @@ -247,7 +247,7 @@ class CacheDecoratorTestCase(unittest.TestCase): @cached(cache_context=True) def func2(self, key, cache_context): callcount2[0] += 1 - return self.func(key, cache_context=cache_context) + return self.func(key, on_invalidate=cache_context.invalidate) a = A() yield a.func2("foo") -- cgit 1.4.1 From 45fd2c8942009a634cf38a90ca1f306aae7022fc Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 19 Aug 2016 15:58:52 +0100 Subject: Ensure invalidation list does not grow unboundedly --- synapse/util/caches/descriptors.py | 20 +++++++--------- synapse/util/caches/lrucache.py | 16 ++++++------- tests/storage/test__base.py | 48 ++++++++++++++++++++++++++++++++++++++ tests/util/test_lrucache.py | 40 +++++++++++++++++++++++++++++++ 4 files changed, 104 insertions(+), 20 deletions(-) (limited to 'synapse') diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index e93ff40dc0..8dba61d49f 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -25,6 +25,7 @@ from synapse.util.logcontext import ( from . import DEBUG_CACHES, register_cache from twisted.internet import defer +from collections import namedtuple import os import functools @@ -210,16 +211,17 @@ class CacheDescriptor(object): # whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) - # Add our own `cache_context` to argument list if the wrapped function - # has asked for one - self_context = _CacheContext(cache, None) + # Add temp cache_context so inspect.getcallargs doesn't explode if self.add_cache_context: - kwargs["cache_context"] = self_context + kwargs["cache_context"] = None arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) - self_context.key = cache_key + # Add our own `cache_context` to argument list if the wrapped function + # has asked for one + if self.add_cache_context: + kwargs["cache_context"] = _CacheContext(cache, cache_key) try: cached_result_d = cache.get(cache_key, callback=invalidate_callback) @@ -414,13 +416,7 @@ class CacheListDescriptor(object): return wrapped -class _CacheContext(object): - __slots__ = ["cache", "key"] - - def __init__(self, cache, key): - self.cache = cache - self.key = key - +class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): def invalidate(self): self.cache.invalidate(self.key) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index a5a827b4d1..9c4c679175 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -32,7 +32,7 @@ def enumerate_leaves(node, depth): class _Node(object): __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"] - def __init__(self, prev_node, next_node, key, value, callbacks=[]): + def __init__(self, prev_node, next_node, key, value, callbacks=set()): self.prev_node = prev_node self.next_node = next_node self.key = key @@ -66,7 +66,7 @@ class LruCache(object): return inner - def add_node(key, value, callbacks=[]): + def add_node(key, value, callbacks=set()): prev_node = list_root next_node = prev_node.next_node node = _Node(prev_node, next_node, key, value, callbacks) @@ -94,7 +94,7 @@ class LruCache(object): for cb in node.callbacks: cb() - node.callbacks = [] + node.callbacks.clear() @synchronized def cache_get(key, default=None, callback=None): @@ -102,7 +102,7 @@ class LruCache(object): if node is not None: move_node_to_front(node) if callback: - node.callbacks.append(callback) + node.callbacks.add(callback) return node.value else: return default @@ -114,18 +114,18 @@ class LruCache(object): if value != node.value: for cb in node.callbacks: cb() - node.callbacks = [] + node.callbacks.clear() if callback: - node.callbacks.append(callback) + node.callbacks.add(callback) move_node_to_front(node) node.value = value else: if callback: - callbacks = [callback] + callbacks = set([callback]) else: - callbacks = [] + callbacks = set() add_node(key, value, callbacks) if len(cache) > max_size: todelete = list_root.prev_node diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 4fc3639de0..ab6095564a 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -17,6 +17,8 @@ from tests import unittest from twisted.internet import defer +from mock import Mock + from synapse.util.async import ObservableDeferred from synapse.util.caches.descriptors import Cache, cached @@ -265,3 +267,49 @@ class CacheDecoratorTestCase(unittest.TestCase): self.assertEquals(callcount[0], 4) self.assertEquals(callcount2[0], 3) + + @defer.inlineCallbacks + def test_double_get(self): + callcount = [0] + callcount2 = [0] + + class A(object): + @cached() + def func(self, key): + callcount[0] += 1 + return key + + @cached(cache_context=True) + def func2(self, key, cache_context): + callcount2[0] += 1 + return self.func(key, on_invalidate=cache_context.invalidate) + + a = A() + a.func2.cache.cache = Mock(wraps=a.func2.cache.cache) + + yield a.func2("foo") + + self.assertEquals(callcount[0], 1) + self.assertEquals(callcount2[0], 1) + + a.func2.invalidate(("foo",)) + self.assertEquals(a.func2.cache.cache.pop.call_count, 1) + + yield a.func2("foo") + a.func2.invalidate(("foo",)) + self.assertEquals(a.func2.cache.cache.pop.call_count, 2) + + self.assertEquals(callcount[0], 1) + self.assertEquals(callcount2[0], 2) + + a.func.invalidate(("foo",)) + self.assertEquals(a.func2.cache.cache.pop.call_count, 3) + yield a.func("foo") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + + yield a.func2("foo") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 3) diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index bacec2f465..1eba5b535e 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -50,6 +50,8 @@ class LruCacheTestCase(unittest.TestCase): self.assertEquals(cache.get("key"), 1) self.assertEquals(cache.setdefault("key", 2), 1) self.assertEquals(cache.get("key"), 1) + cache["key"] = 2 # Make sure overriding works. + self.assertEquals(cache.get("key"), 2) def test_pop(self): cache = LruCache(1) @@ -84,6 +86,44 @@ class LruCacheTestCase(unittest.TestCase): class LruCacheCallbacksTestCase(unittest.TestCase): + def test_get(self): + m = Mock() + cache = LruCache(1) + + cache.set("key", "value") + self.assertFalse(m.called) + + cache.get("key", callback=m) + self.assertFalse(m.called) + + cache.get("key", "value") + self.assertFalse(m.called) + + cache.set("key", "value2") + self.assertEquals(m.call_count, 1) + + cache.set("key", "value") + self.assertEquals(m.call_count, 1) + + def test_multi_get(self): + m = Mock() + cache = LruCache(1) + + cache.set("key", "value") + self.assertFalse(m.called) + + cache.get("key", callback=m) + self.assertFalse(m.called) + + cache.get("key", callback=m) + self.assertFalse(m.called) + + cache.set("key", "value2") + self.assertEquals(m.call_count, 1) + + cache.set("key", "value") + self.assertEquals(m.call_count, 1) + def test_set(self): m = Mock() cache = LruCache(1) -- cgit 1.4.1