From 555d702e3441e355d6b5e23702ab9505728dea71 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sat, 31 Dec 2016 15:21:37 +0000 Subject: limit total timeout for get_missing_events to 10s --- synapse/federation/federation_client.py | 4 +++- synapse/federation/federation_server.py | 5 +++++ synapse/federation/transport/client.py | 5 +++-- 3 files changed, 11 insertions(+), 3 deletions(-) (limited to 'synapse/federation') diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 6851f2376d..b4bcec77ed 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -707,7 +707,7 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def get_missing_events(self, destination, room_id, earliest_events_ids, - latest_events, limit, min_depth): + latest_events, limit, min_depth, timeout): """Tries to fetch events we are missing. This is called when we receive an event without having received all of its ancestors. @@ -721,6 +721,7 @@ class FederationClient(FederationBase): have all previous events for. limit (int): Maximum number of events to return. min_depth (int): Minimum depth of events tor return. + timeout (int): Max time to wait in ms """ try: content = yield self.transport_layer.get_missing_events( @@ -730,6 +731,7 @@ class FederationClient(FederationBase): latest_events=[e.event_id for e in latest_events], limit=limit, min_depth=min_depth, + timeout=timeout, ) events = [ diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index f4c60e67e3..6d76e6f917 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -425,6 +425,7 @@ class FederationServer(FederationBase): " limit: %d, min_depth: %d", earliest_events, latest_events, limit, min_depth ) + missing_events = yield self.handler.on_get_missing_events( origin, room_id, earliest_events, latest_events, limit, min_depth ) @@ -567,6 +568,9 @@ class FederationServer(FederationBase): len(prevs - seen), pdu.room_id, list(prevs - seen)[:5] ) + # XXX: we set timeout to 10s to help workaround + # https://github.com/matrix-org/synapse/issues/1733 + missing_events = yield self.get_missing_events( origin, pdu.room_id, @@ -574,6 +578,7 @@ class FederationServer(FederationBase): latest_events=[pdu], limit=10, min_depth=min_depth, + timeout=10000, ) # We want to sort these by depth so we process them and diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 491cdc29e1..915af34409 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -386,7 +386,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def get_missing_events(self, destination, room_id, earliest_events, - latest_events, limit, min_depth): + latest_events, limit, min_depth, timeout): path = PREFIX + "/get_missing_events/%s" % (room_id,) content = yield self.client.post_json( @@ -397,7 +397,8 @@ class TransportLayerClient(object): "min_depth": int(min_depth), "earliest_events": earliest_events, "latest_events": latest_events, - } + }, + timeout=timeout, ) defer.returnValue(content) -- cgit 1.4.1 From 8e82611f3726bfd577ca77a39328c63ecb29410f Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Thu, 5 Jan 2017 11:44:44 +0000 Subject: fix comment --- synapse/federation/federation_server.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) (limited to 'synapse/federation') diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 6d76e6f917..800f04189f 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -569,7 +569,23 @@ class FederationServer(FederationBase): ) # XXX: we set timeout to 10s to help workaround - # https://github.com/matrix-org/synapse/issues/1733 + # https://github.com/matrix-org/synapse/issues/1733. + # The reason is to avoid holding the linearizer lock + # whilst processing inbound /send transactions, causing + # FDs to stack up and block other inbound transactions + # which empirically can currently take up to 30 minutes. + # + # N.B. this explicitly disables retry attempts. + # + # N.B. this also increases our chances of falling back to + # fetching fresh state for the room if the missing event + # can't be found, which slightly reduces our security. + # it may also increase our DAG extremity count for the room, + # causing additional state resolution? See #1760. + # However, fetching state doesn't hold the linearizer lock + # apparently. + # + # see https://github.com/matrix-org/synapse/pull/1744 missing_events = yield self.get_missing_events( origin, -- cgit 1.4.1 From f7085ac84f76ee621ea52c9eaa0399c786d14027 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 9 Jan 2017 17:17:10 +0000 Subject: Name linearizer's for better logs --- synapse/federation/federation_server.py | 4 ++-- synapse/handlers/room_member.py | 2 +- synapse/rest/media/v1/media_repository.py | 2 +- synapse/state.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) (limited to 'synapse/federation') diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 1fee4e83a6..862ccbef5d 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -52,8 +52,8 @@ class FederationServer(FederationBase): self.auth = hs.get_auth() - self._room_pdu_linearizer = Linearizer() - self._server_linearizer = Linearizer() + self._room_pdu_linearizer = Linearizer("fed_room_pdu") + self._server_linearizer = Linearizer("fed_server") # We cache responses to state queries, as they take a while and often # come in waves. diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 2f8782e522..649aaf6d29 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -45,7 +45,7 @@ class RoomMemberHandler(BaseHandler): def __init__(self, hs): super(RoomMemberHandler, self).__init__(hs) - self.member_linearizer = Linearizer() + self.member_linearizer = Linearizer(name="member") self.clock = hs.get_clock() diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 692e078419..2b693ae548 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -61,7 +61,7 @@ class MediaRepository(object): self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.thumbnail_requirements = hs.config.thumbnail_requirements - self.remote_media_linearizer = Linearizer() + self.remote_media_linearizer = Linearizer(name="media_remote") self.recently_accessed_remotes = set() diff --git a/synapse/state.py b/synapse/state.py index 8003099c88..b9d5627a82 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -89,7 +89,7 @@ class StateHandler(object): # dict of set of event_ids -> _StateCacheEntry. self._state_cache = None - self.resolve_linearizer = Linearizer() + self.resolve_linearizer = Linearizer(name="state_resolve_lock") def start_caching(self): logger.debug("start_caching") -- cgit 1.4.1 From e6153e1bd10529b28b69820decbc039b9d6a1f27 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 13 Jan 2017 13:21:04 +0000 Subject: Fix couple of federation state bugs --- synapse/federation/federation_client.py | 6 ++++-- synapse/handlers/federation.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) (limited to 'synapse/federation') diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index b4bcec77ed..c9175bb33d 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -26,7 +26,7 @@ from synapse.util import unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred -from synapse.events import FrozenEvent +from synapse.events import FrozenEvent, builder import synapse.metrics from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination @@ -499,8 +499,10 @@ class FederationClient(FederationBase): if "prev_state" not in pdu_dict: pdu_dict["prev_state"] = [] + ev = builder.EventBuilder(pdu_dict) + defer.returnValue( - (destination, self.event_from_pdu_json(pdu_dict)) + (destination, ev) ) break except CodeMessageException as e: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index ea89e0cf2d..ced5646e9a 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -596,7 +596,7 @@ class FederationHandler(BaseHandler): preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e]) for e in event_ids ])) - states = dict(zip(event_ids, [s[1] for s in states])) + states = dict(zip(event_ids, [s.state for s in states])) state_map = yield self.store.get_events( [e_id for ids in states.values() for e_id in ids], -- cgit 1.4.1 From f878f64f4314dae6bd68b11ad1edbf0883f9bd8f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Jan 2017 17:20:39 +0000 Subject: Lower the not retrying host log line to debug --- synapse/federation/transaction_queue.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/federation') diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 7db7b806dc..6b3a7abb9e 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -362,7 +362,7 @@ class TransactionQueue(object): if not success: break except NotRetryingDestination: - logger.info( + logger.debug( "TX [%s] not ready for retry yet - " "dropping transaction for now", destination, -- cgit 1.4.1 From 2367c5568c01bc65aacc955b76ba707918b37f1e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 25 Jan 2017 14:27:27 +0000 Subject: Add basic implementation of local device list changes --- synapse/federation/transaction_queue.py | 24 ++- synapse/handlers/device.py | 65 ++++++-- synapse/handlers/e2e_keys.py | 1 + synapse/handlers/sync.py | 13 ++ synapse/rest/client/v2_alpha/sync.py | 6 +- synapse/storage/__init__.py | 11 ++ synapse/storage/_base.py | 6 + synapse/storage/devices.py | 169 +++++++++++++++++++-- synapse/storage/end_to_end_keys.py | 23 ++- .../schema/delta/40/device_list_streams.sql | 56 +++++++ synapse/streams/events.py | 4 + synapse/types.py | 2 + tests/handlers/test_typing.py | 3 + tests/rest/client/v1/test_rooms.py | 4 +- 14 files changed, 348 insertions(+), 39 deletions(-) create mode 100644 synapse/storage/schema/delta/40/device_list_streams.sql (limited to 'synapse/federation') diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 6b3a7abb9e..65c6673a87 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -100,6 +100,7 @@ class TransactionQueue(object): self.pending_failures_by_dest = {} self.last_device_stream_id_by_dest = {} + self.last_device_list_stream_id_by_dest = {} # HACK to get unique tx id self._next_txn_id = int(self.clock.time_msec()) @@ -356,7 +357,7 @@ class TransactionQueue(object): success = yield self._send_new_transaction( destination, pending_pdus, pending_edus, pending_failures, device_stream_id, - should_delete_from_device_stream=bool(device_message_edus), + includes_device_messages=bool(device_message_edus), limiter=limiter, ) if not success: @@ -373,6 +374,8 @@ class TransactionQueue(object): @defer.inlineCallbacks def _get_new_device_messages(self, destination): + # TODO: Send appropriate device list messages + last_device_stream_id = self.last_device_stream_id_by_dest.get(destination, 0) to_device_stream_id = self.store.get_to_device_stream_token() contents, stream_id = yield self.store.get_new_device_msgs_for_remote( @@ -387,13 +390,27 @@ class TransactionQueue(object): ) for content in contents ] + + last_device_list = self.last_device_list_stream_id_by_dest.get(destination, 0) + now_stream_id, results = yield self.store.get_devices_by_remote( + destination, last_device_list + ) + edus.extend( + Edu( + origin=self.server_name, + destination=destination, + edu_type="m.device_list_update", + content=content, + ) + for content in results + ) defer.returnValue((edus, stream_id)) @measure_func("_send_new_transaction") @defer.inlineCallbacks def _send_new_transaction(self, destination, pending_pdus, pending_edus, pending_failures, device_stream_id, - should_delete_from_device_stream, limiter): + includes_device_messages, limiter): # Sort based on the order field pending_pdus.sort(key=lambda t: t[1]) @@ -506,7 +523,8 @@ class TransactionQueue(object): success = False else: # Remove the acknowledged device messages from the database - if should_delete_from_device_stream: + # Only bother if we actually sent some device messages + if includes_device_messages: yield self.store.delete_device_msgs_for_remote( destination, device_stream_id ) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index aa68755936..d92780b642 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -15,6 +15,7 @@ from synapse.api import errors from synapse.util import stringutils +from synapse.types import get_domain_from_id from twisted.internet import defer from ._base import BaseHandler @@ -27,6 +28,8 @@ class DeviceHandler(BaseHandler): def __init__(self, hs): super(DeviceHandler, self).__init__(hs) + self.state = hs.get_state_handler() + @defer.inlineCallbacks def check_device_registered(self, user_id, device_id, initial_device_display_name=None): @@ -45,29 +48,29 @@ class DeviceHandler(BaseHandler): str: device id (generated if none was supplied) """ if device_id is not None: - yield self.store.store_device( + new_device = yield self.store.store_device( user_id=user_id, device_id=device_id, initial_device_display_name=initial_device_display_name, - ignore_if_known=True, ) + if new_device: + yield self.notify_device_update(user_id, device_id) defer.returnValue(device_id) # if the device id is not specified, we'll autogen one, but loop a few # times in case of a clash. attempts = 0 while attempts < 5: - try: - device_id = stringutils.random_string(10).upper() - yield self.store.store_device( - user_id=user_id, - device_id=device_id, - initial_device_display_name=initial_device_display_name, - ignore_if_known=False, - ) + device_id = stringutils.random_string(10).upper() + new_device = yield self.store.store_device( + user_id=user_id, + device_id=device_id, + initial_device_display_name=initial_device_display_name, + ) + if new_device: + yield self.notify_device_update(user_id, device_id) defer.returnValue(device_id) - except errors.StoreError: - attempts += 1 + attempts += 1 raise errors.StoreError(500, "Couldn't generate a device ID.") @@ -147,6 +150,8 @@ class DeviceHandler(BaseHandler): user_id=user_id, device_id=device_id ) + yield self.notify_device_update(user_id, device_id) + @defer.inlineCallbacks def update_device(self, user_id, device_id, content): """ Update the given device @@ -166,12 +171,48 @@ class DeviceHandler(BaseHandler): device_id, new_display_name=content.get("display_name") ) + yield self.notify_device_update(user_id, device_id) except errors.StoreError, e: if e.code == 404: raise errors.NotFoundError() else: raise + @defer.inlineCallbacks + def notify_device_update(self, user_id, device_id): + rooms = yield self.store.get_rooms_for_user(user_id) + room_ids = [r.room_id for r in rooms] + + hosts = set() + for room_id in room_ids: + users = yield self.state.get_current_user_in_room(room_id) + hosts.update(get_domain_from_id(u) for u in users) + hosts.discard(self.server_name) + + position = yield self.store.add_device_change_to_streams( + user_id, device_id, list(hosts) + ) + + yield self.notifier.on_new_event( + "device_list_key", position, rooms=room_ids, + ) + + for host in hosts: + self.federation.send_device_messages(host) + + @defer.inlineCallbacks + def get_device_list_changes(self, user_id, room_ids, from_key): + room_ids = frozenset(room_ids) + + user_ids_changed = set() + changed = yield self.store.get_user_whose_devices_changed(from_key) + for other_user_id in changed: + other_rooms = yield self.store.get_rooms_for_user(other_user_id) + if room_ids.intersection(e.room_id for e in other_rooms): + user_ids_changed.add(other_user_id) + + defer.returnValue(user_ids_changed) + def _update_device_from_client_ips(device, client_ips): ip = client_ips.get((device["user_id"], device["device_id"]), {}) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index b63a660c06..38c2a2d39e 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -259,6 +259,7 @@ class E2eKeysHandler(object): user_id, device_id, time_now, encode_canonical_json(device_keys) ) + yield self.device_handler.notify_device_update(user_id, device_id) one_time_keys = keys.get("one_time_keys", None) if one_time_keys: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c880f61685..06bf626367 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -115,6 +115,7 @@ class SyncResult(collections.namedtuple("SyncResult", [ "invited", # InvitedSyncResult for each invited room. "archived", # ArchivedSyncResult for each archived room. "to_device", # List of direct messages for the device. + "device_lists", # List of user_ids whose devices have chanegd ])): __slots__ = [] @@ -143,6 +144,7 @@ class SyncHandler(object): self.clock = hs.get_clock() self.response_cache = ResponseCache(hs) self.state = hs.get_state_handler() + self.device_handler = hs.get_device_handler() def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, full_state=False): @@ -544,6 +546,16 @@ class SyncHandler(object): yield self._generate_sync_entry_for_to_device(sync_result_builder) + if since_token and since_token.device_list_key: + user_id = sync_config.user.to_string() + rooms = yield self.store.get_rooms_for_user(user_id) + joined_room_ids = set(r.room_id for r in rooms) + device_lists = yield self.device_handler.get_device_list_changes( + user_id, joined_room_ids, since_token.device_list_key + ) + else: + device_lists = [] + defer.returnValue(SyncResult( presence=sync_result_builder.presence, account_data=sync_result_builder.account_data, @@ -551,6 +563,7 @@ class SyncHandler(object): invited=sync_result_builder.invited, archived=sync_result_builder.archived, to_device=sync_result_builder.to_device, + device_lists=device_lists, next_batch=sync_result_builder.now_token, )) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 7199ec883a..b3d8001638 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -170,12 +170,16 @@ class SyncRestServlet(RestServlet): ) archived = self.encode_archived( - sync_result.archived, time_now, requester.access_token_id, filter.event_fields + sync_result.archived, time_now, requester.access_token_id, + filter.event_fields, ) response_content = { "account_data": {"events": sync_result.account_data}, "to_device": {"events": sync_result.to_device}, + "device_lists": { + "changed": list(sync_result.device_lists), + }, "presence": self.encode_presence( sync_result.presence, time_now ), diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index e8495f1eb9..b9968debe5 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -116,6 +116,9 @@ class DataStore(RoomMemberStore, RoomStore, self._public_room_id_gen = StreamIdGenerator( db_conn, "public_room_list_stream", "stream_id" ) + self._device_list_id_gen = StreamIdGenerator( + db_conn, "device_lists_stream", "stream_id", + ) self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") @@ -210,6 +213,14 @@ class DataStore(RoomMemberStore, RoomStore, prefilled_cache=device_outbox_prefill, ) + device_list_max = self._device_list_id_gen.get_current_token() + self._device_list_stream_cache = StreamChangeCache( + "DeviceListStreamChangeCache", device_list_max, + ) + self._device_list_federation_stream_cache = StreamChangeCache( + "DeviceListFederationStreamChangeCache", device_list_max, + ) + cur = LoggingTransaction( db_conn.cursor(), name="_find_stream_orderings_for_times_txn", diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 963ef999d5..05374682fd 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -387,6 +387,10 @@ class SQLBaseStore(object): Args: table : string giving the table name values : dict of new column names and values for them + + Returns: + bool: Whether the row was inserted or not. Only useful when + `or_ignore` is True """ try: yield self.runInteraction( @@ -398,6 +402,8 @@ class SQLBaseStore(object): # a cursor after we receive an error from the db. if not or_ignore: raise + defer.returnValue(False) + defer.returnValue(True) @staticmethod def _simple_insert_txn(txn, table, values): diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 17920d4480..b594f501f9 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import ujson as json from twisted.internet import defer @@ -33,17 +34,13 @@ class DeviceStore(SQLBaseStore): user_id (str): id of user associated with the device device_id (str): id of device initial_device_display_name (str): initial displayname of the - device - ignore_if_known (bool): ignore integrity errors which mean the - device is already known + device. Ignored if device exists. Returns: - defer.Deferred - Raises: - StoreError: if ignore_if_known is False and the device was already - known + defer.Deferred: boolean whether the device was inserted or an + existing device existed with that ID. """ try: - yield self._simple_insert( + inserted = yield self._simple_insert( "devices", values={ "user_id": user_id, @@ -51,8 +48,9 @@ class DeviceStore(SQLBaseStore): "display_name": initial_device_display_name }, desc="store_device", - or_ignore=ignore_if_known, + or_ignore=True, ) + defer.returnValue(inserted) except Exception as e: logger.error("store_device with device_id=%s(%r) user_id=%s(%r)" " display_name=%s(%r) failed: %s", @@ -139,3 +137,156 @@ class DeviceStore(SQLBaseStore): ) defer.returnValue({d["device_id"]: d for d in devices}) + + def get_devices_by_remote(self, destination, from_stream_id): + now_stream_id = self._device_list_id_gen.get_current_token() + + has_changed = self._device_list_stream_cache.has_entity_changed( + destination, int(from_stream_id) + ) + if not has_changed: + defer.returnValue((now_stream_id, [])) + + return self.runInteraction( + "get_devices_by_remote", self._get_devices_by_remote_txn, + destination, from_stream_id, now_stream_id, + ) + + def _get_devices_by_remote_txn(self, txn, destination, from_stream_id, + now_stream_id): + sql = """ + SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes + WHERE destination = ? AND stream_id > ? AND stream_id <= ? AND sent = ? + GROUP BY user_id, device_id + """ + txn.execute( + sql, (destination, from_stream_id, now_stream_id, False) + ) + rows = txn.fetchall() + + if not rows: + return now_stream_id, [] + + # maps (user_id, device_id) -> stream_id + query_map = {(r[0], r[1]): r[2] for r in rows} + devices = self._get_e2e_device_keys_txn( + txn, query_map.keys(), include_all_devices=True + ) + + prev_sent_id_sql = """ + SELECT coalesce(max(stream_id), 0) as stream_id + FROM device_lists_outbound_pokes + WHERE destination = ? AND user_id = ? AND sent = ? + """ + + results = [] + for user_id, user_devices in devices.iteritems(): + txn.execute(prev_sent_id_sql, (destination, user_id, True)) + rows = txn.fetchall() + prev_id = rows[0][0] + for device_id, result in user_devices.iteritems(): + stream_id = query_map[(user_id, device_id)] + result = { + "user_id": user_id, + "device_id": device_id, + "prev_id": prev_id, + "stream_id": stream_id, + } + + prev_id = stream_id + + key_json = result.get("key_json", None) + if key_json: + result["keys"] = json.loads(key_json) + device_display_name = result.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.setdefault(user_id, {})[device_id] = result + + return now_stream_id, results + + def mark_as_sent_devices_by_remote(self, destination, stream_id): + return self.runInteraction( + "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, + destination, stream_id, + ) + + @defer.inlineCallbacks + def get_user_whose_devices_changed(self, from_key): + from_key = int(from_key) + changed = self._device_list_stream_cache.get_all_entities_changed(from_key) + if changed is not None: + defer.returnValue(set(changed)) + + sql = """ + SELECT user_id FROM device_lists_stream WHERE stream_id > ? + """ + rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) + defer.returnValue(set(row["user_id"] for row in rows)) + + def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): + sql = """ + DELETE FROM device_lists_outbound_pokes + WHERE destination = ? AND stream_id < ( + SELECT coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes + WHERE destination = ? AND stream_id <= ? + ) + """ + txn.execute(sql, (destination, destination, stream_id,)) + + sql = """ + UPDATE device_lists_outbound_pokes SET sent = ? + WHERE destination = ? AND stream_id <= ? + """ + txn.execute(sql, (destination, True,)) + + @defer.inlineCallbacks + def add_device_change_to_streams(self, user_id, device_id, hosts): + # device_lists_stream + # device_lists_outbound_pokes + with self._device_list_id_gen.get_next() as stream_id: + yield self.runInteraction( + "add_device_change_to_streams", self._add_device_change_txn, + user_id, device_id, hosts, stream_id, + ) + defer.returnValue(stream_id) + + def _add_device_change_txn(self, txn, user_id, device_id, hosts, stream_id): + txn.call_after( + self._device_list_stream_cache.entity_has_changed, + user_id, stream_id, + ) + for host in hosts: + txn.call_after( + self._device_list_federation_stream_cache.entity_has_changed, + host, stream_id, + ) + + self._simple_insert_txn( + txn, + table="device_lists_stream", + values={ + "stream_id": stream_id, + "user_id": user_id, + "device_id": device_id, + } + ) + + self._simple_insert_many_txn( + txn, + table="device_lists_outbound_pokes", + values=[ + { + "destination": destination, + "stream_id": stream_id, + "user_id": user_id, + "device_id": device_id, + "sent": False, + } + for destination in hosts + ] + ) + + def get_device_stream_token(self): + return self._device_list_id_gen.get_current_token() diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 385d607056..f82943a7a8 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -12,9 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import collections - -import twisted.internet.defer +from twisted.internet import defer from ._base import SQLBaseStore @@ -33,7 +31,7 @@ class EndToEndKeyStore(SQLBaseStore): } ) - def get_e2e_device_keys(self, query_list): + def get_e2e_device_keys(self, query_list, include_all_devices=False): """Fetch a list of device keys. Args: query_list(list): List of pairs of user_ids and device_ids. @@ -45,10 +43,11 @@ class EndToEndKeyStore(SQLBaseStore): return {} return self.runInteraction( - "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list + "get_e2e_device_keys", self._get_e2e_device_keys_txn, + query_list, include_all_devices, ) - def _get_e2e_device_keys_txn(self, txn, query_list): + def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices): query_clauses = [] query_params = [] @@ -63,23 +62,23 @@ class EndToEndKeyStore(SQLBaseStore): query_clauses.append(query_clause) sql = ( - "SELECT k.user_id, k.device_id, " + "SELECT user_id, device_id, " " d.display_name AS device_display_name, " " k.key_json" " FROM e2e_device_keys_json k" - " LEFT JOIN devices d ON d.user_id = k.user_id" - " AND d.device_id = k.device_id" + " %s JOIN devices d USING (user_id, device_id)" " WHERE %s" ) % ( + "FULL OUTER" if include_all_devices else "LEFT", " OR ".join("(" + q + ")" for q in query_clauses) ) txn.execute(sql, query_params) rows = self.cursor_to_dict(txn) - result = collections.defaultdict(dict) + result = {} for row in rows: - result[row["user_id"]][row["device_id"]] = row + result.setdefault(row["user_id"], {})[row["device_id"]] = row return result @@ -152,7 +151,7 @@ class EndToEndKeyStore(SQLBaseStore): "claim_e2e_one_time_keys", _claim_e2e_one_time_keys ) - @twisted.internet.defer.inlineCallbacks + @defer.inlineCallbacks def delete_e2e_keys_by_device(self, user_id, device_id): yield self._simple_delete( table="e2e_device_keys_json", diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/schema/delta/40/device_list_streams.sql new file mode 100644 index 0000000000..61cac63bbb --- /dev/null +++ b/synapse/storage/schema/delta/40/device_list_streams.sql @@ -0,0 +1,56 @@ +/* Copyright 2017 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE device_list_streams_remote ( + list_id TEXT NOT NULL, + origin TEXT NOT NULL, + user_id TEXT NOT NULL, + is_full BOOLEAN NOT NULL, + ts BIGINT NOT NULL +); + +CREATE INDEX device_list_streams_remote_id_origin ON device_list_streams_remote( + origin, list_id, user_id +); + + +CREATE TABLE device_lists_remote_cache ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + content TEXT NOT NULL +); + +CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id); + + +CREATE TABLE device_lists_stream ( + stream_id BIGINT NOT NULL, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL +); + +CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id); + + +CREATE TABLE device_lists_outbound_pokes ( + destination TEXT NOT NULL, + stream_id BIGINT NOT NULL, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + sent BOOLEAN NOT NULL +); + +CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id); +CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id); diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 4d44c3d4ca..91a59b0bae 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -44,6 +44,7 @@ class EventSources(object): def get_current_token(self): push_rules_key, _ = self.store.get_push_rules_stream_token() to_device_key = self.store.get_to_device_stream_token() + device_list_key = self.store.get_device_stream_token() token = StreamToken( room_key=( @@ -63,6 +64,7 @@ class EventSources(object): ), push_rules_key=push_rules_key, to_device_key=to_device_key, + device_list_key=device_list_key, ) defer.returnValue(token) @@ -70,6 +72,7 @@ class EventSources(object): def get_current_token_for_room(self, room_id): push_rules_key, _ = self.store.get_push_rules_stream_token() to_device_key = self.store.get_to_device_stream_token() + device_list_key = self.store.get_device_stream_token() token = StreamToken( room_key=( @@ -89,5 +92,6 @@ class EventSources(object): ), push_rules_key=push_rules_key, to_device_key=to_device_key, + device_list_key=device_list_key, ) defer.returnValue(token) diff --git a/synapse/types.py b/synapse/types.py index 3a3ab21d17..9666f9d73f 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -158,6 +158,7 @@ class StreamToken( "account_data_key", "push_rules_key", "to_device_key", + "device_list_key", )) ): _SEPARATOR = "_" @@ -195,6 +196,7 @@ class StreamToken( or (int(other.account_data_key) < int(self.account_data_key)) or (int(other.push_rules_key) < int(self.push_rules_key)) or (int(other.to_device_key) < int(self.to_device_key)) + or (int(other.device_list_key) < int(self.device_list_key)) ) def copy_and_advance(self, key, new_value): diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index c718d1f98f..f88d2be7c5 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -75,6 +75,7 @@ class TypingNotificationsTestCase(unittest.TestCase): "get_received_txn_response", "set_received_txn_response", "get_destination_retry_timings", + "get_devices_by_remote", ]), state_handler=self.state_handler, handlers=None, @@ -99,6 +100,8 @@ class TypingNotificationsTestCase(unittest.TestCase): defer.succeed(retry_timings_res) ) + self.datastore.get_devices_by_remote.return_value = (0, []) + def get_received_txn_response(*args): return defer.succeed(None) self.datastore.get_received_txn_response = get_received_txn_response diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 6bce352c5f..d746ea8568 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase): @defer.inlineCallbacks def test_topo_token_is_accepted(self): - token = "t1-0_0_0_0_0_0_0" + token = "t1-0_0_0_0_0_0_0_0" (code, response) = yield self.mock_resource.trigger_get( "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)) @@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase): @defer.inlineCallbacks def test_stream_token_is_accepted_for_fwd_pagianation(self): - token = "s0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0" (code, response) = yield self.mock_resource.trigger_get( "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)) -- cgit 1.4.1 From 51e9fe36e46331ac611cec1d4cb425c1bc98721c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 25 Jan 2017 16:55:21 +0000 Subject: Fix up sending of m.device_list_update edus --- synapse/federation/transaction_queue.py | 121 ++++++++++++++++---------------- synapse/handlers/device.py | 1 + synapse/storage/devices.py | 40 +++++------ 3 files changed, 82 insertions(+), 80 deletions(-) (limited to 'synapse/federation') diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 65c6673a87..d18f6b6cfd 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -306,62 +306,74 @@ class TransactionQueue(object): yield run_on_reactor() while True: - pending_pdus = self.pending_pdus_by_dest.pop(destination, []) - pending_edus = self.pending_edus_by_dest.pop(destination, []) - pending_presence = self.pending_presence_by_dest.pop(destination, {}) - pending_failures = self.pending_failures_by_dest.pop(destination, []) + pending_pdus = self.pending_pdus_by_dest.pop(destination, []) + pending_edus = self.pending_edus_by_dest.pop(destination, []) + pending_presence = self.pending_presence_by_dest.pop(destination, {}) + pending_failures = self.pending_failures_by_dest.pop(destination, []) - pending_edus.extend( - self.pending_edus_keyed_by_dest.pop(destination, {}).values() - ) + pending_edus.extend( + self.pending_edus_keyed_by_dest.pop(destination, {}).values() + ) - limiter = yield get_retry_limiter( - destination, - self.clock, - self.store, - ) + limiter = yield get_retry_limiter( + destination, + self.clock, + self.store, + ) - device_message_edus, device_stream_id = ( - yield self._get_new_device_messages(destination) - ) + device_message_edus, device_stream_id, dev_list_id = ( + yield self._get_new_device_messages(destination) + ) - pending_edus.extend(device_message_edus) - if pending_presence: - pending_edus.append( - Edu( - origin=self.server_name, - destination=destination, - edu_type="m.presence", - content={ - "push": [ - format_user_presence_state( - presence, self.clock.time_msec() - ) - for presence in pending_presence.values() - ] - }, - ) + pending_edus.extend(device_message_edus) + if pending_presence: + pending_edus.append( + Edu( + origin=self.server_name, + destination=destination, + edu_type="m.presence", + content={ + "push": [ + format_user_presence_state( + presence, self.clock.time_msec() + ) + for presence in pending_presence.values() + ] + }, ) + ) - if pending_pdus: - logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", - destination, len(pending_pdus)) + if pending_pdus: + logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", + destination, len(pending_pdus)) - if not pending_pdus and not pending_edus and not pending_failures: - logger.debug("TX [%s] Nothing to send", destination) - self.last_device_stream_id_by_dest[destination] = ( - device_stream_id + if not pending_pdus and not pending_edus and not pending_failures: + logger.debug("TX [%s] Nothing to send", destination) + self.last_device_stream_id_by_dest[destination] = ( + device_stream_id + ) + return + + success = yield self._send_new_transaction( + destination, pending_pdus, pending_edus, pending_failures, + limiter=limiter, + ) + if success: + # Remove the acknowledged device messages from the database + # Only bother if we actually sent some device messages + if device_message_edus: + yield self.store.delete_device_msgs_for_remote( + destination, device_stream_id + ) + logger.info("Marking as sent %r %r", destination, dev_list_id) + yield self.store.mark_as_sent_devices_by_remote( + destination, dev_list_id ) - return - success = yield self._send_new_transaction( - destination, pending_pdus, pending_edus, pending_failures, - device_stream_id, - includes_device_messages=bool(device_message_edus), - limiter=limiter, - ) - if not success: - break + self.last_device_stream_id_by_dest[destination] = device_stream_id + self.last_device_list_stream_id_by_dest[destination] = dev_list_id + else: + break except NotRetryingDestination: logger.debug( "TX [%s] not ready for retry yet - " @@ -374,8 +386,6 @@ class TransactionQueue(object): @defer.inlineCallbacks def _get_new_device_messages(self, destination): - # TODO: Send appropriate device list messages - last_device_stream_id = self.last_device_stream_id_by_dest.get(destination, 0) to_device_stream_id = self.store.get_to_device_stream_token() contents, stream_id = yield self.store.get_new_device_msgs_for_remote( @@ -404,13 +414,12 @@ class TransactionQueue(object): ) for content in results ) - defer.returnValue((edus, stream_id)) + defer.returnValue((edus, stream_id, now_stream_id)) @measure_func("_send_new_transaction") @defer.inlineCallbacks def _send_new_transaction(self, destination, pending_pdus, pending_edus, - pending_failures, device_stream_id, - includes_device_messages, limiter): + pending_failures, limiter): # Sort based on the order field pending_pdus.sort(key=lambda t: t[1]) @@ -521,14 +530,6 @@ class TransactionQueue(object): "Failed to send event %s to %s", p.event_id, destination ) success = False - else: - # Remove the acknowledged device messages from the database - # Only bother if we actually sent some device messages - if includes_device_messages: - yield self.store.delete_device_msgs_for_remote( - destination, device_stream_id - ) - self.last_device_stream_id_by_dest[destination] = device_stream_id except RuntimeError as e: # We capture this here as there as nothing actually listens # for this finishing functions deferred. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index d92780b642..ba4c48d590 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -29,6 +29,7 @@ class DeviceHandler(BaseHandler): super(DeviceHandler, self).__init__(hs) self.state = hs.get_state_handler() + self.federation = hs.get_federation_sender() @defer.inlineCallbacks def check_device_registered(self, user_id, device_id, diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index b594f501f9..9628e2ff75 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -141,11 +141,11 @@ class DeviceStore(SQLBaseStore): def get_devices_by_remote(self, destination, from_stream_id): now_stream_id = self._device_list_id_gen.get_current_token() - has_changed = self._device_list_stream_cache.has_entity_changed( + has_changed = self._device_list_federation_stream_cache.has_entity_changed( destination, int(from_stream_id) ) if not has_changed: - defer.returnValue((now_stream_id, [])) + return (now_stream_id, []) return self.runInteraction( "get_devices_by_remote", self._get_devices_by_remote_txn, @@ -165,7 +165,7 @@ class DeviceStore(SQLBaseStore): rows = txn.fetchall() if not rows: - return now_stream_id, [] + return (now_stream_id, []) # maps (user_id, device_id) -> stream_id query_map = {(r[0], r[1]): r[2] for r in rows} @@ -189,7 +189,7 @@ class DeviceStore(SQLBaseStore): result = { "user_id": user_id, "device_id": device_id, - "prev_id": prev_id, + "prev_id": [prev_id] if prev_id else [], "stream_id": stream_id, } @@ -202,9 +202,9 @@ class DeviceStore(SQLBaseStore): if device_display_name: result["device_display_name"] = device_display_name - results.setdefault(user_id, {})[device_id] = result + results.append(result) - return now_stream_id, results + return (now_stream_id, results) def mark_as_sent_devices_by_remote(self, destination, stream_id): return self.runInteraction( @@ -212,19 +212,6 @@ class DeviceStore(SQLBaseStore): destination, stream_id, ) - @defer.inlineCallbacks - def get_user_whose_devices_changed(self, from_key): - from_key = int(from_key) - changed = self._device_list_stream_cache.get_all_entities_changed(from_key) - if changed is not None: - defer.returnValue(set(changed)) - - sql = """ - SELECT user_id FROM device_lists_stream WHERE stream_id > ? - """ - rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) - defer.returnValue(set(row["user_id"] for row in rows)) - def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): sql = """ DELETE FROM device_lists_outbound_pokes @@ -239,7 +226,20 @@ class DeviceStore(SQLBaseStore): UPDATE device_lists_outbound_pokes SET sent = ? WHERE destination = ? AND stream_id <= ? """ - txn.execute(sql, (destination, True,)) + txn.execute(sql, (True, destination, stream_id,)) + + @defer.inlineCallbacks + def get_user_whose_devices_changed(self, from_key): + from_key = int(from_key) + changed = self._device_list_stream_cache.get_all_entities_changed(from_key) + if changed is not None: + defer.returnValue(set(changed)) + + sql = """ + SELECT user_id FROM device_lists_stream WHERE stream_id > ? + """ + rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) + defer.returnValue(set(row["user_id"] for row in rows)) @defer.inlineCallbacks def add_device_change_to_streams(self, user_id, device_id, hosts): -- cgit 1.4.1 From c974116f197d211ba9b42159fe61cfd5957411b5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 26 Jan 2017 16:06:54 +0000 Subject: Implement device key caching over federation --- synapse/federation/federation_client.py | 10 + synapse/federation/federation_server.py | 3 + synapse/federation/transport/client.py | 26 +++ synapse/federation/transport/server.py | 8 + synapse/handlers/device.py | 85 +++++++-- synapse/handlers/e2e_keys.py | 40 +++- synapse/storage/devices.py | 201 +++++++++++++++++++-- synapse/storage/end_to_end_keys.py | 4 +- .../schema/delta/40/device_list_streams.sql | 20 +- tests/handlers/test_device.py | 18 +- tests/handlers/test_directory.py | 1 + tests/handlers/test_profile.py | 1 + tests/storage/test_appservice.py | 21 ++- 13 files changed, 381 insertions(+), 57 deletions(-) (limited to 'synapse/federation') diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index c9175bb33d..b5bcfd705a 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -126,6 +126,16 @@ class FederationClient(FederationBase): destination, content, timeout ) + @log_function + def query_user_devices(self, destination, user_id, timeout=30000): + """Query the device keys for a list of user ids hosted on a remote + server. + """ + sent_queries_counter.inc("user_devices") + return self.transport_layer.query_user_devices( + destination, user_id, timeout + ) + @log_function def claim_client_keys(self, destination, content, timeout): """Claims one-time keys for a device hosted on a remote server. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 862ccbef5d..e922b7ff4a 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -416,6 +416,9 @@ class FederationServer(FederationBase): def on_query_client_keys(self, origin, content): return self.on_query_request("client_keys", content) + def on_query_user_devices(self, origin, user_id): + return self.on_query_request("user_devices", user_id) + @defer.inlineCallbacks @log_function def on_claim_client_keys(self, origin, content): diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 915af34409..f49e8a2cc4 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -346,6 +346,32 @@ class TransportLayerClient(object): ) defer.returnValue(content) + @defer.inlineCallbacks + @log_function + def query_user_devices(self, destination, user_id, timeout): + """Query the devices for a user id hosted on a remote server. + + Response: + { + "stream_id": "...", + "devices": [ { ... } ] + } + + Args: + destination(str): The server to query. + query_content(dict): The user ids to query. + Returns: + A dict containg the device keys. + """ + path = PREFIX + "/user/devices/" + user_id + + content = yield self.client.get_json( + destination=destination, + path=path, + timeout=timeout, + ) + defer.returnValue(content) + @defer.inlineCallbacks @log_function def claim_client_keys(self, destination, query_content, timeout): diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 159dbd1747..c840da834c 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -409,6 +409,13 @@ class FederationClientKeysQueryServlet(BaseFederationServlet): return self.handler.on_query_client_keys(origin, content) +class FederationUserDevicesQueryServlet(BaseFederationServlet): + PATH = "/user/devices/(?P[^/]*)" + + def on_GET(self, origin, content, query, user_id): + return self.handler.on_query_user_devices(origin, user_id) + + class FederationClientKeysClaimServlet(BaseFederationServlet): PATH = "/user/keys/claim" @@ -613,6 +620,7 @@ SERVLET_CLASSES = ( FederationGetMissingEventsServlet, FederationEventAuthServlet, FederationClientKeysQueryServlet, + FederationUserDevicesQueryServlet, FederationClientKeysClaimServlet, FederationThirdPartyInviteExchangeServlet, On3pidBindServlet, diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index ba4c48d590..2d66b3721a 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -15,6 +15,7 @@ from synapse.api import errors from synapse.util import stringutils +from synapse.util.async import Linearizer from synapse.types import get_domain_from_id from twisted.internet import defer from ._base import BaseHandler @@ -28,8 +29,18 @@ class DeviceHandler(BaseHandler): def __init__(self, hs): super(DeviceHandler, self).__init__(hs) + self.hs = hs self.state = hs.get_state_handler() - self.federation = hs.get_federation_sender() + self.federation_sender = hs.get_federation_sender() + self.federation = hs.get_replication_layer() + self._remote_edue_linearizer = Linearizer(name="remote_device_list") + + self.federation.register_edu_handler( + "m.device_list_update", self._incoming_device_list_update, + ) + self.federation.register_query_handler( + "user_devices", self.on_federation_query_user_devices, + ) @defer.inlineCallbacks def check_device_registered(self, user_id, device_id, @@ -55,7 +66,7 @@ class DeviceHandler(BaseHandler): initial_device_display_name=initial_device_display_name, ) if new_device: - yield self.notify_device_update(user_id, device_id) + yield self.notify_device_update(user_id, [device_id]) defer.returnValue(device_id) # if the device id is not specified, we'll autogen one, but loop a few @@ -69,7 +80,7 @@ class DeviceHandler(BaseHandler): initial_device_display_name=initial_device_display_name, ) if new_device: - yield self.notify_device_update(user_id, device_id) + yield self.notify_device_update(user_id, [device_id]) defer.returnValue(device_id) attempts += 1 @@ -151,7 +162,7 @@ class DeviceHandler(BaseHandler): user_id=user_id, device_id=device_id ) - yield self.notify_device_update(user_id, device_id) + yield self.notify_device_update(user_id, [device_id]) @defer.inlineCallbacks def update_device(self, user_id, device_id, content): @@ -172,7 +183,7 @@ class DeviceHandler(BaseHandler): device_id, new_display_name=content.get("display_name") ) - yield self.notify_device_update(user_id, device_id) + yield self.notify_device_update(user_id, [device_id]) except errors.StoreError, e: if e.code == 404: raise errors.NotFoundError() @@ -180,26 +191,28 @@ class DeviceHandler(BaseHandler): raise @defer.inlineCallbacks - def notify_device_update(self, user_id, device_id): + def notify_device_update(self, user_id, device_ids): rooms = yield self.store.get_rooms_for_user(user_id) room_ids = [r.room_id for r in rooms] hosts = set() - for room_id in room_ids: - users = yield self.state.get_current_user_in_room(room_id) - hosts.update(get_domain_from_id(u) for u in users) - hosts.discard(self.server_name) + if self.hs.is_mine_id(user_id): + for room_id in room_ids: + users = yield self.state.get_current_user_in_room(room_id) + hosts.update(get_domain_from_id(u) for u in users) + hosts.discard(self.server_name) position = yield self.store.add_device_change_to_streams( - user_id, device_id, list(hosts) + user_id, device_ids, list(hosts) ) yield self.notifier.on_new_event( "device_list_key", position, rooms=room_ids, ) + logger.info("Sending device list update notif to: %r", hosts) for host in hosts: - self.federation.send_device_messages(host) + self.federation_sender.send_device_messages(host) @defer.inlineCallbacks def get_device_list_changes(self, user_id, room_ids, from_key): @@ -214,6 +227,54 @@ class DeviceHandler(BaseHandler): defer.returnValue(user_ids_changed) + @defer.inlineCallbacks + def _incoming_device_list_update(self, origin, edu_content): + user_id = edu_content["user_id"] + device_id = edu_content["device_id"] + stream_id = edu_content["stream_id"] + prev_ids = edu_content.get("prev_id", []) + + if get_domain_from_id(user_id) != origin: + # TODO: Raise? + return + + logger.info("Got edu: %r", edu_content) + + with (yield self._remote_edue_linearizer.queue(user_id)): + resync = True + if len(prev_ids) == 1: + extremity = yield self.store.get_device_list_remote_extremity(user_id) + logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids) + if str(extremity) == str(prev_ids[0]): + resync = False + + if resync: + result = yield self.federation.query_user_devices(origin, user_id) + stream_id = result["stream_id"] + devices = result["devices"] + yield self.store.update_remote_device_list_cache( + user_id, devices, stream_id, + ) + device_ids = [device["device_id"] for device in devices] + yield self.notify_device_update(user_id, device_ids) + else: + content = dict(edu_content) + for key in ("user_id", "device_id", "stream_id", "prev_ids"): + content.pop(key, None) + yield self.store.update_remote_device_list_cache_entry( + user_id, device_id, content, stream_id, + ) + yield self.notify_device_update(user_id, [device_id]) + + @defer.inlineCallbacks + def on_federation_query_user_devices(self, user_id): + stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) + defer.returnValue({ + "user_id": user_id, + "stream_id": stream_id, + "devices": devices, + }) + def _update_device_from_client_ips(device, client_ips): ip = client_ips.get((device["user_id"], device["device_id"]), {}) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 38c2a2d39e..832998a6d3 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -73,8 +73,7 @@ class E2eKeysHandler(object): if self.is_mine_id(user_id): local_query[user_id] = device_ids else: - domain = get_domain_from_id(user_id) - remote_queries.setdefault(domain, {})[user_id] = device_ids + remote_queries[user_id] = device_ids # do the queries failures = {} @@ -85,9 +84,40 @@ class E2eKeysHandler(object): if user_id in local_query: results[user_id] = keys + remote_queries_not_in_cache = {} + if remote_queries: + query_list = [] + for user_id, device_ids in remote_queries.iteritems(): + if device_ids: + query_list.extend((user_id, device_id) for device_id in device_ids) + else: + query_list.append((user_id, None)) + + user_ids_not_in_cache, remote_results = ( + yield self.store.get_user_devices_from_cache( + query_list + ) + ) + for user_id, devices in remote_results.iteritems(): + user_devices = results.setdefault(user_id, {}) + for device_id, device in devices.iteritems(): + keys = device.get("keys", None) + device_display_name = device.get("device_display_name", None) + if keys: + result = dict(keys) + unsigned = result.setdefault("unsigned", {}) + if device_display_name: + unsigned["device_display_name"] = device_display_name + user_devices[device_id] = result + + for user_id in user_ids_not_in_cache: + domain = get_domain_from_id(user_id) + r = remote_queries_not_in_cache.setdefault(domain, {}) + r[user_id] = remote_queries[user_id] + @defer.inlineCallbacks def do_remote_query(destination): - destination_query = remote_queries[destination] + destination_query = remote_queries_not_in_cache[destination] try: limiter = yield get_retry_limiter( destination, self.clock, self.store @@ -119,7 +149,7 @@ class E2eKeysHandler(object): yield preserve_context_over_deferred(defer.gatherResults([ preserve_fn(do_remote_query)(destination) - for destination in remote_queries + for destination in remote_queries_not_in_cache ])) defer.returnValue({ @@ -259,7 +289,7 @@ class E2eKeysHandler(object): user_id, device_id, time_now, encode_canonical_json(device_keys) ) - yield self.device_handler.notify_device_update(user_id, device_id) + yield self.device_handler.notify_device_update(user_id, [device_id]) one_time_keys = keys.get("one_time_keys", None) if one_time_keys: diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 9628e2ff75..8ee3119db2 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -138,6 +138,89 @@ class DeviceStore(SQLBaseStore): defer.returnValue({d["device_id"]: d for d in devices}) + def get_device_list_remote_extremity(self, user_id): + return self._simple_select_one_onecol( + table="device_lists_remote_extremeties", + keyvalues={"user_id": user_id}, + retcol="stream_id", + desc="get_device_list_remote_extremity", + allow_none=True, + ) + + def update_remote_device_list_cache_entry(self, user_id, device_id, content, + stream_id): + return self.runInteraction( + "update_remote_device_list_cache_entry", + self._update_remote_device_list_cache_entry_txn, + user_id, device_id, content, stream_id, + ) + + def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id, + content, stream_id): + self._simple_upsert_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={ + "content": json.dumps(content), + } + ) + + self._simple_upsert_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + values={ + "stream_id": stream_id, + } + ) + + def update_remote_device_list_cache(self, user_id, devices, stream_id): + return self.runInteraction( + "update_remote_device_list_cache", + self._update_remote_device_list_cache_txn, + user_id, devices, stream_id, + ) + + def _update_remote_device_list_cache_txn(self, txn, user_id, devices, + stream_id): + self._simple_delete_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + }, + ) + + self._simple_insert_many_txn( + txn, + table="device_lists_remote_cache", + values=[ + { + "user_id": user_id, + "device_id": content["device_id"], + "content": json.dumps(content), + } + for content in devices + ] + ) + + self._simple_upsert_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + values={ + "stream_id": stream_id, + } + ) + def get_devices_by_remote(self, destination, from_stream_id): now_stream_id = self._device_list_id_gen.get_current_token() @@ -184,7 +267,7 @@ class DeviceStore(SQLBaseStore): txn.execute(prev_sent_id_sql, (destination, user_id, True)) rows = txn.fetchall() prev_id = rows[0][0] - for device_id, result in user_devices.iteritems(): + for device_id, device in user_devices.iteritems(): stream_id = query_map[(user_id, device_id)] result = { "user_id": user_id, @@ -195,10 +278,10 @@ class DeviceStore(SQLBaseStore): prev_id = stream_id - key_json = result.get("key_json", None) + key_json = device.get("key_json", None) if key_json: result["keys"] = json.loads(key_json) - device_display_name = result.get("device_display_name", None) + device_display_name = device.get("device_display_name", None) if device_display_name: result["device_display_name"] = device_display_name @@ -206,6 +289,96 @@ class DeviceStore(SQLBaseStore): return (now_stream_id, results) + def get_user_devices_from_cache(self, query_list): + return self.runInteraction( + "get_user_devices_from_cache", self._get_user_devices_from_cache_txn, + query_list, + ) + + def _get_user_devices_from_cache_txn(self, txn, query_list): + user_ids = {user_id for user_id, _ in query_list} + + user_ids_in_cache = set() + for user_id in user_ids: + stream_ids = self._simple_select_onecol_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + retcol="stream_id", + ) + if stream_ids: + user_ids_in_cache.add(user_id) + + user_ids_not_in_cache = user_ids - user_ids_in_cache + + results = {} + for user_id, device_id in query_list: + if user_id not in user_ids_in_cache: + continue + + if device_id: + content = self._simple_select_one_onecol_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + retcol="content", + ) + results.setdefault(user_id, {})[device_id] = json.loads(content) + else: + devices = self._simple_select_list_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + }, + retcols=("device_id", "content"), + ) + results[user_id] = { + device["device_id"]: json.loads(device["content"]) + for device in devices + } + user_ids_in_cache.discard(user_id) + + return user_ids_not_in_cache, results + + def get_devices_with_keys_by_user(self, user_id): + return self.runInteraction( + "get_devices_with_keys_by_user", + self._get_devices_with_keys_by_user_txn, user_id, + ) + + def _get_devices_with_keys_by_user_txn(self, txn, user_id): + now_stream_id = self._device_list_id_gen.get_current_token() + + devices = self._get_e2e_device_keys_txn( + txn, [(user_id, None)], include_all_devices=True + ) + + for user_id, user_devices in devices.iteritems(): + results = [] + for device_id, device in user_devices.iteritems(): + result = { + "device_id": device_id, + } + + key_json = device.get("key_json", None) + if key_json: + result["keys"] = json.loads(key_json) + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.append(result) + + return now_stream_id, results + + return now_stream_id, [] + def mark_as_sent_devices_by_remote(self, destination, stream_id): return self.runInteraction( "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, @@ -242,17 +415,17 @@ class DeviceStore(SQLBaseStore): defer.returnValue(set(row["user_id"] for row in rows)) @defer.inlineCallbacks - def add_device_change_to_streams(self, user_id, device_id, hosts): + def add_device_change_to_streams(self, user_id, device_ids, hosts): # device_lists_stream # device_lists_outbound_pokes with self._device_list_id_gen.get_next() as stream_id: yield self.runInteraction( "add_device_change_to_streams", self._add_device_change_txn, - user_id, device_id, hosts, stream_id, + user_id, device_ids, hosts, stream_id, ) defer.returnValue(stream_id) - def _add_device_change_txn(self, txn, user_id, device_id, hosts, stream_id): + def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id): txn.call_after( self._device_list_stream_cache.entity_has_changed, user_id, stream_id, @@ -263,14 +436,17 @@ class DeviceStore(SQLBaseStore): host, stream_id, ) - self._simple_insert_txn( + self._simple_insert_many_txn( txn, table="device_lists_stream", - values={ - "stream_id": stream_id, - "user_id": user_id, - "device_id": device_id, - } + values=[ + { + "stream_id": stream_id, + "user_id": user_id, + "device_id": device_id, + } + for device_id in device_ids + ] ) self._simple_insert_many_txn( @@ -285,6 +461,7 @@ class DeviceStore(SQLBaseStore): "sent": False, } for destination in hosts + for device_id in device_ids ] ) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index f82943a7a8..a915c790ff 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -52,11 +52,11 @@ class EndToEndKeyStore(SQLBaseStore): query_params = [] for (user_id, device_id) in query_list: - query_clause = "k.user_id = ?" + query_clause = "user_id = ?" query_params.append(user_id) if device_id: - query_clause += " AND k.device_id = ?" + query_clause += " AND device_id = ?" query_params.append(device_id) query_clauses.append(query_clause) diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/schema/delta/40/device_list_streams.sql index 61cac63bbb..d1051c6ddf 100644 --- a/synapse/storage/schema/delta/40/device_list_streams.sql +++ b/synapse/storage/schema/delta/40/device_list_streams.sql @@ -13,18 +13,6 @@ * limitations under the License. */ -CREATE TABLE device_list_streams_remote ( - list_id TEXT NOT NULL, - origin TEXT NOT NULL, - user_id TEXT NOT NULL, - is_full BOOLEAN NOT NULL, - ts BIGINT NOT NULL -); - -CREATE INDEX device_list_streams_remote_id_origin ON device_list_streams_remote( - origin, list_id, user_id -); - CREATE TABLE device_lists_remote_cache ( user_id TEXT NOT NULL, @@ -35,6 +23,14 @@ CREATE TABLE device_lists_remote_cache ( CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id); +CREATE TABLE device_lists_remote_extremeties ( + user_id TEXT NOT NULL, + stream_id TEXT NOT NULL +); + +CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id); + + CREATE TABLE device_lists_stream ( stream_id BIGINT NOT NULL, user_id TEXT NOT NULL, diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 85a970a6c9..2eaaa8253c 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -35,51 +35,51 @@ class DeviceTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield utils.setup_test_homeserver(handlers=None) - self.handler = synapse.handlers.device.DeviceHandler(hs) + hs = yield utils.setup_test_homeserver() + self.handler = hs.get_device_handler() self.store = hs.get_datastore() self.clock = hs.get_clock() @defer.inlineCallbacks def test_device_is_created_if_doesnt_exist(self): res = yield self.handler.check_device_registered( - user_id="boris", + user_id="@boris:foo", device_id="fco", initial_device_display_name="display name" ) self.assertEqual(res, "fco") - dev = yield self.handler.store.get_device("boris", "fco") + dev = yield self.handler.store.get_device("@boris:foo", "fco") self.assertEqual(dev["display_name"], "display name") @defer.inlineCallbacks def test_device_is_preserved_if_exists(self): res1 = yield self.handler.check_device_registered( - user_id="boris", + user_id="@boris:foo", device_id="fco", initial_device_display_name="display name" ) self.assertEqual(res1, "fco") res2 = yield self.handler.check_device_registered( - user_id="boris", + user_id="@boris:foo", device_id="fco", initial_device_display_name="new display name" ) self.assertEqual(res2, "fco") - dev = yield self.handler.store.get_device("boris", "fco") + dev = yield self.handler.store.get_device("@boris:foo", "fco") self.assertEqual(dev["display_name"], "display name") @defer.inlineCallbacks def test_device_id_is_made_up_if_unspecified(self): device_id = yield self.handler.check_device_registered( - user_id="theresa", + user_id="@theresa:foo", device_id=None, initial_device_display_name="display" ) - dev = yield self.handler.store.get_device("theresa", device_id) + dev = yield self.handler.store.get_device("@theresa:foo", device_id) self.assertEqual(dev["display_name"], "display") @defer.inlineCallbacks diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 5d602c1531..ceb9aa5765 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -37,6 +37,7 @@ class DirectoryTestCase(unittest.TestCase): def setUp(self): self.mock_federation = Mock(spec=[ "make_query", + "register_edu_handler", ]) self.query_handlers = {} diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index f1f664275f..979cebf600 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -39,6 +39,7 @@ class ProfileTestCase(unittest.TestCase): def setUp(self): self.mock_federation = Mock(spec=[ "make_query", + "register_edu_handler", ]) self.query_handlers = {} diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 9ff1abcd80..9e98d0e330 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -39,7 +39,11 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): event_cache_size=1, password_providers=[], ) - hs = yield setup_test_homeserver(config=config, federation_sender=Mock()) + hs = yield setup_test_homeserver( + config=config, + federation_sender=Mock(), + replication_layer=Mock(), + ) self.as_token = "token1" self.as_url = "some_url" @@ -112,7 +116,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): event_cache_size=1, password_providers=[], ) - hs = yield setup_test_homeserver(config=config, federation_sender=Mock()) + hs = yield setup_test_homeserver( + config=config, + federation_sender=Mock(), + replication_layer=Mock(), + ) self.db_pool = hs.get_db_pool() self.as_list = [ @@ -446,7 +454,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs = yield setup_test_homeserver( config=config, datastore=Mock(), - federation_sender=Mock() + federation_sender=Mock(), + replication_layer=Mock(), ) ApplicationServiceStore(hs) @@ -463,7 +472,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs = yield setup_test_homeserver( config=config, datastore=Mock(), - federation_sender=Mock() + federation_sender=Mock(), + replication_layer=Mock(), ) with self.assertRaises(ConfigError) as cm: @@ -486,7 +496,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs = yield setup_test_homeserver( config=config, datastore=Mock(), - federation_sender=Mock() + federation_sender=Mock(), + replication_layer=Mock(), ) with self.assertRaises(ConfigError) as cm: -- cgit 1.4.1 From ae7a132f38404e9f654ab1b7c5dd84ba6a3efda6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 31 Jan 2017 13:40:09 +0000 Subject: Better handle 404 response for federation /send/ --- synapse/federation/transaction_queue.py | 1 + synapse/util/retryutils.py | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) (limited to 'synapse/federation') diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index d18f6b6cfd..cb106c6a1b 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -319,6 +319,7 @@ class TransactionQueue(object): destination, self.clock, self.store, + backoff_on_404=True, # If we get a 404 the other side has gone ) device_message_edus, device_stream_id, dev_list_id = ( diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index e2de7fce91..cc88a0b532 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -88,7 +88,7 @@ class RetryDestinationLimiter(object): def __init__(self, destination, clock, store, retry_interval, min_retry_interval=10 * 60 * 1000, max_retry_interval=24 * 60 * 60 * 1000, - multiplier_retry_interval=5,): + multiplier_retry_interval=5, backoff_on_404=False): """Marks the destination as "down" if an exception is thrown in the context, except for CodeMessageException with code < 500. @@ -107,6 +107,7 @@ class RetryDestinationLimiter(object): a failed request, in milliseconds. multiplier_retry_interval (int): The multiplier to use to increase the retry interval after a failed request. + backoff_on_404 (bool): Back off if we get a 404 """ self.clock = clock self.store = store @@ -116,6 +117,7 @@ class RetryDestinationLimiter(object): self.min_retry_interval = min_retry_interval self.max_retry_interval = max_retry_interval self.multiplier_retry_interval = multiplier_retry_interval + self.backoff_on_404 = backoff_on_404 def __enter__(self): pass @@ -123,7 +125,16 @@ class RetryDestinationLimiter(object): def __exit__(self, exc_type, exc_val, exc_tb): valid_err_code = False if exc_type is not None and issubclass(exc_type, CodeMessageException): - valid_err_code = exc_val.code != 429 and 0 <= exc_val.code < 500 + if exc_val.code < 400: + valid_err_code = True + elif exc_val.code == 404 and self.backoff_on_404: + valid_err_code = False + elif exc_val.code == 429: + valid_err_code = False + elif exc_val.code < 500: + valid_err_code = True + else: + valid_err_code = False if exc_type is None or valid_err_code: # We connected successfully. -- cgit 1.4.1 From df4ecff5a9f52e26e01ce364e964fc141c920e52 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 1 Feb 2017 15:42:19 +0000 Subject: Correctly raise exceptions for ratelimitng. Ratelimit on 401 --- synapse/federation/transaction_queue.py | 2 +- synapse/util/retryutils.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) (limited to 'synapse/federation') diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index cb106c6a1b..bb3d9258a6 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -504,7 +504,7 @@ class TransactionQueue(object): code = e.code response = e.response - if e.code == 429 or 500 <= e.code: + if e.code in (401, 404, 429) or 500 <= e.code: logger.info( "TX [%s] {%s} got %d response", destination, txn_id, code diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index b94ae369cf..153ef001ad 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -129,11 +129,13 @@ class RetryDestinationLimiter(object): # APIs may expect to never received e.g. a 404. It's important to # handle 404 as some remote servers will return a 404 when the HS # has been decommissioned. + # If we get a 401, then we should probably back off since they + # won't accept our requests for at least a while. + # 429 is us being aggresively rate limited, so lets rate limit + # ourselves. if exc_val.code == 404 and self.backoff_on_404: valid_err_code = False - elif exc_val.code == 429: - # 429 is us being aggresively rate limited, so lets rate limit - # ourselves. + elif exc_val.code in (401, 429): valid_err_code = False elif exc_val.code < 500: valid_err_code = True -- cgit 1.4.1