diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/federation/federation_client.py | 30 | ||||
-rw-r--r-- | synapse/federation/federation_server.py | 43 | ||||
-rw-r--r-- | synapse/federation/transport/client.py | 70 | ||||
-rw-r--r-- | synapse/federation/transport/server.py | 20 | ||||
-rw-r--r-- | synapse/handlers/message.py | 18 | ||||
-rw-r--r-- | synapse/handlers/presence.py | 89 | ||||
-rw-r--r-- | synapse/metrics/__init__.py | 28 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/keys.py | 100 | ||||
-rw-r--r-- | synapse/storage/presence.py | 32 | ||||
-rw-r--r-- | synapse/storage/state.py | 2 | ||||
-rw-r--r-- | synapse/util/caches/descriptors.py | 2 |
11 files changed, 384 insertions, 50 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 58a6d6a0ed..f5e346cdbc 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -134,6 +134,36 @@ class FederationClient(FederationBase): destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail ) + @log_function + def query_client_keys(self, destination, content): + """Query device keys for a device hosted on a remote server. + + Args: + destination (str): Domain name of the remote homeserver + content (dict): The query content. + + Returns: + a Deferred which will eventually yield a JSON object from the + response + """ + sent_queries_counter.inc("client_device_keys") + return self.transport_layer.query_client_keys(destination, content) + + @log_function + def claim_client_keys(self, destination, content): + """Claims one-time keys for a device hosted on a remote server. + + Args: + destination (str): Domain name of the remote homeserver + content (dict): The query content. + + Returns: + a Deferred which will eventually yield a JSON object from the + response + """ + sent_queries_counter.inc("client_one_time_keys") + return self.transport_layer.claim_client_keys(destination, content) + @defer.inlineCallbacks @log_function def backfill(self, dest, context, limit, extremities): diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index cd79e23f4b..725c6f3fa5 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -27,6 +27,7 @@ from synapse.api.errors import FederationError, SynapseError from synapse.crypto.event_signing import compute_event_signature +import simplejson as json import logging @@ -314,6 +315,48 @@ class FederationServer(FederationBase): @defer.inlineCallbacks @log_function + def on_query_client_keys(self, origin, content): + query = [] + for user_id, device_ids in content.get("device_keys", {}).items(): + if not device_ids: + query.append((user_id, None)) + else: + for device_id in device_ids: + query.append((user_id, device_id)) + + results = yield self.store.get_e2e_device_keys(query) + + json_result = {} + for user_id, device_keys in results.items(): + for device_id, json_bytes in device_keys.items(): + json_result.setdefault(user_id, {})[device_id] = json.loads( + json_bytes + ) + + defer.returnValue({"device_keys": json_result}) + + @defer.inlineCallbacks + @log_function + def on_claim_client_keys(self, origin, content): + query = [] + for user_id, device_keys in content.get("one_time_keys", {}).items(): + for device_id, algorithm in device_keys.items(): + query.append((user_id, device_id, algorithm)) + + results = yield self.store.claim_e2e_one_time_keys(query) + + json_result = {} + for user_id, device_keys in results.items(): + for device_id, keys in device_keys.items(): + for key_id, json_bytes in keys.items(): + json_result.setdefault(user_id, {})[device_id] = { + key_id: json.loads(json_bytes) + } + + defer.returnValue({"one_time_keys": json_result}) + + @defer.inlineCallbacks + @log_function def on_get_missing_events(self, origin, room_id, earliest_events, latest_events, limit, min_depth): missing_events = yield self.handler.on_get_missing_events( diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 610a4c3163..ced703364b 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -224,6 +224,76 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function + def query_client_keys(self, destination, query_content): + """Query the device keys for a list of user ids hosted on a remote + server. + + Request: + { + "device_keys": { + "<user_id>": ["<device_id>"] + } } + + Response: + { + "device_keys": { + "<user_id>": { + "<device_id>": {...} + } } } + + Args: + destination(str): The server to query. + query_content(dict): The user ids to query. + Returns: + A dict containg the device keys. + """ + path = PREFIX + "/user/keys/query" + + content = yield self.client.post_json( + destination=destination, + path=path, + data=query_content, + ) + defer.returnValue(content) + + @defer.inlineCallbacks + @log_function + def claim_client_keys(self, destination, query_content): + """Claim one-time keys for a list of devices hosted on a remote server. + + Request: + { + "one_time_keys": { + "<user_id>": { + "<device_id>": "<algorithm>" + } } } + + Response: + { + "device_keys": { + "<user_id>": { + "<device_id>": { + "<algorithm>:<key_id>": "<key_base64>" + } } } } + + Args: + destination(str): The server to query. + query_content(dict): The user ids to query. + Returns: + A dict containg the one-time keys. + """ + + path = PREFIX + "/user/keys/claim" + + content = yield self.client.post_json( + destination=destination, + path=path, + data=query_content, + ) + defer.returnValue(content) + + @defer.inlineCallbacks + @log_function def get_missing_events(self, destination, room_id, earliest_events, latest_events, limit, min_depth): path = PREFIX + "/get_missing_events/%s" % (room_id,) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index bad93c6b2f..36f250e1a3 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -325,6 +325,24 @@ class FederationInviteServlet(BaseFederationServlet): defer.returnValue((200, content)) +class FederationClientKeysQueryServlet(BaseFederationServlet): + PATH = "/user/keys/query" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query): + response = yield self.handler.on_query_client_keys(origin, content) + defer.returnValue((200, response)) + + +class FederationClientKeysClaimServlet(BaseFederationServlet): + PATH = "/user/keys/claim" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query): + response = yield self.handler.on_claim_client_keys(origin, content) + defer.returnValue((200, response)) + + class FederationQueryAuthServlet(BaseFederationServlet): PATH = "/query_auth/([^/]*)/([^/]*)" @@ -373,4 +391,6 @@ SERVLET_CLASSES = ( FederationQueryAuthServlet, FederationGetMissingEventsServlet, FederationEventAuthServlet, + FederationClientKeysQueryServlet, + FederationClientKeysClaimServlet, ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 29e81085d1..f12465fa2c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -460,20 +460,14 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def get_presence(): - presence_defs = yield defer.DeferredList( - [ - presence_handler.get_state( - target_user=UserID.from_string(m.user_id), - auth_user=auth_user, - as_event=True, - check_auth=False, - ) - for m in room_members - ], - consumeErrors=True, + states = yield presence_handler.get_states( + target_users=[UserID.from_string(m.user_id) for m in room_members], + auth_user=auth_user, + as_event=True, + check_auth=False, ) - defer.returnValue([p for success, p in presence_defs if success]) + defer.returnValue(states.values()) receipts_handler = self.hs.get_handlers().receipts_handler diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 341a516da2..e91e81831e 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -192,6 +192,20 @@ class PresenceHandler(BaseHandler): @defer.inlineCallbacks def get_state(self, target_user, auth_user, as_event=False, check_auth=True): + """Get the current presence state of the given user. + + Args: + target_user (UserID): The user whose presence we want + auth_user (UserID): The user requesting the presence, used for + checking if said user is allowed to see the persence of the + `target_user` + as_event (bool): Format the return as an event or not? + check_auth (bool): Perform the auth checks or not? + + Returns: + dict: The presence state of the `target_user`, whose format depends + on the `as_event` argument. + """ if self.hs.is_mine(target_user): if check_auth: visible = yield self.is_presence_visible( @@ -233,6 +247,81 @@ class PresenceHandler(BaseHandler): defer.returnValue(state) @defer.inlineCallbacks + def get_states(self, target_users, auth_user, as_event=False, check_auth=True): + """A batched version of the `get_state` method that accepts a list of + `target_users` + + Args: + target_users (list): The list of UserID's whose presence we want + auth_user (UserID): The user requesting the presence, used for + checking if said user is allowed to see the persence of the + `target_users` + as_event (bool): Format the return as an event or not? + check_auth (bool): Perform the auth checks or not? + + Returns: + dict: A mapping from user -> presence_state + """ + local_users, remote_users = partitionbool( + target_users, + lambda u: self.hs.is_mine(u) + ) + + if check_auth: + for user in local_users: + visible = yield self.is_presence_visible( + observer_user=auth_user, + observed_user=user + ) + + if not visible: + raise SynapseError(404, "Presence information not visible") + + results = {} + if local_users: + for user in local_users: + if user in self._user_cachemap: + results[user] = self._user_cachemap[user].get_state() + + local_to_user = {u.localpart: u for u in local_users} + + states = yield self.store.get_presence_states( + [u.localpart for u in local_users if u not in results] + ) + + for local_part, state in states.items(): + if state is None: + continue + res = {"presence": state["state"]} + if "status_msg" in state and state["status_msg"]: + res["status_msg"] = state["status_msg"] + results[local_to_user[local_part]] = res + + for user in remote_users: + # TODO(paul): Have remote server send us permissions set + results[user] = self._get_or_offline_usercache(user).get_state() + + for state in results.values(): + if "last_active" in state: + state["last_active_ago"] = int( + self.clock.time_msec() - state.pop("last_active") + ) + + if as_event: + for user, state in results.items(): + content = state + content["user_id"] = user.to_string() + + if "last_active" in content: + content["last_active_ago"] = int( + self._clock.time_msec() - content.pop("last_active") + ) + + results[user] = {"type": "m.presence", "content": content} + + defer.returnValue(results) + + @defer.inlineCallbacks @log_function def set_state(self, target_user, auth_user, state): # return diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 0be9772991..d7bcad8a8a 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -158,18 +158,40 @@ def runUntilCurrentTimer(func): @functools.wraps(func) def f(*args, **kwargs): - pending_calls = len(reactor.getDelayedCalls()) + now = reactor.seconds() + num_pending = 0 + + # _newTimedCalls is one long list of *all* pending calls. Below loop + # is based off of impl of reactor.runUntilCurrent + for delayed_call in reactor._newTimedCalls: + if delayed_call.time > now: + break + + if delayed_call.delayed_time > 0: + continue + + num_pending += 1 + + num_pending += len(reactor.threadCallQueue) + start = time.time() * 1000 ret = func(*args, **kwargs) end = time.time() * 1000 tick_time.inc_by(end - start) - pending_calls_metric.inc_by(pending_calls) + pending_calls_metric.inc_by(num_pending) return ret return f -if hasattr(reactor, "runUntilCurrent"): +try: + # Ensure the reactor has all the attributes we expect + reactor.runUntilCurrent + reactor._newTimedCalls + reactor.threadCallQueue + # runUntilCurrent is called when we have pending calls. It is called once # per iteratation after fd polling. reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent) +except AttributeError: + pass diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 5f3a6207b5..718928eedd 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet +from synapse.types import UserID from syutil.jsonutil import encode_canonical_json from ._base import client_v2_pattern @@ -164,45 +165,63 @@ class KeyQueryServlet(RestServlet): super(KeyQueryServlet, self).__init__() self.store = hs.get_datastore() self.auth = hs.get_auth() + self.federation = hs.get_replication_layer() + self.is_mine = hs.is_mine @defer.inlineCallbacks def on_POST(self, request, user_id, device_id): - logger.debug("onPOST") yield self.auth.get_user_by_req(request) try: body = json.loads(request.content.read()) except: raise SynapseError(400, "Invalid key JSON") - query = [] - for user_id, device_ids in body.get("device_keys", {}).items(): - if not device_ids: - query.append((user_id, None)) - else: - for device_id in device_ids: - query.append((user_id, device_id)) - results = yield self.store.get_e2e_device_keys(query) - defer.returnValue(self.json_result(request, results)) + result = yield self.handle_request(body) + defer.returnValue(result) @defer.inlineCallbacks def on_GET(self, request, user_id, device_id): auth_user, client_info = yield self.auth.get_user_by_req(request) auth_user_id = auth_user.to_string() - if not user_id: - user_id = auth_user_id - if not device_id: - device_id = None - # Returns a map of user_id->device_id->json_bytes. - results = yield self.store.get_e2e_device_keys([(user_id, device_id)]) - defer.returnValue(self.json_result(request, results)) - - def json_result(self, request, results): + user_id = user_id if user_id else auth_user_id + device_ids = [device_id] if device_id else [] + result = yield self.handle_request( + {"device_keys": {user_id: device_ids}} + ) + defer.returnValue(result) + + @defer.inlineCallbacks + def handle_request(self, body): + local_query = [] + remote_queries = {} + for user_id, device_ids in body.get("device_keys", {}).items(): + user = UserID.from_string(user_id) + if self.is_mine(user): + if not device_ids: + local_query.append((user_id, None)) + else: + for device_id in device_ids: + local_query.append((user_id, device_id)) + else: + remote_queries.setdefault(user.domain, {})[user_id] = list( + device_ids + ) + results = yield self.store.get_e2e_device_keys(local_query) + json_result = {} for user_id, device_keys in results.items(): for device_id, json_bytes in device_keys.items(): json_result.setdefault(user_id, {})[device_id] = json.loads( json_bytes ) - return (200, {"device_keys": json_result}) + + for destination, device_keys in remote_queries.items(): + remote_result = yield self.federation.query_client_keys( + destination, {"device_keys": device_keys} + ) + for user_id, keys in remote_result["device_keys"].items(): + if user_id in device_keys: + json_result[user_id] = keys + defer.returnValue((200, {"device_keys": json_result})) class OneTimeKeyServlet(RestServlet): @@ -236,14 +255,16 @@ class OneTimeKeyServlet(RestServlet): self.store = hs.get_datastore() self.auth = hs.get_auth() self.clock = hs.get_clock() + self.federation = hs.get_replication_layer() + self.is_mine = hs.is_mine @defer.inlineCallbacks def on_GET(self, request, user_id, device_id, algorithm): yield self.auth.get_user_by_req(request) - results = yield self.store.claim_e2e_one_time_keys( - [(user_id, device_id, algorithm)] + result = yield self.handle_request( + {"one_time_keys": {user_id: {device_id: algorithm}}} ) - defer.returnValue(self.json_result(request, results)) + defer.returnValue(result) @defer.inlineCallbacks def on_POST(self, request, user_id, device_id, algorithm): @@ -252,14 +273,24 @@ class OneTimeKeyServlet(RestServlet): body = json.loads(request.content.read()) except: raise SynapseError(400, "Invalid key JSON") - query = [] + result = yield self.handle_request(body) + defer.returnValue(result) + + @defer.inlineCallbacks + def handle_request(self, body): + local_query = [] + remote_queries = {} for user_id, device_keys in body.get("one_time_keys", {}).items(): - for device_id, algorithm in device_keys.items(): - query.append((user_id, device_id, algorithm)) - results = yield self.store.claim_e2e_one_time_keys(query) - defer.returnValue(self.json_result(request, results)) + user = UserID.from_string(user_id) + if self.is_mine(user): + for device_id, algorithm in device_keys.items(): + local_query.append((user_id, device_id, algorithm)) + else: + remote_queries.setdefault(user.domain, {})[user_id] = ( + device_keys + ) + results = yield self.store.claim_e2e_one_time_keys(local_query) - def json_result(self, request, results): json_result = {} for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): @@ -267,7 +298,16 @@ class OneTimeKeyServlet(RestServlet): json_result.setdefault(user_id, {})[device_id] = { key_id: json.loads(json_bytes) } - return (200, {"one_time_keys": json_result}) + + for destination, device_keys in remote_queries.items(): + remote_result = yield self.federation.claim_client_keys( + destination, {"one_time_keys": device_keys} + ) + for user_id, keys in remote_result["one_time_keys"].items(): + if user_id in device_keys: + json_result[user_id] = keys + + defer.returnValue((200, {"one_time_keys": json_result})) def register_servlets(hs, http_server): diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 4f91a2b87c..34ca3b9a54 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -14,19 +14,22 @@ # limitations under the License. from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, cachedList from twisted.internet import defer class PresenceStore(SQLBaseStore): def create_presence(self, user_localpart): - return self._simple_insert( + res = self._simple_insert( table="presence", values={"user_id": user_localpart}, desc="create_presence", ) + self.get_presence_state.invalidate((user_localpart,)) + return res + def has_presence_state(self, user_localpart): return self._simple_select_one( table="presence", @@ -36,6 +39,7 @@ class PresenceStore(SQLBaseStore): desc="has_presence_state", ) + @cached(max_entries=2000) def get_presence_state(self, user_localpart): return self._simple_select_one( table="presence", @@ -44,8 +48,27 @@ class PresenceStore(SQLBaseStore): desc="get_presence_state", ) + @cachedList(get_presence_state.cache, list_name="user_localparts") + def get_presence_states(self, user_localparts): + def f(txn): + results = {} + for user_localpart in user_localparts: + res = self._simple_select_one_txn( + txn, + table="presence", + keyvalues={"user_id": user_localpart}, + retcols=["state", "status_msg", "mtime"], + allow_none=True, + ) + if res: + results[user_localpart] = res + + return results + + return self.runInteraction("get_presence_states", f) + def set_presence_state(self, user_localpart, new_state): - return self._simple_update_one( + res = self._simple_update_one( table="presence", keyvalues={"user_id": user_localpart}, updatevalues={"state": new_state["state"], @@ -54,6 +77,9 @@ class PresenceStore(SQLBaseStore): desc="set_presence_state", ) + self.get_presence_state.invalidate((user_localpart,)) + return res + def allow_presence_visible(self, observed_localpart, observer_userid): return self._simple_insert( table="presence_allow_inbound", diff --git a/synapse/storage/state.py b/synapse/storage/state.py index ecb62e6dfd..ab3ad5a076 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -240,7 +240,7 @@ class StateStore(SQLBaseStore): defer.returnValue({event: event_to_state[event] for event in event_ids}) - @cached(num_args=2, lru=True, max_entries=100000) + @cached(num_args=2, lru=True, max_entries=10000) def _get_state_group_for_event(self, room_id, event_id): return self._simple_select_one_onecol( table="event_to_state_groups", diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 83bfec2f02..362944bc51 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -293,7 +293,7 @@ class CacheListDescriptor(object): # we can insert the new deferred into the cache. for arg in missing: observer = ret_d.observe() - observer.addCallback(lambda r, arg: r[arg], arg) + observer.addCallback(lambda r, arg: r.get(arg, None), arg) observer = ObservableDeferred(observer) |