summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2016-03-01 09:43:27 +0000
committerErik Johnston <erik@matrix.org>2016-03-01 09:43:27 +0000
commit903fb34b39cd750050b4c89d9c0f5492652b8fcd (patch)
tree7e74b35c054227b2897ed46bccfb2b4de40f47f6
parentReport size of ExpiringCache (diff)
parentMerge pull request #607 from matrix-org/dbkr/send_inviter_member_event (diff)
downloadsynapse-903fb34b39cd750050b4c89d9c0f5492652b8fcd.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/expiring_cache_size
-rw-r--r--synapse/handlers/_base.py8
-rw-r--r--synapse/handlers/register.py15
-rw-r--r--synapse/handlers/room.py8
-rw-r--r--synapse/push/baserules.py57
-rw-r--r--synapse/rest/client/v1/login.py6
-rw-r--r--synapse/rest/client/v1/push_rule.py41
-rw-r--r--synapse/rest/client/v1/room.py14
-rw-r--r--synapse/storage/presence.py15
-rw-r--r--synapse/storage/push_rule.py25
-rw-r--r--synapse/storage/registration.py44
-rw-r--r--synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql24
11 files changed, 227 insertions, 30 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 5613bd2059..bdade98bf7 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -293,6 +293,12 @@ class BaseHandler(object):
 
         if event.type == EventTypes.Member:
             if event.content["membership"] == Membership.INVITE:
+                def is_inviter_member_event(e):
+                    return (
+                        e.type == EventTypes.Member and
+                        e.sender == event.sender
+                    )
+
                 event.unsigned["invite_room_state"] = [
                     {
                         "type": e.type,
@@ -306,7 +312,7 @@ class BaseHandler(object):
                         EventTypes.CanonicalAlias,
                         EventTypes.RoomAvatar,
                         EventTypes.Name,
-                    )
+                    ) or is_inviter_member_event(e)
                 ]
 
                 invitee = UserID.from_string(event.state_key)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index f8959e5d82..6d155d57e7 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -349,3 +349,18 @@ class RegistrationHandler(BaseHandler):
 
     def auth_handler(self):
         return self.hs.get_handlers().auth_handler
