summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/app/generic_worker.py6
-rw-r--r--synapse/federation/transport/client.py49
-rw-r--r--synapse/handlers/device.py14
-rw-r--r--synapse/handlers/e2e_keys.py166
-rw-r--r--synapse/handlers/sync.py7
-rw-r--r--synapse/push/push_rule_evaluator.py69
-rw-r--r--synapse/rest/client/v2_alpha/account_data.py8
8 files changed, 262 insertions, 59 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 3bf2d02450..d8d340f426 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -36,7 +36,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.12.3"
+__version__ = "1.12.4"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 5363642d64..66be6ea2ec 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -98,6 +98,10 @@ from synapse.rest.client.v1.voip import VoipRestServlet
 from synapse.rest.client.v2_alpha import groups, sync, user_directory
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
+from synapse.rest.client.v2_alpha.account_data import (
+    AccountDataServlet,
+    RoomAccountDataServlet,
+)
 from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
 from synapse.rest.client.v2_alpha.register import RegisterRestServlet
 from synapse.rest.client.versions import VersionsRestServlet
@@ -475,6 +479,8 @@ class GenericWorkerServer(HomeServer):
                     ProfileDisplaynameRestServlet(self).register(resource)
                     ProfileRestServlet(self).register(resource)
                     KeyUploadServlet(self).register(resource)
+                    AccountDataServlet(self).register(resource)
+                    RoomAccountDataServlet(self).register(resource)
 
                     sync.register_servlets(self, resource)
                     events.register_servlets(self, resource)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index dc563538de..383e3fdc8b 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -399,20 +399,30 @@ class TransportLayerClient(object):
             {
               "device_keys": {
                 "<user_id>": ["<device_id>"]
-            } }
+              }
+            }
 
         Response:
             {
               "device_keys": {
                 "<user_id>": {
                   "<device_id>": {...}
-            } } }
+                }
+              },
+              "master_key": {
+                "<user_id>": {...}
+                }
+              },
+              "self_signing_key": {
+                "<user_id>": {...}
+              }
+            }
 
         Args:
             destination(str): The server to query.
             query_content(dict): The user ids to query.
         Returns:
-            A dict containg the device keys.
+            A dict containing device and cross-signing keys.
         """
         path = _create_v1_path("/user/keys/query")
 
@@ -429,14 +439,30 @@ class TransportLayerClient(object):
         Response:
             {
               "stream_id": "...",
-              "devices": [ { ... } ]
+              "devices": [ { ... } ],
+              "master_key": {
+                "user_id": "<user_id>",
+                "usage": [...],
+                "keys": {...},
+                "signatures": {
+                  "<user_id>": {...}
+                }
+              },
+              "self_signing_key": {
+                "user_id": "<user_id>",
+                "usage": [...],
+                "keys": {...},
+                "signatures": {
+                  "<user_id>": {...}
+                }
+              }
             }
 
         Args:
             destination(str): The server to query.
             query_content(dict): The user ids to query.
         Returns:
-            A dict containg the device keys.
+            A dict containing device and cross-signing keys.
         """
         path = _create_v1_path("/user/devices/%s", user_id)
 
@@ -454,8 +480,10 @@ class TransportLayerClient(object):
             {
               "one_time_keys": {
                 "<user_id>": {
-                    "<device_id>": "<algorithm>"
-            } } }
+                  "<device_id>": "<algorithm>"
+                }
+              }
+            }
 
         Response:
             {
@@ -463,13 +491,16 @@ class TransportLayerClient(object):
                 "<user_id>": {
                   "<device_id>": {
                     "<algorithm>:<key_id>": "<key_base64>"
-            } } } }
+                  }
+                }
+              }
+            }
 
         Args:
             destination(str): The server to query.
             query_content(dict): The user ids to query.
         Returns:
