Diffstat (limited to 'synapse')
-rw-r--r--  synapse/handlers/message.py                           |   4
-rw-r--r--  synapse/handlers/sync.py                              |  22
-rw-r--r--  synapse/notifier.py                                   |   2
-rw-r--r--  synapse/push/clientformat.py                          | 112
-rw-r--r--  synapse/replication/resource.py                       |  28
-rw-r--r--  synapse/rest/client/v1/push_rule.py                   | 130
-rw-r--r--  synapse/storage/__init__.py                           |  17
-rw-r--r--  synapse/storage/_base.py                              |  25
-rw-r--r--  synapse/storage/push_rule.py                          | 193
-rw-r--r--  synapse/storage/schema/delta/30/push_rule_stream.sql  |  38
-rw-r--r--  synapse/storage/util/id_generators.py                 |  84
-rw-r--r--  synapse/streams/events.py                             |   4
-rw-r--r--  synapse/types.py                                      |   7
13 files changed, 477 insertions, 189 deletions
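Taken together, these patches add a dedicated stream of push rule changes: every mutation writes a row to a new push_rules_stream table, /sync tokens grow a push_rules_key field, and the replication API learns a push_rules stream. As a rough illustration of the widened sync token (field order from the synapse/types.py hunk at the bottom; the values are invented, and this assumes the token serialises by joining its fields with the "_" separator declared on StreamToken):

    # Sketch only, not Synapse code: the sync token gains a sixth
    # "_"-separated field carrying the push rules stream position.
    fields = ("room_key", "presence_key", "typing_key", "receipt_key",
              "account_data_key", "push_rules_key")
    token = "s72_1_2_3_4_5"  # invented example values
    parts = dict(zip(fields, token.split("_")))
    assert parts["push_rules_key"] == "5"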
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index e92c74d07b..5c50c611ba 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -662,8 +662,8 @@ class MessageHandler(BaseHandler):
             user_id, messages, is_peeking=is_peeking
         )
 
-        start_token = StreamToken(token[0], 0, 0, 0, 0)
-        end_token = StreamToken(token[1], 0, 0, 0, 0)
+        start_token = StreamToken.START.copy_and_replace("room_key", token[0])
+        end_token = StreamToken.START.copy_and_replace("room_key", token[1])
 
         time_now = self.clock.time_msec()
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 26a66814eb..1f6fde8e8a 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -20,6 +20,7 @@ from synapse.api.constants import Membership, EventTypes
 from synapse.util import unwrapFirstError
 from synapse.util.logcontext import LoggingContext, preserve_fn
 from synapse.util.metrics import Measure
+from synapse.push.clientformat import format_push_rules_for_user
 
 from twisted.internet import defer
 
@@ -224,6 +225,10 @@ class SyncHandler(BaseHandler):
             )
         )
 
+        account_data['m.push_rules'] = yield self.push_rules_for_user(
+            sync_config.user
+        )
+
         tags_by_room = yield self.store.get_tags_for_user(
             sync_config.user.to_string()
         )
@@ -328,6 +333,14 @@ class SyncHandler(BaseHandler):
 
         defer.returnValue(room_sync)
 
+    @defer.inlineCallbacks
+    def push_rules_for_user(self, user):
+        user_id = user.to_string()
+        rawrules = yield self.store.get_push_rules_for_user(user_id)
+        enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
+        rules = format_push_rules_for_user(user, rawrules, enabled_map)
+        defer.returnValue(rules)
+
     def account_data_for_user(self, account_data):
         account_data_events = []
 
@@ -487,6 +500,15 @@ class SyncHandler(BaseHandler):
             )
         )
 
+        push_rules_changed = yield self.store.have_push_rules_changed_for_user(
+            user_id, int(since_token.push_rules_key)
+        )
+
+        if push_rules_changed:
+            account_data["m.push_rules"] = yield self.push_rules_for_user(
+                sync_config.user
+            )
+
         # Get a list of membership change events that have happened.
         rooms_changed = yield self.store.get_membership_changes_for_user(
             user_id, since_token.room_key, now_token.room_key
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 3c36a20868..9b69b0333a 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -284,7 +284,7 @@ class Notifier(object):
 
     @defer.inlineCallbacks
     def wait_for_events(self, user_id, timeout, callback, room_ids=None,
-                        from_token=StreamToken("s0", "0", "0", "0", "0")):
+                        from_token=StreamToken.START):
         """Wait until the callback returns a non empty response or the
         timeout fires.
         """
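The sync.py hunks above attach the user's formatted rules to account data under the m.push_rules key. The payload is built by format_push_rules_for_user, added in the new clientformat.py below; its rough shape is sketched here with an invented rule (the five template buckets are the PRIORITY_CLASS_MAP keys):

    # Hypothetical example of account_data["m.push_rules"]; the rule
    # itself is made up for illustration.
    m_push_rules = {
        "global": {
            "override": [],
            "content": [{
                "rule_id": "lunch",
                "pattern": "lunch",
                "actions": ["notify"],
                "enabled": True,
            }],
            "room": [],
            "sender": [],
            "underride": [],
        },
        "device": {},
    }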
+
+from synapse.push.baserules import list_with_base_rules
+
+from synapse.push.rulekinds import (
+    PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
+)
+
+import copy
+import simplejson as json
+
+
+def format_push_rules_for_user(user, rawrules, enabled_map):
+    """Converts a list of rawrules and a enabled map into nested dictionaries
+    to match the Matrix client-server format for push rules"""
+
+    ruleslist = []
+    for rawrule in rawrules:
+        rule = dict(rawrule)
+        rule["conditions"] = json.loads(rawrule["conditions"])
+        rule["actions"] = json.loads(rawrule["actions"])
+        ruleslist.append(rule)
+
+    # We're going to be mutating this a lot, so do a deep copy
+    ruleslist = copy.deepcopy(list_with_base_rules(ruleslist))
+
+    rules = {'global': {}, 'device': {}}
+
+    rules['global'] = _add_empty_priority_class_arrays(rules['global'])
+
+    for r in ruleslist:
+        rulearray = None
+
+        template_name = _priority_class_to_template_name(r['priority_class'])
+
+        # Remove internal stuff.
+        for c in r["conditions"]:
+            c.pop("_id", None)
+
+            pattern_type = c.pop("pattern_type", None)
+            if pattern_type == "user_id":
+                c["pattern"] = user.to_string()
+            elif pattern_type == "user_localpart":
+                c["pattern"] = user.localpart
+
+        rulearray = rules['global'][template_name]
+
+        template_rule = _rule_to_template(r)
+        if template_rule:
+            if r['rule_id'] in enabled_map:
+                template_rule['enabled'] = enabled_map[r['rule_id']]
+            elif 'enabled' in r:
+                template_rule['enabled'] = r['enabled']
+            else:
+                template_rule['enabled'] = True
+            rulearray.append(template_rule)
+
+    return rules
+
+
+def _add_empty_priority_class_arrays(d):
+    for pc in PRIORITY_CLASS_MAP.keys():
+        d[pc] = []
+    return d
+
+
+def _rule_to_template(rule):
+    unscoped_rule_id = None
+    if 'rule_id' in rule:
+        unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id'])
+
+    template_name = _priority_class_to_template_name(rule['priority_class'])
+    if template_name in ['override', 'underride']:
+        templaterule = {k: rule[k] for k in ["conditions", "actions"]}
+    elif template_name in ["sender", "room"]:
+        templaterule = {'actions': rule['actions']}
+        unscoped_rule_id = rule['conditions'][0]['pattern']
+    elif template_name == 'content':
+        if len(rule["conditions"]) != 1:
+            return None
+        thecond = rule["conditions"][0]
+        if "pattern" not in thecond:
+            return None
+        templaterule = {'actions': rule['actions']}
+        templaterule["pattern"] = thecond["pattern"]
+
+    if unscoped_rule_id:
+        templaterule['rule_id'] = unscoped_rule_id
+    if 'default' in rule:
+        templaterule['default'] = rule['default']
+    return templaterule
+
+
+def _rule_id_from_namespaced(in_rule_id):
+    return in_rule_id.split('/')[-1]
+
+
+def _priority_class_to_template_name(pc):
+    return PRIORITY_CLASS_INVERSE_MAP[pc]
diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py
index e0d039518d..adc1eb1d0b 100644
--- a/synapse/replication/resource.py
+++ b/synapse/replication/resource.py
@@ -36,6 +36,7 @@ STREAM_NAMES = (
     ("receipts",),
     ("user_account_data", "room_account_data", "tag_account_data",),
     ("backfill",),
+    ("push_rules",),
 )
 
@@ -63,6 +64,7 @@ class ReplicationResource(Resource):
     * "room_account_data: Per room per user account data.
     * "tag_account_data": Per room per user tags.
     * "backfill": Old events that have been backfilled from other servers.
+    * "push_rules": Per user changes to push rules.
 
     The API takes two additional query parameters:
 
@@ -117,14 +119,16 @@ def current_replication_token(self):
         stream_token = yield self.sources.get_current_token()
         backfill_token = yield self.store.get_current_backfill_token()
+        push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
 
         defer.returnValue(_ReplicationToken(
-            stream_token.room_stream_id,
+            room_stream_token,
             int(stream_token.presence_key),
             int(stream_token.typing_key),
             int(stream_token.receipt_key),
             int(stream_token.account_data_key),
             backfill_token,
+            push_rules_token,
         ))
 
     @request_handler
@@ -146,6 +150,7 @@ class ReplicationResource(Resource):
         yield self.presence(writer, current_token)  # TODO: implement limit
         yield self.typing(writer, current_token)  # TODO: implement limit
         yield self.receipts(writer, current_token, limit)
+        yield self.push_rules(writer, current_token, limit)
         self.streams(writer, current_token)
 
         logger.info("Replicated %d rows", writer.total)
@@ -277,6 +282,21 @@ class ReplicationResource(Resource):
             "position", "user_id", "room_id", "tags"
         ))
 
+    @defer.inlineCallbacks
+    def push_rules(self, writer, current_token, limit):
+        current_position = current_token.push_rules
+
+        push_rules = parse_integer(writer.request, "push_rules")
+
+        if push_rules is not None:
+            rows = yield self.store.get_all_push_rule_updates(
+                push_rules, current_position, limit
+            )
+            writer.write_header_and_rows("push_rules", rows, (
+                "position", "event_stream_ordering", "user_id", "rule_id", "op",
+                "priority_class", "priority", "conditions", "actions"
+            ))
+
 
 class _Writer(object):
     """Writes the streams as a JSON object as the response to the request"""
@@ -307,12 +327,16 @@ class _Writer(object):
 
 class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
     "events", "presence", "typing", "receipts", "account_data", "backfill",
+    "push_rules"
 ))):
     __slots__ = []
 
     def __new__(cls, *args):
         if len(args) == 1:
-            return cls(*(int(value) for value in args[0].split("_")))
+            streams = [int(value) for value in args[0].split("_")]
+            if len(streams) < len(cls._fields):
+                streams.extend([0] * (len(cls._fields) - len(streams)))
+            return cls(*streams)
         else:
             return super(_ReplicationToken, cls).__new__(cls, *args)
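One subtlety above: _ReplicationToken.__new__ now pads short tokens with zeros, so a replication client still holding an older six-field token keeps working once the push_rules stream is appended. A minimal sketch of that behaviour, reusing the field list from the namedtuple:

    # Sketch of the padding logic mirrored from __new__ above.
    fields = ("events", "presence", "typing", "receipts",
              "account_data", "backfill", "push_rules")
    streams = [int(v) for v in "100_1_2_3_4_5".split("_")]  # old-style token
    if len(streams) < len(fields):
        streams.extend([0] * (len(fields) - len(streams)))
    assert streams[-1] == 0  # the push_rules position defaults to 0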
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 970a019223..981d7708db 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -22,12 +22,10 @@ from .base import ClientV1RestServlet, client_path_patterns
 from synapse.storage.push_rule import (
     InconsistentRuleException, RuleNotFoundException
 )
-from synapse.push.baserules import list_with_base_rules, BASE_RULE_IDS
-from synapse.push.rulekinds import (
-    PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
-)
+from synapse.push.clientformat import format_push_rules_for_user
+from synapse.push.baserules import BASE_RULE_IDS
+from synapse.push.rulekinds import PRIORITY_CLASS_MAP
 
-import copy
 import simplejson as json
 
 
@@ -36,6 +34,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
     SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
         "Unrecognised request: You probably wanted a trailing slash")
 
+    def __init__(self, hs):
+        super(PushRuleRestServlet, self).__init__(hs)
+        self.store = hs.get_datastore()
+        self.notifier = hs.get_notifier()
+
     @defer.inlineCallbacks
     def on_PUT(self, request):
         spec = _rule_spec_from_path(request.postpath)
@@ -51,8 +54,11 @@
 
         content = _parse_json(request)
 
+        user_id = requester.user.to_string()
+
         if 'attr' in spec:
-            yield self.set_rule_attr(requester.user.to_string(), spec, content)
+            yield self.set_rule_attr(user_id, spec, content)
+            self.notify_user(user_id)
             defer.returnValue((200, {}))
 
         if spec['rule_id'].startswith('.'):
@@ -77,8 +83,8 @@
             after = _namespaced_rule_id(spec, after[0])
 
         try:
-            yield self.hs.get_datastore().add_push_rule(
-                user_id=requester.user.to_string(),
+            yield self.store.add_push_rule(
+                user_id=user_id,
                 rule_id=_namespaced_rule_id_from_spec(spec),
                 priority_class=priority_class,
                 conditions=conditions,
@@ -86,6 +92,7 @@
                 before=before,
                 after=after
             )
+            self.notify_user(user_id)
         except InconsistentRuleException as e:
             raise SynapseError(400, e.message)
         except RuleNotFoundException as e:
@@ -98,13 +105,15 @@
         spec = _rule_spec_from_path(request.postpath)
 
         requester = yield self.auth.get_user_by_req(request)
+        user_id = requester.user.to_string()
 
         namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
 
         try:
-            yield self.hs.get_datastore().delete_push_rule(
-                requester.user.to_string(), namespaced_rule_id
+            yield self.store.delete_push_rule(
+                user_id, namespaced_rule_id
             )
+            self.notify_user(user_id)
             defer.returnValue((200, {}))
         except StoreError as e:
             if e.code == 404:
@@ -115,58 +124,16 @@
     @defer.inlineCallbacks
     def on_GET(self, request):
         requester = yield self.auth.get_user_by_req(request)
-        user = requester.user
+        user_id = requester.user.to_string()
 
         # we build up the full structure and then decide which bits of it
         # to send which means doing unnecessary work sometimes but is
         # is probably not going to make a whole lot of difference
-        rawrules = yield self.hs.get_datastore().get_push_rules_for_user(
-            user.to_string()
-        )
-
-        ruleslist = []
-        for rawrule in rawrules:
-            rule = dict(rawrule)
-            rule["conditions"] = json.loads(rawrule["conditions"])
-            rule["actions"] = json.loads(rawrule["actions"])
-            ruleslist.append(rule)
-
-        # We're going to be mutating this a lot, so do a deep copy
-        ruleslist = copy.deepcopy(list_with_base_rules(ruleslist))
-
-        rules = {'global': {}, 'device': {}}
-
-        rules['global'] = _add_empty_priority_class_arrays(rules['global'])
-
-        enabled_map = yield self.hs.get_datastore().\
-            get_push_rules_enabled_for_user(user.to_string())
-
-        for r in ruleslist:
-            rulearray = None
-
-            template_name = _priority_class_to_template_name(r['priority_class'])
-
-            # Remove internal stuff.
-            for c in r["conditions"]:
-                c.pop("_id", None)
-
-                pattern_type = c.pop("pattern_type", None)
-                if pattern_type == "user_id":
-                    c["pattern"] = user.to_string()
-                elif pattern_type == "user_localpart":
-                    c["pattern"] = user.localpart
+        rawrules = yield self.store.get_push_rules_for_user(user_id)
 
-            rulearray = rules['global'][template_name]
+        enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
 
-            template_rule = _rule_to_template(r)
-            if template_rule:
-                if r['rule_id'] in enabled_map:
-                    template_rule['enabled'] = enabled_map[r['rule_id']]
-                elif 'enabled' in r:
-                    template_rule['enabled'] = r['enabled']
-                else:
-                    template_rule['enabled'] = True
-                rulearray.append(template_rule)
+        rules = format_push_rules_for_user(requester.user, rawrules, enabled_map)
 
         path = request.postpath[1:]
 
@@ -188,6 +155,12 @@ class PushRuleRestServlet(ClientV1RestServlet):
     def on_OPTIONS(self, _):
         return 200, {}
 
+    def notify_user(self, user_id):
+        stream_id, _ = self.store.get_push_rules_stream_token()
+        self.notifier.on_new_event(
+            "push_rules_key", stream_id, users=[user_id]
+        )
+
     def set_rule_attr(self, user_id, spec, val):
         if spec['attr'] == 'enabled':
             if isinstance(val, dict) and "enabled" in val:
@@ -198,7 +171,7 @@
                 # bools directly, so let's not break them.
                 raise SynapseError(400, "Value for 'enabled' must be boolean")
             namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
-            return self.hs.get_datastore().set_push_rule_enabled(
+            return self.store.set_push_rule_enabled(
                 user_id, namespaced_rule_id, val
             )
         elif spec['attr'] == 'actions':
@@ -210,7 +183,7 @@
             if is_default_rule:
                 if namespaced_rule_id not in BASE_RULE_IDS:
                     raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
-            return self.hs.get_datastore().set_push_rule_actions(
+            return self.store.set_push_rule_actions(
                 user_id, namespaced_rule_id, actions, is_default_rule
             )
         else:
@@ -308,12 +281,6 @@ def _check_actions(actions):
             raise InvalidRuleException("Unrecognised action")
 
 
-def _add_empty_priority_class_arrays(d):
-    for pc in PRIORITY_CLASS_MAP.keys():
-        d[pc] = []
-    return d
-
-
 def _filter_ruleset_with_path(ruleset, path):
     if path == []:
         raise UnrecognizedRequestError(
@@ -362,37 +329,6 @@ def _priority_class_from_spec(spec):
     return pc
 
 
-def _priority_class_to_template_name(pc):
-    return PRIORITY_CLASS_INVERSE_MAP[pc]
-
-
-def _rule_to_template(rule):
-    unscoped_rule_id = None
-    if 'rule_id' in rule:
-        unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id'])
-
-    template_name = _priority_class_to_template_name(rule['priority_class'])
-    if template_name in ['override', 'underride']:
-        templaterule = {k: rule[k] for k in ["conditions", "actions"]}
-    elif template_name in ["sender", "room"]:
-        templaterule = {'actions': rule['actions']}
-        unscoped_rule_id = rule['conditions'][0]['pattern']
-    elif template_name == 'content':
-        if len(rule["conditions"]) != 1:
-            return None
-        thecond = rule["conditions"][0]
-        if "pattern" not in thecond:
-            return None
-        templaterule = {'actions': rule['actions']}
-        templaterule["pattern"] = thecond["pattern"]
-
-    if unscoped_rule_id:
-        templaterule['rule_id'] = unscoped_rule_id
-    if 'default' in rule:
-        templaterule['default'] = rule['default']
-    return templaterule
-
-
 def _namespaced_rule_id_from_spec(spec):
     return _namespaced_rule_id(spec, spec['rule_id'])
 
@@ -401,10 +337,6 @@ def _namespaced_rule_id(spec, rule_id):
     return "global/%s/%s" % (spec['template'], rule_id)
 
 
-def _rule_id_from_namespaced(in_rule_id):
-    return in_rule_id.split('/')[-1]
-
-
 class InvalidRuleException(Exception):
     pass
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index f257721ea3..6f37a85d09 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -45,7 +45,7 @@ from .search import SearchStore
 from .tags import TagsStore
 from .account_data import AccountDataStore
 
-from util.id_generators import IdGenerator, StreamIdGenerator
+from util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
 
 from synapse.api.constants import PresenceState
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -122,6 +122,9 @@ class DataStore(RoomMemberStore, RoomStore,
         self._pushers_id_gen = IdGenerator(db_conn, "pushers", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
         self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
+        self._push_rules_stream_id_gen = ChainedIdGenerator(
+            self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
+        )
 
         events_max = self._stream_id_gen.get_max_token()
         event_cache_prefill, min_event_val = self._get_cache_dict(
@@ -157,6 +160,18 @@ class DataStore(RoomMemberStore, RoomStore,
             prefilled_cache=presence_cache_prefill
         )
 
+        push_rules_prefill, push_rules_id = self._get_cache_dict(
+            db_conn, "push_rules_stream",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self._push_rules_stream_id_gen.get_max_token()[0],
+        )
+
+        self.push_rules_stream_cache = StreamChangeCache(
+            "PushRulesStreamChangeCache", push_rules_id,
+            prefilled_cache=push_rules_prefill,
+        )
+
         super(DataStore, self).__init__(hs)
 
     def take_presence_startup_info(self):
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 2e97ac84a8..7dc67ecd57 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -770,18 +770,29 @@ class SQLBaseStore(object):
             table : string giving the table name
             keyvalues : dict of column names and values to select the row with
         """
+        return self.runInteraction(
+            desc, self._simple_delete_one_txn, table, keyvalues
+        )
+
+    @staticmethod
+    def _simple_delete_one_txn(txn, table, keyvalues):
+        """Executes a DELETE query on the named table, expecting to delete a
+        single row.
+
+        Args:
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the row with
+        """
         sql = "DELETE FROM %s WHERE %s" % (
             table,
             " AND ".join("%s = ?" % (k, ) for k in keyvalues)
         )
 
-        def func(txn):
-            txn.execute(sql, keyvalues.values())
-            if txn.rowcount == 0:
-                raise StoreError(404, "No row found")
-            if txn.rowcount > 1:
-                raise StoreError(500, "more than one row matched")
-        return self.runInteraction(desc, func)
+        txn.execute(sql, keyvalues.values())
+        if txn.rowcount == 0:
+            raise StoreError(404, "No row found")
+        if txn.rowcount > 1:
+            raise StoreError(500, "more than one row matched")
 
     @staticmethod
     def _simple_delete_txn(txn, table, keyvalues):
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 56e69495b1..9dbad2fd5f 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -99,30 +99,32 @@ class PushRuleStore(SQLBaseStore):
             results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
         defer.returnValue(results)
 
+    @defer.inlineCallbacks
     def add_push_rule(
         self, user_id, rule_id, priority_class,
         conditions, actions, before=None, after=None
     ):
         conditions_json = json.dumps(conditions)
         actions_json = json.dumps(actions)
-
-        if before or after:
-            return self.runInteraction(
-                "_add_push_rule_relative_txn",
-                self._add_push_rule_relative_txn,
-                user_id, rule_id, priority_class,
-                conditions_json, actions_json, before, after,
-            )
-        else:
-            return self.runInteraction(
-                "_add_push_rule_highest_priority_txn",
-                self._add_push_rule_highest_priority_txn,
-                user_id, rule_id, priority_class,
-                conditions_json, actions_json,
-            )
+        with self._push_rules_stream_id_gen.get_next() as ids:
+            stream_id, event_stream_ordering = ids
+            if before or after:
+                yield self.runInteraction(
+                    "_add_push_rule_relative_txn",
+                    self._add_push_rule_relative_txn,
+                    stream_id, event_stream_ordering, user_id, rule_id, priority_class,
+                    conditions_json, actions_json, before, after,
+                )
+            else:
+                yield self.runInteraction(
+                    "_add_push_rule_highest_priority_txn",
+                    self._add_push_rule_highest_priority_txn,
+                    stream_id, event_stream_ordering, user_id, rule_id, priority_class,
+                    conditions_json, actions_json,
+                )
 
     def _add_push_rule_relative_txn(
-        self, txn, user_id, rule_id, priority_class,
+        self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
         conditions_json, actions_json, before, after
     ):
         # Lock the table since otherwise we'll have annoying races between the
@@ -174,12 +176,12 @@ class PushRuleStore(SQLBaseStore):
         txn.execute(sql, (user_id, priority_class, new_rule_priority))
 
         self._upsert_push_rule_txn(
-            txn, user_id, rule_id, priority_class, new_rule_priority,
-            conditions_json, actions_json,
+            txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
+            new_rule_priority, conditions_json, actions_json,
         )
 
     def _add_push_rule_highest_priority_txn(
-        self, txn, user_id, rule_id, priority_class,
+        self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
         conditions_json, actions_json
     ):
         # Lock the table since otherwise we'll have annoying races between the
@@ -201,13 +203,13 @@
 
         self._upsert_push_rule_txn(
             txn,
-            user_id, rule_id, priority_class, new_prio,
+            stream_id, event_stream_ordering, user_id, rule_id, priority_class, new_prio,
             conditions_json, actions_json,
         )
 
     def _upsert_push_rule_txn(
-        self, txn, user_id, rule_id, priority_class,
-        priority, conditions_json, actions_json
+        self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
+        priority, conditions_json, actions_json, update_stream=True
     ):
         """Specialised version of _simple_upsert_txn that picks a push_rule_id
         using the _push_rule_id_gen if it needs to insert the rule. It assumes
@@ -242,12 +244,17 @@
             },
         )
 
-        txn.call_after(
-            self.get_push_rules_for_user.invalidate, (user_id,)
-        )
-        txn.call_after(
-            self.get_push_rules_enabled_for_user.invalidate, (user_id,)
-        )
+        if update_stream:
+            self._insert_push_rules_update_txn(
+                txn, stream_id, event_stream_ordering, user_id, rule_id,
+                op="ADD",
+                data={
+                    "priority_class": priority_class,
+                    "priority": priority,
+                    "conditions": conditions_json,
+                    "actions": actions_json,
+                }
+            )
 
     @defer.inlineCallbacks
     def delete_push_rule(self, user_id, rule_id):
@@ -260,25 +267,37 @@
             user_id (str): The matrix ID of the push rule owner
             rule_id (str): The rule_id of the rule to be deleted
         """
-        yield self._simple_delete_one(
-            "push_rules",
-            {'user_name': user_id, 'rule_id': rule_id},
-            desc="delete_push_rule",
-        )
+        def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
+            self._simple_delete_one_txn(
+                txn,
+                "push_rules",
+                {'user_name': user_id, 'rule_id': rule_id},
+            )
 
-        self.get_push_rules_for_user.invalidate((user_id,))
-        self.get_push_rules_enabled_for_user.invalidate((user_id,))
+            self._insert_push_rules_update_txn(
+                txn, stream_id, event_stream_ordering, user_id, rule_id,
+                op="DELETE"
+            )
+
+        with self._push_rules_stream_id_gen.get_next() as ids:
+            stream_id, event_stream_ordering = ids
+            yield self.runInteraction(
+                "delete_push_rule", delete_push_rule_txn, stream_id, event_stream_ordering
+            )
 
     @defer.inlineCallbacks
     def set_push_rule_enabled(self, user_id, rule_id, enabled):
-        ret = yield self.runInteraction(
-            "_set_push_rule_enabled_txn",
-            self._set_push_rule_enabled_txn,
-            user_id, rule_id, enabled
-        )
-        defer.returnValue(ret)
+        with self._push_rules_stream_id_gen.get_next() as ids:
+            stream_id, event_stream_ordering = ids
+            yield self.runInteraction(
+                "_set_push_rule_enabled_txn",
+                self._set_push_rule_enabled_txn,
+                stream_id, event_stream_ordering, user_id, rule_id, enabled
+            )
 
-    def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled):
+    def _set_push_rule_enabled_txn(
+        self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
+    ):
         new_id = self._push_rules_enable_id_gen.get_next()
         self._simple_upsert_txn(
             txn,
@@ -287,25 +306,26 @@
             {'enabled': 1 if enabled else 0},
             {'id': new_id},
         )
-        txn.call_after(
-            self.get_push_rules_for_user.invalidate, (user_id,)
-        )
-        txn.call_after(
-            self.get_push_rules_enabled_for_user.invalidate, (user_id,)
+
+        self._insert_push_rules_update_txn(
+            txn, stream_id, event_stream_ordering, user_id, rule_id,
+            op="ENABLE" if enabled else "DISABLE"
         )
 
+    @defer.inlineCallbacks
     def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
         actions_json = json.dumps(actions)
 
-        def set_push_rule_actions_txn(txn):
+        def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
            if is_default_rule:
                 # Add a dummy rule to the rules table with the user specified
                 # actions.
                 priority_class = -1
                 priority = 1
                 self._upsert_push_rule_txn(
-                    txn, user_id, rule_id, priority_class, priority,
-                    "[]", actions_json
+                    txn, stream_id, event_stream_ordering, user_id, rule_id,
+                    priority_class, priority, "[]", actions_json,
+                    update_stream=False
                 )
             else:
                 self._simple_update_one_txn(
@@ -315,10 +335,81 @@
                     {'actions': actions_json},
                 )
 
+            self._insert_push_rules_update_txn(
+                txn, stream_id, event_stream_ordering, user_id, rule_id,
+                op="ACTIONS", data={"actions": actions_json}
+            )
+
+        with self._push_rules_stream_id_gen.get_next() as ids:
+            stream_id, event_stream_ordering = ids
+            yield self.runInteraction(
+                "set_push_rule_actions", set_push_rule_actions_txn,
+                stream_id, event_stream_ordering
+            )
+
+    def _insert_push_rules_update_txn(
+        self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
+    ):
+        values = {
+            "stream_id": stream_id,
+            "event_stream_ordering": event_stream_ordering,
+            "user_id": user_id,
+            "rule_id": rule_id,
+            "op": op,
+        }
+        if data is not None:
+            values.update(data)
+
+        self._simple_insert_txn(txn, "push_rules_stream", values=values)
+
+        txn.call_after(
+            self.get_push_rules_for_user.invalidate, (user_id,)
+        )
+        txn.call_after(
+            self.get_push_rules_enabled_for_user.invalidate, (user_id,)
+        )
+        txn.call_after(
+            self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
+        )
+
+    def get_all_push_rule_updates(self, last_id, current_id, limit):
+        """Get all the push rules changes that have happend on the server"""
+        def get_all_push_rule_updates_txn(txn):
+            sql = (
+                "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
+                " op, priority_class, priority, conditions, actions"
+                " FROM push_rules_stream"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC LIMIT ?"
+            )
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()
         return self.runInteraction(
-            "set_push_rule_actions", set_push_rule_actions_txn,
+            "get_all_push_rule_updates", get_all_push_rule_updates_txn
         )
 
+    def get_push_rules_stream_token(self):
+        """Get the position of the push rules stream.
+        Returns a pair of a stream id for the push_rules stream and the
+        room stream ordering it corresponds to."""
+        return self._push_rules_stream_id_gen.get_max_token()
+
+    def have_push_rules_changed_for_user(self, user_id, last_id):
+        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
+            return defer.succeed(False)
+        else:
+            def have_push_rules_changed_txn(txn):
+                sql = (
+                    "SELECT COUNT(stream_id) FROM push_rules_stream"
+                    " WHERE user_id = ? AND ? < stream_id"
+                )
+                txn.execute(sql, (user_id, last_id))
+                count, = txn.fetchone()
+                return bool(count)
+            return self.runInteraction(
+                "have_push_rules_changed", have_push_rules_changed_txn
+            )
+
 
 class RuleNotFoundException(Exception):
     pass
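A replication consumer would page through the new stream in (last_id, current_id] windows using the accessors above. Roughly, assuming a hypothetical process_row callback:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def replicate_push_rules(store, last_id, limit=100):
        # get_push_rules_stream_token() returns (stream_id, room_stream_ordering)
        current_id, _ = store.get_push_rules_stream_token()
        rows = yield store.get_all_push_rule_updates(last_id, current_id, limit)
        for row in rows:
            process_row(row)  # hypothetical consumer callback
        defer.returnValue(current_id)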
diff --git a/synapse/storage/schema/delta/30/push_rule_stream.sql b/synapse/storage/schema/delta/30/push_rule_stream.sql
new file mode 100644
index 0000000000..735aa8d5f6
--- /dev/null
+++ b/synapse/storage/schema/delta/30/push_rule_stream.sql
@@ -0,0 +1,38 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+CREATE TABLE push_rules_stream(
+    stream_id BIGINT NOT NULL,
+    event_stream_ordering BIGINT NOT NULL,
+    user_id TEXT NOT NULL,
+    rule_id TEXT NOT NULL,
+    op TEXT NOT NULL, -- One of "ENABLE", "DISABLE", "ACTIONS", "ADD", "DELETE"
+    priority_class SMALLINT,
+    priority INTEGER,
+    conditions TEXT,
+    actions TEXT
+);
+
+-- The extra data for each operation is:
+--  * ENABLE, DISABLE, DELETE: []
+--  * ACTIONS: ["actions"]
+--  * ADD: ["priority_class", "priority", "actions", "conditions"]
+
+-- Index for replication queries.
+CREATE INDEX push_rules_stream_id ON push_rules_stream(stream_id);
+-- Index for /sync queries.
+CREATE INDEX push_rules_stream_user_stream_id on push_rules_stream(user_id, stream_id);
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index efe3f68e6e..af425ba9a4 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -20,23 +20,21 @@ import threading
 
 class IdGenerator(object):
     def __init__(self, db_conn, table, column):
-        self.table = table
-        self.column = column
         self._lock = threading.Lock()
-        cur = db_conn.cursor()
-        self._next_id = self._load_next_id(cur)
-        cur.close()
-
-    def _load_next_id(self, txn):
-        txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table,))
-        val, = txn.fetchone()
-        return val + 1 if val else 1
+        self._next_id = _load_max_id(db_conn, table, column)
 
     def get_next(self):
         with self._lock:
-            i = self._next_id
             self._next_id += 1
-            return i
+            return self._next_id
+
+
+def _load_max_id(db_conn, table, column):
+    cur = db_conn.cursor()
+    cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
+    val, = cur.fetchone()
+    cur.close()
+    return val if val else 1
 
 
 class StreamIdGenerator(object):
@@ -52,23 +50,10 @@ class StreamIdGenerator(object):
         # ... persist event ...
     """
     def __init__(self, db_conn, table, column):
-        self.table = table
-        self.column = column
-
         self._lock = threading.Lock()
-
-        cur = db_conn.cursor()
-        self._current_max = self._load_current_max(cur)
-        cur.close()
-
+        self._current_max = _load_max_id(db_conn, table, column)
         self._unfinished_ids = deque()
 
-    def _load_current_max(self, txn):
-        txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
-        rows = txn.fetchall()
-        val, = rows[0]
-        return int(val) if val else 1
-
     def get_next(self):
         """
         Usage:
@@ -124,3 +109,50 @@ class StreamIdGenerator(object):
                 return self._unfinished_ids[0] - 1
 
             return self._current_max
+
+
+class ChainedIdGenerator(object):
+    """Used to generate new stream ids where the stream must be kept in sync
+    with another stream. It generates pairs of IDs, the first element is an
+    integer ID for this stream, the second element is the ID for the stream
+    that this stream needs to be kept in sync with."""
+
+    def __init__(self, chained_generator, db_conn, table, column):
+        self.chained_generator = chained_generator
+        self._lock = threading.Lock()
+        self._current_max = _load_max_id(db_conn, table, column)
+        self._unfinished_ids = deque()
+
+    def get_next(self):
+        """
+        Usage:
+            with stream_id_gen.get_next() as (stream_id, chained_id):
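Usage of the new ChainedIdGenerator mirrors StreamIdGenerator, except that the context manager yields a pair: a fresh id on this stream plus the current position of the chained stream (here, the room event stream). A sketch of the idiom used by the storage code above; persist_change is a hypothetical stand-in for the real transaction:

    with push_rules_stream_id_gen.get_next() as (stream_id, event_stream_ordering):
        # Write the change inside the context manager so that
        # get_max_token() only advances once the write has finished.
        persist_change(stream_id, event_stream_ordering)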
+                # ... persist event ...
+        """
+        with self._lock:
+            self._current_max += 1
+            next_id = self._current_max
+            chained_id = self.chained_generator.get_max_token()
+
+            self._unfinished_ids.append((next_id, chained_id))
+
+        @contextlib.contextmanager
+        def manager():
+            try:
+                yield (next_id, chained_id)
+            finally:
+                with self._lock:
+                    self._unfinished_ids.remove((next_id, chained_id))
+
+        return manager()
+
+    def get_max_token(self):
+        """Returns the maximum stream id such that all stream ids less than or
+        equal to it have been successfully persisted.
+        """
+        with self._lock:
+            if self._unfinished_ids:
+                stream_id, chained_id = self._unfinished_ids[0]
+                return (stream_id - 1, chained_id)
+
+            return (self._current_max, self.chained_generator.get_max_token())
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 5ddf4e988b..d4c0bb6732 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -38,9 +38,12 @@ class EventSources(object):
             name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items()
         }
+        self.store = hs.get_datastore()
 
     @defer.inlineCallbacks
     def get_current_token(self, direction='f'):
+        push_rules_key, _ = self.store.get_push_rules_stream_token()
+
         token = StreamToken(
             room_key=(
                 yield self.sources["room"].get_current_key(direction)
@@ -57,5 +60,6 @@ class EventSources(object):
             account_data_key=(
                 yield self.sources["account_data"].get_current_key()
             ),
+            push_rules_key=push_rules_key,
         )
         defer.returnValue(token)
diff --git a/synapse/types.py b/synapse/types.py
index d5bd95cbd3..5b166835bd 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -115,6 +115,7 @@ class StreamToken(
         "typing_key",
         "receipt_key",
         "account_data_key",
+        "push_rules_key",
     ))
 ):
     _SEPARATOR = "_"
@@ -150,6 +151,7 @@ class StreamToken(
             or (int(other.typing_key) < int(self.typing_key))
             or (int(other.receipt_key) < int(self.receipt_key))
             or (int(other.account_data_key) < int(self.account_data_key))
+            or (int(other.push_rules_key) < int(self.push_rules_key))
         )
 
     def copy_and_advance(self, key, new_value):
@@ -174,6 +176,11 @@ class StreamToken(
         return StreamToken(**d)
 
 
+StreamToken.START = StreamToken(
+    *(["s0"] + ["0"] * (len(StreamToken._fields) - 1))
+)
+
+
 class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
     """Tokens are positions between events. The token "s1" comes after event 1.
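Finally, StreamToken.START gives callers an all-zero token whose width automatically tracks StreamToken._fields, which is what lets the message.py hunk at the top drop its hard-coded five-argument constructor. For example (a sketch assuming this module):

    from synapse.types import StreamToken

    # Start from the all-zero token and swap in a real room position.
    start_token = StreamToken.START.copy_and_replace("room_key", "s42")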