diff options
Diffstat (limited to 'synapse/replication')
26 files changed, 357 insertions, 355 deletions
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 0a432a16fa..fe482e279f 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -83,8 +83,7 @@ class ReplicationEndpoint(object): def __init__(self, hs): if self.CACHE: self.response_cache = ResponseCache( - hs, "repl." + self.NAME, - timeout_ms=30 * 60 * 1000, + hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000 ) assert self.METHOD in ("PUT", "POST", "GET") @@ -134,8 +133,7 @@ class ReplicationEndpoint(object): data = yield cls._serialize_payload(**kwargs) url_args = [ - urllib.parse.quote(kwargs[name], safe='') - for name in cls.PATH_ARGS + urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS ] if cls.CACHE: @@ -156,7 +154,10 @@ class ReplicationEndpoint(object): ) uri = "http://%s:%s/_synapse/replication/%s/%s" % ( - host, port, cls.NAME, "/".join(url_args) + host, + port, + cls.NAME, + "/".join(url_args), ) try: @@ -202,10 +203,7 @@ class ReplicationEndpoint(object): url_args.append("txn_id") args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args) - pattern = re.compile("^/_synapse/replication/%s/%s$" % ( - self.NAME, - args - )) + pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args)) http_server.register_paths(method, [pattern], handler) @@ -219,8 +217,4 @@ class ReplicationEndpoint(object): assert self.CACHE - return self.response_cache.wrap( - txn_id, - self._handle_request, - request, **kwargs - ) + return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 0f0a07c422..61eafbe708 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -68,18 +68,17 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): for event, context in event_and_contexts: serialized_context = yield context.serialize(event, store) - event_payloads.append({ - "event": event.get_pdu_json(), - "event_format_version": event.format_version, - "internal_metadata": event.internal_metadata.get_dict(), - "rejected_reason": event.rejected_reason, - "context": serialized_context, - }) - - payload = { - "events": event_payloads, - "backfilled": backfilled, - } + event_payloads.append( + { + "event": event.get_pdu_json(), + "event_format_version": event.format_version, + "internal_metadata": event.internal_metadata.get_dict(), + "rejected_reason": event.rejected_reason, + "context": serialized_context, + } + ) + + payload = {"events": event_payloads, "backfilled": backfilled} defer.returnValue(payload) @@ -103,18 +102,15 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): event = EventType(event_dict, internal_metadata, rejected_reason) context = yield EventContext.deserialize( - self.store, event_payload["context"], + self.store, event_payload["context"] ) event_and_contexts.append((event, context)) - logger.info( - "Got %d events from federation", - len(event_and_contexts), - ) + logger.info("Got %d events from federation", len(event_and_contexts)) yield self.federation_handler.persist_events_and_notify( - event_and_contexts, backfilled, + event_and_contexts, backfilled ) defer.returnValue((200, {})) @@ -146,10 +142,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): @staticmethod def _serialize_payload(edu_type, origin, content): - return { - "origin": origin, - "content": content, - } + return {"origin": origin, "content": content} @defer.inlineCallbacks def _handle_request(self, request, edu_type): @@ -159,10 +152,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): origin = content["origin"] edu_content = content["content"] - logger.info( - "Got %r edu from %s", - edu_type, origin, - ) + logger.info("Got %r edu from %s", edu_type, origin) result = yield self.registry.on_edu(edu_type, origin, edu_content) @@ -201,9 +191,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): query_type (str) args (dict): The arguments received for the given query type """ - return { - "args": args, - } + return {"args": args} @defer.inlineCallbacks def _handle_request(self, request, query_type): @@ -212,10 +200,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): args = content["args"] - logger.info( - "Got %r query", - query_type, - ) + logger.info("Got %r query", query_type) result = yield self.registry.on_query(query_type, args) diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 63bc0405ea..7c1197e5dd 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -61,13 +61,10 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): is_guest = content["is_guest"] device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name, is_guest, + user_id, device_id, initial_display_name, is_guest ) - defer.returnValue((200, { - "device_id": device_id, - "access_token": access_token, - })) + defer.returnValue((200, {"device_id": device_id, "access_token": access_token})) def register_servlets(hs, http_server): diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 81a2b204c7..0a76a3762f 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -40,7 +40,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): """ NAME = "remote_join" - PATH_ARGS = ("room_id", "user_id",) + PATH_ARGS = ("room_id", "user_id") def __init__(self, hs): super(ReplicationRemoteJoinRestServlet, self).__init__(hs) @@ -50,8 +50,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - def _serialize_payload(requester, room_id, user_id, remote_room_hosts, - content): + def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content): """ Args: requester(Requester) @@ -78,16 +77,10 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): if requester.user: request.authenticated_entity = requester.user.to_string() - logger.info( - "remote_join: %s into room: %s", - user_id, room_id, - ) + logger.info("remote_join: %s into room: %s", user_id, room_id) yield self.federation_handler.do_invite_join( - remote_room_hosts, - room_id, - user_id, - event_content, + remote_room_hosts, room_id, user_id, event_content ) defer.returnValue((200, {})) @@ -107,7 +100,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): """ NAME = "remote_reject_invite" - PATH_ARGS = ("room_id", "user_id",) + PATH_ARGS = ("room_id", "user_id") def __init__(self, hs): super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs) @@ -141,16 +134,11 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): if requester.user: request.authenticated_entity = requester.user.to_string() - logger.info( - "remote_reject_invite: %s out of room: %s", - user_id, room_id, - ) + logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id) try: event = yield self.federation_handler.do_remotely_reject_invite( - remote_room_hosts, - room_id, - user_id, + remote_room_hosts, room_id, user_id ) ret = event.get_pdu_json() except Exception as e: @@ -162,9 +150,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): # logger.warn("Failed to reject invite: %s", e) - yield self.store.locally_reject_invite( - user_id, room_id - ) + yield self.store.locally_reject_invite(user_id, room_id) ret = {} defer.returnValue((200, ret)) @@ -228,7 +214,7 @@ class ReplicationRegister3PIDGuestRestServlet(ReplicationEndpoint): logger.info("get_or_register_3pid_guest: %r", content) ret = yield self.registeration_handler.get_or_register_3pid_guest( - medium, address, inviter_user_id, + medium, address, inviter_user_id ) defer.returnValue((200, ret)) @@ -264,7 +250,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): user_id (str) change (str): Either "joined" or "left" """ - assert change in ("joined", "left",) + assert change in ("joined", "left") return {} diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 912a5ac341..f81a0f1b8f 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -37,8 +37,16 @@ class ReplicationRegisterServlet(ReplicationEndpoint): @staticmethod def _serialize_payload( - user_id, token, password_hash, was_guest, make_guest, appservice_id, - create_profile_with_displayname, admin, user_type, address, + user_id, + token, + password_hash, + was_guest, + make_guest, + appservice_id, + create_profile_with_displayname, + admin, + user_type, + address, ): """ Args: @@ -85,7 +93,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): create_profile_with_displayname=content["create_profile_with_displayname"], admin=content["admin"], user_type=content["user_type"], - address=content["address"] + address=content["address"], ) defer.returnValue((200, {})) @@ -104,8 +112,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): self.registration_handler = hs.get_registration_handler() @staticmethod - def _serialize_payload(user_id, auth_result, access_token, bind_email, - bind_msisdn): + def _serialize_payload(user_id, auth_result, access_token, bind_email, bind_msisdn): """ Args: user_id (str): The user ID that consented diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 3635015eda..034763fe99 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -45,6 +45,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "extra_users": [], } """ + NAME = "send_event" PATH_ARGS = ("event_id",) @@ -57,8 +58,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): @staticmethod @defer.inlineCallbacks - def _serialize_payload(event_id, store, event, context, requester, - ratelimit, extra_users): + def _serialize_payload( + event_id, store, event, context, requester, ratelimit, extra_users + ): """ Args: event_id (str) @@ -108,14 +110,11 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): request.authenticated_entity = requester.user.to_string() logger.info( - "Got event to send with ID: %s into room: %s", - event.event_id, event.room_id, + "Got event to send with ID: %s into room: %s", event.event_id, event.room_id ) yield self.event_creation_handler.persist_and_notify_client_event( - requester, event, context, - ratelimit=ratelimit, - extra_users=extra_users, + requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) defer.returnValue((200, {})) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 817d1f67f9..182cb2a1d8 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -37,7 +37,7 @@ class BaseSlavedStore(SQLBaseStore): super(BaseSlavedStore, self).__init__(db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = SlavedIdTracker( - db_conn, "cache_invalidation_stream", "stream_id", + db_conn, "cache_invalidation_stream", "stream_id" ) else: self._cache_id_gen = None diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index d9ba6d69b1..3c44d1d48d 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -21,10 +21,9 @@ from synapse.storage.tags import TagsWorkerStore class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): self._account_data_id_gen = SlavedIdTracker( - db_conn, "account_data_max_stream_id", "stream_id", + db_conn, "account_data_max_stream_id", "stream_id" ) super(SlavedAccountDataStore, self).__init__(db_conn, hs) @@ -45,24 +44,20 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved self._account_data_id_gen.advance(token) for row in rows: self.get_tags_for_user.invalidate((row.user_id,)) - self._account_data_stream_cache.entity_has_changed( - row.user_id, token - ) + self._account_data_stream_cache.entity_has_changed(row.user_id, token) elif stream_name == "account_data": self._account_data_id_gen.advance(token) for row in rows: if not row.room_id: self.get_global_account_data_by_type_for_user.invalidate( - (row.data_type, row.user_id,) + (row.data_type, row.user_id) ) self.get_account_data_for_user.invalidate((row.user_id,)) - self.get_account_data_for_room.invalidate((row.user_id, row.room_id,)) + self.get_account_data_for_room.invalidate((row.user_id, row.room_id)) self.get_account_data_for_room_and_type.invalidate( - (row.user_id, row.room_id, row.data_type,), - ) - self._account_data_stream_cache.entity_has_changed( - row.user_id, token + (row.user_id, row.room_id, row.data_type) ) + self._account_data_stream_cache.entity_has_changed(row.user_id, token) return super(SlavedAccountDataStore, self).process_replication_rows( stream_name, token, rows ) diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py index b53a4c6bd1..cda12ea70d 100644 --- a/synapse/replication/slave/storage/appservice.py +++ b/synapse/replication/slave/storage/appservice.py @@ -20,6 +20,7 @@ from synapse.storage.appservice import ( ) -class SlavedApplicationServiceStore(ApplicationServiceTransactionWorkerStore, - ApplicationServiceWorkerStore): +class SlavedApplicationServiceStore( + ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore +): pass diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 5b8521c770..14ced32333 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -25,9 +25,7 @@ class SlavedClientIpStore(BaseSlavedStore): super(SlavedClientIpStore, self).__init__(db_conn, hs) self.client_ip_last_seen = Cache( - name="client_ip_last_seen", - keylen=4, - max_entries=50000 * CACHE_SIZE_FACTOR, + name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR ) def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 4d59778863..284fd30d89 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -24,15 +24,15 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedDeviceInboxStore, self).__init__(db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( - db_conn, "device_max_stream_id", "stream_id", + db_conn, "device_max_stream_id", "stream_id" ) self._device_inbox_stream_cache = StreamChangeCache( "DeviceInboxStreamChangeCache", - self._device_inbox_id_gen.get_current_token() + self._device_inbox_id_gen.get_current_token(), ) self._device_federation_outbox_stream_cache = StreamChangeCache( "DeviceFederationOutboxStreamChangeCache", - self._device_inbox_id_gen.get_current_token() + self._device_inbox_id_gen.get_current_token(), ) self._last_device_delete_cache = ExpiringCache( diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 16c9a162c5..d9300fce33 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -27,14 +27,14 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto self.hs = hs self._device_list_id_gen = SlavedIdTracker( - db_conn, "device_lists_stream", "stream_id", + db_conn, "device_lists_stream", "stream_id" ) device_list_max = self._device_list_id_gen.get_current_token() self._device_list_stream_cache = StreamChangeCache( - "DeviceListStreamChangeCache", device_list_max, + "DeviceListStreamChangeCache", device_list_max ) self._device_list_federation_stream_cache = StreamChangeCache( - "DeviceListFederationStreamChangeCache", device_list_max, + "DeviceListFederationStreamChangeCache", device_list_max ) def stream_positions(self): @@ -46,17 +46,13 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto if stream_name == "device_lists": self._device_list_id_gen.advance(token) for row in rows: - self._invalidate_caches_for_devices( - token, row.user_id, row.destination, - ) + self._invalidate_caches_for_devices(token, row.user_id, row.destination) return super(SlavedDeviceStore, self).process_replication_rows( stream_name, token, rows ) def _invalidate_caches_for_devices(self, token, user_id, destination): - self._device_list_stream_cache.entity_has_changed( - user_id, token - ) + self._device_list_stream_cache.entity_has_changed(user_id, token) if destination: self._device_list_federation_stream_cache.entity_has_changed( diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index a3952506c1..ab5937e638 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -45,21 +45,20 @@ logger = logging.getLogger(__name__) # the method descriptor on the DataStore and chuck them into our class. -class SlavedEventStore(EventFederationWorkerStore, - RoomMemberWorkerStore, - EventPushActionsWorkerStore, - StreamWorkerStore, - StateGroupWorkerStore, - EventsWorkerStore, - SignatureWorkerStore, - UserErasureWorkerStore, - RelationsWorkerStore, - BaseSlavedStore): - +class SlavedEventStore( + EventFederationWorkerStore, + RoomMemberWorkerStore, + EventPushActionsWorkerStore, + StreamWorkerStore, + StateGroupWorkerStore, + EventsWorkerStore, + SignatureWorkerStore, + UserErasureWorkerStore, + RelationsWorkerStore, + BaseSlavedStore, +): def __init__(self, db_conn, hs): - self._stream_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", - ) + self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") self._backfill_id_gen = SlavedIdTracker( db_conn, "events", "stream_ordering", step=-1 ) @@ -90,8 +89,13 @@ class SlavedEventStore(EventFederationWorkerStore, self._backfill_id_gen.advance(-token) for row in rows: self.invalidate_caches_for_event( - -token, row.event_id, row.room_id, row.type, row.state_key, - row.redacts, row.relates_to, + -token, + row.event_id, + row.room_id, + row.type, + row.state_key, + row.redacts, + row.relates_to, backfilled=True, ) return super(SlavedEventStore, self).process_replication_rows( @@ -103,41 +107,48 @@ class SlavedEventStore(EventFederationWorkerStore, if row.type == EventsStreamEventRow.TypeId: self.invalidate_caches_for_event( - token, data.event_id, data.room_id, data.type, data.state_key, - data.redacts, data.relates_to, + token, + data.event_id, + data.room_id, + data.type, + data.state_key, + data.redacts, + data.relates_to, backfilled=False, ) elif row.type == EventsStreamCurrentStateRow.TypeId: if data.type == EventTypes.Member: self.get_rooms_for_user_with_stream_ordering.invalidate( - (data.state_key, ), + (data.state_key,) ) else: - raise Exception("Unknown events stream row type %s" % (row.type, )) - - def invalidate_caches_for_event(self, stream_ordering, event_id, room_id, - etype, state_key, redacts, relates_to, - backfilled): + raise Exception("Unknown events stream row type %s" % (row.type,)) + + def invalidate_caches_for_event( + self, + stream_ordering, + event_id, + room_id, + etype, + state_key, + redacts, + relates_to, + backfilled, + ): self._invalidate_get_event_cache(event_id) self.get_latest_event_ids_in_room.invalidate((room_id,)) - self.get_unread_event_push_actions_by_room_for_user.invalidate_many( - (room_id,) - ) + self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) if not backfilled: - self._events_stream_cache.entity_has_changed( - room_id, stream_ordering - ) + self._events_stream_cache.entity_has_changed(room_id, stream_ordering) if redacts: self._invalidate_get_event_cache(redacts) if etype == EventTypes.Member: - self._membership_stream_cache.entity_has_changed( - state_key, stream_ordering - ) + self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) self.get_invited_rooms_for_user.invalidate((state_key,)) if relates_to: diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index e933b170bb..28a46edd28 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -27,10 +27,11 @@ class SlavedGroupServerStore(BaseSlavedStore): self.hs = hs self._group_updates_id_gen = SlavedIdTracker( - db_conn, "local_group_updates", "stream_id", + db_conn, "local_group_updates", "stream_id" ) self._group_updates_stream_cache = StreamChangeCache( - "_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(), + "_group_updates_stream_cache", + self._group_updates_id_gen.get_current_token(), ) get_groups_changes_for_user = __func__(DataStore.get_groups_changes_for_user) @@ -46,9 +47,7 @@ class SlavedGroupServerStore(BaseSlavedStore): if stream_name == "groups": self._group_updates_id_gen.advance(token) for row in rows: - self._group_updates_stream_cache.entity_has_changed( - row.user_id, token - ) + self._group_updates_stream_cache.entity_has_changed(row.user_id, token) return super(SlavedGroupServerStore, self).process_replication_rows( stream_name, token, rows diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 0ec1db25ce..82d808af4c 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -24,9 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedPresenceStore(BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedPresenceStore, self).__init__(db_conn, hs) - self._presence_id_gen = SlavedIdTracker( - db_conn, "presence_stream", "stream_id", - ) + self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") self._presence_on_startup = self._get_active_presence(db_conn) @@ -55,9 +53,7 @@ class SlavedPresenceStore(BaseSlavedStore): if stream_name == "presence": self._presence_id_gen.advance(token) for row in rows: - self.presence_stream_cache.entity_has_changed( - row.user_id, token - ) + self.presence_stream_cache.entity_has_changed(row.user_id, token) self._get_presence_for_user.invalidate((row.user_id,)) return super(SlavedPresenceStore, self).process_replication_rows( stream_name, token, rows diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 45fc913c52..af7012702e 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -23,7 +23,7 @@ from .events import SlavedEventStore class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): def __init__(self, db_conn, hs): self._push_rules_stream_id_gen = SlavedIdTracker( - db_conn, "push_rules_stream", "stream_id", + db_conn, "push_rules_stream", "stream_id" ) super(SlavedPushRuleStore, self).__init__(db_conn, hs) @@ -47,9 +47,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): for row in rows: self.get_push_rules_for_user.invalidate((row.user_id,)) self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) - self.push_rules_stream_cache.entity_has_changed( - row.user_id, token - ) + self.push_rules_stream_cache.entity_has_changed(row.user_id, token) return super(SlavedPushRuleStore, self).process_replication_rows( stream_name, token, rows ) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 3b2213c0d4..8eeb267d61 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -21,12 +21,10 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): super(SlavedPusherStore, self).__init__(db_conn, hs) self._pushers_id_gen = SlavedIdTracker( - db_conn, "pushers", "id", - extra_tables=[("deleted_pushers", "stream_id")], + db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) def stream_positions(self): diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index ed12342f40..91afa5a72b 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -29,7 +29,6 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 0cb474928c..f68b3378e3 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -38,6 +38,4 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore): if stream_name == "public_rooms": self._public_room_id_gen.advance(token) - return super(RoomStore, self).process_replication_rows( - stream_name, token, rows - ) + return super(RoomStore, self).process_replication_rows(stream_name, token, rows) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 206dc3b397..a44ceb00e7 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -39,6 +39,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): Accepts a handler that will be called when new data is available or data is required. """ + maxDelay = 30 # Try at least once every N seconds def __init__(self, hs, client_name, handler): @@ -64,9 +65,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): def clientConnectionFailed(self, connector, reason): logger.error("Failed to connect to replication: %r", reason) - ReconnectingClientFactory.clientConnectionFailed( - self, connector, reason - ) + ReconnectingClientFactory.clientConnectionFailed(self, connector, reason) class ReplicationClientHandler(object): @@ -74,6 +73,7 @@ class ReplicationClientHandler(object): By default proxies incoming replication data to the SlaveStore. """ + def __init__(self, store): self.store = store diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 2098c32a77..0ff2a7199f 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -23,9 +23,11 @@ import platform if platform.python_implementation() == "PyPy": import json + _json_encoder = json.JSONEncoder() else: import simplejson as json + _json_encoder = json.JSONEncoder(namedtuple_as_object=False) logger = logging.getLogger(__name__) @@ -41,6 +43,7 @@ class Command(object): The default implementation creates a command of form `<NAME> <data>` """ + NAME = None def __init__(self, data): @@ -73,6 +76,7 @@ class ServerCommand(Command): SERVER <server_name> """ + NAME = "SERVER" @@ -99,6 +103,7 @@ class RdataCommand(Command): RDATA presence batch ["@bar:example.com", "online", ...] RDATA presence 59 ["@baz:example.com", "online", ...] """ + NAME = "RDATA" def __init__(self, stream_name, token, row): @@ -110,17 +115,17 @@ class RdataCommand(Command): def from_line(cls, line): stream_name, token, row_json = line.split(" ", 2) return cls( - stream_name, - None if token == "batch" else int(token), - json.loads(row_json) + stream_name, None if token == "batch" else int(token), json.loads(row_json) ) def to_line(self): - return " ".join(( - self.stream_name, - str(self.token) if self.token is not None else "batch", - _json_encoder.encode(self.row), - )) + return " ".join( + ( + self.stream_name, + str(self.token) if self.token is not None else "batch", + _json_encoder.encode(self.row), + ) + ) def get_logcontext_id(self): return "RDATA-" + self.stream_name @@ -133,6 +138,7 @@ class PositionCommand(Command): Sent to the client after all missing updates for a stream have been sent to the client and they're now up to date. """ + NAME = "POSITION" def __init__(self, stream_name, token): @@ -145,19 +151,21 @@ class PositionCommand(Command): return cls(stream_name, int(token)) def to_line(self): - return " ".join((self.stream_name, str(self.token),)) + return " ".join((self.stream_name, str(self.token))) class ErrorCommand(Command): """Sent by either side if there was an ERROR. The data is a string describing the error. """ + NAME = "ERROR" class PingCommand(Command): """Sent by either side as a keep alive. The data is arbitary (often timestamp) """ + NAME = "PING" @@ -165,6 +173,7 @@ class NameCommand(Command): """Sent by client to inform the server of the client's identity. The data is the name """ + NAME = "NAME" @@ -184,6 +193,7 @@ class ReplicateCommand(Command): REPLICATE ALL NOW """ + NAME = "REPLICATE" def __init__(self, stream_name, token): @@ -200,7 +210,7 @@ class ReplicateCommand(Command): return cls(stream_name, token) def to_line(self): - return " ".join((self.stream_name, str(self.token),)) + return " ".join((self.stream_name, str(self.token))) def get_logcontext_id(self): return "REPLICATE-" + self.stream_name @@ -218,6 +228,7 @@ class UserSyncCommand(Command): Where <state> is either "start" or "stop" """ + NAME = "USER_SYNC" def __init__(self, user_id, is_syncing, last_sync_ms): @@ -235,9 +246,13 @@ class UserSyncCommand(Command): return cls(user_id, state == "start", int(last_sync_ms)) def to_line(self): - return " ".join(( - self.user_id, "start" if self.is_syncing else "end", str(self.last_sync_ms), - )) + return " ".join( + ( + self.user_id, + "start" if self.is_syncing else "end", + str(self.last_sync_ms), + ) + ) class FederationAckCommand(Command): @@ -251,6 +266,7 @@ class FederationAckCommand(Command): FEDERATION_ACK <token> """ + NAME = "FEDERATION_ACK" def __init__(self, token): @@ -268,6 +284,7 @@ class SyncCommand(Command): """Used for testing. The client protocol implementation allows waiting on a SYNC command with a specified data. """ + NAME = "SYNC" @@ -278,6 +295,7 @@ class RemovePusherCommand(Command): REMOVE_PUSHER <app_id> <push_key> <user_id> """ + NAME = "REMOVE_PUSHER" def __init__(self, app_id, push_key, user_id): @@ -309,6 +327,7 @@ class InvalidateCacheCommand(Command): Where <keys_json> is a json list. """ + NAME = "INVALIDATE_CACHE" def __init__(self, cache_func, keys): @@ -322,9 +341,7 @@ class InvalidateCacheCommand(Command): return cls(cache_func, json.loads(keys_json)) def to_line(self): - return " ".join(( - self.cache_func, _json_encoder.encode(self.keys), - )) + return " ".join((self.cache_func, _json_encoder.encode(self.keys))) class UserIpCommand(Command): @@ -334,6 +351,7 @@ class UserIpCommand(Command): USER_IP <user_id>, <access_token>, <ip>, <device_id>, <last_seen>, <user_agent> """ + NAME = "USER_IP" def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen): @@ -350,15 +368,22 @@ class UserIpCommand(Command): access_token, ip, user_agent, device_id, last_seen = json.loads(jsn) - return cls( - user_id, access_token, ip, user_agent, device_id, last_seen - ) + return cls(user_id, access_token, ip, user_agent, device_id, last_seen) def to_line(self): - return self.user_id + " " + _json_encoder.encode(( - self.access_token, self.ip, self.user_agent, self.device_id, - self.last_seen, - )) + return ( + self.user_id + + " " + + _json_encoder.encode( + ( + self.access_token, + self.ip, + self.user_agent, + self.device_id, + self.last_seen, + ) + ) + ) # Map of command name to command type. diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index b51590cf8f..97efb835ad 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -84,7 +84,8 @@ from .commands import ( from .streams import STREAMS_MAP connection_close_counter = Counter( - "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]) + "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] +) # A list of all connected protocols. This allows us to send metrics about the # connections. @@ -119,7 +120,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): It also sends `PING` periodically, and correctly times out remote connections (if they send a `PING` command) """ - delimiter = b'\n' + + delimiter = b"\n" VALID_INBOUND_COMMANDS = [] # Valid commands we expect to receive VALID_OUTBOUND_COMMANDS = [] # Valid commans we can send @@ -183,10 +185,14 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if now - self.last_sent_command >= PING_TIME: self.send_command(PingCommand(now)) - if self.received_ping and now - self.last_received_command > PING_TIMEOUT_MS: + if ( + self.received_ping + and now - self.last_received_command > PING_TIMEOUT_MS + ): logger.info( "[%s] Connection hasn't received command in %r ms. Closing.", - self.id(), now - self.last_received_command + self.id(), + now - self.last_received_command, ) self.send_error("ping timeout") @@ -208,7 +214,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.last_received_command = self.clock.time_msec() self.inbound_commands_counter[cmd_name] = ( - self.inbound_commands_counter[cmd_name] + 1) + self.inbound_commands_counter[cmd_name] + 1 + ) cmd_cls = COMMAND_MAP[cmd_name] try: @@ -224,9 +231,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # Now lets try and call on_<CMD_NAME> function run_as_background_process( - "replication-" + cmd.get_logcontext_id(), - self.handle_command, - cmd, + "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd ) def handle_command(self, cmd): @@ -274,8 +279,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): return self.outbound_commands_counter[cmd.NAME] = ( - self.outbound_commands_counter[cmd.NAME] + 1) - string = "%s %s" % (cmd.NAME, cmd.to_line(),) + self.outbound_commands_counter[cmd.NAME] + 1 + ) + string = "%s %s" % (cmd.NAME, cmd.to_line()) if "\n" in string: raise Exception("Unexpected newline in command: %r", string) @@ -283,10 +289,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if len(encoded_string) > self.MAX_LENGTH: raise Exception( - "Failed to send command %s as too long (%d > %d)" % ( - cmd.NAME, - len(encoded_string), self.MAX_LENGTH, - ) + "Failed to send command %s as too long (%d > %d)" + % (cmd.NAME, len(encoded_string), self.MAX_LENGTH) ) self.sendLine(encoded_string) @@ -379,7 +383,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if self.transport: addr = str(self.transport.getPeer()) return "ReplicationConnection<name=%s,conn_id=%s,addr=%s>" % ( - self.name, self.conn_id, addr, + self.name, + self.conn_id, + addr, ) def id(self): @@ -422,7 +428,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): def on_USER_SYNC(self, cmd): return self.streamer.on_user_sync( - self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms, + self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) def on_REPLICATE(self, cmd): @@ -432,10 +438,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): if stream_name == "ALL": # Subscribe to all streams we're publishing to. deferreds = [ - run_in_background( - self.subscribe_to_stream, - stream, token, - ) + run_in_background(self.subscribe_to_stream, stream, token) for stream in iterkeys(self.streamer.streams_by_name) ] @@ -449,16 +452,18 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): return self.streamer.federation_ack(cmd.token) def on_REMOVE_PUSHER(self, cmd): - return self.streamer.on_remove_pusher( - cmd.app_id, cmd.push_key, cmd.user_id, - ) + return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) def on_INVALIDATE_CACHE(self, cmd): return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) def on_USER_IP(self, cmd): return self.streamer.on_user_ip( - cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id, + cmd.user_id, + cmd.access_token, + cmd.ip, + cmd.user_agent, + cmd.device_id, cmd.last_seen, ) @@ -476,7 +481,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): try: # Get missing updates updates, current_token = yield self.streamer.get_stream_updates( - stream_name, token, + stream_name, token ) # Send all the missing updates @@ -608,8 +613,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): row = STREAMS_MAP[stream_name].parse_row(cmd.row) except Exception: logger.exception( - "[%s] Failed to parse RDATA: %r %r", - self.id(), stream_name, cmd.row + "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row ) raise @@ -643,7 +647,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): logger.info( "[%s] Subscribing to replication stream: %r from %r", - self.id(), stream_name, token + self.id(), + stream_name, + token, ) self.streams_connecting.add(stream_name) @@ -661,9 +667,7 @@ pending_commands = LaterGauge( "synapse_replication_tcp_protocol_pending_commands", "", ["name"], - lambda: { - (p.name,): len(p.pending_commands) for p in connected_connections - }, + lambda: {(p.name,): len(p.pending_commands) for p in connected_connections}, ) @@ -678,9 +682,7 @@ transport_send_buffer = LaterGauge( "synapse_replication_tcp_protocol_transport_send_buffer", "", ["name"], - lambda: { - (p.name,): transport_buffer_size(p) for p in connected_connections - }, + lambda: {(p.name,): transport_buffer_size(p) for p in connected_connections}, ) @@ -694,7 +696,7 @@ def transport_kernel_read_buffer_size(protocol, read=True): op = SIOCINQ else: op = SIOCOUTQ - size = struct.unpack("I", fcntl.ioctl(fileno, op, '\0\0\0\0'))[0] + size = struct.unpack("I", fcntl.ioctl(fileno, op, "\0\0\0\0"))[0] return size return 0 @@ -726,7 +728,7 @@ tcp_inbound_commands = LaterGauge( "", ["command", "name"], lambda: { - (k, p.name,): count + (k, p.name): count for p in connected_connections for k, count in iteritems(p.inbound_commands_counter) }, @@ -737,7 +739,7 @@ tcp_outbound_commands = LaterGauge( "", ["command", "name"], lambda: { - (k, p.name,): count + (k, p.name): count for p in connected_connections for k, count in iteritems(p.outbound_commands_counter) }, diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index f6a38f5140..d1e98428bc 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -33,13 +33,15 @@ from .protocol import ServerReplicationStreamProtocol from .streams import STREAMS_MAP from .streams.federation import FederationStream -stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates", - "", ["stream_name"]) +stream_updates_counter = Counter( + "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"] +) user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") -invalidate_cache_counter = Counter("synapse_replication_tcp_resource_invalidate_cache", - "") +invalidate_cache_counter = Counter( + "synapse_replication_tcp_resource_invalidate_cache", "" +) user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") logger = logging.getLogger(__name__) @@ -48,6 +50,7 @@ logger = logging.getLogger(__name__) class ReplicationStreamProtocolFactory(Factory): """Factory for new replication connections. """ + def __init__(self, hs): self.streamer = ReplicationStreamer(hs) self.clock = hs.get_clock() @@ -55,9 +58,7 @@ class ReplicationStreamProtocolFactory(Factory): def buildProtocol(self, addr): return ServerReplicationStreamProtocol( - self.server_name, - self.clock, - self.streamer, + self.server_name, self.clock, self.streamer ) @@ -80,29 +81,39 @@ class ReplicationStreamer(object): # Current connections. self.connections = [] - LaterGauge("synapse_replication_tcp_resource_total_connections", "", [], - lambda: len(self.connections)) + LaterGauge( + "synapse_replication_tcp_resource_total_connections", + "", + [], + lambda: len(self.connections), + ) # List of streams that clients can subscribe to. # We only support federation stream if federation sending hase been # disabled on the master. self.streams = [ - stream(hs) for stream in itervalues(STREAMS_MAP) + stream(hs) + for stream in itervalues(STREAMS_MAP) if stream != FederationStream or not hs.config.send_federation ] self.streams_by_name = {stream.NAME: stream for stream in self.streams} LaterGauge( - "synapse_replication_tcp_resource_connections_per_stream", "", + "synapse_replication_tcp_resource_connections_per_stream", + "", ["stream_name"], lambda: { - (stream_name,): len([ - conn for conn in self.connections - if stream_name in conn.replication_streams - ]) + (stream_name,): len( + [ + conn + for conn in self.connections + if stream_name in conn.replication_streams + ] + ) for stream_name in self.streams_by_name - }) + }, + ) self.federation_sender = None if not hs.config.send_federation: @@ -179,7 +190,9 @@ class ReplicationStreamer(object): logger.debug( "Getting stream: %s: %s -> %s", - stream.NAME, stream.last_token, stream.upto_token + stream.NAME, + stream.last_token, + stream.upto_token, ) try: updates, current_token = yield stream.get_updates() @@ -189,7 +202,8 @@ class ReplicationStreamer(object): logger.debug( "Sending %d updates to %d connections", - len(updates), len(self.connections), + len(updates), + len(self.connections), ) if updates: @@ -243,7 +257,7 @@ class ReplicationStreamer(object): """ user_sync_counter.inc() yield self.presence_handler.update_external_syncs_row( - conn_id, user_id, is_syncing, last_sync_ms, + conn_id, user_id, is_syncing, last_sync_ms ) @measure_func("repl.on_remove_pusher") @@ -272,7 +286,7 @@ class ReplicationStreamer(object): """ user_ip_cache_counter.inc() yield self.store.insert_client_ip( - user_id, access_token, ip, user_agent, device_id, last_seen, + user_id, access_token, ip, user_agent, device_id, last_seen ) yield self._server_notices_sender.on_user_ip(user_id) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index b6ce7a7bee..7ef67a5a73 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -26,78 +26,75 @@ logger = logging.getLogger(__name__) MAX_EVENTS_BEHIND = 10000 -BackfillStreamRow = namedtuple("BackfillStreamRow", ( - "event_id", # str - "room_id", # str - "type", # str - "state_key", # str, optional - "redacts", # str, optional - "relates_to", # str, optional -)) -PresenceStreamRow = namedtuple("PresenceStreamRow", ( - "user_id", # str - "state", # str - "last_active_ts", # int - "last_federation_update_ts", # int - "last_user_sync_ts", # int - "status_msg", # str - "currently_active", # bool -)) -TypingStreamRow = namedtuple("TypingStreamRow", ( - "room_id", # str - "user_ids", # list(str) -)) -ReceiptsStreamRow = namedtuple("ReceiptsStreamRow", ( - "room_id", # str - "receipt_type", # str - "user_id", # str - "event_id", # str - "data", # dict -)) -PushRulesStreamRow = namedtuple("PushRulesStreamRow", ( - "user_id", # str -)) -PushersStreamRow = namedtuple("PushersStreamRow", ( - "user_id", # str - "app_id", # str - "pushkey", # str - "deleted", # bool -)) -CachesStreamRow = namedtuple("CachesStreamRow", ( - "cache_func", # str - "keys", # list(str) - "invalidation_ts", # int -)) -PublicRoomsStreamRow = namedtuple("PublicRoomsStreamRow", ( - "room_id", # str - "visibility", # str - "appservice_id", # str, optional - "network_id", # str, optional -)) -DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", ( - "user_id", # str - "destination", # str -)) -ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ( - "entity", # str -)) -TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", ( - "user_id", # str - "room_id", # str - "data", # dict -)) -AccountDataStreamRow = namedtuple("AccountDataStream", ( - "user_id", # str - "room_id", # str - "data_type", # str - "data", # dict -)) -GroupsStreamRow = namedtuple("GroupsStreamRow", ( - "group_id", # str - "user_id", # str - "type", # str - "content", # dict -)) +BackfillStreamRow = namedtuple( + "BackfillStreamRow", + ( + "event_id", # str + "room_id", # str + "type", # str + "state_key", # str, optional + "redacts", # str, optional + "relates_to", # str, optional + ), +) +PresenceStreamRow = namedtuple( + "PresenceStreamRow", + ( + "user_id", # str + "state", # str + "last_active_ts", # int + "last_federation_update_ts", # int + "last_user_sync_ts", # int + "status_msg", # str + "currently_active", # bool + ), +) +TypingStreamRow = namedtuple( + "TypingStreamRow", ("room_id", "user_ids") # str # list(str) +) +ReceiptsStreamRow = namedtuple( + "ReceiptsStreamRow", + ( + "room_id", # str + "receipt_type", # str + "user_id", # str + "event_id", # str + "data", # dict + ), +) +PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str +PushersStreamRow = namedtuple( + "PushersStreamRow", + ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool +) +CachesStreamRow = namedtuple( + "CachesStreamRow", + ("cache_func", "keys", "invalidation_ts"), # str # list(str) # int +) +PublicRoomsStreamRow = namedtuple( + "PublicRoomsStreamRow", + ( + "room_id", # str + "visibility", # str + "appservice_id", # str, optional + "network_id", # str, optional + ), +) +DeviceListsStreamRow = namedtuple( + "DeviceListsStreamRow", ("user_id", "destination") # str # str +) +ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str +TagAccountDataStreamRow = namedtuple( + "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict +) +AccountDataStreamRow = namedtuple( + "AccountDataStream", + ("user_id", "room_id", "data_type", "data"), # str # str # str # dict +) +GroupsStreamRow = namedtuple( + "GroupsStreamRow", + ("group_id", "user_id", "type", "content"), # str # str # str # dict +) class Stream(object): @@ -106,6 +103,7 @@ class Stream(object): Provides a `get_updates()` function that returns new updates since the last time it was called up until the point `advance_current_token` was called. """ + NAME = None # The name of the stream ROW_TYPE = None # The type of the row. Used by the default impl of parse_row. _LIMITED = True # Whether the update function takes a limit @@ -185,16 +183,13 @@ class Stream(object): if self._LIMITED: rows = yield self.update_function( - from_token, current_token, - limit=MAX_EVENTS_BEHIND + 1, + from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 ) # never turn more than MAX_EVENTS_BEHIND + 1 into updates. rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) else: - rows = yield self.update_function( - from_token, current_token, - ) + rows = yield self.update_function(from_token, current_token) updates = [(row[0], row[1:]) for row in rows] @@ -230,6 +225,7 @@ class BackfillStream(Stream): """We fetched some old events and either we had never seen that event before or it went from being an outlier to not. """ + NAME = "backfill" ROW_TYPE = BackfillStreamRow @@ -286,6 +282,7 @@ class ReceiptsStream(Stream): class PushRulesStream(Stream): """A user has changed their push rules """ + NAME = "push_rules" ROW_TYPE = PushRulesStreamRow @@ -306,6 +303,7 @@ class PushRulesStream(Stream): class PushersStream(Stream): """A user has added/changed/removed a pusher """ + NAME = "pushers" ROW_TYPE = PushersStreamRow @@ -322,6 +320,7 @@ class CachesStream(Stream): """A cache was invalidated on the master and no other stream would invalidate the cache on the workers """ + NAME = "caches" ROW_TYPE = CachesStreamRow @@ -337,6 +336,7 @@ class CachesStream(Stream): class PublicRoomsStream(Stream): """The public rooms list changed """ + NAME = "public_rooms" ROW_TYPE = PublicRoomsStreamRow @@ -352,6 +352,7 @@ class PublicRoomsStream(Stream): class DeviceListsStream(Stream): """Someone added/changed/removed a device """ + NAME = "device_lists" _LIMITED = False ROW_TYPE = DeviceListsStreamRow @@ -368,6 +369,7 @@ class DeviceListsStream(Stream): class ToDeviceStream(Stream): """New to_device messages for a client """ + NAME = "to_device" ROW_TYPE = ToDeviceStreamRow @@ -383,6 +385,7 @@ class ToDeviceStream(Stream): class TagAccountDataStream(Stream): """Someone added/removed a tag for a room """ + NAME = "tag_account_data" ROW_TYPE = TagAccountDataStreamRow @@ -398,6 +401,7 @@ class TagAccountDataStream(Stream): class AccountDataStream(Stream): """Global or per room account data was changed """ + NAME = "account_data" ROW_TYPE = AccountDataStreamRow @@ -416,7 +420,7 @@ class AccountDataStream(Stream): results = list(room_results) results.extend( - (stream_id, user_id, None, account_data_type, content,) + (stream_id, user_id, None, account_data_type, content) for stream_id, user_id, account_data_type, content in global_results ) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index f1290d022a..3d0694bb11 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -52,6 +52,7 @@ data part are: @attr.s(slots=True, frozen=True) class EventsStreamRow(object): """A parsed row from the events replication stream""" + type = attr.ib() # str: the TypeId of one of the *EventsStreamRows data = attr.ib() # BaseEventsStreamRow @@ -80,11 +81,11 @@ class BaseEventsStreamRow(object): class EventsStreamEventRow(BaseEventsStreamRow): TypeId = "ev" - event_id = attr.ib() # str - room_id = attr.ib() # str - type = attr.ib() # str - state_key = attr.ib() # str, optional - redacts = attr.ib() # str, optional + event_id = attr.ib() # str + room_id = attr.ib() # str + type = attr.ib() # str + state_key = attr.ib() # str, optional + redacts = attr.ib() # str, optional relates_to = attr.ib() # str, optional @@ -92,24 +93,21 @@ class EventsStreamEventRow(BaseEventsStreamRow): class EventsStreamCurrentStateRow(BaseEventsStreamRow): TypeId = "state" - room_id = attr.ib() # str - type = attr.ib() # str + room_id = attr.ib() # str + type = attr.ib() # str state_key = attr.ib() # str - event_id = attr.ib() # str, optional + event_id = attr.ib() # str, optional TypeToRow = { - Row.TypeId: Row - for Row in ( - EventsStreamEventRow, - EventsStreamCurrentStateRow, - ) + Row.TypeId: Row for Row in (EventsStreamEventRow, EventsStreamCurrentStateRow) } class EventsStream(Stream): """We received a new event, or an event went from being an outlier to not """ + NAME = "events" def __init__(self, hs): @@ -121,19 +119,17 @@ class EventsStream(Stream): @defer.inlineCallbacks def update_function(self, from_token, current_token, limit=None): event_rows = yield self._store.get_all_new_forward_event_rows( - from_token, current_token, limit, + from_token, current_token, limit ) event_updates = ( - (row[0], EventsStreamEventRow.TypeId, row[1:]) - for row in event_rows + (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows ) state_rows = yield self._store.get_all_updated_current_state_deltas( from_token, current_token, limit ) state_updates = ( - (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) - for row in state_rows + (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) for row in state_rows ) all_updates = heapq.merge(event_updates, state_updates) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 9aa43aa8d2..dc2484109d 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -17,16 +17,20 @@ from collections import namedtuple from ._base import Stream -FederationStreamRow = namedtuple("FederationStreamRow", ( - "type", # str, the type of data as defined in the BaseFederationRows - "data", # dict, serialization of a federation.send_queue.BaseFederationRow -)) +FederationStreamRow = namedtuple( + "FederationStreamRow", + ( + "type", # str, the type of data as defined in the BaseFederationRows + "data", # dict, serialization of a federation.send_queue.BaseFederationRow + ), +) class FederationStream(Stream): """Data to be sent over federation. Only available when master has federation sending disabled. """ + NAME = "federation" ROW_TYPE = FederationStreamRow |