diff --git a/README.rst b/README.rst
index d658670835..172dd4dfa0 100644
--- a/README.rst
+++ b/README.rst
@@ -95,7 +95,7 @@ Synapse is the reference python/twisted Matrix homeserver implementation.
System requirements:
- POSIX-compliant system (tested on Linux & OS X)
- Python 2.7
-- At least 512 MB RAM.
+- At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
Synapse is written in python but some of the libraries it uses are written in
C. So before we can install synapse itself we need a working C compiler and the
diff --git a/docs/workers.rst b/docs/workers.rst
new file mode 100644
index 0000000000..4eb05b0e59
--- /dev/null
+++ b/docs/workers.rst
@@ -0,0 +1,97 @@
+Scaling synapse via workers
+---------------------------
+
+Synapse has experimental support for splitting out functionality into
+multiple separate python processes, helping greatly with scalability. These
+processes are called 'workers', and are (eventually) intended to scale
+horizontally and independently of one another.
+
+All processes continue to share the same database instance, and as such, workers
+only work with postgres-based synapse deployments (sharing a single sqlite
+database across multiple processes is a recipe for disaster, plus you should
+be using postgres anyway if you care about scalability).
+
+The workers communicate with the master synapse process via a synapse-specific
+HTTP protocol called 'replication', analogous to MySQL- or Postgres-style
+database replication: a stream of relevant data is fed to the workers so they
+can be kept in sync with the main synapse process and database state.
+
+To enable workers, you need to add a replication listener to the master synapse, e.g.::
+
+ listeners:
+ - port: 9092
+ bind_address: '127.0.0.1'
+ type: http
+ tls: false
+ x_forwarded: false
+ resources:
+ - names: [replication]
+ compress: false
+
+Under **no circumstances** should this replication API listener be exposed to the
+public internet; it currently implements no authentication whatsoever and is
+unencrypted HTTP.
+
+You then create a set of configs for the various worker processes. These
+worker configuration files should be stored in a dedicated subdirectory, to allow
+synctl to manipulate them.
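+
+For example, you might end up with a layout like this (names illustrative)::
+
+    $CONFIG/workers/
+        synchrotron.yaml
+        pusher.yaml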
+
+The currently available worker applications are:
+ * synapse.app.pusher - handles sending push notifications to sygnal and email
+ * synapse.app.synchrotron - handles /sync endpoints; can scale horizontally across multiple instances.
+ * synapse.app.appservice - handles output traffic to Application Services
+ * synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API)
+ * synapse.app.media_repository - handles the media repository.
+
+Each worker configuration file inherits the configuration of the main homeserver
+configuration file. You can then override configuration specific to that worker,
+e.g. the HTTP listener that it provides (if any); logging configuration; etc.
+You should minimise the number of overrides, though, to maintain a usable config.
+
+You must specify the type of worker application (worker_app) and the replication
+endpoint that it's talking to on the main synapse process (worker_replication_url).
+
+For instance::
+
+ worker_app: synapse.app.synchrotron
+
+ # The replication listener on the synapse to talk to.
+ worker_replication_url: http://127.0.0.1:9092/_synapse/replication
+
+ worker_listeners:
+ - type: http
+ port: 8083
+ resources:
+ - names:
+ - client
+
+ worker_daemonize: True
+ worker_pid_file: /home/matrix/synapse/synchrotron.pid
+ worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml
+
+...is a full configuration for a synchrotron worker instance, which will expose a
+plain HTTP /sync endpoint on port 8083 separately from the /sync endpoint provided
+by the main synapse.
+
+Obviously you should configure your loadbalancer to route the /sync endpoint to
+the synchrotron instance(s) in this setup.
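+
+Similarly, a worker which serves no HTTP endpoints of its own (such as
+synapse.app.pusher) needs no worker_listeners block at all; a minimal config
+might look like this (a sketch only; the paths are illustrative)::
+
+    worker_app: synapse.app.pusher
+
+    # The replication listener on the synapse to talk to.
+    worker_replication_url: http://127.0.0.1:9092/_synapse/replication
+
+    worker_daemonize: True
+    worker_pid_file: /home/matrix/synapse/pusher.pid
+    worker_log_config: /home/matrix/synapse/config/pusher_log_config.yaml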
+
+Finally, to actually run your worker-based synapse, you must pass synctl the -a
+commandline option to tell it to operate on all the worker configurations found
+in the given directory, e.g.::
+
+ synctl -a $CONFIG/workers start
+
+Currently you should always restart all workers when restarting or upgrading
+synapse, unless you explicitly know it's safe not to. For instance, restarting
+synapse without restarting all the synchrotrons may result in broken typing
+notifications.
+
+To manipulate a specific worker, you pass the -w option to synctl::
+
+ synctl -w $CONFIG/workers/synchrotron.yaml restart
+
+All of the above is highly experimental and subject to change as Synapse evolves,
+but we document it here to help folks who need highly scalable Synapses similar
+to the one running matrix.org!
+
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index b1b91d0a55..bde9b51b2e 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -81,13 +81,17 @@ class ApplicationService(object):
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
def __init__(self, token, url=None, namespaces=None, hs_token=None,
- sender=None, id=None):
+ sender=None, id=None, protocols=None):
self.token = token
self.url = url
self.hs_token = hs_token
self.sender = sender
self.namespaces = self._check_namespaces(namespaces)
self.id = id
+ if protocols:
+ self.protocols = set(protocols)
+ else:
+ self.protocols = set()
def _check_namespaces(self, namespaces):
# Sanity check that it is of the form:
@@ -219,6 +223,9 @@ class ApplicationService(object):
or user_id == self.sender
)
+ def is_interested_in_protocol(self, protocol):
+ return protocol in self.protocols
+
def is_exclusive_alias(self, alias):
return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 6da6a1b62e..066127b666 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event
+from synapse.types import ThirdPartyEntityKind
import logging
import urllib
@@ -24,6 +25,28 @@ import urllib
logger = logging.getLogger(__name__)
+def _is_valid_3pe_result(r, field):
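+    # A valid result is a dict of the form (values illustrative; `field` is
+    # "userid" for 3PU lookups or "alias" for 3PL lookups):
+    #   {"protocol": "irc", "userid": "@alice:example.org",
+    #    "fields": {"nickname": "alice"}}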
+ if not isinstance(r, dict):
+ return False
+
+ for k in (field, "protocol"):
+ if k not in r:
+ return False
+        # JSON strings arrive as unicode on python 2, so check basestring
+        if not isinstance(r[k], basestring):
+ return False
+
+ if "fields" not in r:
+ return False
+ fields = r["fields"]
+ if not isinstance(fields, dict):
+ return False
+ for k in fields.keys():
+        if not isinstance(fields[k], basestring):
+ return False
+
+ return True
+
+
class ApplicationServiceApi(SimpleHttpClient):
"""This class manages HS -> AS communications, including querying and
pushing.
@@ -72,6 +95,43 @@ class ApplicationServiceApi(SimpleHttpClient):
defer.returnValue(False)
@defer.inlineCallbacks
+ def query_3pe(self, service, kind, protocol, fields):
+ if kind == ThirdPartyEntityKind.USER:
+ uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
+ required_field = "userid"
+ elif kind == ThirdPartyEntityKind.LOCATION:
+ uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
+ required_field = "alias"
+ else:
+ raise ValueError(
+ "Unrecognised 'kind' argument %r to query_3pe()", kind
+ )
+
+ try:
+ response = yield self.get_json(uri, fields)
+ if not isinstance(response, list):
+ logger.warning(
+ "query_3pe to %s returned an invalid response %r",
+ uri, response
+ )
+ defer.returnValue([])
+
+ ret = []
+ for r in response:
+ if _is_valid_3pe_result(r, field=required_field):
+ ret.append(r)
+ else:
+ logger.warning(
+ "query_3pe to %s returned an invalid result %r",
+ uri, r
+ )
+
+ defer.returnValue(ret)
+ except Exception as ex:
+ logger.warning("query_3pe to %s threw exception %s", uri, ex)
+ defer.returnValue([])
+
+ @defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None):
events = self._serialize(events)
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 70d28892c6..dfe43b0b4c 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -123,6 +123,15 @@ def _load_appservice(hostname, as_info, config_filename):
raise ValueError(
"Missing/bad type 'exclusive' key in %s", regex_obj
)
+ # protocols check
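+    # e.g. an AS registration file might declare (values illustrative):
+    #   protocols: ["irc", "gitter"]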
+ protocols = as_info.get("protocols")
+ if protocols:
+        # Strings are iterable too, so explicitly reject them before the
+        # list check
+        if isinstance(protocols, basestring) or not isinstance(protocols, list):
+            raise KeyError("Optional 'protocols' must be a list if present.")
+        for p in protocols:
+            if not isinstance(p, basestring):
+ raise KeyError("Bad value for 'protocols' item")
return ApplicationService(
token=as_info["as_token"],
url=as_info["url"],
@@ -130,4 +139,5 @@ def _load_appservice(hostname, as_info, config_filename):
hs_token=as_info["hs_token"],
sender=user_id,
id=as_info["id"],
+ protocols=protocols,
)
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 6556dd1ae8..dd285452cd 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -160,6 +160,22 @@ class ApplicationServicesHandler(object):
defer.returnValue(result)
@defer.inlineCallbacks
+ def query_3pe(self, kind, protocol, fields):
+ services = yield self._get_services_for_3pn(protocol)
+
+ results = yield defer.DeferredList([
+ self.appservice_api.query_3pe(service, kind, protocol, fields)
+ for service in services
+ ], consumeErrors=True)
+
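+        # Collate the results, dropping any service that failed: one
+        # misbehaving AS shouldn't break 3PE lookups against the others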
+ ret = []
+ for (success, result) in results:
+ if success:
+ ret.extend(result)
+
+ defer.returnValue(ret)
+
+ @defer.inlineCallbacks
def _get_services_for_event(self, event):
"""Retrieve a list of application services interested in this event.
@@ -188,6 +204,14 @@ class ApplicationServicesHandler(object):
defer.returnValue(interested_list)
@defer.inlineCallbacks
+ def _get_services_for_3pn(self, protocol):
+ services = yield self.store.get_app_services()
+ interested_list = [
+ s for s in services if s.is_interested_in_protocol(protocol)
+ ]
+ defer.returnValue(interested_list)
+
+ @defer.inlineCallbacks
def _is_unknown_user(self, user_id):
if not self.is_mine_id(user_id):
# we don't know if they are unknown or not since it isn't one of our
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index b2c94bfaac..ed2ccc4dfb 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -40,7 +40,7 @@ class ActionGenerator:
def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "evaluator_for_event"):
bulk_evaluator = yield evaluator_for_event(
- event, self.hs, self.store, context.current_state
+ event, self.hs, self.store, context.state_group, context.current_state
)
with Measure(self.clock, "action_for_event_by_user"):
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 756e5da513..004eded61f 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -36,35 +36,11 @@ def _get_rules(room_id, user_ids, store):
@defer.inlineCallbacks
-def evaluator_for_event(event, hs, store, current_state):
- room_id = event.room_id
- # We also will want to generate notifs for other people in the room so
- # their unread countss are correct in the event stream, but to avoid
- # generating them for bot / AS users etc, we only do so for people who've
- # sent a read receipt into the room.
-
- local_users_in_room = set(
- e.state_key for e in current_state.values()
- if e.type == EventTypes.Member and e.membership == Membership.JOIN
- and hs.is_mine_id(e.state_key)
+def evaluator_for_event(event, hs, store, state_group, current_state):
+ rules_by_user = yield store.bulk_get_push_rules_for_room(
+ event.room_id, state_group, current_state
)
- # users in the room who have pushers need to get push rules run because
- # that's how their pushers work
- if_users_with_pushers = yield store.get_if_users_have_pushers(
- local_users_in_room
- )
- user_ids = set(
- uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
- )
-
- users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)
-
- # any users with pushers must be ours: they have pushers
- for uid in users_with_receipts:
- if uid in local_users_in_room:
- user_ids.add(uid)
-
# if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited
if event.type == 'm.room.member' and event.content['membership'] == 'invite':
@@ -72,12 +48,12 @@ def evaluator_for_event(event, hs, store, current_state):
if invited_user and hs.is_mine_id(invited_user):
has_pusher = yield store.user_has_pusher(invited_user)
if has_pusher:
- user_ids.add(invited_user)
-
- rules_by_user = yield _get_rules(room_id, user_ids, store)
+ rules_by_user[invited_user] = yield store.get_push_rules_for_user(
+ invited_user
+ )
defer.returnValue(BulkPushRuleEvaluator(
- room_id, rules_by_user, user_ids, store
+ event.room_id, rules_by_user, store
))
@@ -90,10 +66,9 @@ class BulkPushRuleEvaluator:
the same logic to run the actual rules, but could be optimised further
(see https://matrix.org/jira/browse/SYN-562)
"""
- def __init__(self, room_id, rules_by_user, users_in_room, store):
+ def __init__(self, room_id, rules_by_user, store):
self.room_id = room_id
self.rules_by_user = rules_by_user
- self.users_in_room = users_in_room
self.store = store
@defer.inlineCallbacks
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 7c23f5a4a8..326780405e 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -48,6 +48,7 @@ from synapse.rest.client.v2_alpha import (
openid,
notifications,
devices,
+ thirdparty,
)
from synapse.http.server import JsonResource
@@ -94,3 +95,4 @@ class ClientRestResource(JsonResource):
openid.register_servlets(hs, client_resource)
notifications.register_servlets(hs, client_resource)
devices.register_servlets(hs, client_resource)
+ thirdparty.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
new file mode 100644
index 0000000000..9abca3a8ad
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -0,0 +1,78 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.http.servlet import RestServlet
+from synapse.types import ThirdPartyEntityKind
+from ._base import client_v2_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class ThirdPartyUserServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/3pu(/(?P<protocol>[^/]+))?$",
+ releases=())
+
+ def __init__(self, hs):
+ super(ThirdPartyUserServlet, self).__init__()
+
+ self.auth = hs.get_auth()
+ self.appservice_handler = hs.get_application_service_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, protocol):
+ yield self.auth.get_user_by_req(request)
+
+ fields = request.args
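+        # Strip the access_token so it isn't forwarded on to the
+        # application services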
+ del fields["access_token"]
+
+ results = yield self.appservice_handler.query_3pe(
+ ThirdPartyEntityKind.USER, protocol, fields
+ )
+
+ defer.returnValue((200, results))
+
+
+class ThirdPartyLocationServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/3pl(/(?P<protocol>[^/]+))?$",
+ releases=())
+
+ def __init__(self, hs):
+ super(ThirdPartyLocationServlet, self).__init__()
+
+ self.auth = hs.get_auth()
+ self.appservice_handler = hs.get_application_service_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, protocol):
+ yield self.auth.get_user_by_req(request)
+
+ fields = request.args
+ del fields["access_token"]
+
+ results = yield self.appservice_handler.query_3pe(
+ ThirdPartyEntityKind.LOCATION, protocol, fields
+ )
+
+ defer.returnValue((200, results))
+
+
+def register_servlets(hs, http_server):
+ ThirdPartyUserServlet(hs).register(http_server)
+ ThirdPartyLocationServlet(hs).register(http_server)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 029f6612e6..49fa8614f2 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -166,7 +166,7 @@ class SQLBaseStore(object):
self._txn_perf_counters = PerformanceCounters()
self._get_event_counters = PerformanceCounters()
- self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
+ self._get_event_cache = Cache("*getEvent*", keylen=3,
max_entries=hs.config.event_cache_size)
self._state_group_cache = DictionaryCache(
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index b496b918b7..a854a87eab 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -366,8 +366,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
def get_new_events_for_appservice_txn(txn):
sql = (
"SELECT e.stream_ordering, e.event_id"
- " FROM events AS e, appservice_stream_position AS a"
- " WHERE a.stream_ordering < e.stream_ordering AND e.stream_ordering <= ?"
+ " FROM events AS e"
+ " WHERE"
+ " (SELECT stream_ordering FROM appservice_stream_position)"
+ " < e.stream_ordering"
+ " AND e.stream_ordering <= ?"
" ORDER BY e.stream_ordering ASC"
" LIMIT ?"
)
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 0ba0310c0d..eb15fb751b 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -56,7 +56,7 @@ class EventPushActionsStore(SQLBaseStore):
)
self._simple_insert_many_txn(txn, "event_push_actions", values)
- @cachedInlineCallbacks(num_args=3, lru=True, tree=True, max_entries=5000)
+ @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
):
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 8183b7f1b0..78334a98cf 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -16,6 +16,7 @@
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.push.baserules import list_with_base_rules
+from synapse.api.constants import EventTypes, Membership
from twisted.internet import defer
import logging
@@ -48,7 +49,7 @@ def _load_rules(rawrules, enabled_map):
class PushRuleStore(SQLBaseStore):
- @cachedInlineCallbacks(lru=True)
+ @cachedInlineCallbacks()
def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list(
table="push_rules",
@@ -72,7 +73,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rules)
- @cachedInlineCallbacks(lru=True)
+ @cachedInlineCallbacks()
def get_push_rules_enabled_for_user(self, user_id):
results = yield self._simple_select_list(
table="push_rules_enable",
@@ -123,6 +124,61 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(results)
+ def bulk_get_push_rules_for_room(self, room_id, state_group, current_state):
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ return self._bulk_get_push_rules_for_room(room_id, state_group, current_state)
+
+ @cachedInlineCallbacks(num_args=2, cache_context=True)
+ def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state,
+ cache_context):
+        # We don't use `state_group`; it's there so that we can cache based
+        # on it. However, it's important that it's never None, since two
+        # current_states with a state_group of None are likely to be different.
+ # See bulk_get_push_rules_for_room for how we work around this.
+ assert state_group is not None
+
+ # We also will want to generate notifs for other people in the room so
+        # their unread counts are correct in the event stream, but to avoid
+ # generating them for bot / AS users etc, we only do so for people who've
+ # sent a read receipt into the room.
+ local_users_in_room = set(
+ e.state_key for e in current_state.values()
+ if e.type == EventTypes.Member and e.membership == Membership.JOIN
+ and self.hs.is_mine_id(e.state_key)
+ )
+
+ # users in the room who have pushers need to get push rules run because
+ # that's how their pushers work
+ if_users_with_pushers = yield self.get_if_users_have_pushers(
+ local_users_in_room, on_invalidate=cache_context.invalidate,
+ )
+ user_ids = set(
+ uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
+ )
+
+ users_with_receipts = yield self.get_users_with_read_receipts_in_room(
+ room_id, on_invalidate=cache_context.invalidate,
+ )
+
+ # any users with pushers must be ours: they have pushers
+ for uid in users_with_receipts:
+ if uid in local_users_in_room:
+ user_ids.add(uid)
+
+ rules_by_user = yield self.bulk_get_push_rules(
+ user_ids, on_invalidate=cache_context.invalidate,
+ )
+
+ rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
+
+ defer.returnValue(rules_by_user)
+
@cachedList(cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids", num_args=1, inlineCallbacks=True)
def bulk_get_push_rules_enabled(self, user_ids):
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index a7d7c54d7e..8f5f8f24a9 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -135,7 +135,7 @@ class PusherStore(SQLBaseStore):
"get_all_updated_pushers", get_all_updated_pushers_txn
)
- @cachedInlineCallbacks(lru=True, num_args=1, max_entries=15000)
+ @cachedInlineCallbacks(num_args=1, max_entries=15000)
def get_if_user_has_pusher(self, user_id):
result = yield self._simple_select_many_batch(
table='pushers',
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 2c1df0e2b9..ccc3811e84 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -145,7 +145,7 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue([ev for res in results.values() for ev in res])
- @cachedInlineCallbacks(num_args=3, max_entries=5000, lru=True, tree=True)
+ @cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True)
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients.
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index ea6823f18d..e1dca927d7 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -25,7 +25,7 @@ from synapse.util.caches.descriptors import cached, cachedList
class SignatureStore(SQLBaseStore):
"""Persistence for event signatures and hashes"""
- @cached(lru=True)
+ @cached()
def get_event_reference_hash(self, event_id):
return self._get_event_reference_hashes_txn(event_id)
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 5b743db67a..0e8fa93e1f 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -174,7 +174,7 @@ class StateStore(SQLBaseStore):
return [r[0] for r in results]
return self.runInteraction("get_current_state_for_key", f)
- @cached(num_args=2, lru=True, max_entries=1000)
+ @cached(num_args=2, max_entries=1000)
def _get_state_group_from_group(self, group, types):
raise NotImplementedError()
@@ -272,7 +272,7 @@ class StateStore(SQLBaseStore):
state_map = yield self.get_state_for_events([event_id], types)
defer.returnValue(state_map[event_id])
- @cached(num_args=2, lru=True, max_entries=10000)
+ @cached(num_args=2, max_entries=10000)
def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol(
table="event_to_state_groups",
diff --git a/synapse/types.py b/synapse/types.py
index 5349b0c450..fd17ecbbe0 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -269,3 +269,10 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "t%d-%d" % (self.topological, self.stream)
else:
return "s%d" % (self.stream,)
+
+
+# Some arbitrary constants used for internal API enumerations. Don't rely on
+# exact values; always pass or compare symbolically
+class ThirdPartyEntityKind(object):
+ USER = 'user'
+ LOCATION = 'location'
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index f31dfb22b7..8dba61d49f 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -25,8 +25,7 @@ from synapse.util.logcontext import (
from . import DEBUG_CACHES, register_cache
from twisted.internet import defer
-
-from collections import OrderedDict
+from collections import namedtuple
import os
import functools
@@ -54,16 +53,11 @@ class Cache(object):
"metrics",
)
- def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
- if lru:
- cache_type = TreeCache if tree else dict
- self.cache = LruCache(
- max_size=max_entries, keylen=keylen, cache_type=cache_type
- )
- self.max_entries = None
- else:
- self.cache = OrderedDict()
- self.max_entries = max_entries
+ def __init__(self, name, max_entries=1000, keylen=1, tree=False):
+ cache_type = TreeCache if tree else dict
+ self.cache = LruCache(
+ max_size=max_entries, keylen=keylen, cache_type=cache_type
+ )
self.name = name
self.keylen = keylen
@@ -81,8 +75,8 @@ class Cache(object):
"Cache objects can only be accessed from the main thread"
)
- def get(self, key, default=_CacheSentinel):
- val = self.cache.get(key, _CacheSentinel)
+ def get(self, key, default=_CacheSentinel, callback=None):
+ val = self.cache.get(key, _CacheSentinel, callback=callback)
if val is not _CacheSentinel:
self.metrics.inc_hits()
return val
@@ -94,19 +88,15 @@ class Cache(object):
else:
return default
- def update(self, sequence, key, value):
+ def update(self, sequence, key, value, callback=None):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
- self.prefill(key, value)
-
- def prefill(self, key, value):
- if self.max_entries is not None:
- while len(self.cache) >= self.max_entries:
- self.cache.popitem(last=False)
+ self.prefill(key, value, callback=callback)
- self.cache[key] = value
+ def prefill(self, key, value, callback=None):
+ self.cache.set(key, value, callback=callback)
def invalidate(self, key):
self.check_thread()
@@ -151,9 +141,21 @@ class CacheDescriptor(object):
The wrapped function has another additional callable, called "prefill",
which can be used to insert values into the cache specifically, without
calling the calculation function.
+
+ Cached functions can be "chained" (i.e. a cached function can call other cached
+    functions and get appropriately invalidated when the caches they call are
+ invalidated) by adding a special "cache_context" argument to the function
+ and passing that as a kwarg to all caches called. For example::
+
+ @cachedInlineCallbacks(cache_context=True)
+ def foo(self, key, cache_context):
+ r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
+ r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
+ defer.returnValue(r1 + r2)
+
"""
- def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False,
- inlineCallbacks=False):
+ def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
+ inlineCallbacks=False, cache_context=False):
max_entries = int(max_entries * CACHE_SIZE_FACTOR)
self.orig = orig
@@ -165,15 +167,33 @@ class CacheDescriptor(object):
self.max_entries = max_entries
self.num_args = num_args
- self.lru = lru
self.tree = tree
- self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
+ all_args = inspect.getargspec(orig)
+ self.arg_names = all_args.args[1:num_args + 1]
+
+ if "cache_context" in all_args.args:
+ if not cache_context:
+ raise ValueError(
+ "Cannot have a 'cache_context' arg without setting"
+ " cache_context=True"
+ )
+ try:
+ self.arg_names.remove("cache_context")
+ except ValueError:
+ pass
+ elif cache_context:
+ raise ValueError(
+ "Cannot have cache_context=True without having an arg"
+ " named `cache_context`"
+ )
+
+ self.add_cache_context = cache_context
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
- " (@cached cannot key off of *args or **kwars)"
+ " (@cached cannot key off of *args or **kwargs)"
% (orig.__name__,)
)
@@ -182,16 +202,29 @@ class CacheDescriptor(object):
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
- lru=self.lru,
tree=self.tree,
)
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
+ # If we're passed a cache_context then we'll want to call its invalidate()
+ # whenever we are invalidated
+ invalidate_callback = kwargs.pop("on_invalidate", None)
+
+ # Add temp cache_context so inspect.getcallargs doesn't explode
+ if self.add_cache_context:
+ kwargs["cache_context"] = None
+
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
+
+ # Add our own `cache_context` to argument list if the wrapped function
+ # has asked for one
+ if self.add_cache_context:
+ kwargs["cache_context"] = _CacheContext(cache, cache_key)
+
try:
- cached_result_d = cache.get(cache_key)
+ cached_result_d = cache.get(cache_key, callback=invalidate_callback)
observer = cached_result_d.observe()
if DEBUG_CACHES:
@@ -228,7 +261,7 @@ class CacheDescriptor(object):
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
- cache.update(sequence, cache_key, ret)
+ cache.update(sequence, cache_key, ret, callback=invalidate_callback)
return preserve_context_over_deferred(ret.observe())
@@ -297,6 +330,10 @@ class CacheListDescriptor(object):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
+ # If we're passed a cache_context then we'll want to call its invalidate()
+ # whenever we are invalidated
+ invalidate_callback = kwargs.pop("on_invalidate", None)
+
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]
@@ -311,7 +348,7 @@ class CacheListDescriptor(object):
key[self.list_pos] = arg
try:
- res = cache.get(tuple(key))
+ res = cache.get(tuple(key), callback=invalidate_callback)
if not res.has_succeeded():
res = res.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
@@ -345,7 +382,10 @@ class CacheListDescriptor(object):
key = list(keyargs)
key[self.list_pos] = arg
- cache.update(sequence, tuple(key), observer)
+ cache.update(
+ sequence, tuple(key), observer,
+ callback=invalidate_callback
+ )
def invalidate(f, key):
cache.invalidate(key)
@@ -376,24 +416,29 @@ class CacheListDescriptor(object):
return wrapped
-def cached(max_entries=1000, num_args=1, lru=True, tree=False):
+class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
+ def invalidate(self):
+ self.cache.invalidate(self.key)
+
+
+def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
- lru=lru,
tree=tree,
+ cache_context=cache_context,
)
-def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
- lru=lru,
tree=tree,
inlineCallbacks=True,
+ cache_context=cache_context,
)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index f9df445a8d..9c4c679175 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -30,13 +30,14 @@ def enumerate_leaves(node, depth):
class _Node(object):
- __slots__ = ["prev_node", "next_node", "key", "value"]
+ __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
- def __init__(self, prev_node, next_node, key, value):
+    def __init__(self, prev_node, next_node, key, value, callbacks=None):
self.prev_node = prev_node
self.next_node = next_node
self.key = key
self.value = value
+        # Use a fresh set per node: a shared mutable default argument would
+        # leak callbacks between nodes
+        self.callbacks = callbacks if callbacks is not None else set()
class LruCache(object):
@@ -44,6 +45,9 @@ class LruCache(object):
Least-recently-used cache.
Supports del_multi only if cache_type=TreeCache
If cache_type=TreeCache, all keys must be tuples.
+
+ Can also set callbacks on objects when getting/setting which are fired
+ when that key gets invalidated/evicted.
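+
+    For example (a minimal sketch)::
+
+        def on_invalidated():
+            pass  # runs when "key" is invalidated or evicted
+
+        cache = LruCache(max_size=10)
+        cache.set("key", "value", callback=on_invalidated)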
"""
def __init__(self, max_size, keylen=1, cache_type=dict):
cache = cache_type()
@@ -62,10 +66,10 @@ class LruCache(object):
return inner
- def add_node(key, value):
+    def add_node(key, value, callbacks=None):
prev_node = list_root
next_node = prev_node.next_node
- node = _Node(prev_node, next_node, key, value)
+ node = _Node(prev_node, next_node, key, value, callbacks)
prev_node.next_node = node
next_node.prev_node = node
cache[key] = node
@@ -88,23 +92,41 @@ class LruCache(object):
prev_node.next_node = next_node
next_node.prev_node = prev_node
+ for cb in node.callbacks:
+ cb()
+ node.callbacks.clear()
+
@synchronized
- def cache_get(key, default=None):
+ def cache_get(key, default=None, callback=None):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
+ if callback:
+ node.callbacks.add(callback)
return node.value
else:
return default
@synchronized
- def cache_set(key, value):
+ def cache_set(key, value, callback=None):
node = cache.get(key, None)
if node is not None:
+ if value != node.value:
+ for cb in node.callbacks:
+ cb()
+ node.callbacks.clear()
+
+ if callback:
+ node.callbacks.add(callback)
+
move_node_to_front(node)
node.value = value
else:
- add_node(key, value)
+ if callback:
+ callbacks = set([callback])
+ else:
+ callbacks = set()
+ add_node(key, value, callbacks)
if len(cache) > max_size:
todelete = list_root.prev_node
delete_node(todelete)
@@ -148,6 +170,9 @@ class LruCache(object):
def cache_clear():
list_root.next_node = list_root
list_root.prev_node = list_root
+ for node in cache.values():
+ for cb in node.callbacks:
+ cb()
cache.clear()
@synchronized
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index 03bc1401b7..c31585aea3 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -64,6 +64,9 @@ class TreeCache(object):
self.size -= cnt
return popped
+    def values(self):
+        # Keys may be tuples, so the tree can be several levels deep: walk it
+        # recursively instead of assuming the root's values are leaf entries
+        def leaves(node):
+            if isinstance(node, dict):
+                return [v for n in node.values() for v in leaves(n)]
+            return [node.value]
+
+        return leaves(self.root)
+
def __len__(self):
return self.size
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 96b7dba5fe..ab6095564a 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -17,6 +17,8 @@
from tests import unittest
from twisted.internet import defer
+from mock import Mock
+
from synapse.util.async import ObservableDeferred
from synapse.util.caches.descriptors import Cache, cached
@@ -72,7 +74,7 @@ class CacheTestCase(unittest.TestCase):
cache.get(3)
def test_eviction_lru(self):
- cache = Cache("test", max_entries=2, lru=True)
+ cache = Cache("test", max_entries=2)
cache.prefill(1, "one")
cache.prefill(2, "two")
@@ -199,3 +201,115 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0)
+
+ @defer.inlineCallbacks
+ def test_invalidate_context(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A(object):
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 1)
+
+ a.func.invalidate(("foo",))
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 1)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ @defer.inlineCallbacks
+ def test_eviction_context(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A(object):
+ @cached(max_entries=2)
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ yield a.func2("foo")
+ yield a.func2("foo2")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func("foo3")
+
+ self.assertEquals(callcount[0], 3)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 4)
+ self.assertEquals(callcount2[0], 3)
+
+ @defer.inlineCallbacks
+ def test_double_get(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A(object):
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 1)
+
+ a.func2.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
+
+ yield a.func2("foo")
+ a.func2.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 2)
+
+ a.func.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 3)
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index bab366fb7f..1eba5b535e 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -19,6 +19,8 @@ from .. import unittest
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
+from mock import Mock
+
class LruCacheTestCase(unittest.TestCase):
@@ -48,6 +50,8 @@ class LruCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get("key"), 1)
self.assertEquals(cache.setdefault("key", 2), 1)
self.assertEquals(cache.get("key"), 1)
+ cache["key"] = 2 # Make sure overriding works.
+ self.assertEquals(cache.get("key"), 2)
def test_pop(self):
cache = LruCache(1)
@@ -79,3 +83,152 @@ class LruCacheTestCase(unittest.TestCase):
cache["key"] = 1
cache.clear()
self.assertEquals(len(cache), 0)
+
+
+class LruCacheCallbacksTestCase(unittest.TestCase):
+ def test_get(self):
+ m = Mock()
+ cache = LruCache(1)
+
+ cache.set("key", "value")
+ self.assertFalse(m.called)
+
+ cache.get("key", callback=m)
+ self.assertFalse(m.called)
+
+ cache.get("key", "value")
+ self.assertFalse(m.called)
+
+ cache.set("key", "value2")
+ self.assertEquals(m.call_count, 1)
+
+ cache.set("key", "value")
+ self.assertEquals(m.call_count, 1)
+
+ def test_multi_get(self):
+ m = Mock()
+ cache = LruCache(1)
+
+ cache.set("key", "value")
+ self.assertFalse(m.called)
+
+ cache.get("key", callback=m)
+ self.assertFalse(m.called)
+
+ cache.get("key", callback=m)
+ self.assertFalse(m.called)
+
+ cache.set("key", "value2")
+ self.assertEquals(m.call_count, 1)
+
+ cache.set("key", "value")
+ self.assertEquals(m.call_count, 1)
+
+ def test_set(self):
+ m = Mock()
+ cache = LruCache(1)
+
+ cache.set("key", "value", m)
+ self.assertFalse(m.called)
+
+ cache.set("key", "value")
+ self.assertFalse(m.called)
+
+ cache.set("key", "value2")
+ self.assertEquals(m.call_count, 1)
+
+ cache.set("key", "value")
+ self.assertEquals(m.call_count, 1)
+
+ def test_pop(self):
+ m = Mock()
+ cache = LruCache(1)
+
+ cache.set("key", "value", m)
+ self.assertFalse(m.called)
+
+ cache.pop("key")
+ self.assertEquals(m.call_count, 1)
+
+ cache.set("key", "value")
+ self.assertEquals(m.call_count, 1)
+
+ cache.pop("key")
+ self.assertEquals(m.call_count, 1)
+
+ def test_del_multi(self):
+ m1 = Mock()
+ m2 = Mock()
+ m3 = Mock()
+ m4 = Mock()
+ cache = LruCache(4, 2, cache_type=TreeCache)
+
+ cache.set(("a", "1"), "value", m1)
+ cache.set(("a", "2"), "value", m2)
+ cache.set(("b", "1"), "value", m3)
+ cache.set(("b", "2"), "value", m4)
+
+ self.assertEquals(m1.call_count, 0)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 0)
+ self.assertEquals(m4.call_count, 0)
+
+ cache.del_multi(("a",))
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 1)
+ self.assertEquals(m3.call_count, 0)
+ self.assertEquals(m4.call_count, 0)
+
+ def test_clear(self):
+ m1 = Mock()
+ m2 = Mock()
+ cache = LruCache(5)
+
+ cache.set("key1", "value", m1)
+ cache.set("key2", "value", m2)
+
+ self.assertEquals(m1.call_count, 0)
+ self.assertEquals(m2.call_count, 0)
+
+ cache.clear()
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 1)
+
+ def test_eviction(self):
+ m1 = Mock(name="m1")
+ m2 = Mock(name="m2")
+ m3 = Mock(name="m3")
+ cache = LruCache(2)
+
+ cache.set("key1", "value", m1)
+ cache.set("key2", "value", m2)
+
+ self.assertEquals(m1.call_count, 0)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 0)
+
+ cache.set("key3", "value", m3)
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 0)
+
+ cache.set("key3", "value")
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 0)
+
+ cache.get("key2")
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 0)
+
+ cache.set("key1", "value", m1)
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 1)
|