-            A dict containg the one-time keys.
+            A dict containing the one-time keys.
         """
 
         path = _create_v1_path("/user/keys/claim")
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index a514c30714..993499f446 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -125,8 +125,14 @@ class DeviceWorkerHandler(BaseHandler):
         users_who_share_room = yield self.store.get_users_who_share_room_with_user(
             user_id
         )
+
+        tracked_users = set(users_who_share_room)
+
+        # Always tell the user about their own devices
+        tracked_users.add(user_id)
+
         changed = yield self.store.get_users_whose_devices_changed(
-            from_token.device_list_key, users_who_share_room
+            from_token.device_list_key, tracked_users
         )
 
         # Then work out if any users have since joined
@@ -456,7 +462,11 @@ class DeviceHandler(DeviceWorkerHandler):
 
         room_ids = yield self.store.get_rooms_for_user(user_id)
 
-        yield self.notifier.on_new_event("device_list_key", position, rooms=room_ids)
+        # specify the user ID too since the user should always get their own device list
+        # updates, even if they aren't in any rooms.
+        yield self.notifier.on_new_event(
+            "device_list_key", position, users=[user_id], rooms=room_ids
+        )
 
         if hosts:
             logger.info(
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 95a9d71f41..8f1bc0323c 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -54,19 +54,23 @@ class E2eKeysHandler(object):
 
         self._edu_updater = SigningKeyEduUpdater(hs, self)
 
+        federation_registry = hs.get_federation_registry()
+
         self._is_master = hs.config.worker_app is None
         if not self._is_master:
             self._user_device_resync_client = ReplicationUserDevicesResyncRestServlet.make_client(
                 hs
             )
+        else:
+            # Only register this edu handler on master as it requires writing
+            # device updates to the db
+            #
+            # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+            federation_registry.register_edu_handler(
+                "org.matrix.signing_key_update",
+                self._edu_updater.incoming_signing_key_update,
+            )
 
-        federation_registry = hs.get_federation_registry()
-
-        # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
-        federation_registry.register_edu_handler(
-            "org.matrix.signing_key_update",
-            self._edu_updater.incoming_signing_key_update,
-        )
         # doesn't really work as part of the generic query API, because the
         # query request requires an object POST, but we abuse the
         # "query handler" interface.
@@ -170,8 +174,8 @@ class E2eKeysHandler(object):
             """This is called when we are querying the device list of a user on
             a remote homeserver and their device list is not in the device list
             cache. If we share a room with this user and we're not querying for
-            specific user we will update the cache
-            with their device list."""
+            specific user we will update the cache with their device list.
+            """
 
             destination_query = remote_queries_not_in_cache[destination]
 
@@ -957,13 +961,19 @@ class E2eKeysHandler(object):
         return signature_list, failures
 
     @defer.inlineCallbacks
-    def _get_e2e_cross_signing_verify_key(self, user_id, key_type, from_user_id=None):
-        """Fetch the cross-signing public key from storage and interpret it.
+    def _get_e2e_cross_signing_verify_key(
+        self, user_id: str, key_type: str, from_user_id: str = None
+    ):
+        """Fetch locally or remotely query for a cross-signing public key.
+
+        First, attempt to fetch the cross-signing public key from storage.
+        If that fails, query the keys from the homeserver they belong to
+        and update our local copy.
 
         Args:
-            user_id (str): the user whose key should be fetched
-            key_type (str): the type of key to fetch
-            from_user_id (str): the user that we are fetching the keys for.
+            user_id: the user whose key should be fetched
+            key_type: the type of key to fetch
+            from_user_id: the user that we are fetching the keys for.
                 This affects what signatures are fetched.
 
         Returns:
@@ -972,16 +982,140 @@ class E2eKeysHandler(object):
 
         Raises:
             NotFoundError: if the key is not found
