diff options
Diffstat (limited to 'synapse/replication')
27 files changed, 880 insertions, 438 deletions
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py new file mode 100644 index 0000000000..589ee94c66 --- /dev/null +++ b/synapse/replication/http/__init__.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.http.server import JsonResource +from synapse.replication.http import membership, send_event + +REPLICATION_PREFIX = "/_synapse/replication" + + +class ReplicationRestResource(JsonResource): + def __init__(self, hs): + JsonResource.__init__(self, hs, canonical_json=False) + self.register_servlets(hs) + + def register_servlets(self, hs): + send_event.register_servlets(hs, self) + membership.register_servlets(hs, self) diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py new file mode 100644 index 0000000000..6bfc8a5b89 --- /dev/null +++ b/synapse/replication/http/membership.py @@ -0,0 +1,334 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +from twisted.internet import defer + +from synapse.api.errors import MatrixCodeMessageException, SynapseError +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.types import Requester, UserID +from synapse.util.distributor import user_joined_room, user_left_room + +logger = logging.getLogger(__name__) + + +@defer.inlineCallbacks +def remote_join(client, host, port, requester, remote_room_hosts, + room_id, user_id, content): + """Ask the master to do a remote join for the given user to the given room + + Args: + client (SimpleHttpClient) + host (str): host of master + port (int): port on master listening for HTTP replication + requester (Requester) + remote_room_hosts (list[str]): Servers to try and join via + room_id (str) + user_id (str) + content (dict): The event content to use for the join event + + Returns: + Deferred + """ + uri = "http://%s:%s/_synapse/replication/remote_join" % (host, port) + + payload = { + "requester": requester.serialize(), + "remote_room_hosts": remote_room_hosts, + "room_id": room_id, + "user_id": user_id, + "content": content, + } + + try: + result = yield client.post_json_get_json(uri, payload) + except MatrixCodeMessageException as e: + # We convert to SynapseError as we know that it was a SynapseError + # on the master process that we should send to the client. (And + # importantly, not stack traces everywhere) + raise SynapseError(e.code, e.msg, e.errcode) + defer.returnValue(result) + + +@defer.inlineCallbacks +def remote_reject_invite(client, host, port, requester, remote_room_hosts, + room_id, user_id): + """Ask master to reject the invite for the user and room. + + Args: + client (SimpleHttpClient) + host (str): host of master + port (int): port on master listening for HTTP replication + requester (Requester) + remote_room_hosts (list[str]): Servers to try and reject via + room_id (str) + user_id (str) + + Returns: + Deferred + """ + uri = "http://%s:%s/_synapse/replication/remote_reject_invite" % (host, port) + + payload = { + "requester": requester.serialize(), + "remote_room_hosts": remote_room_hosts, + "room_id": room_id, + "user_id": user_id, + } + + try: + result = yield client.post_json_get_json(uri, payload) + except MatrixCodeMessageException as e: + # We convert to SynapseError as we know that it was a SynapseError + # on the master process that we should send to the client. (And + # importantly, not stack traces everywhere) + raise SynapseError(e.code, e.msg, e.errcode) + defer.returnValue(result) + + +@defer.inlineCallbacks +def get_or_register_3pid_guest(client, host, port, requester, + medium, address, inviter_user_id): + """Ask the master to get/create a guest account for given 3PID. + + Args: + client (SimpleHttpClient) + host (str): host of master + port (int): port on master listening for HTTP replication + requester (Requester) + medium (str) + address (str) + inviter_user_id (str): The user ID who is trying to invite the + 3PID + + Returns: + Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the + 3PID guest account. + """ + + uri = "http://%s:%s/_synapse/replication/get_or_register_3pid_guest" % (host, port) + + payload = { + "requester": requester.serialize(), + "medium": medium, + "address": address, + "inviter_user_id": inviter_user_id, + } + + try: + result = yield client.post_json_get_json(uri, payload) + except MatrixCodeMessageException as e: + # We convert to SynapseError as we know that it was a SynapseError + # on the master process that we should send to the client. (And + # importantly, not stack traces everywhere) + raise SynapseError(e.code, e.msg, e.errcode) + defer.returnValue(result) + + +@defer.inlineCallbacks +def notify_user_membership_change(client, host, port, user_id, room_id, change): + """Notify master that a user has joined or left the room + + Args: + client (SimpleHttpClient) + host (str): host of master + port (int): port on master listening for HTTP replication. + user_id (str) + room_id (str) + change (str): Either "join" or "left" + + Returns: + Deferred + """ + assert change in ("joined", "left") + + uri = "http://%s:%s/_synapse/replication/user_%s_room" % (host, port, change) + + payload = { + "user_id": user_id, + "room_id": room_id, + } + + try: + result = yield client.post_json_get_json(uri, payload) + except MatrixCodeMessageException as e: + # We convert to SynapseError as we know that it was a SynapseError + # on the master process that we should send to the client. (And + # importantly, not stack traces everywhere) + raise SynapseError(e.code, e.msg, e.errcode) + defer.returnValue(result) + + +class ReplicationRemoteJoinRestServlet(RestServlet): + PATTERNS = [re.compile("^/_synapse/replication/remote_join$")] + + def __init__(self, hs): + super(ReplicationRemoteJoinRestServlet, self).__init__() + + self.federation_handler = hs.get_handlers().federation_handler + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def on_POST(self, request): + content = parse_json_object_from_request(request) + + remote_room_hosts = content["remote_room_hosts"] + room_id = content["room_id"] + user_id = content["user_id"] + event_content = content["content"] + + requester = Requester.deserialize(self.store, content["requester"]) + + if requester.user: + request.authenticated_entity = requester.user.to_string() + + logger.info( + "remote_join: %s into room: %s", + user_id, room_id, + ) + + yield self.federation_handler.do_invite_join( + remote_room_hosts, + room_id, + user_id, + event_content, + ) + + defer.returnValue((200, {})) + + +class ReplicationRemoteRejectInviteRestServlet(RestServlet): + PATTERNS = [re.compile("^/_synapse/replication/remote_reject_invite$")] + + def __init__(self, hs): + super(ReplicationRemoteRejectInviteRestServlet, self).__init__() + + self.federation_handler = hs.get_handlers().federation_handler + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def on_POST(self, request): + content = parse_json_object_from_request(request) + + remote_room_hosts = content["remote_room_hosts"] + room_id = content["room_id"] + user_id = content["user_id"] + + requester = Requester.deserialize(self.store, content["requester"]) + + if requester.user: + request.authenticated_entity = requester.user.to_string() + + logger.info( + "remote_reject_invite: %s out of room: %s", + user_id, room_id, + ) + + try: + event = yield self.federation_handler.do_remotely_reject_invite( + remote_room_hosts, + room_id, + user_id, + ) + ret = event.get_pdu_json() + except Exception as e: + # if we were unable to reject the exception, just mark + # it as rejected on our end and plough ahead. + # + # The 'except' clause is very broad, but we need to + # capture everything from DNS failures upwards + # + logger.warn("Failed to reject invite: %s", e) + + yield self.store.locally_reject_invite( + user_id, room_id + ) + ret = {} + + defer.returnValue((200, ret)) + + +class ReplicationRegister3PIDGuestRestServlet(RestServlet): + PATTERNS = [re.compile("^/_synapse/replication/get_or_register_3pid_guest$")] + + def __init__(self, hs): + super(ReplicationRegister3PIDGuestRestServlet, self).__init__() + + self.registeration_handler = hs.get_handlers().registration_handler + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def on_POST(self, request): + content = parse_json_object_from_request(request) + + medium = content["medium"] + address = content["address"] + inviter_user_id = content["inviter_user_id"] + + requester = Requester.deserialize(self.store, content["requester"]) + + if requester.user: + request.authenticated_entity = requester.user.to_string() + + logger.info("get_or_register_3pid_guest: %r", content) + + ret = yield self.registeration_handler.get_or_register_3pid_guest( + medium, address, inviter_user_id, + ) + + defer.returnValue((200, ret)) + + +class ReplicationUserJoinedLeftRoomRestServlet(RestServlet): + PATTERNS = [re.compile("^/_synapse/replication/user_(?P<change>joined|left)_room$")] + + def __init__(self, hs): + super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__() + + self.registeration_handler = hs.get_handlers().registration_handler + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self.distributor = hs.get_distributor() + + def on_POST(self, request, change): + content = parse_json_object_from_request(request) + + user_id = content["user_id"] + room_id = content["room_id"] + + logger.info("user membership change: %s in %s", user_id, room_id) + + user = UserID.from_string(user_id) + + if change == "joined": + user_joined_room(self.distributor, user, room_id) + elif change == "left": + user_left_room(self.distributor, user, room_id) + else: + raise Exception("Unrecognized change: %r", change) + + return (200, {}) + + +def register_servlets(hs, http_server): + ReplicationRemoteJoinRestServlet(hs).register(http_server) + ReplicationRemoteRejectInviteRestServlet(hs).register(http_server) + ReplicationRegister3PIDGuestRestServlet(hs).register(http_server) + ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py new file mode 100644 index 0000000000..5227bc333d --- /dev/null +++ b/synapse/replication/http/send_event.py @@ -0,0 +1,165 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +from twisted.internet import defer + +from synapse.api.errors import ( + CodeMessageException, + MatrixCodeMessageException, + SynapseError, +) +from synapse.events import FrozenEvent +from synapse.events.snapshot import EventContext +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.types import Requester, UserID +from synapse.util.caches.response_cache import ResponseCache +from synapse.util.metrics import Measure + +logger = logging.getLogger(__name__) + + +@defer.inlineCallbacks +def send_event_to_master(clock, store, client, host, port, requester, event, context, + ratelimit, extra_users): + """Send event to be handled on the master + + Args: + clock (synapse.util.Clock) + store (DataStore) + client (SimpleHttpClient) + host (str): host of master + port (int): port on master listening for HTTP replication + requester (Requester) + event (FrozenEvent) + context (EventContext) + ratelimit (bool) + extra_users (list(UserID)): Any extra users to notify about event + """ + uri = "http://%s:%s/_synapse/replication/send_event/%s" % ( + host, port, event.event_id, + ) + + serialized_context = yield context.serialize(event, store) + + payload = { + "event": event.get_pdu_json(), + "internal_metadata": event.internal_metadata.get_dict(), + "rejected_reason": event.rejected_reason, + "context": serialized_context, + "requester": requester.serialize(), + "ratelimit": ratelimit, + "extra_users": [u.to_string() for u in extra_users], + } + + try: + # We keep retrying the same request for timeouts. This is so that we + # have a good idea that the request has either succeeded or failed on + # the master, and so whether we should clean up or not. + while True: + try: + result = yield client.put_json(uri, payload) + break + except CodeMessageException as e: + if e.code != 504: + raise + + logger.warn("send_event request timed out") + + # If we timed out we probably don't need to worry about backing + # off too much, but lets just wait a little anyway. + yield clock.sleep(1) + except MatrixCodeMessageException as e: + # We convert to SynapseError as we know that it was a SynapseError + # on the master process that we should send to the client. (And + # importantly, not stack traces everywhere) + raise SynapseError(e.code, e.msg, e.errcode) + defer.returnValue(result) + + +class ReplicationSendEventRestServlet(RestServlet): + """Handles events newly created on workers, including persisting and + notifying. + + The API looks like: + + POST /_synapse/replication/send_event/:event_id + + { + "event": { .. serialized event .. }, + "internal_metadata": { .. serialized internal_metadata .. }, + "rejected_reason": .., // The event.rejected_reason field + "context": { .. serialized event context .. }, + "requester": { .. serialized requester .. }, + "ratelimit": true, + "extra_users": [], + } + """ + PATTERNS = [re.compile("^/_synapse/replication/send_event/(?P<event_id>[^/]+)$")] + + def __init__(self, hs): + super(ReplicationSendEventRestServlet, self).__init__() + + self.event_creation_handler = hs.get_event_creation_handler() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + # The responses are tiny, so we may as well cache them for a while + self.response_cache = ResponseCache(hs, "send_event", timeout_ms=30 * 60 * 1000) + + def on_PUT(self, request, event_id): + return self.response_cache.wrap( + event_id, + self._handle_request, + request + ) + + @defer.inlineCallbacks + def _handle_request(self, request): + with Measure(self.clock, "repl_send_event_parse"): + content = parse_json_object_from_request(request) + + event_dict = content["event"] + internal_metadata = content["internal_metadata"] + rejected_reason = content["rejected_reason"] + event = FrozenEvent(event_dict, internal_metadata, rejected_reason) + + requester = Requester.deserialize(self.store, content["requester"]) + context = yield EventContext.deserialize(self.store, content["context"]) + + ratelimit = content["ratelimit"] + extra_users = [UserID.from_string(u) for u in content["extra_users"]] + + if requester.user: + request.authenticated_entity = requester.user.to_string() + + logger.info( + "Got event to send with ID: %s into room: %s", + event.event_id, event.room_id, + ) + + yield self.event_creation_handler.persist_and_notify_client_event( + requester, event, context, + ratelimit=ratelimit, + extra_users=extra_users, + ) + + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + ReplicationSendEventRestServlet(hs).register(http_server) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index b962641166..3f7be74e02 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -13,19 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from synapse.storage._base import SQLBaseStore from synapse.storage.engines import PostgresEngine from ._slaved_id_tracker import SlavedIdTracker -import logging - logger = logging.getLogger(__name__) class BaseSlavedStore(SQLBaseStore): def __init__(self, db_conn, hs): - super(BaseSlavedStore, self).__init__(hs) + super(BaseSlavedStore, self).__init__(db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = SlavedIdTracker( db_conn, "cache_invalidation_stream", "stream_id", diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index efbd87918e..d9ba6d69b1 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,50 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker -from synapse.storage import DataStore -from synapse.storage.account_data import AccountDataStore -from synapse.storage.tags import TagsStore -from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.storage.account_data import AccountDataWorkerStore +from synapse.storage.tags import TagsWorkerStore -class SlavedAccountDataStore(BaseSlavedStore): +class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): - super(SlavedAccountDataStore, self).__init__(db_conn, hs) self._account_data_id_gen = SlavedIdTracker( db_conn, "account_data_max_stream_id", "stream_id", ) - self._account_data_stream_cache = StreamChangeCache( - "AccountDataAndTagsChangeCache", - self._account_data_id_gen.get_current_token(), - ) - - get_account_data_for_user = ( - AccountDataStore.__dict__["get_account_data_for_user"] - ) - - get_global_account_data_by_type_for_users = ( - AccountDataStore.__dict__["get_global_account_data_by_type_for_users"] - ) - get_global_account_data_by_type_for_user = ( - AccountDataStore.__dict__["get_global_account_data_by_type_for_user"] - ) - - get_tags_for_user = TagsStore.__dict__["get_tags_for_user"] - get_tags_for_room = ( - DataStore.get_tags_for_room.__func__ - ) - get_account_data_for_room = ( - DataStore.get_account_data_for_room.__func__ - ) - - get_updated_tags = DataStore.get_updated_tags.__func__ - get_updated_account_data_for_user = ( - DataStore.get_updated_account_data_for_user.__func__ - ) + super(SlavedAccountDataStore, self).__init__(db_conn, hs) def get_max_account_data_stream_id(self): return self._account_data_id_gen.get_current_token() @@ -85,6 +56,10 @@ class SlavedAccountDataStore(BaseSlavedStore): (row.data_type, row.user_id,) ) self.get_account_data_for_user.invalidate((row.user_id,)) + self.get_account_data_for_room.invalidate((row.user_id, row.room_id,)) + self.get_account_data_for_room_and_type.invalidate( + (row.user_id, row.room_id, row.data_type,), + ) self._account_data_stream_cache.entity_has_changed( row.user_id, token ) diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py index 0d3f31a50c..b53a4c6bd1 100644 --- a/synapse/replication/slave/storage/appservice.py +++ b/synapse/replication/slave/storage/appservice.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,33 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore -from synapse.storage import DataStore -from synapse.config.appservice import load_appservices -from synapse.storage.appservice import _make_exclusive_regex +from synapse.storage.appservice import ( + ApplicationServiceTransactionWorkerStore, + ApplicationServiceWorkerStore, +) -class SlavedApplicationServiceStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedApplicationServiceStore, self).__init__(db_conn, hs) - self.services_cache = load_appservices( - hs.config.server_name, - hs.config.app_service_config_files - ) - self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) - - get_app_service_by_token = DataStore.get_app_service_by_token.__func__ - get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__ - get_app_services = DataStore.get_app_services.__func__ - get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__ - create_appservice_txn = DataStore.create_appservice_txn.__func__ - get_appservices_by_state = DataStore.get_appservices_by_state.__func__ - get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__ - _get_last_txn = DataStore._get_last_txn.__func__ - complete_appservice_txn = DataStore.complete_appservice_txn.__func__ - get_appservice_state = DataStore.get_appservice_state.__func__ - set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__ - set_appservice_state = DataStore.set_appservice_state.__func__ - get_if_app_services_interested_in_user = ( - DataStore.get_if_app_services_interested_in_user.__func__ - ) +class SlavedApplicationServiceStore(ApplicationServiceTransactionWorkerStore, + ApplicationServiceWorkerStore): + pass diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 65250285e8..60641f1a49 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -13,11 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore from synapse.storage.client_ips import LAST_SEEN_GRANULARITY from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache +from ._base import BaseSlavedStore + class SlavedClientIpStore(BaseSlavedStore): def __init__(self, db_conn, hs): @@ -29,9 +30,8 @@ class SlavedClientIpStore(BaseSlavedStore): max_entries=50000 * CACHE_SIZE_FACTOR, ) - def insert_client_ip(self, user, access_token, ip, user_agent, device_id): + def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) - user_id = user.to_string() key = (user_id, access_token, ip) try: diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 6f3fb64770..87eaa53004 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -13,11 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker from synapse.storage import DataStore -from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.caches.stream_change_cache import StreamChangeCache + +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker class SlavedDeviceInboxStore(BaseSlavedStore): diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 7687867aee..8206a988f7 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker from synapse.storage import DataStore from synapse.storage.end_to_end_keys import EndToEndKeyStore from synapse.util.caches.stream_change_cache import StreamChangeCache +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker + class SlavedDeviceStore(BaseSlavedStore): def __init__(self, db_conn, hs): diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py index 7301d885f2..1d1d48709a 100644 --- a/synapse/replication/slave/storage/directory.py +++ b/synapse/replication/slave/storage/directory.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.storage.directory import DirectoryWorkerStore + from ._base import BaseSlavedStore -from synapse.storage.directory import DirectoryStore -class DirectoryStore(BaseSlavedStore): - get_aliases_for_room = DirectoryStore.__dict__[ - "get_aliases_for_room" - ] +class DirectoryStore(DirectoryWorkerStore, BaseSlavedStore): + pass diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 94ebbffc1b..bdb5eee4af 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,20 +13,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker +import logging from synapse.api.constants import EventTypes -from synapse.storage import DataStore -from synapse.storage.roommember import RoomMemberStore -from synapse.storage.event_federation import EventFederationStore -from synapse.storage.event_push_actions import EventPushActionsStore -from synapse.storage.state import StateStore -from synapse.storage.stream import StreamStore -from synapse.util.caches.stream_change_cache import StreamChangeCache - -import logging +from synapse.storage.event_federation import EventFederationWorkerStore +from synapse.storage.event_push_actions import EventPushActionsWorkerStore +from synapse.storage.events_worker import EventsWorkerStore +from synapse.storage.roommember import RoomMemberWorkerStore +from synapse.storage.signatures import SignatureWorkerStore +from synapse.storage.state import StateGroupWorkerStore +from synapse.storage.stream import StreamWorkerStore +from synapse.storage.user_erasure_store import UserErasureWorkerStore +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker logger = logging.getLogger(__name__) @@ -39,163 +40,34 @@ logger = logging.getLogger(__name__) # the method descriptor on the DataStore and chuck them into our class. -class SlavedEventStore(BaseSlavedStore): +class SlavedEventStore(EventFederationWorkerStore, + RoomMemberWorkerStore, + EventPushActionsWorkerStore, + StreamWorkerStore, + EventsWorkerStore, + StateGroupWorkerStore, + SignatureWorkerStore, + UserErasureWorkerStore, + BaseSlavedStore): def __init__(self, db_conn, hs): - super(SlavedEventStore, self).__init__(db_conn, hs) self._stream_id_gen = SlavedIdTracker( db_conn, "events", "stream_ordering", ) self._backfill_id_gen = SlavedIdTracker( db_conn, "events", "stream_ordering", step=-1 ) - events_max = self._stream_id_gen.get_current_token() - event_cache_prefill, min_event_val = self._get_cache_dict( - db_conn, "events", - entity_column="room_id", - stream_column="stream_ordering", - max_value=events_max, - ) - self._events_stream_cache = StreamChangeCache( - "EventsRoomStreamChangeCache", min_event_val, - prefilled_cache=event_cache_prefill, - ) - self._membership_stream_cache = StreamChangeCache( - "MembershipStreamChangeCache", events_max, - ) - self.stream_ordering_month_ago = 0 - self._stream_order_on_start = self.get_room_max_stream_ordering() + super(SlavedEventStore, self).__init__(db_conn, hs) # Cached functions can't be accessed through a class instance so we need # to reach inside the __dict__ to extract them. - get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] - get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"] - get_hosts_in_room = RoomMemberStore.__dict__["get_hosts_in_room"] - get_users_who_share_room_with_user = ( - RoomMemberStore.__dict__["get_users_who_share_room_with_user"] - ) - get_latest_event_ids_in_room = EventFederationStore.__dict__[ - "get_latest_event_ids_in_room" - ] - get_invited_rooms_for_user = RoomMemberStore.__dict__[ - "get_invited_rooms_for_user" - ] - get_unread_event_push_actions_by_room_for_user = ( - EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"] - ) - _get_unread_counts_by_receipt_txn = ( - DataStore._get_unread_counts_by_receipt_txn.__func__ - ) - _get_unread_counts_by_pos_txn = ( - DataStore._get_unread_counts_by_pos_txn.__func__ - ) - _get_state_group_for_events = ( - StateStore.__dict__["_get_state_group_for_events"] - ) - _get_state_group_for_event = ( - StateStore.__dict__["_get_state_group_for_event"] - ) - _get_state_groups_from_groups = ( - StateStore.__dict__["_get_state_groups_from_groups"] - ) - _get_state_groups_from_groups_txn = ( - DataStore._get_state_groups_from_groups_txn.__func__ - ) - get_recent_event_ids_for_room = ( - StreamStore.__dict__["get_recent_event_ids_for_room"] - ) - get_current_state_ids = ( - StateStore.__dict__["get_current_state_ids"] - ) - get_state_group_delta = StateStore.__dict__["get_state_group_delta"] - _get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"] - has_room_changed_since = DataStore.has_room_changed_since.__func__ - - get_unread_push_actions_for_user_in_range_for_http = ( - DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__ - ) - get_unread_push_actions_for_user_in_range_for_email = ( - DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__ - ) - get_push_action_users_in_range = ( - DataStore.get_push_action_users_in_range.__func__ - ) - get_event = DataStore.get_event.__func__ - get_events = DataStore.get_events.__func__ - get_rooms_for_user_where_membership_is = ( - DataStore.get_rooms_for_user_where_membership_is.__func__ - ) - get_membership_changes_for_user = ( - DataStore.get_membership_changes_for_user.__func__ - ) - get_room_events_max_id = DataStore.get_room_events_max_id.__func__ - get_room_events_stream_for_room = ( - DataStore.get_room_events_stream_for_room.__func__ - ) - get_events_around = DataStore.get_events_around.__func__ - get_state_for_event = DataStore.get_state_for_event.__func__ - get_state_for_events = DataStore.get_state_for_events.__func__ - get_state_groups = DataStore.get_state_groups.__func__ - get_state_groups_ids = DataStore.get_state_groups_ids.__func__ - get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__ - get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__ - get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__ - get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__ - _get_joined_users_from_context = ( - RoomMemberStore.__dict__["_get_joined_users_from_context"] - ) - - get_joined_hosts = DataStore.get_joined_hosts.__func__ - _get_joined_hosts = RoomMemberStore.__dict__["_get_joined_hosts"] - - get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__ - get_room_events_stream_for_rooms = ( - DataStore.get_room_events_stream_for_rooms.__func__ - ) - is_host_joined = RoomMemberStore.__dict__["is_host_joined"] - get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__ - - _set_before_and_after = staticmethod(DataStore._set_before_and_after) - - _get_events = DataStore._get_events.__func__ - _get_events_from_cache = DataStore._get_events_from_cache.__func__ - - _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__ - _enqueue_events = DataStore._enqueue_events.__func__ - _do_fetch = DataStore._do_fetch.__func__ - _fetch_event_rows = DataStore._fetch_event_rows.__func__ - _get_event_from_row = DataStore._get_event_from_row.__func__ - _get_rooms_for_user_where_membership_is_txn = ( - DataStore._get_rooms_for_user_where_membership_is_txn.__func__ - ) - _get_state_for_groups = DataStore._get_state_for_groups.__func__ - _get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__ - _get_events_around_txn = DataStore._get_events_around_txn.__func__ - _get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__ - - get_backfill_events = DataStore.get_backfill_events.__func__ - _get_backfill_events = DataStore._get_backfill_events.__func__ - get_missing_events = DataStore.get_missing_events.__func__ - _get_missing_events = DataStore._get_missing_events.__func__ - - get_auth_chain = DataStore.get_auth_chain.__func__ - get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__ - _get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__ - - get_room_max_stream_ordering = DataStore.get_room_max_stream_ordering.__func__ - - get_forward_extremeties_for_room = ( - DataStore.get_forward_extremeties_for_room.__func__ - ) - _get_forward_extremeties_for_room = ( - EventFederationStore.__dict__["_get_forward_extremeties_for_room"] - ) - - get_all_new_events_stream = DataStore.get_all_new_events_stream.__func__ - - get_federation_out_pos = DataStore.get_federation_out_pos.__func__ - update_federation_out_pos = DataStore.update_federation_out_pos.__func__ + + def get_room_max_stream_ordering(self): + return self._stream_id_gen.get_current_token() + + def get_room_min_stream_ordering(self): + return self._backfill_id_gen.get_current_token() def stream_positions(self): result = super(SlavedEventStore, self).stream_positions() diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py index 819ed62881..456a14cd5c 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py @@ -13,9 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore from synapse.storage.filtering import FilteringStore +from ._base import BaseSlavedStore + class SlavedFilteringStore(BaseSlavedStore): def __init__(self, db_conn, hs): diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py new file mode 100644 index 0000000000..5777f07c8d --- /dev/null +++ b/synapse/replication/slave/storage/groups.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage import DataStore +from synapse.util.caches.stream_change_cache import StreamChangeCache + +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker + + +class SlavedGroupServerStore(BaseSlavedStore): + def __init__(self, db_conn, hs): + super(SlavedGroupServerStore, self).__init__(db_conn, hs) + + self.hs = hs + + self._group_updates_id_gen = SlavedIdTracker( + db_conn, "local_group_updates", "stream_id", + ) + self._group_updates_stream_cache = StreamChangeCache( + "_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(), + ) + + get_groups_changes_for_user = DataStore.get_groups_changes_for_user.__func__ + get_group_stream_token = DataStore.get_group_stream_token.__func__ + get_all_groups_for_user = DataStore.get_all_groups_for_user.__func__ + + def stream_positions(self): + result = super(SlavedGroupServerStore, self).stream_positions() + result["groups"] = self._group_updates_id_gen.get_current_token() + return result + + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "groups": + self._group_updates_id_gen.advance(token) + for row in rows: + self._group_updates_stream_cache.entity_has_changed( + row.user_id, token + ) + + return super(SlavedGroupServerStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py index dd2ae49e48..05ed168463 100644 --- a/synapse/replication/slave/storage/keys.py +++ b/synapse/replication/slave/storage/keys.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore from synapse.storage import DataStore from synapse.storage.keys import KeyStore +from ._base import BaseSlavedStore + class SlavedKeyStore(BaseSlavedStore): _get_server_verify_key = KeyStore.__dict__[ diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index cfb9280181..80b744082a 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker - -from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.storage import DataStore from synapse.storage.presence import PresenceStore +from synapse.util.caches.stream_change_cache import StreamChangeCache + +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker class SlavedPresenceStore(BaseSlavedStore): diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py new file mode 100644 index 0000000000..46c28d4171 --- /dev/null +++ b/synapse/replication/slave/storage/profile.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.storage.profile import ProfileWorkerStore + + +class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore): + pass diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 83e880fdd2..f0200c1e98 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,31 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .events import SlavedEventStore +from synapse.storage.push_rule import PushRulesWorkerStore + from ._slaved_id_tracker import SlavedIdTracker -from synapse.storage import DataStore -from synapse.storage.push_rule import PushRuleStore -from synapse.util.caches.stream_change_cache import StreamChangeCache +from .events import SlavedEventStore -class SlavedPushRuleStore(SlavedEventStore): +class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore): def __init__(self, db_conn, hs): - super(SlavedPushRuleStore, self).__init__(db_conn, hs) self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id", ) - self.push_rules_stream_cache = StreamChangeCache( - "PushRulesStreamChangeCache", - self._push_rules_stream_id_gen.get_current_token(), - ) - - get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"] - get_push_rules_enabled_for_user = ( - PushRuleStore.__dict__["get_push_rules_enabled_for_user"] - ) - have_push_rules_changed_for_user = ( - DataStore.have_push_rules_changed_for_user.__func__ - ) + super(SlavedPushRuleStore, self).__init__(db_conn, hs) def get_push_rules_stream_token(self): return ( @@ -45,6 +33,9 @@ class SlavedPushRuleStore(SlavedEventStore): self._stream_id_gen.get_current_token(), ) + def get_max_push_rules_stream_id(self): + return self._push_rules_stream_id_gen.get_current_token() + def stream_positions(self): result = super(SlavedPushRuleStore, self).stream_positions() result["push_rules"] = self._push_rules_stream_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 4e8d68ece9..3b2213c0d4 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.storage.pusher import PusherWorkerStore + from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker -from synapse.storage import DataStore - -class SlavedPusherStore(BaseSlavedStore): +class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedPusherStore, self).__init__(db_conn, hs) @@ -28,13 +29,6 @@ class SlavedPusherStore(BaseSlavedStore): extra_tables=[("deleted_pushers", "stream_id")], ) - get_all_pushers = DataStore.get_all_pushers.__func__ - get_pushers_by = DataStore.get_pushers_by.__func__ - get_pushers_by_app_id_and_pushkey = ( - DataStore.get_pushers_by_app_id_and_pushkey.__func__ - ) - _decode_pushers_rows = DataStore._decode_pushers_rows.__func__ - def stream_positions(self): result = super(SlavedPusherStore, self).stream_positions() result["pushers"] = self._pushers_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index b371574ece..ed12342f40 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,13 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.storage.receipts import ReceiptsWorkerStore + from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker -from synapse.storage import DataStore -from synapse.storage.receipts import ReceiptsStore -from synapse.util.caches.stream_change_cache import StreamChangeCache - # So, um, we want to borrow a load of functions intended for reading from # a DataStore, but we don't want to take functions that either write to the # DataStore or are cached and don't have cache invalidation logic. @@ -29,36 +28,19 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache # the method descriptor on the DataStore and chuck them into our class. -class SlavedReceiptsStore(BaseSlavedStore): +class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): - super(SlavedReceiptsStore, self).__init__(db_conn, hs) - + # We instantiate this first as the ReceiptsWorkerStore constructor + # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = SlavedIdTracker( db_conn, "receipts_linearized", "stream_id" ) - self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() - ) - - get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"] - get_linearized_receipts_for_room = ( - ReceiptsStore.__dict__["get_linearized_receipts_for_room"] - ) - _get_linearized_receipts_for_rooms = ( - ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"] - ) - get_last_receipt_event_id_for_user = ( - ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"] - ) - - get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__ - get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__ + super(SlavedReceiptsStore, self).__init__(db_conn, hs) - get_linearized_receipts_for_rooms = ( - DataStore.get_linearized_receipts_for_rooms.__func__ - ) + def get_max_receipt_stream_id(self): + return self._receipts_id_gen.get_current_token() def stream_positions(self): result = super(SlavedReceiptsStore, self).stream_positions() @@ -67,10 +49,12 @@ class SlavedReceiptsStore(BaseSlavedStore): def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): self.get_receipts_for_user.invalidate((user_id, receipt_type)) - self.get_linearized_receipts_for_room.invalidate_many((room_id,)) + self._get_linearized_receipts_for_room.invalidate_many((room_id,)) self.get_last_receipt_event_id_for_user.invalidate( (user_id, room_id, receipt_type) ) + self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) + self.get_receipts_for_room.invalidate((room_id, receipt_type)) def process_replication_rows(self, stream_name, token, rows): if stream_name == "receipts": diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py index e27c7332d2..408d91df1c 100644 --- a/synapse/replication/slave/storage/registration.py +++ b/synapse/replication/slave/storage/registration.py @@ -13,21 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore -from synapse.storage import DataStore -from synapse.storage.registration import RegistrationStore - +from synapse.storage.registration import RegistrationWorkerStore -class SlavedRegistrationStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedRegistrationStore, self).__init__(db_conn, hs) +from ._base import BaseSlavedStore - # TODO: use the cached version and invalidate deleted tokens - get_user_by_access_token = RegistrationStore.__dict__[ - "get_user_by_access_token" - ] - _query_for_auth = DataStore._query_for_auth.__func__ - get_user_by_id = RegistrationStore.__dict__[ - "get_user_by_id" - ] +class SlavedRegistrationStore(RegistrationWorkerStore, BaseSlavedStore): + pass diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index f510384033..0cb474928c 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -13,33 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.storage.room import RoomWorkerStore + from ._base import BaseSlavedStore -from synapse.storage import DataStore -from synapse.storage.room import RoomStore from ._slaved_id_tracker import SlavedIdTracker -class RoomStore(BaseSlavedStore): +class RoomStore(RoomWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): super(RoomStore, self).__init__(db_conn, hs) self._public_room_id_gen = SlavedIdTracker( db_conn, "public_room_list_stream", "stream_id" ) - get_public_room_ids = DataStore.get_public_room_ids.__func__ - get_current_public_room_stream_id = ( - DataStore.get_current_public_room_stream_id.__func__ - ) - get_public_room_ids_at_stream_id = ( - RoomStore.__dict__["get_public_room_ids_at_stream_id"] - ) - get_public_room_ids_at_stream_id_txn = ( - DataStore.get_public_room_ids_at_stream_id_txn.__func__ - ) - get_published_at_stream_id_txn = ( - DataStore.get_published_at_stream_id_txn.__func__ - ) - get_public_room_changes = DataStore.get_public_room_changes.__func__ + def get_current_public_room_stream_id(self): + return self._public_room_id_gen.get_current_token() def stream_positions(self): result = super(RoomStore, self).stream_positions() diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py index fbb58f35da..9c9a5eadd9 100644 --- a/synapse/replication/slave/storage/transactions.py +++ b/synapse/replication/slave/storage/transactions.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore from synapse.storage import DataStore from synapse.storage.transactions import TransactionStore +from ._base import BaseSlavedStore + class TransactionStore(BaseSlavedStore): get_destination_retry_timings = TransactionStore.__dict__[ diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 6d2513c4e2..e592ab57bf 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -15,17 +15,20 @@ """A replication client for use by synapse workers. """ -from twisted.internet import reactor, defer +import logging + +from twisted.internet import defer from twisted.internet.protocol import ReconnectingClientFactory from .commands import ( - FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand, + FederationAckCommand, + InvalidateCacheCommand, + RemovePusherCommand, UserIpCommand, + UserSyncCommand, ) from .protocol import ClientReplicationStreamProtocol -import logging - logger = logging.getLogger(__name__) @@ -44,7 +47,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): self.server_name = hs.config.server_name self._clock = hs.get_clock() # As self.clock is defined in super class - reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying) + hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying) def startedConnecting(self, connector): logger.info("Connecting to replication: %r", connector.getDestination()) @@ -95,7 +98,7 @@ class ReplicationClientHandler(object): factory = ReplicationClientFactory(hs, client_name, self) host = hs.config.worker_replication_host port = hs.config.worker_replication_port - reactor.connectTCP(host, port, factory) + hs.get_reactor().connectTCP(host, port, factory) def on_rdata(self, stream_name, token, rows): """Called when we get new replication data. By default this just pokes diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index a009214e43..f3908df642 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -19,8 +19,14 @@ allowed to be sent by which side. """ import logging -import ujson as json +import platform +if platform.python_implementation() == "PyPy": + import json + _json_encoder = json.JSONEncoder() +else: + import simplejson as json + _json_encoder = json.JSONEncoder(namedtuple_as_object=False) logger = logging.getLogger(__name__) @@ -107,7 +113,7 @@ class RdataCommand(Command): return " ".join(( self.stream_name, str(self.token) if self.token is not None else "batch", - json.dumps(self.row), + _json_encoder.encode(self.row), )) @@ -301,7 +307,9 @@ class InvalidateCacheCommand(Command): return cls(cache_func, json.loads(keys_json)) def to_line(self): - return " ".join((self.cache_func, json.dumps(self.keys))) + return " ".join(( + self.cache_func, _json_encoder.encode(self.keys), + )) class UserIpCommand(Command): @@ -323,14 +331,18 @@ class UserIpCommand(Command): @classmethod def from_line(cls, line): - user_id, access_token, ip, device_id, last_seen, user_agent = line.split(" ", 5) + user_id, jsn = line.split(" ", 1) + + access_token, ip, user_agent, device_id, last_seen = json.loads(jsn) - return cls(user_id, access_token, ip, user_agent, device_id, int(last_seen)) + return cls( + user_id, access_token, ip, user_agent, device_id, last_seen + ) def to_line(self): - return " ".join(( - self.user_id, self.access_token, self.ip, self.device_id, - str(self.last_seen), self.user_agent, + return self.user_id + " " + _json_encoder.encode(( + self.access_token, self.ip, self.user_agent, self.device_id, + self.last_seen, )) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 062272f8dd..dec5ac0913 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -49,32 +49,40 @@ indicate which side is sending, these are *not* included on the wire:: * connection closed by server * """ +import fcntl +import logging +import struct +from collections import defaultdict + +from six import iteritems, iterkeys + +from prometheus_client import Counter + from twisted.internet import defer from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure -from commands import ( - COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, - ErrorCommand, ServerCommand, RdataCommand, PositionCommand, PingCommand, - NameCommand, ReplicateCommand, UserSyncCommand, SyncCommand, -) -from streams import STREAMS_MAP - +from synapse.metrics import LaterGauge from synapse.util.stringutils import random_string -from synapse.metrics.metric import CounterMetric - -import logging -import synapse.metrics -import struct -import fcntl - -metrics = synapse.metrics.get_metrics_for(__name__) - -connection_close_counter = metrics.register_counter( - "close_reason", labels=["reason_type"], +from .commands import ( + COMMAND_MAP, + VALID_CLIENT_COMMANDS, + VALID_SERVER_COMMANDS, + ErrorCommand, + NameCommand, + PingCommand, + PositionCommand, + RdataCommand, + ReplicateCommand, + ServerCommand, + SyncCommand, + UserSyncCommand, ) +from .streams import STREAMS_MAP +connection_close_counter = Counter( + "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]) # A list of all connected protocols. This allows us to send metrics about the # connections. @@ -136,12 +144,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # The LoopingCall for sending pings. self._send_ping_loop = None - self.inbound_commands_counter = CounterMetric( - "inbound_commands", labels=["command"], - ) - self.outbound_commands_counter = CounterMetric( - "outbound_commands", labels=["command"], - ) + self.inbound_commands_counter = defaultdict(int) + self.outbound_commands_counter = defaultdict(int) def connectionMade(self): logger.info("[%s] Connection established", self.id()) @@ -201,7 +205,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.last_received_command = self.clock.time_msec() - self.inbound_commands_counter.inc(cmd_name) + self.inbound_commands_counter[cmd_name] = ( + self.inbound_commands_counter[cmd_name] + 1) cmd_cls = COMMAND_MAP[cmd_name] try: @@ -244,15 +249,15 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): becoming full. """ if self.state == ConnectionStates.CLOSED: - logger.info("[%s] Not sending, connection closed", self.id()) + logger.debug("[%s] Not sending, connection closed", self.id()) return if do_buffer and self.state != ConnectionStates.ESTABLISHED: self._queue_command(cmd) return - self.outbound_commands_counter.inc(cmd.NAME) - + self.outbound_commands_counter[cmd.NAME] = ( + self.outbound_commands_counter[cmd.NAME] + 1) string = "%s %s" % (cmd.NAME, cmd.to_line(),) if "\n" in string: raise Exception("Unexpected newline in command: %r", string) @@ -264,7 +269,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def _queue_command(self, cmd): """Queue the command until the connection is ready to write to again. """ - logger.info("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd) + logger.debug("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd) self.pending_commands.append(cmd) if len(self.pending_commands) > self.max_line_buffer: @@ -317,9 +322,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def connectionLost(self, reason): logger.info("[%s] Replication connection closed: %r", self.id(), reason) if isinstance(reason, Failure): - connection_close_counter.inc(reason.type.__name__) + connection_close_counter.labels(reason.type.__name__).inc() else: - connection_close_counter.inc(reason.__class__.__name__) + connection_close_counter.labels(reason.__class__.__name__).inc() try: # Remove us from list of connections to be monitored @@ -392,7 +397,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): if stream_name == "ALL": # Subscribe to all streams we're publishing to. - for stream in self.streamer.streams_by_name.iterkeys(): + for stream in iterkeys(self.streamer.streams_by_name): self.subscribe_to_stream(stream, token) else: self.subscribe_to_stream(stream_name, token) @@ -498,7 +503,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): BaseReplicationStreamProtocol.connectionMade(self) # Once we've connected subscribe to the necessary streams - for stream_name, token in self.handler.get_streams_to_replicate().iteritems(): + for stream_name, token in iteritems(self.handler.get_streams_to_replicate()): self.replicate(stream_name, token) # Tell the server if we have any users currently syncing (should only @@ -517,25 +522,28 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.send_error("Wrong remote") def on_RDATA(self, cmd): + stream_name = cmd.stream_name + inbound_rdata_count.labels(stream_name).inc() + try: - row = STREAMS_MAP[cmd.stream_name].ROW_TYPE(*cmd.row) + row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row) except Exception: logger.exception( "[%s] Failed to parse RDATA: %r %r", - self.id(), cmd.stream_name, cmd.row + self.id(), stream_name, cmd.row ) raise if cmd.token is None: # I.e. this is part of a batch of updates for this stream. Batch # until we get an update for the stream with a non None token - self.pending_batches.setdefault(cmd.stream_name, []).append(row) + self.pending_batches.setdefault(stream_name, []).append(row) else: # Check if this is the last of a batch of updates - rows = self.pending_batches.pop(cmd.stream_name, []) + rows = self.pending_batches.pop(stream_name, []) rows.append(row) - self.handler.on_rdata(cmd.stream_name, cmd.token, rows) + self.handler.on_rdata(stream_name, cmd.token, rows) def on_POSITION(self, cmd): self.handler.on_position(cmd.stream_name, cmd.token) @@ -563,13 +571,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # The following simply registers metrics for the replication connections -metrics.register_callback( - "pending_commands", +pending_commands = LaterGauge( + "synapse_replication_tcp_protocol_pending_commands", + "", + ["name", "conn_id"], lambda: { - (p.name, p.conn_id): len(p.pending_commands) - for p in connected_connections + (p.name, p.conn_id): len(p.pending_commands) for p in connected_connections }, - labels=["name", "conn_id"], ) @@ -580,13 +588,13 @@ def transport_buffer_size(protocol): return 0 -metrics.register_callback( - "transport_send_buffer", +transport_send_buffer = LaterGauge( + "synapse_replication_tcp_protocol_transport_send_buffer", + "", + ["name", "conn_id"], lambda: { - (p.name, p.conn_id): transport_buffer_size(p) - for p in connected_connections + (p.name, p.conn_id): transport_buffer_size(p) for p in connected_connections }, - labels=["name", "conn_id"], ) @@ -605,42 +613,51 @@ def transport_kernel_read_buffer_size(protocol, read=True): return 0 -metrics.register_callback( - "transport_kernel_send_buffer", +tcp_transport_kernel_send_buffer = LaterGauge( + "synapse_replication_tcp_protocol_transport_kernel_send_buffer", + "", + ["name", "conn_id"], lambda: { (p.name, p.conn_id): transport_kernel_read_buffer_size(p, False) for p in connected_connections }, - labels=["name", "conn_id"], ) -metrics.register_callback( - "transport_kernel_read_buffer", +tcp_transport_kernel_read_buffer = LaterGauge( + "synapse_replication_tcp_protocol_transport_kernel_read_buffer", + "", + ["name", "conn_id"], lambda: { (p.name, p.conn_id): transport_kernel_read_buffer_size(p, True) for p in connected_connections }, - labels=["name", "conn_id"], ) -metrics.register_callback( - "inbound_commands", +tcp_inbound_commands = LaterGauge( + "synapse_replication_tcp_protocol_inbound_commands", + "", + ["command", "name", "conn_id"], lambda: { (k[0], p.name, p.conn_id): count for p in connected_connections - for k, count in p.inbound_commands_counter.counts.iteritems() + for k, count in iteritems(p.inbound_commands_counter) }, - labels=["command", "name", "conn_id"], ) -metrics.register_callback( - "outbound_commands", +tcp_outbound_commands = LaterGauge( + "synapse_replication_tcp_protocol_outbound_commands", + "", + ["command", "name", "conn_id"], lambda: { (k[0], p.name, p.conn_id): count for p in connected_connections - for k, count in p.outbound_commands_counter.counts.iteritems() + for k, count in iteritems(p.outbound_commands_counter) }, - labels=["command", "name", "conn_id"], +) + +# number of updates received for each RDATA stream +inbound_rdata_count = Counter( + "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] ) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 3ea3ca5a6f..611fb66e1d 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -15,27 +15,29 @@ """The server side of the replication stream. """ -from twisted.internet import defer, reactor -from twisted.internet.protocol import Factory +import logging -from streams import STREAMS_MAP, FederationStream -from protocol import ServerReplicationStreamProtocol +from six import itervalues -from synapse.util.metrics import Measure, measure_func +from prometheus_client import Counter -import logging -import synapse.metrics +from twisted.internet import defer +from twisted.internet.protocol import Factory +from synapse.metrics import LaterGauge +from synapse.util.metrics import Measure, measure_func + +from .protocol import ServerReplicationStreamProtocol +from .streams import STREAMS_MAP, FederationStream -metrics = synapse.metrics.get_metrics_for(__name__) -stream_updates_counter = metrics.register_counter( - "stream_updates", labels=["stream_name"] -) -user_sync_counter = metrics.register_counter("user_sync") -federation_ack_counter = metrics.register_counter("federation_ack") -remove_pusher_counter = metrics.register_counter("remove_pusher") -invalidate_cache_counter = metrics.register_counter("invalidate_cache") -user_ip_cache_counter = metrics.register_counter("user_ip_cache") +stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates", + "", ["stream_name"]) +user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") +federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") +remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") +invalidate_cache_counter = Counter("synapse_replication_tcp_resource_invalidate_cache", + "") +user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") logger = logging.getLogger(__name__) @@ -69,33 +71,34 @@ class ReplicationStreamer(object): self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() self.notifier = hs.get_notifier() + self._server_notices_sender = hs.get_server_notices_sender() # Current connections. self.connections = [] - metrics.register_callback("total_connections", lambda: len(self.connections)) + LaterGauge("synapse_replication_tcp_resource_total_connections", "", [], + lambda: len(self.connections)) # List of streams that clients can subscribe to. # We only support federation stream if federation sending hase been # disabled on the master. self.streams = [ - stream(hs) for stream in STREAMS_MAP.itervalues() + stream(hs) for stream in itervalues(STREAMS_MAP) if stream != FederationStream or not hs.config.send_federation ] self.streams_by_name = {stream.NAME: stream for stream in self.streams} - metrics.register_callback( - "connections_per_stream", + LaterGauge( + "synapse_replication_tcp_resource_connections_per_stream", "", + ["stream_name"], lambda: { (stream_name,): len([ conn for conn in self.connections if stream_name in conn.replication_streams ]) for stream_name in self.streams_by_name - }, - labels=["stream_name"], - ) + }) self.federation_sender = None if not hs.config.send_federation: @@ -107,7 +110,7 @@ class ReplicationStreamer(object): self.is_looping = False self.pending_updates = False - reactor.addSystemEventTrigger("before", "shutdown", self.on_shutdown) + hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown) def on_shutdown(self): # close all connections on shutdown @@ -160,7 +163,11 @@ class ReplicationStreamer(object): "Getting stream: %s: %s -> %s", stream.NAME, stream.last_token, stream.upto_token ) - updates, current_token = yield stream.get_updates() + try: + updates, current_token = yield stream.get_updates() + except Exception: + logger.info("Failed to handle stream %s", stream.NAME) + raise logger.debug( "Sending %d updates to %d connections", @@ -171,7 +178,7 @@ class ReplicationStreamer(object): logger.info( "Streaming: %s -> %s", stream.NAME, updates[-1][0] ) - stream_updates_counter.inc_by(len(updates), stream.NAME) + stream_updates_counter.labels(stream.NAME).inc(len(updates)) # Some streams return multiple rows with the same stream IDs, # we need to make sure they get sent out in batches. We do @@ -212,11 +219,12 @@ class ReplicationStreamer(object): self.federation_sender.federation_ack(token) @measure_func("repl.on_user_sync") + @defer.inlineCallbacks def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): """A client has started/stopped syncing on a worker. """ user_sync_counter.inc() - self.presence_handler.update_external_syncs_row( + yield self.presence_handler.update_external_syncs_row( conn_id, user_id, is_syncing, last_sync_ms, ) @@ -240,13 +248,15 @@ class ReplicationStreamer(object): getattr(self.store, cache_func).invalidate(tuple(keys)) @measure_func("repl.on_user_ip") + @defer.inlineCallbacks def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): """The client saw a user request """ user_ip_cache_counter.inc() - self.store.insert_client_ip( + yield self.store.insert_client_ip( user_id, access_token, ip, user_agent, device_id, last_seen, ) + yield self._server_notices_sender.on_user_ip(user_id) def send_sync_to_all_connections(self, data): """Sends a SYNC command to all clients. diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py index fbafe12cc2..55fe701c5c 100644 --- a/synapse/replication/tcp/streams.py +++ b/synapse/replication/tcp/streams.py @@ -24,11 +24,10 @@ Each stream is defined by the following information: update_function: The function that returns a list of updates between two tokens """ -from twisted.internet import defer -from collections import namedtuple - import logging +from collections import namedtuple +from twisted.internet import defer logger = logging.getLogger(__name__) @@ -118,6 +117,12 @@ CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", ( "state_key", # str "event_id", # str, optional )) +GroupsStreamRow = namedtuple("GroupsStreamRow", ( + "group_id", # str + "user_id", # str + "type", # str + "content", # dict +)) class Stream(object): @@ -464,6 +469,19 @@ class CurrentStateDeltaStream(Stream): super(CurrentStateDeltaStream, self).__init__(hs) +class GroupServerStream(Stream): + NAME = "groups" + ROW_TYPE = GroupsStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_group_stream_token + self.update_function = store.get_all_groups_changes + + super(GroupServerStream, self).__init__(hs) + + STREAMS_MAP = { stream.NAME: stream for stream in ( @@ -482,5 +500,6 @@ STREAMS_MAP = { TagAccountDataStream, AccountDataStream, CurrentStateDeltaStream, + GroupServerStream, ) } |