diff options
author | Kegan Dougal <kegan@matrix.org> | 2015-03-06 17:28:49 +0000 |
---|---|---|
committer | Kegan Dougal <kegan@matrix.org> | 2015-03-06 17:28:49 +0000 |
commit | 34ce2ca62f261311a9082b6e243eb6c584487025 (patch) | |
tree | 0144c0f0a8231e2ea68e23c7d5676f48e42ee775 /synapse | |
parent | Assign the AS ID from the database; replace old placeholder txn id. (diff) | |
parent | When setting display name more graciously handle failures to update room state. (diff) | |
download | synapse-34ce2ca62f261311a9082b6e243eb6c584487025.tar.xz |
Merge branch 'develop' into application-services-txn-reliability
Diffstat (limited to '')
-rw-r--r-- | synapse/__init__.py | 2 | ||||
-rwxr-xr-x | synapse/app/homeserver.py | 109 | ||||
-rw-r--r-- | synapse/config/server.py | 3 | ||||
-rw-r--r-- | synapse/config/tls.py | 15 | ||||
-rw-r--r-- | synapse/crypto/context_factory.py | 5 | ||||
-rw-r--r-- | synapse/crypto/keyring.py | 13 | ||||
-rw-r--r-- | synapse/federation/federation_client.py | 121 | ||||
-rw-r--r-- | synapse/federation/federation_server.py | 8 | ||||
-rw-r--r-- | synapse/handlers/events.py | 9 | ||||
-rw-r--r-- | synapse/handlers/profile.py | 20 | ||||
-rw-r--r-- | synapse/http/server.py | 44 | ||||
-rw-r--r-- | synapse/push/__init__.py | 18 | ||||
-rw-r--r-- | synapse/push/baserules.py | 107 | ||||
-rw-r--r-- | synapse/push/httppusher.py | 1 | ||||
-rw-r--r-- | synapse/rest/client/v1/push_rule.py | 61 | ||||
-rw-r--r-- | synapse/storage/push_rule.py | 48 | ||||
-rw-r--r-- | synapse/storage/schema/delta/14/v14.sql | 9 |
17 files changed, 470 insertions, 123 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index f8ec4b0459..f46a6df1fb 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.7.1-r1" +__version__ = "0.8.0" diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index b3ba7dfddc..f96535a978 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -219,61 +219,64 @@ class SynapseHomeServer(HomeServer): def get_version_string(): - null = open(os.devnull, 'w') - cwd = os.path.dirname(os.path.abspath(__file__)) try: - git_branch = subprocess.check_output( - ['git', 'rev-parse', '--abbrev-ref', 'HEAD'], - stderr=null, - cwd=cwd, - ).strip() - git_branch = "b=" + git_branch - except subprocess.CalledProcessError: - git_branch = "" - - try: - git_tag = subprocess.check_output( - ['git', 'describe', '--exact-match'], - stderr=null, - cwd=cwd, - ).strip() - git_tag = "t=" + git_tag - except subprocess.CalledProcessError: - git_tag = "" - - try: - git_commit = subprocess.check_output( - ['git', 'rev-parse', '--short', 'HEAD'], - stderr=null, - cwd=cwd, - ).strip() - except subprocess.CalledProcessError: - git_commit = "" - - try: - dirty_string = "-this_is_a_dirty_checkout" - is_dirty = subprocess.check_output( - ['git', 'describe', '--dirty=' + dirty_string], - stderr=null, - cwd=cwd, - ).strip().endswith(dirty_string) - - git_dirty = "dirty" if is_dirty else "" - except subprocess.CalledProcessError: - git_dirty = "" - - if git_branch or git_tag or git_commit or git_dirty: - git_version = ",".join( - s for s in - (git_branch, git_tag, git_commit, git_dirty,) - if s - ) - - return ( - "Synapse/%s (%s)" % ( - synapse.__version__, git_version, + null = open(os.devnull, 'w') + cwd = os.path.dirname(os.path.abspath(__file__)) + try: + git_branch = subprocess.check_output( + ['git', 'rev-parse', '--abbrev-ref', 'HEAD'], + stderr=null, + cwd=cwd, + ).strip() + git_branch = "b=" + git_branch + except subprocess.CalledProcessError: + git_branch = "" + + try: + git_tag = subprocess.check_output( + ['git', 'describe', '--exact-match'], + stderr=null, + cwd=cwd, + ).strip() + git_tag = "t=" + git_tag + except subprocess.CalledProcessError: + git_tag = "" + + try: + git_commit = subprocess.check_output( + ['git', 'rev-parse', '--short', 'HEAD'], + stderr=null, + cwd=cwd, + ).strip() + except subprocess.CalledProcessError: + git_commit = "" + + try: + dirty_string = "-this_is_a_dirty_checkout" + is_dirty = subprocess.check_output( + ['git', 'describe', '--dirty=' + dirty_string], + stderr=null, + cwd=cwd, + ).strip().endswith(dirty_string) + + git_dirty = "dirty" if is_dirty else "" + except subprocess.CalledProcessError: + git_dirty = "" + + if git_branch or git_tag or git_commit or git_dirty: + git_version = ",".join( + s for s in + (git_branch, git_tag, git_commit, git_dirty,) + if s ) - ).encode("ascii") + + return ( + "Synapse/%s (%s)" % ( + synapse.__version__, git_version, + ) + ).encode("ascii") + except Exception as e: + logger.warn("Failed to check for git repository: %s", e) return ("Synapse/%s" % (synapse.__version__,)).encode("ascii") diff --git a/synapse/config/server.py b/synapse/config/server.py index 4e4892d40b..b042d4eed9 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -30,7 +30,6 @@ class ServerConfig(Config): self.pid_file = self.abspath(args.pid_file) self.webclient = True self.manhole = args.manhole - self.no_tls = args.no_tls self.soft_file_limit = args.soft_file_limit if not args.content_addr: @@ -76,8 +75,6 @@ class ServerConfig(Config): server_group.add_argument("--content-addr", default=None, help="The host and scheme to use for the " "content repository") - server_group.add_argument("--no-tls", action='store_true', - help="Don't bind to the https port.") server_group.add_argument("--soft-file-limit", type=int, default=0, help="Set the soft limit on the number of " "file descriptors synapse can use. " diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 384b29e7ba..034f9a7bf0 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -28,9 +28,16 @@ class TlsConfig(Config): self.tls_certificate = self.read_tls_certificate( args.tls_certificate_path ) - self.tls_private_key = self.read_tls_private_key( - args.tls_private_key_path - ) + + self.no_tls = args.no_tls + + if self.no_tls: + self.tls_private_key = None + else: + self.tls_private_key = self.read_tls_private_key( + args.tls_private_key_path + ) + self.tls_dh_params_path = self.check_file( args.tls_dh_params_path, "tls_dh_params" ) @@ -45,6 +52,8 @@ class TlsConfig(Config): help="PEM encoded private key for TLS") tls_group.add_argument("--tls-dh-params-path", help="PEM dh parameters for ephemeral keys") + tls_group.add_argument("--no-tls", action='store_true', + help="Don't bind to the https port.") def read_tls_certificate(self, cert_path): cert_pem = self.read_file(cert_path, "tls_certificate") diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 24d4abf3e9..2f8618a0df 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -38,7 +38,10 @@ class ServerContextFactory(ssl.ContextFactory): logger.exception("Failed to enable eliptic curve for TLS") context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3) context.use_certificate(config.tls_certificate) - context.use_privatekey(config.tls_private_key) + + if not config.no_tls: + context.use_privatekey(config.tls_private_key) + context.load_tmp_dh(config.tls_dh_params_path) context.set_cipher_list("!ADH:HIGH+kEDH:!AECDH:HIGH+kEECDH") diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 828aced44a..f4db7b8a05 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -50,18 +50,27 @@ class Keyring(object): ) try: verify_key = yield self.get_server_verify_key(server_name, key_ids) - except IOError: + except IOError as e: + logger.warn( + "Got IOError when downloading keys for %s: %s %s", + server_name, type(e).__name__, str(e.message), + ) raise SynapseError( 502, "Error downloading keys for %s" % (server_name,), Codes.UNAUTHORIZED, ) - except: + except Exception as e: + logger.warn( + "Got Exception when downloading keys for %s: %s %s", + server_name, type(e).__name__, str(e.message), + ) raise SynapseError( 401, "No key for %s with id %s" % (server_name, key_ids), Codes.UNAUTHORIZED, ) + try: verify_signed_json(json_object, server_name, verify_key) except: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index ca89a0787c..f131941f45 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -19,14 +19,18 @@ from twisted.internet import defer from .federation_base import FederationBase from .units import Edu -from synapse.api.errors import CodeMessageException, SynapseError +from synapse.api.errors import ( + CodeMessageException, HttpResponseException, SynapseError, +) from synapse.util.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.events import FrozenEvent from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination +import itertools import logging +import random logger = logging.getLogger(__name__) @@ -440,21 +444,112 @@ class FederationClient(FederationBase): defer.returnValue(ret) @defer.inlineCallbacks - def get_missing_events(self, destination, room_id, earliest_events, + def get_missing_events(self, destination, room_id, earliest_events_ids, latest_events, limit, min_depth): - content = yield self.transport_layer.get_missing_events( - destination, room_id, earliest_events, latest_events, limit, - min_depth, - ) + """Tries to fetch events we are missing. This is called when we receive + an event without having received all of its ancestors. - events = [ - self.event_from_pdu_json(e) - for e in content.get("events", []) - ] + Args: + destination (str) + room_id (str) + earliest_events_ids (list): List of event ids. Effectively the + events we expected to receive, but haven't. `get_missing_events` + should only return events that didn't happen before these. + latest_events (list): List of events we have received that we don't + have all previous events for. + limit (int): Maximum number of events to return. + min_depth (int): Minimum depth of events tor return. + """ + try: + content = yield self.transport_layer.get_missing_events( + destination=destination, + room_id=room_id, + earliest_events=earliest_events_ids, + latest_events=[e.event_id for e in latest_events], + limit=limit, + min_depth=min_depth, + ) + + events = [ + self.event_from_pdu_json(e) + for e in content.get("events", []) + ] + + signed_events = yield self._check_sigs_and_hash_and_fetch( + destination, events, outlier=True + ) + + have_gotten_all_from_destination = True + except HttpResponseException as e: + if not e.code == 400: + raise - signed_events = yield self._check_sigs_and_hash_and_fetch( - destination, events, outlier=True - ) + # We are probably hitting an old server that doesn't support + # get_missing_events + signed_events = [] + have_gotten_all_from_destination = False + + if len(signed_events) >= limit: + defer.returnValue(signed_events) + + servers = yield self.store.get_joined_hosts_for_room(room_id) + + servers = set(servers) + servers.discard(self.server_name) + + failed_to_fetch = set() + + while len(signed_events) < limit: + # Are we missing any? + + seen_events = set(earliest_events_ids) + seen_events.update(e.event_id for e in signed_events) + + missing_events = {} + for e in itertools.chain(latest_events, signed_events): + if e.depth > min_depth: + missing_events.update({ + e_id: e.depth for e_id, _ in e.prev_events + if e_id not in seen_events + and e_id not in failed_to_fetch + }) + + if not missing_events: + break + + have_seen = yield self.store.have_events(missing_events) + + for k in have_seen: + missing_events.pop(k, None) + + if not missing_events: + break + + # Okay, we haven't gotten everything yet. Lets get them. + ordered_missing = sorted(missing_events.items(), key=lambda x: x[0]) + + if have_gotten_all_from_destination: + servers.discard(destination) + + def random_server_list(): + srvs = list(servers) + random.shuffle(srvs) + return srvs + + deferreds = [ + self.get_pdu( + destinations=random_server_list(), + event_id=e_id, + ) + for e_id, depth in ordered_missing[:limit - len(signed_events)] + ] + + res = yield defer.DeferredList(deferreds, consumeErrors=True) + for (result, val), (e_id, _) in zip(res, ordered_missing): + if result: + signed_events.append(val) + else: + failed_to_fetch.add(e_id) defer.returnValue(signed_events) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 4264d857be..9c7dcdba96 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -413,12 +413,16 @@ class FederationServer(FederationBase): missing_events = yield self.get_missing_events( origin, pdu.room_id, - earliest_events=list(latest), - latest_events=[pdu.event_id], + earliest_events_ids=list(latest), + latest_events=[pdu], limit=10, min_depth=min_depth, ) + # We want to sort these by depth so we process them and + # tell clients about them in order. + missing_events.sort(key=lambda x: x.depth) + for e in missing_events: yield self._handle_new_pdu( origin, diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 8d5f5c8499..d3297b7292 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -23,6 +23,7 @@ from synapse.events.utils import serialize_event from ._base import BaseHandler import logging +import random logger = logging.getLogger(__name__) @@ -72,6 +73,14 @@ class EventStreamHandler(BaseHandler): rm_handler = self.hs.get_handlers().room_member_handler room_ids = yield rm_handler.get_rooms_for_user(auth_user) + if timeout: + # If they've set a timeout set a minimum limit. + timeout = max(timeout, 500) + + # Add some randomness to this value to try and mitigate against + # thundering herds on restart. + timeout = random.randint(int(timeout*0.9), int(timeout*1.1)) + with PreserveLoggingContext(): events, tokens = yield self.notifier.get_events_for( auth_user, room_ids, pagin_config, timeout diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 03b2159c53..2ddf9d5378 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -212,10 +212,16 @@ class ProfileHandler(BaseHandler): ) msg_handler = self.hs.get_handlers().message_handler - yield msg_handler.create_and_send_event({ - "type": EventTypes.Member, - "room_id": j.room_id, - "state_key": user.to_string(), - "content": content, - "sender": user.to_string() - }, ratelimit=False) + try: + yield msg_handler.create_and_send_event({ + "type": EventTypes.Member, + "room_id": j.room_id, + "state_key": user.to_string(), + "content": content, + "sender": user.to_string() + }, ratelimit=False) + except Exception as e: + logger.warn( + "Failed to update join event for room %s - %s", + j.room_id, str(e.message) + ) diff --git a/synapse/http/server.py b/synapse/http/server.py index 74a101a5d7..767c3ef79b 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -124,27 +124,29 @@ class JsonResource(HttpServer, resource.Resource): # and path regex match for path_entry in self.path_regexs.get(request.method, []): m = path_entry.pattern.match(request.path) - if m: - # We found a match! Trigger callback and then return the - # returned response. We pass both the request and any - # matched groups from the regex to the callback. - - args = [ - urllib.unquote(u).decode("UTF-8") for u in m.groups() - ] - - logger.info( - "Received request: %s %s", - request.method, request.path - ) - - code, response = yield path_entry.callback( - request, - *args - ) - - self._send_response(request, code, response) - return + if not m: + continue + + # We found a match! Trigger callback and then return the + # returned response. We pass both the request and any + # matched groups from the regex to the callback. + + args = [ + urllib.unquote(u).decode("UTF-8") for u in m.groups() + ] + + logger.info( + "Received request: %s %s", + request.method, request.path + ) + + code, response = yield path_entry.callback( + request, + *args + ) + + self._send_response(request, code, response) + return # Huh. No one wanted to handle that? Fiiiiiine. Send 400. raise UnrecognizedRequestError() diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 0fb3e4f7f3..3da0ce8703 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -32,7 +32,7 @@ class Pusher(object): INITIAL_BACKOFF = 1000 MAX_BACKOFF = 60 * 60 * 1000 GIVE_UP_AFTER = 24 * 60 * 60 * 1000 - DEFAULT_ACTIONS = ['notify'] + DEFAULT_ACTIONS = ['dont-notify'] INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") @@ -72,16 +72,14 @@ class Pusher(object): # let's assume you probably know about messages you sent yourself defer.returnValue(['dont_notify']) - if ev['type'] == 'm.room.member': - if ev['state_key'] != self.user_name: - defer.returnValue(['dont_notify']) - - rawrules = yield self.store.get_push_rules_for_user_name(self.user_name) + rawrules = yield self.store.get_push_rules_for_user(self.user_name) for r in rawrules: r['conditions'] = json.loads(r['conditions']) r['actions'] = json.loads(r['actions']) + enabled_map = yield self.store.get_push_rules_enabled_for_user(self.user_name) + user = UserID.from_string(self.user_name) rules = baserules.list_with_base_rules(rawrules, user) @@ -107,6 +105,8 @@ class Pusher(object): room_member_count += 1 for r in rules: + if r['rule_id'] in enabled_map and not enabled_map[r['rule_id']]: + continue matches = True conditions = r['conditions'] @@ -117,7 +117,11 @@ class Pusher(object): ev, c, display_name=my_display_name, room_member_count=room_member_count ) - # ignore rules with no actions (we have an explict 'dont_notify' + logger.debug( + "Rule %s %s", + r['rule_id'], "matches" if matches else "doesn't match" + ) + # ignore rules with no actions (we have an explict 'dont_notify') if len(actions) == 0: logger.warn( "Ignoring rule id %s with no actions for user %s" % diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 162d265f66..eddc7fcbe2 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -32,12 +32,14 @@ def make_base_rules(user, kind): if kind == 'override': rules = make_base_override_rules() + elif kind == 'underride': + rules = make_base_underride_rules(user) elif kind == 'content': rules = make_base_content_rules(user) for r in rules: r['priority_class'] = PRIORITY_CLASS_MAP[kind] - r['default'] = True + r['default'] = True # Deprecated, left for backwards compat return rules @@ -45,6 +47,7 @@ def make_base_rules(user, kind): def make_base_content_rules(user): return [ { + 'rule_id': 'global/content/.m.rule.contains_user_name', 'conditions': [ { 'kind': 'event_match', @@ -57,6 +60,8 @@ def make_base_content_rules(user): { 'set_tweak': 'sound', 'value': 'default', + }, { + 'set_tweak': 'highlight' } ] }, @@ -66,6 +71,20 @@ def make_base_content_rules(user): def make_base_override_rules(): return [ { + 'rule_id': 'global/underride/.m.rule.suppress_notices', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'content.msgtype', + 'pattern': 'm.notice', + } + ], + 'actions': [ + 'dont-notify', + ] + }, + { + 'rule_id': 'global/override/.m.rule.contains_display_name', 'conditions': [ { 'kind': 'contains_display_name' @@ -76,10 +95,13 @@ def make_base_override_rules(): { 'set_tweak': 'sound', 'value': 'default' + }, { + 'set_tweak': 'highlight' } ] }, { + 'rule_id': 'global/override/.m.rule.room_one_to_one', 'conditions': [ { 'kind': 'room_member_count', @@ -95,3 +117,86 @@ def make_base_override_rules(): ] } ] + + +def make_base_underride_rules(user): + return [ + { + 'rule_id': 'global/underride/.m.rule.invite_for_me', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'type', + 'pattern': 'm.room.member', + }, + { + 'kind': 'event_match', + 'key': 'content.membership', + 'pattern': 'invite', + }, + { + 'kind': 'event_match', + 'key': 'state_key', + 'pattern': user.to_string(), + }, + ], + 'actions': [ + 'notify', + { + 'set_tweak': 'sound', + 'value': 'default' + } + ] + }, + { + 'rule_id': 'global/underride/.m.rule.member_event', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'type', + 'pattern': 'm.room.member', + } + ], + 'actions': [ + 'notify', + ] + }, + { + 'rule_id': 'global/underride/.m.rule.message', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'type', + 'pattern': 'm.room.message', + } + ], + 'actions': [ + 'notify', + ] + }, + { + 'rule_id': 'global/underride/.m.rule.call', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'type', + 'pattern': 'm.call.invite', + } + ], + 'actions': [ + 'notify', + { + 'set_tweak': 'sound', + 'value': 'ring' + } + ] + }, + { + 'rule_id': 'global/underride/.m.rule.fallback', + 'conditions': [ + ], + 'actions': [ + 'notify', + ] + }, + ] diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 202fcb42f4..caca709f10 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -88,6 +88,7 @@ class HttpPusher(Pusher): } if event['type'] == 'm.room.member': d['notification']['membership'] = event['content']['membership'] + d['notification']['user_is_target'] = event['state_key'] == self.user_name if 'content' in event: d['notification']['content'] = event['content'] diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 73ba0494e6..fef0eb6572 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -50,6 +50,10 @@ class PushRuleRestServlet(ClientV1RestServlet): content = _parse_json(request) + if 'attr' in spec: + self.set_rule_attr(user.to_string(), spec, content) + defer.returnValue((200, {})) + try: (conditions, actions) = _rule_tuple_from_request_object( spec['template'], @@ -110,7 +114,7 @@ class PushRuleRestServlet(ClientV1RestServlet): # we build up the full structure and then decide which bits of it # to send which means doing unnecessary work sometimes but is # is probably not going to make a whole lot of difference - rawrules = yield self.hs.get_datastore().get_push_rules_for_user_name( + rawrules = yield self.hs.get_datastore().get_push_rules_for_user( user.to_string() ) @@ -124,6 +128,9 @@ class PushRuleRestServlet(ClientV1RestServlet): rules['global'] = _add_empty_priority_class_arrays(rules['global']) + enabled_map = yield self.hs.get_datastore().\ + get_push_rules_enabled_for_user(user.to_string()) + for r in ruleslist: rulearray = None @@ -149,6 +156,9 @@ class PushRuleRestServlet(ClientV1RestServlet): template_rule = _rule_to_template(r) if template_rule: + template_rule['enabled'] = True + if r['rule_id'] in enabled_map: + template_rule['enabled'] = enabled_map[r['rule_id']] rulearray.append(template_rule) path = request.postpath[1:] @@ -189,6 +199,25 @@ class PushRuleRestServlet(ClientV1RestServlet): def on_OPTIONS(self, _): return 200, {} + def set_rule_attr(self, user_name, spec, val): + if spec['attr'] == 'enabled': + if not isinstance(val, bool): + raise SynapseError(400, "Value for 'enabled' must be boolean") + namespaced_rule_id = _namespaced_rule_id_from_spec(spec) + self.hs.get_datastore().set_push_rule_enabled( + user_name, namespaced_rule_id, val + ) + else: + raise UnrecognizedRequestError() + + def get_rule_attr(self, user_name, namespaced_rule_id, attr): + if attr == 'enabled': + return self.hs.get_datastore().get_push_rule_enabled_by_user_rule_id( + user_name, namespaced_rule_id + ) + else: + raise UnrecognizedRequestError() + def _rule_spec_from_path(path): if len(path) < 2: @@ -226,6 +255,12 @@ def _rule_spec_from_path(path): } if device: spec['profile_tag'] = device + + path = path[1:] + + if len(path) > 0 and len(path[0]) > 0: + spec['attr'] = path[0] + return spec @@ -275,7 +310,7 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None for a in actions: if a in ['notify', 'dont_notify', 'coalesce']: pass - elif isinstance(a, dict) and 'set_sound' in a: + elif isinstance(a, dict) and 'set_tweak' in a: pass else: raise InvalidRuleException("Unrecognised action") @@ -319,10 +354,23 @@ def _filter_ruleset_with_path(ruleset, path): if path[0] == '': return ruleset[template_kind] rule_id = path[0] + + the_rule = None for r in ruleset[template_kind]: if r['rule_id'] == rule_id: - return r - raise NotFoundError + the_rule = r + if the_rule is None: + raise NotFoundError + + path = path[1:] + if len(path) == 0: + return the_rule + + attr = path[0] + if attr in the_rule: + return the_rule[attr] + else: + raise UnrecognizedRequestError() def _priority_class_from_spec(spec): @@ -339,7 +387,7 @@ def _priority_class_from_spec(spec): def _priority_class_to_template_name(pc): if pc > PRIORITY_CLASS_MAP['override']: # per-device - prio_class_index = pc - len(PushRuleRestServlet.PRIORITY_CLASS_MAP) + prio_class_index = pc - len(PRIORITY_CLASS_MAP) return PRIORITY_CLASS_INVERSE_MAP[prio_class_index] else: return PRIORITY_CLASS_INVERSE_MAP[pc] @@ -399,9 +447,6 @@ class InvalidRuleException(Exception): def _parse_json(request): try: content = json.loads(request.content.read()) - if type(content) != dict: - raise SynapseError(400, "Content must be a JSON object.", - errcode=Codes.NOT_JSON) return content except ValueError: raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index ae46b39cc1..bbf322cc84 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) class PushRuleStore(SQLBaseStore): @defer.inlineCallbacks - def get_push_rules_for_user_name(self, user_name): + def get_push_rules_for_user(self, user_name): sql = ( "SELECT "+",".join(PushRuleTable.fields)+" " "FROM "+PushRuleTable.table_name+" " @@ -46,6 +46,28 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(dicts) @defer.inlineCallbacks + def get_push_rules_enabled_for_user(self, user_name): + results = yield self._simple_select_list( + PushRuleEnableTable.table_name, + {'user_name': user_name}, + PushRuleEnableTable.fields + ) + defer.returnValue( + {r['rule_id']: False if r['enabled'] == 0 else True for r in results} + ) + + @defer.inlineCallbacks + def get_push_rule_enabled_by_user_rule_id(self, user_name, rule_id): + results = yield self._simple_select_list( + PushRuleEnableTable.table_name, + {'user_name': user_name, 'rule_id': rule_id}, + ['enabled'] + ) + if not results: + defer.returnValue(True) + defer.returnValue(results[0]) + + @defer.inlineCallbacks def add_push_rule(self, before, after, **kwargs): vals = copy.copy(kwargs) if 'conditions' in vals: @@ -193,6 +215,20 @@ class PushRuleStore(SQLBaseStore): {'user_name': user_name, 'rule_id': rule_id} ) + @defer.inlineCallbacks + def set_push_rule_enabled(self, user_name, rule_id, enabled): + if enabled: + yield self._simple_delete_one( + PushRuleEnableTable.table_name, + {'user_name': user_name, 'rule_id': rule_id} + ) + else: + yield self._simple_upsert( + PushRuleEnableTable.table_name, + {'user_name': user_name, 'rule_id': rule_id}, + {'enabled': False} + ) + class RuleNotFoundException(Exception): pass @@ -216,3 +252,13 @@ class PushRuleTable(Table): ] EntryType = collections.namedtuple("PushRuleEntry", fields) + + +class PushRuleEnableTable(Table): + table_name = "push_rules_enable" + + fields = [ + "user_name", + "rule_id", + "enabled" + ] diff --git a/synapse/storage/schema/delta/14/v14.sql b/synapse/storage/schema/delta/14/v14.sql new file mode 100644 index 0000000000..0212726448 --- /dev/null +++ b/synapse/storage/schema/delta/14/v14.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS push_rules_enable ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT NOT NULL, + rule_id TEXT NOT NULL, + enabled TINYINT, + UNIQUE(user_name, rule_id) +); + +CREATE INDEX IF NOT EXISTS push_rules_enable_user_name on push_rules_enable (user_name); |