diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 007a0998a7..31e1abb964 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""This module contains classes for authenticating the user."""
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException
@@ -42,13 +41,20 @@ AuthEventTypes = (
class Auth(object):
-
+ """
+ FIXME: This class contains a mix of functions for authenticating users
+ of our client-server API and authenticating events added to room graphs.
+ """
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
+ # Docs for these currently lives at
+ # https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
+ # In addition, we have type == delete_pusher which grants access only to
+ # delete pushers.
self._KNOWN_CAVEAT_PREFIXES = set([
"gen = ",
"guest = ",
@@ -525,7 +531,7 @@ class Auth(object):
return default
@defer.inlineCallbacks
- def get_user_by_req(self, request, allow_guest=False):
+ def get_user_by_req(self, request, allow_guest=False, rights="access"):
""" Get a registered user's ID.
Args:
@@ -547,7 +553,7 @@ class Auth(object):
)
access_token = request.args["access_token"][0]
- user_info = yield self.get_user_by_access_token(access_token)
+ user_info = yield self.get_user_by_access_token(access_token, rights)
user = user_info["user"]
token_id = user_info["token_id"]
is_guest = user_info["is_guest"]
@@ -608,7 +614,7 @@ class Auth(object):
defer.returnValue(user_id)
@defer.inlineCallbacks
- def get_user_by_access_token(self, token):
+ def get_user_by_access_token(self, token, rights="access"):
""" Get a registered user's ID.
Args:
@@ -619,7 +625,7 @@ class Auth(object):
AuthError if no user by that token exists or the token is invalid.
"""
try:
- ret = yield self.get_user_from_macaroon(token)
+ ret = yield self.get_user_from_macaroon(token, rights)
except AuthError:
# TODO(daniel): Remove this fallback when all existing access tokens
# have been re-issued as macaroons.
@@ -627,11 +633,11 @@ class Auth(object):
defer.returnValue(ret)
@defer.inlineCallbacks
- def get_user_from_macaroon(self, macaroon_str):
+ def get_user_from_macaroon(self, macaroon_str, rights="access"):
try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
- self.validate_macaroon(macaroon, "access", self.hs.config.expire_access_token)
+ self.validate_macaroon(macaroon, rights, self.hs.config.expire_access_token)
user_prefix = "user_id = "
user = None
@@ -654,6 +660,13 @@ class Auth(object):
"is_guest": True,
"token_id": None,
}
+ elif rights == "delete_pusher":
+ # We don't store these tokens in the database
+ ret = {
+ "user": user,
+ "is_guest": False,
+ "token_id": None,
+ }
else:
# This codepath exists so that we can actually return a
# token ID, because we use token IDs in place of device
@@ -685,7 +698,8 @@ class Auth(object):
Args:
macaroon(pymacaroons.Macaroon): The macaroon to validate
- type_string(str): The kind of token this is (e.g. "access", "refresh")
+ type_string(str): The kind of token required (e.g. "access", "refresh",
+ "delete_pusher")
verify_expiry(bool): Whether to verify whether the macaroon has expired.
This should really always be True, but no clients currently implement
token refresh, so we can't enforce expiry yet.
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 135dd58c15..f1de1e7ce9 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -21,6 +21,7 @@ from synapse.config._base import ConfigError
from synapse.config.database import DatabaseConfig
from synapse.config.logger import LoggingConfig
from synapse.config.emailconfig import EmailConfig
+from synapse.config.key import KeyConfig
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.storage.roommember import RoomMemberStore
@@ -63,6 +64,26 @@ class SlaveConfig(DatabaseConfig):
self.pid_file = self.abspath(config.get("pid_file"))
self.public_baseurl = config["public_baseurl"]
+ # some things used by the auth handler but not actually used in the
+ # pusher codebase
+ self.bcrypt_rounds = None
+ self.ldap_enabled = None
+ self.ldap_server = None
+ self.ldap_port = None
+ self.ldap_tls = None
+ self.ldap_search_base = None
+ self.ldap_search_property = None
+ self.ldap_email_property = None
+ self.ldap_full_name_property = None
+
+ # We would otherwise try to use the registration shared secret as the
+ # macaroon shared secret if there was no macaroon_shared_secret, but
+ # that means pulling in RegistrationConfig too. We don't need to be
+ # backwards compaitible in the pusher codebase so just make people set
+ # macaroon_shared_secret. We set this to None to prevent it referencing
+ # an undefined key.
+ self.registration_shared_secret = None
+
def default_config(self, server_name, **kwargs):
pid_file = self.abspath("pusher.pid")
return """\
@@ -95,7 +116,7 @@ class SlaveConfig(DatabaseConfig):
""" % locals()
-class PusherSlaveConfig(SlaveConfig, LoggingConfig, EmailConfig):
+class PusherSlaveConfig(SlaveConfig, LoggingConfig, EmailConfig, KeyConfig):
pass
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 26c865e171..200793b5ed 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -529,6 +529,11 @@ class AuthHandler(BaseHandler):
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
+ def generate_delete_pusher_token(self, user_id):
+ macaroon = self._generate_base_macaroon(user_id)
+ macaroon.add_first_party_caveat("type = delete_pusher")
+ return macaroon.serialize()
+
def validate_short_term_login_token_and_get_user_id(self, login_token):
try:
macaroon = pymacaroons.Macaroon.deserialize(login_token)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index c41dafdef5..15caf1950a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -26,9 +26,9 @@ from synapse.types import (
UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id
)
from synapse.util import unwrapFirstError
-from synapse.util.async import concurrently_execute
+from synapse.util.async import concurrently_execute, run_on_reactor
from synapse.util.caches.snapshot_cache import SnapshotCache
-from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
+from synapse.util.logcontext import preserve_fn
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -908,13 +908,16 @@ class MessageHandler(BaseHandler):
"Failed to get destination from event %s", s.event_id
)
- with PreserveLoggingContext():
- # Don't block waiting on waking up all the listeners.
+ @defer.inlineCallbacks
+ def _notify():
+ yield run_on_reactor()
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
+ preserve_fn(_notify)()
+
# If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 5307b62b85..be26a491ff 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -198,9 +198,8 @@ class SyncHandler(object):
@defer.inlineCallbacks
def push_rules_for_user(self, user):
user_id = user.to_string()
- rawrules = yield self.store.get_push_rules_for_user(user_id)
- enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
- rules = format_push_rules_for_user(user, rawrules, enabled_map)
+ rules = yield self.store.get_push_rules_for_user(user_id)
+ rules = format_push_rules_for_user(user, rules)
defer.returnValue(rules)
@defer.inlineCallbacks
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 5664d5a381..c38f24485a 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -33,11 +33,7 @@ from .metric import (
logger = logging.getLogger(__name__)
-# We'll keep all the available metrics in a single toplevel dict, one shared
-# for the entire process. We don't currently support per-HomeServer instances
-# of metrics, because in practice any one python VM will host only one
-# HomeServer anyway. This makes a lot of implementation neater
-all_metrics = {}
+all_metrics = []
class Metrics(object):
@@ -53,7 +49,7 @@ class Metrics(object):
metric = metric_class(full_name, *args, **kwargs)
- all_metrics[full_name] = metric
+ all_metrics.append(metric)
return metric
def register_counter(self, *args, **kwargs):
@@ -84,12 +80,12 @@ def render_all():
# TODO(paul): Internal hack
update_resource_metrics()
- for name in sorted(all_metrics.keys()):
+ for metric in all_metrics:
try:
- strs += all_metrics[name].render()
+ strs += metric.render()
except Exception:
- strs += ["# FAILED to render %s" % name]
- logger.exception("Failed to render %s metric", name)
+ strs += ["# FAILED to render"]
+ logger.exception("Failed to render metric")
strs.append("") # to generate a final CRLF
diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py
index 368fc24984..341043952a 100644
--- a/synapse/metrics/metric.py
+++ b/synapse/metrics/metric.py
@@ -47,9 +47,6 @@ class BaseMetric(object):
for k, v in zip(self.labels, values)])
)
- def render(self):
- return map_concat(self.render_item, sorted(self.counts.keys()))
-
class CounterMetric(BaseMetric):
"""The simplest kind of metric; one that stores a monotonically-increasing
@@ -83,6 +80,9 @@ class CounterMetric(BaseMetric):
def render_item(self, k):
return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])]
+ def render(self):
+ return map_concat(self.render_item, sorted(self.counts.keys()))
+
class CallbackMetric(BaseMetric):
"""A metric that returns the numeric value returned by a callback whenever
@@ -126,30 +126,30 @@ class DistributionMetric(object):
class CacheMetric(object):
- """A combination of two CounterMetrics, one to count cache hits and one to
- count a total, and a callback metric to yield the current size.
+ __slots__ = ("name", "cache_name", "hits", "misses", "size_callback")
- This metric generates standard metric name pairs, so that monitoring rules
- can easily be applied to measure hit ratio."""
-
- def __init__(self, name, size_callback, labels=[]):
+ def __init__(self, name, size_callback, cache_name):
self.name = name
+ self.cache_name = cache_name
- self.hits = CounterMetric(name + ":hits", labels=labels)
- self.total = CounterMetric(name + ":total", labels=labels)
+ self.hits = 0
+ self.misses = 0
- self.size = CallbackMetric(
- name + ":size",
- callback=size_callback,
- labels=labels,
- )
+ self.size_callback = size_callback
- def inc_hits(self, *values):
- self.hits.inc(*values)
- self.total.inc(*values)
+ def inc_hits(self):
+ self.hits += 1
- def inc_misses(self, *values):
- self.total.inc(*values)
+ def inc_misses(self):
+ self.misses += 1
def render(self):
- return self.hits.render() + self.total.render() + self.size.render()
+ size = self.size_callback()
+ hits = self.hits
+ total = self.misses + self.hits
+
+ return [
+ """%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits),
+ """%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
+ """%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
+ ]
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index 9b208668b6..46e768e35c 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -40,7 +40,7 @@ class ActionGenerator:
def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "handle_push_actions_for_event"):
bulk_evaluator = yield evaluator_for_event(
- event, self.hs, self.store
+ event, self.hs, self.store, context.current_state
)
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 25f2fb9da4..6e42121b1d 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -18,10 +18,9 @@ import ujson as json
from twisted.internet import defer
-from .baserules import list_with_base_rules
from .push_rule_evaluator import PushRuleEvaluatorForEvent
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
from synapse.visibility import filter_events_for_clients
@@ -38,62 +37,41 @@ def decode_rule_json(rule):
@defer.inlineCallbacks
def _get_rules(room_id, user_ids, store):
rules_by_user = yield store.bulk_get_push_rules(user_ids)
- rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids)
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
- rules_by_user = {
- uid: list_with_base_rules([
- decode_rule_json(rule_list)
- for rule_list in rules_by_user.get(uid, [])
- ])
- for uid in user_ids
- }
-
- # We apply the rules-enabled map here: bulk_get_push_rules doesn't
- # fetch disabled rules, but this won't account for any server default
- # rules the user has disabled, so we need to do this too.
- for uid in user_ids:
- user_enabled_map = rules_enabled_by_user.get(uid)
- if not user_enabled_map:
- continue
-
- for i, rule in enumerate(rules_by_user[uid]):
- rule_id = rule['rule_id']
-
- if rule_id in user_enabled_map:
- if rule.get('enabled', True) != bool(user_enabled_map[rule_id]):
- # Rules are cached across users.
- rule = dict(rule)
- rule['enabled'] = bool(user_enabled_map[rule_id])
- rules_by_user[uid][i] = rule
-
defer.returnValue(rules_by_user)
@defer.inlineCallbacks
-def evaluator_for_event(event, hs, store):
+def evaluator_for_event(event, hs, store, current_state):
room_id = event.room_id
-
- # users in the room who have pushers need to get push rules run because
- # that's how their pushers work
- users_with_pushers = yield store.get_users_with_pushers_in_room(room_id)
-
# We also will want to generate notifs for other people in the room so
# their unread countss are correct in the event stream, but to avoid
# generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room.
- all_in_room = yield store.get_users_in_room(room_id)
- all_in_room = set(all_in_room)
+ local_users_in_room = set(
+ e.state_key for e in current_state.values()
+ if e.type == EventTypes.Member and e.membership == Membership.JOIN
+ and hs.is_mine_id(e.state_key)
+ )
- receipts = yield store.get_receipts_for_room(room_id, "m.read")
+ # users in the room who have pushers need to get push rules run because
+ # that's how their pushers work
+ if_users_with_pushers = yield store.get_if_users_have_pushers(
+ local_users_in_room
+ )
+ user_ids = set(
+ uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
+ )
+
+ users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)
# any users with pushers must be ours: they have pushers
- user_ids = set(users_with_pushers)
- for r in receipts:
- if hs.is_mine_id(r['user_id']) and r['user_id'] in all_in_room:
- user_ids.add(r['user_id'])
+ for uid in users_with_receipts:
+ if uid in local_users_in_room:
+ user_ids.add(uid)
# if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited
@@ -104,8 +82,6 @@ def evaluator_for_event(event, hs, store):
if has_pusher:
user_ids.add(invited_user)
- user_ids = list(user_ids)
-
rules_by_user = yield _get_rules(room_id, user_ids, store)
defer.returnValue(BulkPushRuleEvaluator(
@@ -143,7 +119,10 @@ class BulkPushRuleEvaluator:
self.store, user_tuples, [event], {event.event_id: current_state}
)
- room_members = yield self.store.get_users_in_room(self.room_id)
+ room_members = set(
+ e.state_key for e in current_state.values()
+ if e.type == EventTypes.Member and e.membership == Membership.JOIN
+ )
evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index ae9db9ec2f..b3983f7940 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -23,10 +23,7 @@ import copy
import simplejson as json
-def format_push_rules_for_user(user, rawrules, enabled_map):
- """Converts a list of rawrules and a enabled map into nested dictionaries
- to match the Matrix client-server format for push rules"""
-
+def load_rules_for_user(user, rawrules, enabled_map):
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
@@ -35,7 +32,26 @@ def format_push_rules_for_user(user, rawrules, enabled_map):
ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy
- ruleslist = copy.deepcopy(list_with_base_rules(ruleslist))
+ rules = list(list_with_base_rules(ruleslist))
+
+ for i, rule in enumerate(rules):
+ rule_id = rule['rule_id']
+ if rule_id in enabled_map:
+ if rule.get('enabled', True) != bool(enabled_map[rule_id]):
+ # Rules are cached across users.
+ rule = dict(rule)
+ rule['enabled'] = bool(enabled_map[rule_id])
+ rules[i] = rule
+
+ return rules
+
+
+def format_push_rules_for_user(user, ruleslist):
+ """Converts a list of rawrules and a enabled map into nested dictionaries
+ to match the Matrix client-server format for push rules"""
+
+ # We're going to be mutating this a lot, so do a deep copy
+ ruleslist = copy.deepcopy(ruleslist)
rules = {'global': {}, 'device': {}}
@@ -60,9 +76,7 @@ def format_push_rules_for_user(user, rawrules, enabled_map):
template_rule = _rule_to_template(r)
if template_rule:
- if r['rule_id'] in enabled_map:
- template_rule['enabled'] = enabled_map[r['rule_id']]
- elif 'enabled' in r:
+ if 'enabled' in r:
template_rule['enabled'] = r['enabled']
else:
template_rule['enabled'] = True
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 2c21ed3088..12a3ec7fd8 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -279,5 +279,5 @@ class EmailPusher(object):
logger.info("Sending notif email for user %r", self.user_id)
yield self.mailer.send_notification_mail(
- self.user_id, self.email, push_actions, reason
+ self.app_id, self.user_id, self.email, push_actions, reason
)
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index c1e9057eb6..88402e42a6 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -41,7 +41,7 @@ logger = logging.getLogger(__name__)
MESSAGE_FROM_PERSON_IN_ROOM = "You have a message on %(app)s from %(person)s " \
- "in the %s room..."
+ "in the %(room)s room..."
MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..."
MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..."
MESSAGES_IN_ROOM = "You have messages on %(app)s in the %(room)s room..."
@@ -81,6 +81,7 @@ class Mailer(object):
def __init__(self, hs, app_name):
self.hs = hs
self.store = self.hs.get_datastore()
+ self.auth_handler = self.hs.get_auth_handler()
self.state_handler = self.hs.get_state_handler()
loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
self.app_name = app_name
@@ -96,7 +97,8 @@ class Mailer(object):
)
@defer.inlineCallbacks
- def send_notification_mail(self, user_id, email_address, push_actions, reason):
+ def send_notification_mail(self, app_id, user_id, email_address,
+ push_actions, reason):
try:
from_string = self.hs.config.email_notif_from % {
"app": self.app_name
@@ -167,7 +169,9 @@ class Mailer(object):
template_vars = {
"user_display_name": user_display_name,
- "unsubscribe_link": self.make_unsubscribe_link(),
+ "unsubscribe_link": self.make_unsubscribe_link(
+ user_id, app_id, email_address
+ ),
"summary_text": summary_text,
"app_name": self.app_name,
"rooms": rooms,
@@ -433,9 +437,18 @@ class Mailer(object):
notif['room_id'], notif['event_id']
)
- def make_unsubscribe_link(self):
- # XXX: matrix.to
- return "https://vector.im/#/settings"
+ def make_unsubscribe_link(self, user_id, app_id, email_address):
+ params = {
+ "access_token": self.auth_handler.generate_delete_pusher_token(user_id),
+ "app_id": app_id,
+ "pushkey": email_address,
+ }
+
+ # XXX: make r0 once API is stable
+ return "%s_matrix/client/unstable/pushers/remove?%s" % (
+ self.hs.config.public_baseurl,
+ urllib.urlencode(params),
+ )
def mxc_to_http_filter(self, value, width, height, resize_method="crop"):
if value[0:6] != "mxc://":
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index f59b0eabbc..735c03c7eb 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -15,7 +15,10 @@
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage import DataStore
from synapse.storage.account_data import AccountDataStore
+from synapse.storage.tags import TagsStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedAccountDataStore(BaseSlavedStore):
@@ -25,6 +28,14 @@ class SlavedAccountDataStore(BaseSlavedStore):
self._account_data_id_gen = SlavedIdTracker(
db_conn, "account_data_max_stream_id", "stream_id",
)
+ self._account_data_stream_cache = StreamChangeCache(
+ "AccountDataAndTagsChangeCache",
+ self._account_data_id_gen.get_current_token(),
+ )
+
+ get_account_data_for_user = (
+ AccountDataStore.__dict__["get_account_data_for_user"]
+ )
get_global_account_data_by_type_for_users = (
AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
@@ -34,6 +45,16 @@ class SlavedAccountDataStore(BaseSlavedStore):
AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
)
+ get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
+
+ get_updated_tags = DataStore.get_updated_tags.__func__
+ get_updated_account_data_for_user = (
+ DataStore.get_updated_account_data_for_user.__func__
+ )
+
+ def get_max_account_data_stream_id(self):
+ return self._account_data_id_gen.get_current_token()
+
def stream_positions(self):
result = super(SlavedAccountDataStore, self).stream_positions()
position = self._account_data_id_gen.get_current_token()
@@ -47,15 +68,33 @@ class SlavedAccountDataStore(BaseSlavedStore):
if stream:
self._account_data_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
- user_id, data_type = row[1:3]
+ position, user_id, data_type = row[:3]
self.get_global_account_data_by_type_for_user.invalidate(
(data_type, user_id,)
)
+ self.get_account_data_for_user.invalidate((user_id,))
+ self._account_data_stream_cache.entity_has_changed(
+ user_id, position
+ )
stream = result.get("room_account_data")
if stream:
self._account_data_id_gen.advance(int(stream["position"]))
+ for row in stream["rows"]:
+ position, user_id = row[:2]
+ self.get_account_data_for_user.invalidate((user_id,))
+ self._account_data_stream_cache.entity_has_changed(
+ user_id, position
+ )
stream = result.get("tag_account_data")
if stream:
self._account_data_id_gen.advance(int(stream["position"]))
+ for row in stream["rows"]:
+ position, user_id = row[:2]
+ self.get_tags_for_user.invalidate((user_id,))
+ self._account_data_stream_cache.entity_has_changed(
+ user_id, position
+ )
+
+ return super(SlavedAccountDataStore, self).process_replication(result)
diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py
new file mode 100644
index 0000000000..25792d9429
--- /dev/null
+++ b/synapse/replication/slave/storage/appservice.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from synapse.storage import DataStore
+from synapse.config.appservice import load_appservices
+
+
+class SlavedApplicationServiceStore(BaseSlavedStore):
+ def __init__(self, db_conn, hs):
+ super(SlavedApplicationServiceStore, self).__init__(db_conn, hs)
+ self.services_cache = load_appservices(
+ hs.config.server_name,
+ hs.config.app_service_config_files
+ )
+
+ get_app_service_by_token = DataStore.get_app_service_by_token.__func__
+ get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index c0d741452d..cbc1ae4190 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -23,6 +23,7 @@ from synapse.storage.roommember import RoomMemberStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.event_push_actions import EventPushActionsStore
from synapse.storage.state import StateStore
+from synapse.storage.stream import StreamStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
import ujson as json
@@ -57,6 +58,9 @@ class SlavedEventStore(BaseSlavedStore):
"EventsRoomStreamChangeCache", min_event_val,
prefilled_cache=event_cache_prefill,
)
+ self._membership_stream_cache = StreamChangeCache(
+ "MembershipStreamChangeCache", events_max,
+ )
# Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them.
@@ -87,6 +91,9 @@ class SlavedEventStore(BaseSlavedStore):
_get_state_group_from_group = (
StateStore.__dict__["_get_state_group_from_group"]
)
+ get_recent_event_ids_for_room = (
+ StreamStore.__dict__["get_recent_event_ids_for_room"]
+ )
get_unread_push_actions_for_user_in_range = (
DataStore.get_unread_push_actions_for_user_in_range.__func__
@@ -109,10 +116,16 @@ class SlavedEventStore(BaseSlavedStore):
DataStore.get_room_events_stream_for_room.__func__
)
get_events_around = DataStore.get_events_around.__func__
+ get_state_for_event = DataStore.get_state_for_event.__func__
get_state_for_events = DataStore.get_state_for_events.__func__
get_state_groups = DataStore.get_state_groups.__func__
+ get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
+ get_room_events_stream_for_rooms = (
+ DataStore.get_room_events_stream_for_rooms.__func__
+ )
+ get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
- _set_before_and_after = DataStore._set_before_and_after
+ _set_before_and_after = staticmethod(DataStore._set_before_and_after)
_get_events = DataStore._get_events.__func__
_get_events_from_cache = DataStore._get_events_from_cache.__func__
@@ -220,9 +233,9 @@ class SlavedEventStore(BaseSlavedStore):
self.get_rooms_for_user.invalidate((event.state_key,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
self.get_users_in_room.invalidate((event.room_id,))
- # self._membership_stream_cache.entity_has_changed(
- # event.state_key, event.internal_metadata.stream_ordering
- # )
+ self._membership_stream_cache.entity_has_changed(
+ event.state_key, event.internal_metadata.stream_ordering
+ )
self.get_invited_rooms_for_user.invalidate((event.state_key,))
if not event.is_state():
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
new file mode 100644
index 0000000000..819ed62881
--- /dev/null
+++ b/synapse/replication/slave/storage/filtering.py
@@ -0,0 +1,25 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from synapse.storage.filtering import FilteringStore
+
+
+class SlavedFilteringStore(BaseSlavedStore):
+ def __init__(self, db_conn, hs):
+ super(SlavedFilteringStore, self).__init__(db_conn, hs)
+
+ # Filters are immutable so this cache doesn't need to be expired
+ get_user_filter = FilteringStore.__dict__["get_user_filter"]
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
new file mode 100644
index 0000000000..703f4a49bf
--- /dev/null
+++ b/synapse/replication/slave/storage/presence.py
@@ -0,0 +1,59 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
+
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.storage import DataStore
+
+
+class SlavedPresenceStore(BaseSlavedStore):
+ def __init__(self, db_conn, hs):
+ super(SlavedPresenceStore, self).__init__(db_conn, hs)
+ self._presence_id_gen = SlavedIdTracker(
+ db_conn, "presence_stream", "stream_id",
+ )
+
+ self._presence_on_startup = self._get_active_presence(db_conn)
+
+ self.presence_stream_cache = self.presence_stream_cache = StreamChangeCache(
+ "PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
+ )
+
+ _get_active_presence = DataStore._get_active_presence.__func__
+ take_presence_startup_info = DataStore.take_presence_startup_info.__func__
+ get_presence_for_users = DataStore.get_presence_for_users.__func__
+
+ def get_current_presence_token(self):
+ return self._presence_id_gen.get_current_token()
+
+ def stream_positions(self):
+ result = super(SlavedPresenceStore, self).stream_positions()
+ position = self._presence_id_gen.get_current_token()
+ result["presence"] = position
+ return result
+
+ def process_replication(self, result):
+ stream = result.get("presence")
+ if stream:
+ self._presence_id_gen.advance(int(stream["position"]))
+ for row in stream["rows"]:
+ position, user_id = row[:2]
+ self.presence_stream_cache.entity_has_changed(
+ user_id, position
+ )
+
+ return super(SlavedPresenceStore, self).process_replication(result)
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
new file mode 100644
index 0000000000..21ceb0213a
--- /dev/null
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -0,0 +1,67 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .events import SlavedEventStore
+from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage import DataStore
+from synapse.storage.push_rule import PushRuleStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+
+class SlavedPushRuleStore(SlavedEventStore):
+ def __init__(self, db_conn, hs):
+ super(SlavedPushRuleStore, self).__init__(db_conn, hs)
+ self._push_rules_stream_id_gen = SlavedIdTracker(
+ db_conn, "push_rules_stream", "stream_id",
+ )
+ self.push_rules_stream_cache = StreamChangeCache(
+ "PushRulesStreamChangeCache",
+ self._push_rules_stream_id_gen.get_current_token(),
+ )
+
+ get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
+ get_push_rules_enabled_for_user = (
+ PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
+ )
+ have_push_rules_changed_for_user = (
+ DataStore.have_push_rules_changed_for_user.__func__
+ )
+
+ def get_push_rules_stream_token(self):
+ return (
+ self._push_rules_stream_id_gen.get_current_token(),
+ self._stream_id_gen.get_current_token(),
+ )
+
+ def stream_positions(self):
+ result = super(SlavedPushRuleStore, self).stream_positions()
+ result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
+ return result
+
+ def process_replication(self, result):
+ stream = result.get("push_rules")
+ if stream:
+ for row in stream["rows"]:
+ position = row[0]
+ user_id = row[2]
+ self.get_push_rules_for_user.invalidate((user_id,))
+ self.get_push_rules_enabled_for_user.invalidate((user_id,))
+ self.push_rules_stream_cache.entity_has_changed(
+ user_id, position
+ )
+
+ self._push_rules_stream_id_gen.advance(int(stream["position"]))
+
+ return super(SlavedPushRuleStore, self).process_replication(result)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index ec007516d0..ac9662d399 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -18,6 +18,7 @@ from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.receipts import ReceiptsStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
# So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the
@@ -37,11 +38,28 @@ class SlavedReceiptsStore(BaseSlavedStore):
db_conn, "receipts_linearized", "stream_id"
)
+ self._receipts_stream_cache = StreamChangeCache(
+ "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
+ )
+
get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
+ get_linearized_receipts_for_room = (
+ ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
+ )
+ _get_linearized_receipts_for_rooms = (
+ ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
+ )
+ get_last_receipt_event_id_for_user = (
+ ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
+ )
get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
+ get_linearized_receipts_for_rooms = (
+ DataStore.get_linearized_receipts_for_rooms.__func__
+ )
+
def stream_positions(self):
result = super(SlavedReceiptsStore, self).stream_positions()
result["receipts"] = self._receipts_id_gen.get_current_token()
@@ -52,10 +70,15 @@ class SlavedReceiptsStore(BaseSlavedStore):
if stream:
self._receipts_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
- room_id, receipt_type, user_id = row[1:4]
+ position, room_id, receipt_type, user_id = row[:4]
self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
+ self._receipts_stream_cache.entity_has_changed(room_id, position)
return super(SlavedReceiptsStore, self).process_replication(result)
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
+ self.get_linearized_receipts_for_room.invalidate_many((room_id,))
+ self.get_last_receipt_event_id_for_user.invalidate(
+ (user_id, room_id, receipt_type)
+ )
diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py
new file mode 100644
index 0000000000..307833f9e1
--- /dev/null
+++ b/synapse/replication/slave/storage/registration.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from synapse.storage import DataStore
+from synapse.storage.registration import RegistrationStore
+
+
+class SlavedRegistrationStore(BaseSlavedStore):
+ def __init__(self, db_conn, hs):
+ super(SlavedRegistrationStore, self).__init__(db_conn, hs)
+
+ # TODO: use the cached version and invalidate deleted tokens
+ get_user_by_access_token = RegistrationStore.__dict__[
+ "get_user_by_access_token"
+ ].orig
+
+ _query_for_auth = DataStore._query_for_auth.__func__
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 02d837ee6a..6bb4821ec6 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -128,11 +128,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
# we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference
- rawrules = yield self.store.get_push_rules_for_user(user_id)
+ rules = yield self.store.get_push_rules_for_user(user_id)
- enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
-
- rules = format_push_rules_for_user(requester.user, rawrules, enabled_map)
+ rules = format_push_rules_for_user(requester.user, rules)
path = request.postpath[1:]
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index ab928a16da..9a2ed6ed88 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -17,7 +17,11 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes
from synapse.push import PusherConfigException
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import (
+ parse_json_object_from_request, parse_string, RestServlet
+)
+from synapse.http.server import finish_request
+from synapse.api.errors import StoreError
from .base import ClientV1RestServlet, client_path_patterns
@@ -136,6 +140,57 @@ class PushersSetRestServlet(ClientV1RestServlet):
return 200, {}
+class PushersRemoveRestServlet(RestServlet):
+ """
+ To allow pusher to be delete by clicking a link (ie. GET request)
+ """
+ PATTERNS = client_path_patterns("/pushers/remove$")
+ SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>"
+
+ def __init__(self, hs):
+ super(RestServlet, self).__init__()
+ self.hs = hs
+ self.notifier = hs.get_notifier()
+ self.auth = hs.get_v1auth()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ requester = yield self.auth.get_user_by_req(request, rights="delete_pusher")
+ user = requester.user
+
+ app_id = parse_string(request, "app_id", required=True)
+ pushkey = parse_string(request, "pushkey", required=True)
+
+ pusher_pool = self.hs.get_pusherpool()
+
+ try:
+ yield pusher_pool.remove_pusher(
+ app_id=app_id,
+ pushkey=pushkey,
+ user_id=user.to_string(),
+ )
+ except StoreError as se:
+ if se.code != 404:
+ # This is fine: they're already unsubscribed
+ raise
+
+ self.notifier.on_new_replication_data()
+
+ request.setResponseCode(200)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Server", self.hs.version_string)
+ request.setHeader(b"Content-Length", b"%d" % (
+ len(PushersRemoveRestServlet.SUCCESS_HTML),
+ ))
+ request.write(PushersRemoveRestServlet.SUCCESS_HTML)
+ finish_request(request)
+ defer.returnValue(None)
+
+ def on_OPTIONS(self, _):
+ return 200, {}
+
+
def register_servlets(hs, http_server):
PushersRestServlet(hs).register(http_server)
PushersSetRestServlet(hs).register(http_server)
+ PushersRemoveRestServlet(hs).register(http_server)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 8581796b7e..6928a213e8 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -149,7 +149,7 @@ class DataStore(RoomMemberStore, RoomStore,
"AccountDataAndTagsChangeCache", account_max,
)
- self.__presence_on_startup = self._get_active_presence(db_conn)
+ self._presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self._get_cache_dict(
db_conn, "presence_stream",
@@ -190,8 +190,8 @@ class DataStore(RoomMemberStore, RoomStore,
super(DataStore, self).__init__(hs)
def take_presence_startup_info(self):
- active_on_startup = self.__presence_on_startup
- self.__presence_on_startup = None
+ active_on_startup = self._presence_on_startup
+ self._presence_on_startup = None
return active_on_startup
def _get_active_presence(self, db_conn):
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 4655669ba0..2b3f79577b 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -342,9 +342,6 @@ class EventsStore(SQLBaseStore):
txn.call_after(self._get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
- txn.call_after(
- self.get_users_with_pushers_in_room.invalidate, (event.room_id,)
- )
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_room_name_and_aliases.invalidate, (event.room_id,))
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index ebb97c8474..786d6f6d67 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -15,6 +15,7 @@
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.push.baserules import list_with_base_rules
from twisted.internet import defer
import logging
@@ -23,6 +24,29 @@ import simplejson as json
logger = logging.getLogger(__name__)
+def _load_rules(rawrules, enabled_map):
+ ruleslist = []
+ for rawrule in rawrules:
+ rule = dict(rawrule)
+ rule["conditions"] = json.loads(rawrule["conditions"])
+ rule["actions"] = json.loads(rawrule["actions"])
+ ruleslist.append(rule)
+
+ # We're going to be mutating this a lot, so do a deep copy
+ rules = list(list_with_base_rules(ruleslist))
+
+ for i, rule in enumerate(rules):
+ rule_id = rule['rule_id']
+ if rule_id in enabled_map:
+ if rule.get('enabled', True) != bool(enabled_map[rule_id]):
+ # Rules are cached across users.
+ rule = dict(rule)
+ rule['enabled'] = bool(enabled_map[rule_id])
+ rules[i] = rule
+
+ return rules
+
+
class PushRuleStore(SQLBaseStore):
@cachedInlineCallbacks(lru=True)
def get_push_rules_for_user(self, user_id):
@@ -42,7 +66,11 @@ class PushRuleStore(SQLBaseStore):
key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
)
- defer.returnValue(rows)
+ enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
+
+ rules = _load_rules(rows, enabled_map)
+
+ defer.returnValue(rules)
@cachedInlineCallbacks(lru=True)
def get_push_rules_enabled_for_user(self, user_id):
@@ -85,6 +113,14 @@ class PushRuleStore(SQLBaseStore):
for row in rows:
results.setdefault(row['user_name'], []).append(row)
+
+ enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
+
+ for user_id, rules in results.items():
+ results[user_id] = _load_rules(
+ rules, enabled_map_by_user.get(user_id, {})
+ )
+
defer.returnValue(results)
@cachedList(cached_method_name="get_push_rules_enabled_for_user",
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 9e8e2e2964..a7d7c54d7e 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
from canonicaljson import encode_canonical_json
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
import logging
import simplejson as json
@@ -135,19 +135,35 @@ class PusherStore(SQLBaseStore):
"get_all_updated_pushers", get_all_updated_pushers_txn
)
- @cachedInlineCallbacks(num_args=1)
- def get_users_with_pushers_in_room(self, room_id):
- users = yield self.get_users_in_room(room_id)
-
+ @cachedInlineCallbacks(lru=True, num_args=1, max_entries=15000)
+ def get_if_user_has_pusher(self, user_id):
result = yield self._simple_select_many_batch(
table='pushers',
+ keyvalues={
+ 'user_name': 'user_id',
+ },
+ retcol='user_name',
+ desc='get_if_user_has_pusher',
+ allow_none=True,
+ )
+
+ defer.returnValue(bool(result))
+
+ @cachedList(cached_method_name="get_if_user_has_pusher",
+ list_name="user_ids", num_args=1, inlineCallbacks=True)
+ def get_if_users_have_pushers(self, user_ids):
+ rows = yield self._simple_select_many_batch(
+ table='pushers',
column='user_name',
- iterable=users,
+ iterable=user_ids,
retcols=['user_name'],
- desc='get_users_with_pushers_in_room'
+ desc='get_if_users_have_pushers'
)
- defer.returnValue([r['user_name'] for r in result])
+ result = {user_id: False for user_id in user_ids}
+ result.update({r['user_name']: True for r in rows})
+
+ defer.returnValue(result)
@defer.inlineCallbacks
def add_pusher(self, user_id, access_token, kind, app_id,
@@ -178,16 +194,16 @@ class PusherStore(SQLBaseStore):
},
)
if newly_inserted:
- # get_users_with_pushers_in_room only cares if the user has
+ # get_if_user_has_pusher only cares if the user has
# at least *one* pusher.
- txn.call_after(self.get_users_with_pushers_in_room.invalidate_all)
+ txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
yield self.runInteraction("add_pusher", f)
@defer.inlineCallbacks
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
def delete_pusher_txn(txn, stream_id):
- txn.call_after(self.get_users_with_pushers_in_room.invalidate_all)
+ txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
self._simple_delete_one_txn(
txn,
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index f1774f0e44..8c26f39fbb 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -34,6 +34,26 @@ class ReceiptsStore(SQLBaseStore):
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
)
+ @cachedInlineCallbacks()
+ def get_users_with_read_receipts_in_room(self, room_id):
+ receipts = yield self.get_receipts_for_room(room_id, "m.read")
+ defer.returnValue(set(r['user_id'] for r in receipts))
+
+ def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
+ user_id):
+ if receipt_type != "m.read":
+ return
+
+ # Returns an ObservableDeferred
+ res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None)
+
+ if res and res.called and user_id in res.result:
+ # We'd only be adding to the set, so no point invalidating if the
+ # user is already there
+ return
+
+ self.get_users_with_read_receipts_in_room.invalidate((room_id,))
+
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
return self._simple_select_list(
@@ -229,6 +249,10 @@ class ReceiptsStore(SQLBaseStore):
self.get_receipts_for_room.invalidate, (room_id, receipt_type)
)
txn.call_after(
+ self._invalidate_get_users_with_receipts_in_room,
+ room_id, receipt_type, user_id,
+ )
+ txn.call_after(
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
)
# FIXME: This shouldn't invalidate the whole cache
@@ -374,6 +398,10 @@ class ReceiptsStore(SQLBaseStore):
self.get_receipts_for_room.invalidate, (room_id, receipt_type)
)
txn.call_after(
+ self._invalidate_get_users_with_receipts_in_room,
+ room_id, receipt_type, user_id,
+ )
+ txn.call_after(
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
)
# FIXME: This shouldn't invalidate the whole cache
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index face685ed2..64b4bd371b 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -59,9 +59,6 @@ class RoomMemberStore(SQLBaseStore):
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(
- self.get_users_with_pushers_in_room.invalidate, (event.room_id,)
- )
- txn.call_after(
self._membership_stream_cache.entity_has_changed,
event.state_key, event.internal_metadata.stream_ordering
)
@@ -241,23 +238,10 @@ class RoomMemberStore(SQLBaseStore):
return results
- @cached(max_entries=5000)
+ @cachedInlineCallbacks(max_entries=5000)
def get_joined_hosts_for_room(self, room_id):
- return self.runInteraction(
- "get_joined_hosts_for_room",
- self._get_joined_hosts_for_room_txn,
- room_id,
- )
-
- def _get_joined_hosts_for_room_txn(self, txn, room_id):
- rows = self._get_members_rows_txn(
- txn,
- room_id, membership=Membership.JOIN
- )
-
- joined_domains = set(get_domain_from_id(r["user_id"]) for r in rows)
-
- return joined_domains
+ user_ids = yield self.get_users_in_room(room_id)
+ defer.returnValue(set(get_domain_from_id(uid) for uid in user_ids))
def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None):
rows = self._get_members_rows_txn(
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 0d6f48e2d8..40be7fe7e3 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -102,6 +102,15 @@ class ObservableDeferred(object):
def observers(self):
return self._observers
+ def has_called(self):
+ return self._result is not None
+
+ def has_succeeded(self):
+ return self._result is not None and self._result[0] is True
+
+ def get_result(self):
+ return self._result[1]
+
def __getattr__(self, name):
return getattr(self._deferred, name)
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index d53569ca49..ebd715c5dc 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -24,11 +24,21 @@ DEBUG_CACHES = False
metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
caches_by_name = {}
-cache_counter = metrics.register_cache(
- "cache",
- lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
- labels=["name"],
-)
+# cache_counter = metrics.register_cache(
+# "cache",
+# lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
+# labels=["name"],
+# )
+
+
+def register_cache(name, cache):
+ caches_by_name[name] = cache
+ return metrics.register_cache(
+ "cache",
+ lambda: len(cache),
+ name,
+ )
+
_string_cache = LruCache(int(5000 * CACHE_SIZE_FACTOR))
caches_by_name["string_cache"] = _string_cache
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 758f5982b0..f31dfb22b7 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -22,7 +22,7 @@ from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
)
-from . import caches_by_name, DEBUG_CACHES, cache_counter
+from . import DEBUG_CACHES, register_cache
from twisted.internet import defer
@@ -33,6 +33,7 @@ import functools
import inspect
import threading
+
logger = logging.getLogger(__name__)
@@ -43,6 +44,15 @@ CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
class Cache(object):
+ __slots__ = (
+ "cache",
+ "max_entries",
+ "name",
+ "keylen",
+ "sequence",
+ "thread",
+ "metrics",
+ )
def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
if lru:
@@ -59,7 +69,7 @@ class Cache(object):
self.keylen = keylen
self.sequence = 0
self.thread = None
- caches_by_name[name] = self.cache
+ self.metrics = register_cache(name, self.cache)
def check_thread(self):
expected_thread = self.thread
@@ -74,10 +84,10 @@ class Cache(object):
def get(self, key, default=_CacheSentinel):
val = self.cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
- cache_counter.inc_hits(self.name)
+ self.metrics.inc_hits()
return val
- cache_counter.inc_misses(self.name)
+ self.metrics.inc_misses()
if default is _CacheSentinel:
raise KeyError()
@@ -293,16 +303,21 @@ class CacheListDescriptor(object):
# cached is a dict arg -> deferred, where deferred results in a
# 2-tuple (`arg`, `result`)
- cached = {}
+ results = {}
+ cached_defers = {}
missing = []
for arg in list_args:
key = list(keyargs)
key[self.list_pos] = arg
try:
- res = cache.get(tuple(key)).observe()
- res.addCallback(lambda r, arg: (arg, r), arg)
- cached[arg] = res
+ res = cache.get(tuple(key))
+ if not res.has_succeeded():
+ res = res.observe()
+ res.addCallback(lambda r, arg: (arg, r), arg)
+ cached_defers[arg] = res
+ else:
+ results[arg] = res.get_result()
except KeyError:
missing.append(arg)
@@ -340,12 +355,21 @@ class CacheListDescriptor(object):
res = observer.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
- cached[arg] = res
-
- return preserve_context_over_deferred(defer.gatherResults(
- cached.values(),
- consumeErrors=True,
- ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
+ cached_defers[arg] = res
+
+ if cached_defers:
+ def update_results_dict(res):
+ results.update(res)
+ return results
+
+ return preserve_context_over_deferred(defer.gatherResults(
+ cached_defers.values(),
+ consumeErrors=True,
+ ).addCallback(update_results_dict).addErrback(
+ unwrapFirstError
+ ))
+ else:
+ return results
obj.__dict__[self.orig.__name__] = wrapped
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index f92d80542b..b0ca1bb79d 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -15,7 +15,7 @@
from synapse.util.caches.lrucache import LruCache
from collections import namedtuple
-from . import caches_by_name, cache_counter
+from . import register_cache
import threading
import logging
@@ -43,7 +43,7 @@ class DictionaryCache(object):
__slots__ = []
self.sentinel = Sentinel()
- caches_by_name[name] = self.cache
+ self.metrics = register_cache(name, self.cache)
def check_thread(self):
expected_thread = self.thread
@@ -58,7 +58,7 @@ class DictionaryCache(object):
def get(self, key, dict_keys=None):
entry = self.cache.get(key, self.sentinel)
if entry is not self.sentinel:
- cache_counter.inc_hits(self.name)
+ self.metrics.inc_hits()
if dict_keys is None:
return DictionaryEntry(entry.full, dict(entry.value))
@@ -69,7 +69,7 @@ class DictionaryCache(object):
if k in entry.value
})
- cache_counter.inc_misses(self.name)
+ self.metrics.inc_misses()
return DictionaryEntry(False, {})
def invalidate(self, key):
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 2b68c1ac93..080388958f 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.caches import cache_counter, caches_by_name
+from synapse.util.caches import register_cache
import logging
@@ -49,7 +49,7 @@ class ExpiringCache(object):
self._cache = {}
- caches_by_name[cache_name] = self._cache
+ self.metrics = register_cache(cache_name, self._cache)
def start(self):
if not self._expiry_ms:
@@ -78,9 +78,9 @@ class ExpiringCache(object):
def __getitem__(self, key):
try:
entry = self._cache[key]
- cache_counter.inc_hits(self._cache_name)
+ self.metrics.inc_hits()
except KeyError:
- cache_counter.inc_misses(self._cache_name)
+ self.metrics.inc_misses()
raise
if self._reset_expiry_on_get:
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index ea8a74ca69..3c051dabc4 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.caches import cache_counter, caches_by_name
+from synapse.util.caches import register_cache
from blist import sorteddict
@@ -42,7 +42,7 @@ class StreamChangeCache(object):
self._cache = sorteddict()
self._earliest_known_stream_pos = current_stream_pos
self.name = name
- caches_by_name[self.name] = self._cache
+ self.metrics = register_cache(self.name, self._cache)
for entity, stream_pos in prefilled_cache.items():
self.entity_has_changed(entity, stream_pos)
@@ -53,19 +53,19 @@ class StreamChangeCache(object):
assert type(stream_pos) is int
if stream_pos < self._earliest_known_stream_pos:
- cache_counter.inc_misses(self.name)
+ self.metrics.inc_misses()
return True
latest_entity_change_pos = self._entity_to_key.get(entity, None)
if latest_entity_change_pos is None:
- cache_counter.inc_hits(self.name)
+ self.metrics.inc_hits()
return False
if stream_pos < latest_entity_change_pos:
- cache_counter.inc_misses(self.name)
+ self.metrics.inc_misses()
return True
- cache_counter.inc_hits(self.name)
+ self.metrics.inc_hits()
return False
def get_entities_changed(self, entities, stream_pos):
@@ -82,10 +82,10 @@ class StreamChangeCache(object):
self._cache[k] for k in keys[i:]
).intersection(entities)
- cache_counter.inc_hits(self.name)
+ self.metrics.inc_hits()
else:
result = entities
- cache_counter.inc_misses(self.name)
+ self.metrics.inc_misses()
return result
diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py
index f3c1927ce1..f85455a5af 100644
--- a/tests/metrics/test_metric.py
+++ b/tests/metrics/test_metric.py
@@ -61,9 +61,6 @@ class CounterMetricTestCase(unittest.TestCase):
'vector{method="PUT"} 1',
])
- # Check that passing too few values errors
- self.assertRaises(ValueError, counter.inc)
-
class CallbackMetricTestCase(unittest.TestCase):
@@ -138,27 +135,27 @@ class CacheMetricTestCase(unittest.TestCase):
def test_cache(self):
d = dict()
- metric = CacheMetric("cache", lambda: len(d))
+ metric = CacheMetric("cache", lambda: len(d), "cache_name")
self.assertEquals(metric.render(), [
- 'cache:hits 0',
- 'cache:total 0',
- 'cache:size 0',
+ 'cache:hits{name="cache_name"} 0',
+ 'cache:total{name="cache_name"} 0',
+ 'cache:size{name="cache_name"} 0',
])
metric.inc_misses()
d["key"] = "value"
self.assertEquals(metric.render(), [
- 'cache:hits 0',
- 'cache:total 1',
- 'cache:size 1',
+ 'cache:hits{name="cache_name"} 0',
+ 'cache:total{name="cache_name"} 1',
+ 'cache:size{name="cache_name"} 1',
])
metric.inc_hits()
self.assertEquals(metric.render(), [
- 'cache:hits 1',
- 'cache:total 2',
- 'cache:size 1',
+ 'cache:hits{name="cache_name"} 1',
+ 'cache:total{name="cache_name"} 2',
+ 'cache:size{name="cache_name"} 1',
])
|