+
+    @defer.inlineCallbacks
+    def guest_access_token_for(self, medium, address, inviter_user_id):
+        access_token = yield self.store.get_3pid_guest_access_token(medium, address)
+        if access_token:
+            defer.returnValue(access_token)
+
+        _, access_token = yield self.register(
+            generate_token=True,
+            make_guest=True
+        )
+        access_token = yield self.store.save_or_get_3pid_guest_access_token(
+            medium, address, access_token, inviter_user_id
+        )
+        defer.returnValue(access_token)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index eb9700a35b..d2de23a6cc 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -848,6 +848,13 @@ class RoomMemberHandler(BaseHandler):
                 user.
         """
 
+        registration_handler = self.hs.get_handlers().registration_handler
+        guest_access_token = yield registration_handler.guest_access_token_for(
+            medium=medium,
+            address=address,
+            inviter_user_id=inviter_user_id,
+        )
+
         is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
             id_server_scheme, id_server,
         )
@@ -864,6 +871,7 @@ class RoomMemberHandler(BaseHandler):
                 "sender": inviter_user_id,
                 "sender_display_name": inviter_display_name,
                 "sender_avatar_url": inviter_avatar_url,
+                "guest_access_token": guest_access_token,
             }
         )
         # TODO: Check for success
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 0832c77cb4..86a2998bcc 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -13,46 +13,67 @@
 # limitations under the License.
 
 from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
+import copy
 
 
 def list_with_base_rules(rawrules):
+    """Combine the list of rules set by the user with the default push rules
+
+    :param list rawrules: The rules the user has modified or set.
+    :returns: A new list with the rules set by the user combined with the
+        defaults.
+    """
     ruleslist = []
 
+    # Grab the base rules that the user has modified.
+    # The modified base rules have a priority_class of -1.
+    modified_base_rules = {
+        r['rule_id']: r for r in rawrules if r['priority_class'] < 0
+    }
+
+    # Remove the modified base rules from the list, They'll be added back
+    # in the default postions in the list.
+    rawrules = [r for r in rawrules if r['priority_class'] >= 0]
+
     # shove the server default rules for each kind onto the end of each
     current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1]
 
     ruleslist.extend(make_base_prepend_rules(
-        PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+        PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
     ))
 
     for r in rawrules:
         if r['priority_class'] < current_prio_class:
             while r['priority_class'] < current_prio_class:
                 ruleslist.extend(make_base_append_rules(
-                    PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+                    PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
+                    modified_base_rules,
                 ))
                 current_prio_class -= 1
                 if current_prio_class > 0:
                     ruleslist.extend(make_base_prepend_rules(
-                        PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+                        PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
+                        modified_base_rules,
                     ))
 
         ruleslist.append(r)
 
     while current_prio_class > 0:
         ruleslist.extend(make_base_append_rules(
-            PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+            PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
+            modified_base_rules,
         ))
         current_prio_class -= 1
         if current_prio_class > 0:
             ruleslist.extend(make_base_prepend_rules(
-                PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+                PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
+                modified_base_rules,
             ))
 
     return ruleslist
 
 
-def make_base_append_rules(kind):
+def make_base_append_rules(kind, modified_base_rules):
     rules = []
 
     if kind == 'override':
@@ -62,15 +83,31 @@ def make_base_append_rules(kind):
     elif kind == 'content':
         rules = BASE_APPEND_CONTENT_RULES
 
+    # Copy the rules before modifying them
+    rules = copy.deepcopy(rules)
+    for r in rules:
+        # Only modify the actions, keep the conditions the same.
+        modified = modified_base_rules.get(r['rule_id'])
+        if modified:
+            r['actions'] = modified['actions']
+
     return rules
 
 
-def make_base_prepend_rules(kind):
+def make_base_prepend_rules(kind, modified_base_rules):
     rules = []
 
     if kind == 'override':
         rules = BASE_PREPEND_OVERRIDE_RULES
 
+    # Copy the rules before modifying them
+    rules = copy.deepcopy(rules)
+    for r in rules:
+        # Only modify the actions, keep the conditions the same.
+        modified = modified_base_rules.get(r['rule_id'])
+        if modified:
+            r['actions'] = modified['actions']
+
     return rules
 
 
@@ -263,18 +300,24 @@ BASE_APPEND_UNDERRIDE_RULES = [
 ]
 
 
+BASE_RULE_IDS = set()
+
 for r in BASE_APPEND_CONTENT_RULES:
     r['priority_class'] = PRIORITY_CLASS_MAP['content']
     r['default'] = True
+    BASE_RULE_IDS.add(r['rule_id'])
 
 for r in BASE_PREPEND_OVERRIDE_RULES:
     r['priority_class'] = PRIORITY_CLASS_MAP['override']
     r['default'] = True
+    BASE_RULE_IDS.add(r['rule_id'])
 
 for r in BASE_APPEND_OVRRIDE_RULES:
     r['priority_class'] = PRIORITY_CLASS_MAP['override']
     r['default'] = True
+    BASE_RULE_IDS.add(r['rule_id'])
 
 for r in BASE_APPEND_UNDERRIDE_RULES:
     r['priority_class'] = PRIORITY_CLASS_MAP['underride']
     r['default'] = True
+    BASE_RULE_IDS.add(r['rule_id'])
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 79101106ac..f13272da8e 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -404,10 +404,12 @@ def _parse_json(request):
     try:
         content = json.loads(request.content.read())
         if type(content) != dict:
-            raise SynapseError(400, "Content must be a JSON object.")
+            raise SynapseError(
+                400, "Content must be a JSON object.", errcode=Codes.BAD_JSON
+            )
         return content
     except ValueError:
-        raise SynapseError(400, "Content not JSON.")
+        raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
 
 
 def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 5db2805d68..970a019223 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -22,7 +22,7 @@ from .base import ClientV1RestServlet, client_path_patterns
 from synapse.storage.push_rule import (
     InconsistentRuleException, RuleNotFoundException
 )
-import synapse.push.baserules as baserules
+from synapse.push.baserules import list_with_base_rules, BASE_RULE_IDS
 from synapse.push.rulekinds import (
     PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
 )
@@ -55,6 +55,10 @@ class PushRuleRestServlet(ClientV1RestServlet):
             yield self.set_rule_attr(requester.user.to_string(), spec, content)
             defer.returnValue((200, {}))
 
+        if spec['rule_id'].startswith('.'):
+            # Rule ids starting with '.' are reserved for server default rules.
+            raise SynapseError(400, "cannot add new rule_ids that start with '.'")
+
         try:
             (conditions, actions) = _rule_tuple_from_request_object(
                 spec['template'],
@@ -128,7 +132,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
             ruleslist.append(rule)
 
         # We're going to be mutating this a lot, so do a deep copy
-        ruleslist = copy.deepcopy(baserules.list_with_base_rules(ruleslist))
+        ruleslist = copy.deepcopy(list_with_base_rules(ruleslist))
 
         rules = {'global': {}, 'device': {}}
 
@@ -197,13 +201,17 @@ class PushRuleRestServlet(ClientV1RestServlet):
             return self.hs.get_datastore().set_push_rule_enabled(
                 user_id, namespaced_rule_id, val
             )
-        else:
-            raise UnrecognizedRequestError()
-
-    def get_rule_attr(self, user_id, namespaced_rule_id, attr):
-        if attr == 'enabled':
-            return self.hs.get_datastore().get_push_rule_enabled_by_user_rule_id(
-                user_id, namespaced_rule_id
+        elif spec['attr'] == 'actions':
+            actions = val.get('actions')
+            _check_actions(actions)
+            namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
+            rule_id = spec['rule_id']
+            is_default_rule = rule_id.startswith(".")
+            if is_default_rule:
+                if namespaced_rule_id not in BASE_RULE_IDS:
+                    raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
+            return self.hs.get_datastore().set_push_rule_actions(
+                user_id, namespaced_rule_id, actions, is_default_rule
             )
         else:
             raise UnrecognizedRequestError()
@@ -282,6 +290,15 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj):
         raise InvalidRuleException("No actions found")
     actions = req_obj['actions']
 
+    _check_actions(actions)
+
+    return conditions, actions
+
+
+def _check_actions(actions):
+    if not isinstance(actions, list):
+        raise InvalidRuleException("No actions found")
+
     for a in actions:
         if a in ['notify', 'dont_notify', 'coalesce']:
             pass
@@ -290,8 +307,6 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj):
         else:
             raise InvalidRuleException("Unrecognised action")
 
-    return conditions, actions
-
 
 def _add_empty_priority_class_arrays(d):
     for pc in PRIORITY_CLASS_MAP.keys():
@@ -332,7 +347,9 @@ def _filter_ruleset_with_path(ruleset, path):
 
     attr = path[0]
     if attr in the_rule:
-        return the_rule[attr]
+        # Make sure we return a JSON object as the attribute may be a
+        # JSON value.
+        return {attr: the_rule[attr]}
     else:
         raise UnrecognizedRequestError()
 
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 07a2a5dd82..f5ed4f7302 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -228,7 +228,12 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
             allow_guest=True,
         )
 
-        content = _parse_json(request)
+        try:
+            content = _parse_json(request)
+        except:
+            # Turns out we used to ignore the body entirely, and some clients
+            # cheekily send invalid bodies.
+            content = {}
 
         if RoomID.is_valid(room_identifier):
             room_id = room_identifier
@@ -427,7 +432,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
         }:
             raise AuthError(403, "Guest access not allowed")
 
-        content = _parse_json(request)
+        try:
+            content = _parse_json(request)
+        except:
+            # Turns out we used to ignore the body entirely, and some clients
+            # cheekily send invalid bodies.
+            content = {}
 
         if membership_action == "invite" and self._has_3pid_invite_keys(content):
             yield self.handlers.room_member_handler.do_3pid_invite(
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 70ece56548..3ef91d34db 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -58,17 +58,20 @@ class UserPresenceState(namedtuple("UserPresenceState",
 class PresenceStore(SQLBaseStore):
     @defer.inlineCallbacks
     def update_presence(self, presence_states):
-        stream_id_manager = yield self._presence_id_gen.get_next(self)
-        with stream_id_manager as stream_id:
+        stream_ordering_manager = yield self._presence_id_gen.get_next_mult(
+            self, len(presence_states)
+        )
+
+        with stream_ordering_manager as stream_orderings:
             yield self.runInteraction(
                 "update_presence",
-                self._update_presence_txn, stream_id, presence_states,
+                self._update_presence_txn, stream_orderings, presence_states,
             )
 
-        defer.returnValue((stream_id, self._presence_id_gen.get_max_token()))
+        defer.returnValue((stream_orderings[-1], self._presence_id_gen.get_max_token()))
 
-    def _update_presence_txn(self, txn, stream_id, presence_states):
-        for state in presence_states:
+    def _update_presence_txn(self, txn, stream_orderings, presence_states):
+        for stream_id, state in zip(stream_orderings, presence_states):
             txn.call_after(
                 self.presence_stream_cache.entity_has_changed,
                 state.user_id, stream_id,
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index e19a81e41f..bb5c14d912 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -294,6 +294,31 @@ class PushRuleStore(SQLBaseStore):
             self.get_push_rules_enabled_for_user.invalidate, (user_id,)
         )
 
+    def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
+        actions_json = json.dumps(actions)
+
+        def set_push_rule_actions_txn(txn):
+            if is_default_rule:
+                # Add a dummy rule to the rules table with the user specified
+                # actions.
+                priority_class = -1
+                priority = 1
+                self._upsert_push_rule_txn(
+                    txn, user_id, rule_id, priority_class, priority,
+                    "[]", actions_json
+                )
+            else:
+                self._simple_update_one_txn(
+                    txn,
+                    "push_rules",
+                    {'user_name': user_id, 'rule_id': rule_id},
+                    {'actions': actions_json},
+                )
+
+        return self.runInteraction(
+            "set_push_rule_actions", set_push_rule_actions_txn,
+        )
+
 
 class RuleNotFoundException(Exception):
     pass
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 967c732bda..03a9b66e4a 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -387,3 +387,47 @@ class RegistrationStore(SQLBaseStore):
             "find_next_generated_user_id",
             _find_next_generated_user_id
         )))
+
+    @defer.inlineCallbacks
+    def get_3pid_guest_access_token(self, medium, address):
+        ret = yield self._simple_select_one(
+            "threepid_guest_access_tokens",
+            {
+                "medium": medium,
+                "address": address
+            },
+            ["guest_access_token"], True, 'get_3pid_guest_access_token'
+        )
+        if ret:
+            defer.returnValue(ret["guest_access_token"])
+        defer.returnValue(None)
+
+    @defer.inlineCallbacks
+    def save_or_get_3pid_guest_access_token(
+            self, medium, address, access_token, inviter_user_id
+    ):
+        """
+        Gets the 3pid's guest access token if exists, else saves access_token.
+
+        :param medium (str): Medium of the 3pid. Must be "email".
+        :param address (str): 3pid address.
+        :param access_token (str): The access token to persist if none is
+            already persisted.
+        :param inviter_user_id (str): User ID of the inviter.
+        :return (deferred str): Whichever access token is persisted at the end
+            of this function call.
+        """
+        def insert(txn):
+            txn.execute(
+                "INSERT INTO threepid_guest_access_tokens "
+                "(medium, address, guest_access_token, first_inviter) "
+                "VALUES (?, ?, ?, ?)",
+                (medium, address, access_token, inviter_user_id)
+            )
+
+        try:
+            yield self.runInteraction("save_3pid_guest_access_token", insert)
+            defer.returnValue(access_token)
+        except self.database_engine.module.IntegrityError:
+            ret = yield self.get_3pid_guest_access_token(medium, address)
+            defer.returnValue(ret)
diff --git a/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql b/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql
new file mode 100644
index 0000000000..0dd2f1360c
--- /dev/null
+++ b/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql
@@ -0,0 +1,24 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Stores guest account access tokens generated for unbound 3pids.
+CREATE TABLE threepid_guest_access_tokens(
+    medium TEXT, -- The medium of the 3pid. Must be "email".
+    address TEXT, -- The 3pid address.
+    guest_access_token TEXT, -- The access token for a guest user for this 3pid.
+    first_inviter TEXT -- User ID of the first user to invite this 3pid to a room.
+);
+
+CREATE UNIQUE INDEX threepid_guest_access_tokens_index ON threepid_guest_access_tokens(medium, address);