+            SynapseError: if `user_id` is invalid
         """
+        user = UserID.from_string(user_id)
         key = yield self.store.get_e2e_cross_signing_key(
             user_id, key_type, from_user_id
         )
+
+        if key:
+            # We found a copy of this key in our database. Decode and return it
+            key_id, verify_key = get_verify_key_from_cross_signing_key(key)
+            return key, key_id, verify_key
+
+        # If we couldn't find the key locally, and we're looking for keys of
+        # another user then attempt to fetch the missing key from the remote
+        # user's server.
+        #
+        # We may run into this in possible edge cases where a user tries to
+        # cross-sign a remote user, but does not share any rooms with them yet.
+        # Thus, we would not have their key list yet. We instead fetch the key,
+        # store it and notify clients of new, associated device IDs.
+        if self.is_mine(user) or key_type not in ["master", "self_signing"]:
+            # Note that master and self_signing keys are the only cross-signing keys we
+            # can request over federation
+            raise NotFoundError("No %s key found for %s" % (key_type, user_id))
+
+        (
+            key,
+            key_id,
+            verify_key,
+        ) = yield self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
+
         if key is None:
-            logger.debug("no %s key found for %s", key_type, user_id)
             raise NotFoundError("No %s key found for %s" % (key_type, user_id))
-        key_id, verify_key = get_verify_key_from_cross_signing_key(key)
+
         return key, key_id, verify_key
 
+    @defer.inlineCallbacks
+    def _retrieve_cross_signing_keys_for_remote_user(
+        self, user: UserID, desired_key_type: str,
+    ):
+        """Queries cross-signing keys for a remote user and saves them to the database
+
+        Only the key specified by `key_type` will be returned, while all retrieved keys
+        will be saved regardless
+
+        Args:
+            user: The user to query remote keys for
+            desired_key_type: The type of key to receive. One of "master", "self_signing"
+
+        Returns:
+            Deferred[Tuple[Optional[Dict], Optional[str], Optional[VerifyKey]]]: A tuple
+            of the retrieved key content, the key's ID and the matching VerifyKey.
+            If the key cannot be retrieved, all values in the tuple will instead be None.
+        """
+        try:
+            remote_result = yield self.federation.query_user_devices(
+                user.domain, user.to_string()
+            )
+        except Exception as e:
+            logger.warning(
+                "Unable to query %s for cross-signing keys of user %s: %s %s",
+                user.domain,
+                user.to_string(),
+                type(e),
+                e,
+            )
+            return None, None, None
+
+        # Process each of the retrieved cross-signing keys
+        desired_key = None
+        desired_key_id = None
+        desired_verify_key = None
+        retrieved_device_ids = []
+        for key_type in ["master", "self_signing"]:
+            key_content = remote_result.get(key_type + "_key")
+            if not key_content:
+                continue
+
+            # Ensure these keys belong to the correct user
+            if "user_id" not in key_content:
+                logger.warning(
+                    "Invalid %s key retrieved, missing user_id field: %s",
+                    key_type,
+                    key_content,
+                )
+                continue
+            if user.to_string() != key_content["user_id"]:
+                logger.warning(
+                    "Found %s key of user %s when querying for keys of user %s",
+                    key_type,
+                    key_content["user_id"],
+                    user.to_string(),
+                )
+                continue
+
+            # Validate the key contents
+            try:
+                # verify_key is a VerifyKey from signedjson, which uses
+                # .version to denote the portion of the key ID after the
+                # algorithm and colon, which is the device ID
+                key_id, verify_key = get_verify_key_from_cross_signing_key(key_content)
+            except ValueError as e:
+                logger.warning(
+                    "Invalid %s key retrieved: %s - %s %s",
+                    key_type,
+                    key_content,
+                    type(e),
+                    e,
+                )
+                continue
+
+            # Note down the device ID attached to this key
+            retrieved_device_ids.append(verify_key.version)
+
+            # If this is the desired key type, save it and its ID/VerifyKey
+            if key_type == desired_key_type:
+                desired_key = key_content
+                desired_verify_key = verify_key
+                desired_key_id = key_id
+
+            # At the same time, store this key in the db for subsequent queries
+            yield self.store.set_e2e_cross_signing_key(
+                user.to_string(), key_type, key_content
+            )
+
+        # Notify clients that new devices for this user have been discovered
+        if retrieved_device_ids:
+            # XXX is this necessary?
+            yield self.device_handler.notify_device_update(
+                user.to_string(), retrieved_device_ids
+            )
+
+        return desired_key, desired_key_id, desired_verify_key
+
 
 def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
     """Check a cross-signing key uploaded by a user.  Performs some basic sanity
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 669dbc8a48..cfd5dfc9e5 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1143,9 +1143,14 @@ class SyncHandler(object):
                 user_id
             )
 
+            tracked_users = set(users_who_share_room)
+
+            # Always tell the user about their own devices
+            tracked_users.add(user_id)
+
             # Step 1a, check for changes in devices of users we share a room with
             users_that_have_changed = await self.store.get_users_whose_devices_changed(
-                since_token.device_list_key, users_who_share_room
+                since_token.device_list_key, tracked_users
             )
 
             # Step 1b, check for newly joined rooms
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index b1587183a8..4cd702b5fa 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -16,9 +16,11 @@
 
 import logging
 import re
+from typing import Pattern
 
 from six import string_types
 
+from synapse.events import EventBase
 from synapse.types import UserID
 from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
 from synapse.util.caches.lrucache import LruCache
@@ -56,18 +58,18 @@ def _test_ineq_condition(condition, number):
     rhs = m.group(2)
     if not rhs.isdigit():
         return False
-    rhs = int(rhs)
+    rhs_int = int(rhs)
 
     if ineq == "" or ineq == "==":
-        return number == rhs
+        return number == rhs_int
     elif ineq == "<":
-        return number < rhs
+        return number < rhs_int
     elif ineq == ">":
-        return number > rhs
+        return number > rhs_int
     elif ineq == ">=":
-        return number >= rhs
+        return number >= rhs_int
     elif ineq == "<=":
-        return number <= rhs
+        return number <= rhs_int
     else:
         return False
 
@@ -83,7 +85,13 @@ def tweaks_for_actions(actions):
 
 
 class PushRuleEvaluatorForEvent(object):
-    def __init__(self, event, room_member_count, sender_power_level, power_levels):
+    def __init__(
+        self,
+        event: EventBase,
+        room_member_count: int,
+        sender_power_level: int,
+        power_levels: dict,
+    ):
         self._event = event
         self._room_member_count = room_member_count
         self._sender_power_level = sender_power_level
@@ -92,7 +100,7 @@ class PushRuleEvaluatorForEvent(object):
         # Maps strings of e.g. 'content.body' -> event["content"]["body"]
         self._value_cache = _flatten_dict(event)
 
-    def matches(self, condition, user_id, display_name):
+    def matches(self, condition: dict, user_id: str, display_name: str) -> bool:
         if condition["kind"] == "event_match":
             return self._event_match(condition, user_id)
         elif condition["kind"] == "contains_display_name":
@@ -106,7 +114,7 @@ class PushRuleEvaluatorForEvent(object):
         else:
             return True
 
-    def _event_match(self, condition, user_id):
+    def _event_match(self, condition: dict, user_id: str) -> bool:
         pattern = condition.get("pattern", None)
 
         if not pattern:
@@ -134,7 +142,7 @@ class PushRuleEvaluatorForEvent(object):
 
             return _glob_matches(pattern, haystack)
 
-    def _contains_display_name(self, display_name):
+    def _contains_display_name(self, display_name: str) -> bool:
         if not display_name:
             return False
 
@@ -142,51 +150,52 @@ class PushRuleEvaluatorForEvent(object):
         if not body:
             return False
 
-        return _glob_matches(display_name, body, word_boundary=True)
+        # Similar to _glob_matches, but do not treat display_name as a glob.
+        r = regex_cache.get((display_name, False, True), None)
+        if not r:
+            r = re.escape(display_name)
+            r = _re_word_boundary(r)
+            r = re.compile(r, flags=re.IGNORECASE)
+            regex_cache[(display_name, False, True)] = r
+
+        return r.search(body)
 
-    def _get_value(self, dotted_key):
+    def _get_value(self, dotted_key: str) -> str:
         return self._value_cache.get(dotted_key, None)
 
 
-# Caches (glob, word_boundary) -> regex for push. See _glob_matches
+# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
 regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
 register_cache("cache", "regex_push_cache", regex_cache)
 
 
-def _glob_matches(glob, value, word_boundary=False):
+def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
     """Tests if value matches glob.
 
     Args:
-        glob (string)
-        value (string): String to test against glob.
-        word_boundary (bool): Whether to match against word boundaries or entire
+        glob
+        value: String to test against glob.
+        word_boundary: Whether to match against word boundaries or entire
             string. Defaults to False.
-
-    Returns:
-        bool
     """
 
     try:
-        r = regex_cache.get((glob, word_boundary), None)
+        r = regex_cache.get((glob, True, word_boundary), None)
         if not r:
             r = _glob_to_re(glob, word_boundary)
-            regex_cache[(glob, word_boundary)] = r
+            regex_cache[(glob, True, word_boundary)] = r
         return r.search(value)
     except re.error:
         logger.warning("Failed to parse glob to regex: %r", glob)
         return False
 
 
-def _glob_to_re(glob, word_boundary):
+def _glob_to_re(glob: str, word_boundary: bool) -> Pattern:
     """Generates regex for a given glob.
 
     Args:
-        glob (string)
-        word_boundary (bool): Whether to match against word boundaries or entire
-            string. Defaults to False.
-
-    Returns:
-        regex object
+        glob
+        word_boundary: Whether to match against word boundaries or entire string.
     """
     if IS_GLOB.search(glob):
         r = re.escape(glob)
@@ -219,7 +228,7 @@ def _glob_to_re(glob, word_boundary):
         return re.compile(r, flags=re.IGNORECASE)
 
 
-def _re_word_boundary(r):
+def _re_word_boundary(r: str) -> str:
     """
     Adds word boundary characters to the start and end of an
     expression to require that the match occur as a whole word,
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 64eb7fec3b..c1d4cd0caf 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -38,8 +38,12 @@ class AccountDataServlet(RestServlet):
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
+        self._is_worker = hs.config.worker_app is not None
 
     async def on_PUT(self, request, user_id, account_data_type):
+        if self._is_worker:
+            raise Exception("Cannot handle PUT /account_data on worker")
+
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
@@ -86,8 +90,12 @@ class RoomAccountDataServlet(RestServlet):
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
+        self._is_worker = hs.config.worker_app is not None
 
     async def on_PUT(self, request, user_id, room_id, account_data_type):
+        if self._is_worker:
+            raise Exception("Cannot handle PUT /account_data on worker")
+
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")