diff options
29 files changed, 1037 insertions, 731 deletions
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index a6f1e7594e..9a476efa63 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -26,17 +26,17 @@ from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.replication.tcp.client import ReplicationClientHandler from synapse.storage.engines import create_engine -from synapse.util.async import sleep from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext +from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn from synapse.util.manhole import manhole from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string from synapse import events -from twisted.internet import reactor, defer +from twisted.internet import reactor from twisted.web.resource import Resource from daemonize import Daemonize @@ -120,30 +120,25 @@ class AppserviceServer(HomeServer): else: logger.warn("Unrecognized listener type: %s", listener["type"]) - @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url - appservice_handler = self.get_application_service_handler() - - @defer.inlineCallbacks - def replicate(results): - stream = results.get("events") - if stream: - max_stream_id = stream["position"] - yield appservice_handler.notify_interested_services(max_stream_id) - - while True: - try: - args = store.stream_positions() - args["timeout"] = 30000 - result = yield http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - replicate(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(30) + self.get_tcp_replication().start_replication(self) + + def build_tcp_replication(self): + return ASReplicationHandler(self) + + +class ASReplicationHandler(ReplicationClientHandler): + def __init__(self, hs): + super(ASReplicationHandler, self).__init__(hs.get_datastore()) + self.appservice_handler = hs.get_application_service_handler() + + def on_rdata(self, stream_name, token, rows): + super(ASReplicationHandler, self).on_rdata(stream_name, token, rows) + + if stream_name == "events": + max_stream_id = self.store.get_room_max_stream_ordering() + preserve_fn( + self.appservice_handler.notify_interested_services + )(max_stream_id) def start(config_options): @@ -199,7 +194,6 @@ def start(config_options): reactor.run() def start(): - ps.replicate() ps.get_datastore().start_profiling() ps.get_state_handler().start_caching() diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index e4ea3ab933..9b72c649ac 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -30,11 +30,11 @@ from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.transactions import TransactionStore +from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.client.v1.room import PublicRoomListRestServlet from synapse.server import HomeServer from synapse.storage.client_ips import ClientIpStore from synapse.storage.engines import create_engine -from synapse.util.async import sleep from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.manhole import manhole @@ -45,7 +45,7 @@ from synapse.crypto import context_factory from synapse import events -from twisted.internet import reactor, defer +from twisted.internet import reactor from twisted.web.resource import Resource from daemonize import Daemonize @@ -145,21 +145,10 @@ class ClientReaderServer(HomeServer): else: logger.warn("Unrecognized listener type: %s", listener["type"]) - @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url + self.get_tcp_replication().start_replication(self) - while True: - try: - args = store.stream_positions() - args["timeout"] = 30000 - result = yield http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(5) + def build_tcp_replication(self): + return ReplicationClientHandler(self.get_datastore()) def start(config_options): @@ -209,7 +198,6 @@ def start(config_options): def start(): ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() - ss.replicate() reactor.callWhenRunning(start) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index e52b0f240d..eb392e1c9d 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -27,9 +27,9 @@ from synapse.replication.slave.storage.keys import SlavedKeyStore from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.slave.storage.directory import DirectoryStore +from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer from synapse.storage.engines import create_engine -from synapse.util.async import sleep from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.manhole import manhole @@ -42,7 +42,7 @@ from synapse.crypto import context_factory from synapse import events -from twisted.internet import reactor, defer +from twisted.internet import reactor from twisted.web.resource import Resource from daemonize import Daemonize @@ -134,21 +134,10 @@ class FederationReaderServer(HomeServer): else: logger.warn("Unrecognized listener type: %s", listener["type"]) - @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url + self.get_tcp_replication().start_replication(self) - while True: - try: - args = store.stream_positions() - args["timeout"] = 30000 - result = yield http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(5) + def build_tcp_replication(self): + return ReplicationClientHandler(self.get_datastore()) def start(config_options): @@ -198,7 +187,6 @@ def start(config_options): def start(): ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() - ss.replicate() reactor.callWhenRunning(start) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 76c4cc54d1..477e16e0fa 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -23,7 +23,6 @@ from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.http.site import SynapseSite from synapse.federation import send_queue -from synapse.federation.units import Edu from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.replication.slave.storage.events import SlavedEventStore @@ -31,11 +30,11 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.slave.storage.devices import SlavedDeviceStore +from synapse.replication.tcp.client import ReplicationClientHandler from synapse.storage.engines import create_engine -from synapse.storage.presence import UserPresenceState -from synapse.util.async import sleep +from synapse.util.async import Linearizer from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext +from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn from synapse.util.manhole import manhole from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string @@ -50,7 +49,6 @@ from daemonize import Daemonize import sys import logging import gc -import ujson as json logger = logging.getLogger("synapse.app.appservice") @@ -59,7 +57,28 @@ class FederationSenderSlaveStore( SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore, SlavedRegistrationStore, SlavedDeviceStore, ): - pass + def __init__(self, db_conn, hs): + super(FederationSenderSlaveStore, self).__init__(db_conn, hs) + + # We pull out the current federation stream position now so that we + # always have a known value for the federation position in memory so + # that we don't have to bounce via a deferred once when we start the + # replication streams. + self.federation_out_pos_startup = self._get_federation_out_pos(db_conn) + + def _get_federation_out_pos(self, db_conn): + sql = ( + "SELECT stream_id FROM federation_stream_position" + " WHERE type = ?" + ) + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, ("federation",)) + rows = txn.fetchall() + txn.close() + + return rows[0][0] if rows else -1 class FederationSenderServer(HomeServer): @@ -127,26 +146,27 @@ class FederationSenderServer(HomeServer): else: logger.warn("Unrecognized listener type: %s", listener["type"]) - @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url - send_handler = FederationSenderHandler(self) - - send_handler.on_start() - - while True: - try: - args = store.stream_positions() - args.update((yield send_handler.stream_positions())) - args["timeout"] = 30000 - result = yield http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - yield send_handler.process_replication(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(30) + self.get_tcp_replication().start_replication(self) + + def build_tcp_replication(self): + return FederationSenderReplicationHandler(self) + + +class FederationSenderReplicationHandler(ReplicationClientHandler): + def __init__(self, hs): + super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore()) + self.send_handler = FederationSenderHandler(hs, self) + + def on_rdata(self, stream_name, token, rows): + super(FederationSenderReplicationHandler, self).on_rdata( + stream_name, token, rows + ) + self.send_handler.process_replication_rows(stream_name, token, rows) + + def get_streams_to_replicate(self): + args = super(FederationSenderReplicationHandler, self).get_streams_to_replicate() + args.update(self.send_handler.stream_positions()) + return args def start(config_options): @@ -205,7 +225,6 @@ def start(config_options): reactor.run() def start(): - ps.replicate() ps.get_datastore().start_profiling() ps.get_state_handler().start_caching() @@ -229,9 +248,15 @@ class FederationSenderHandler(object): """Processes the replication stream and forwards the appropriate entries to the federation sender. """ - def __init__(self, hs): + def __init__(self, hs, replication_client): self.store = hs.get_datastore() self.federation_sender = hs.get_federation_sender() + self.replication_client = replication_client + + self.federation_position = self.store.federation_out_pos_startup + self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer") + + self._last_ack = self.federation_position self._room_serials = {} self._room_typing = {} @@ -243,98 +268,35 @@ class FederationSenderHandler(object): self.store.get_room_max_stream_ordering() ) - @defer.inlineCallbacks def stream_positions(self): - stream_id = yield self.store.get_federation_out_pos("federation") - defer.returnValue({ - "federation": stream_id, + return {"federation": self.federation_position} - # Ack stuff we've "processed", this should only be called from - # one process. - "federation_ack": stream_id, - }) - - @defer.inlineCallbacks - def process_replication(self, result): + def process_replication_rows(self, stream_name, token, rows): # The federation stream contains things that we want to send out, e.g. # presence, typing, etc. - fed_stream = result.get("federation") - if fed_stream: - latest_id = int(fed_stream["position"]) - - # The federation stream containis a bunch of different types of - # rows that need to be handled differently. We parse the rows, put - # them into the appropriate collection and then send them off. - presence_to_send = {} - keyed_edus = {} - edus = {} - failures = {} - device_destinations = set() - - # Parse the rows in the stream - for row in fed_stream["rows"]: - position, typ, content_js = row - content = json.loads(content_js) - - if typ == send_queue.PRESENCE_TYPE: - destination = content["destination"] - state = UserPresenceState.from_dict(content["state"]) - - presence_to_send.setdefault(destination, []).append(state) - elif typ == send_queue.KEYED_EDU_TYPE: - key = content["key"] - edu = Edu(**content["edu"]) - - keyed_edus.setdefault( - edu.destination, {} - )[(edu.destination, tuple(key))] = edu - elif typ == send_queue.EDU_TYPE: - edu = Edu(**content) - - edus.setdefault(edu.destination, []).append(edu) - elif typ == send_queue.FAILURE_TYPE: - destination = content["destination"] - failure = content["failure"] - - failures.setdefault(destination, []).append(failure) - elif typ == send_queue.DEVICE_MESSAGE_TYPE: - device_destinations.add(content["destination"]) - else: - raise Exception("Unrecognised federation type: %r", typ) - - # We've finished collecting, send everything off - for destination, states in presence_to_send.items(): - self.federation_sender.send_presence(destination, states) - - for destination, edu_map in keyed_edus.items(): - for key, edu in edu_map.items(): - self.federation_sender.send_edu( - edu.destination, edu.edu_type, edu.content, key=key, - ) - - for destination, edu_list in edus.items(): - for edu in edu_list: - self.federation_sender.send_edu( - edu.destination, edu.edu_type, edu.content, key=None, - ) - - for destination, failure_list in failures.items(): - for failure in failure_list: - self.federation_sender.send_failure(destination, failure) - - for destination in device_destinations: - self.federation_sender.send_device_messages(destination) - - # Record where we are in the stream. - yield self.store.update_federation_out_pos( - "federation", latest_id - ) + if stream_name == "federation": + send_queue.process_rows_for_federation(self.federation_sender, rows) + preserve_fn(self.update_token)(token) # We also need to poke the federation sender when new events happen - event_stream = result.get("events") - if event_stream: - latest_pos = event_stream["position"] - self.federation_sender.notify_new_events(latest_pos) + elif stream_name == "events": + self.federation_sender.notify_new_events(token) + + @defer.inlineCallbacks + def update_token(self, token): + self.federation_position = token + + # We linearize here to ensure we don't have races updating the token + with (yield self._fed_position_linearizer.queue(None)): + if self._last_ack < self.federation_position: + yield self.store.update_federation_out_pos( + "federation", self.federation_position + ) + + # We ACK this token over replication so that the master can drop + # its in memory queues + self.replication_client.send_federation_ack(self.federation_position) + self._last_ack = self.federation_position if __name__ == '__main__': diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 1444e69a42..26c4416956 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -25,13 +25,13 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.transactions import TransactionStore +from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.server import HomeServer from synapse.storage.client_ips import ClientIpStore from synapse.storage.engines import create_engine from synapse.storage.media_repository import MediaRepositoryStore -from synapse.util.async import sleep from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.manhole import manhole @@ -45,7 +45,7 @@ from synapse.crypto import context_factory from synapse import events -from twisted.internet import reactor, defer +from twisted.internet import reactor from twisted.web.resource import Resource from daemonize import Daemonize @@ -142,21 +142,10 @@ class MediaRepositoryServer(HomeServer): else: logger.warn("Unrecognized listener type: %s", listener["type"]) - @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url + self.get_tcp_replication().start_replication(self) - while True: - try: - args = store.stream_positions() - args["timeout"] = 30000 - result = yield http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(5) + def build_tcp_replication(self): + return ReplicationClientHandler(self.get_datastore()) def start(config_options): @@ -206,7 +195,6 @@ def start(config_options): def start(): ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() - ss.replicate() reactor.callWhenRunning(start) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index ab682e52ec..f9114acfcb 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -27,9 +27,9 @@ from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.pushers import SlavedPusherStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore +from synapse.replication.tcp.client import ReplicationClientHandler from synapse.storage.engines import create_engine from synapse.storage import DataStore -from synapse.util.async import sleep from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext, preserve_fn, \ PreserveLoggingContext @@ -89,7 +89,6 @@ class PusherSlaveStore( class PusherServer(HomeServer): - def get_db_conn(self, run_new_connection=True): # Any param beginning with cp_ is a parameter for adbapi, and should # not be passed to the database engine. @@ -109,16 +108,7 @@ class PusherServer(HomeServer): logger.info("Finished setting up.") def remove_pusher(self, app_id, push_key, user_id): - http_client = self.get_simple_http_client() - replication_url = self.config.worker_replication_url - url = replication_url + "/remove_pushers" - return http_client.post_json_get_json(url, { - "remove": [{ - "app_id": app_id, - "push_key": push_key, - "user_id": user_id, - }] - }) + self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) def _listen_http(self, listener_config): port = listener_config["port"] @@ -166,73 +156,52 @@ class PusherServer(HomeServer): else: logger.warn("Unrecognized listener type: %s", listener["type"]) + self.get_tcp_replication().start_replication(self) + + def build_tcp_replication(self): + return PusherReplicationHandler(self) + + +class PusherReplicationHandler(ReplicationClientHandler): + def __init__(self, hs): + super(PusherReplicationHandler, self).__init__(hs.get_datastore()) + + self.pusher_pool = hs.get_pusherpool() + + def on_rdata(self, stream_name, token, rows): + super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows) + preserve_fn(self.poke_pushers)(stream_name, token, rows) + @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url - pusher_pool = self.get_pusherpool() - - def stop_pusher(user_id, app_id, pushkey): - key = "%s:%s" % (app_id, pushkey) - pushers_for_user = pusher_pool.pushers.get(user_id, {}) - pusher = pushers_for_user.pop(key, None) - if pusher is None: - return - logger.info("Stopping pusher %r / %r", user_id, key) - pusher.on_stop() - - def start_pusher(user_id, app_id, pushkey): - key = "%s:%s" % (app_id, pushkey) - logger.info("Starting pusher %r / %r", user_id, key) - return pusher_pool._refresh_pusher(app_id, pushkey, user_id) - - @defer.inlineCallbacks - def poke_pushers(results): - pushers_rows = set( - map(tuple, results.get("pushers", {}).get("rows", [])) + def poke_pushers(self, stream_name, token, rows): + if stream_name == "pushers": + for row in rows: + if row.deleted: + yield self.stop_pusher(row.user_id, row.app_id, row.pushkey) + else: + yield self.start_pusher(row.user_id, row.app_id, row.pushkey) + elif stream_name == "events": + yield self.pusher_pool.on_new_notifications( + token, token, ) - deleted_pushers_rows = set( - map(tuple, results.get("deleted_pushers", {}).get("rows", [])) + elif stream_name == "receipts": + yield self.pusher_pool.on_new_receipts( + token, token, set(row.room_id for row in rows) ) - for row in sorted(pushers_rows | deleted_pushers_rows): - if row in deleted_pushers_rows: - user_id, app_id, pushkey = row[1:4] - stop_pusher(user_id, app_id, pushkey) - elif row in pushers_rows: - user_id = row[1] - app_id = row[5] - pushkey = row[8] - yield start_pusher(user_id, app_id, pushkey) - - stream = results.get("events") - if stream and stream["rows"]: - min_stream_id = stream["rows"][0][0] - max_stream_id = stream["position"] - preserve_fn(pusher_pool.on_new_notifications)( - min_stream_id, max_stream_id - ) - - stream = results.get("receipts") - if stream and stream["rows"]: - rows = stream["rows"] - affected_room_ids = set(row[1] for row in rows) - min_stream_id = rows[0][0] - max_stream_id = stream["position"] - preserve_fn(pusher_pool.on_new_receipts)( - min_stream_id, max_stream_id, affected_room_ids - ) - - while True: - try: - args = store.stream_positions() - args["timeout"] = 30000 - result = yield http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - poke_pushers(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(30) + + def stop_pusher(self, user_id, app_id, pushkey): + key = "%s:%s" % (app_id, pushkey) + pushers_for_user = self.pusher_pool.pushers.get(user_id, {}) + pusher = pushers_for_user.pop(key, None) + if pusher is None: + return + logger.info("Stopping pusher %r / %r", user_id, key) + pusher.on_stop() + + def start_pusher(self, user_id, app_id, pushkey): + key = "%s:%s" % (app_id, pushkey) + logger.info("Starting pusher %r / %r", user_id, key) + return self.pusher_pool._refresh_pusher(app_id, pushkey, user_id) def start(config_options): @@ -288,7 +257,6 @@ def start(config_options): reactor.run() def start(): - ps.replicate() ps.get_pusherpool().start() ps.get_datastore().start_profiling() ps.get_state_handler().start_caching() diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 34e34e5580..d39e3161fe 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -16,7 +16,7 @@ import synapse -from synapse.api.constants import EventTypes, PresenceState +from synapse.api.constants import EventTypes from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging @@ -40,15 +40,14 @@ from synapse.replication.slave.storage.presence import SlavedPresenceStore from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.room import RoomStore +from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer from synapse.storage.client_ips import ClientIpStore from synapse.storage.engines import create_engine from synapse.storage.presence import PresenceStore, UserPresenceState from synapse.storage.roommember import RoomMemberStore -from synapse.util.async import sleep from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, preserve_fn, \ - PreserveLoggingContext +from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn from synapse.util.manhole import manhole from synapse.util.rlimit import change_resource_limit from synapse.util.stringutils import random_string @@ -63,7 +62,6 @@ import sys import logging import contextlib import gc -import ujson as json logger = logging.getLogger("synapse.app.synchrotron") @@ -107,11 +105,11 @@ UPDATE_SYNCING_USERS_MS = 10 * 1000 class SynchrotronPresence(object): def __init__(self, hs): + self.hs = hs self.is_mine_id = hs.is_mine_id self.http_client = hs.get_simple_http_client() self.store = hs.get_datastore() self.user_to_num_current_syncs = {} - self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users" self.clock = hs.get_clock() self.notifier = hs.get_notifier() @@ -121,17 +119,52 @@ class SynchrotronPresence(object): for state in active_presence } - self.process_id = random_string(16) - logger.info("Presence process_id is %r", self.process_id) + # user_id -> last_sync_ms. Lists the users that have stopped syncing + # but we haven't notified the master of that yet + self.users_going_offline = {} - self._sending_sync = False - self._need_to_send_sync = False - self.clock.looping_call( - self._send_syncing_users_regularly, - UPDATE_SYNCING_USERS_MS, + self._send_stop_syncing_loop = self.clock.looping_call( + self.send_stop_syncing, 10 * 1000 ) - reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown) + self.process_id = random_string(16) + logger.info("Presence process_id is %r", self.process_id) + + def send_user_sync(self, user_id, is_syncing, last_sync_ms): + self.hs.get_tcp_replication().send_user_sync(user_id, is_syncing, last_sync_ms) + + def mark_as_coming_online(self, user_id): + """A user has started syncing. Send a UserSync to the master, unless they + had recently stopped syncing. + + Args: + user_id (str) + """ + going_offline = self.users_going_offline.pop(user_id, None) + if not going_offline: + # Safe to skip because we haven't yet told the master they were offline + self.send_user_sync(user_id, True, self.clock.time_msec()) + + def mark_as_going_offline(self, user_id): + """A user has stopped syncing. We wait before notifying the master as + its likely they'll come back soon. This allows us to avoid sending + a stopped syncing immediately followed by a started syncing notification + to the master + + Args: + user_id (str) + """ + self.users_going_offline[user_id] = self.clock.time_msec() + + def send_stop_syncing(self): + """Check if there are any users who have stopped syncing a while ago + and haven't come back yet. If there are poke the master about them. + """ + now = self.clock.time_msec() + for user_id, last_sync_ms in self.users_going_offline.items(): + if now - last_sync_ms > 10 * 1000: + self.users_going_offline.pop(user_id, None) + self.send_user_sync(user_id, False, last_sync_ms) def set_state(self, user, state, ignore_status_msg=False): # TODO Hows this supposed to work? @@ -142,15 +175,14 @@ class SynchrotronPresence(object): _get_interested_parties = PresenceHandler._get_interested_parties.__func__ current_state_for_users = PresenceHandler.current_state_for_users.__func__ - @defer.inlineCallbacks def user_syncing(self, user_id, affect_presence): if affect_presence: curr_sync = self.user_to_num_current_syncs.get(user_id, 0) self.user_to_num_current_syncs[user_id] = curr_sync + 1 - prev_states = yield self.current_state_for_users([user_id]) - if prev_states[user_id].state == PresenceState.OFFLINE: - # TODO: Don't block the sync request on this HTTP hit. - yield self._send_syncing_users_now() + + # If we went from no in flight sync to some, notify replication + if self.user_to_num_current_syncs[user_id] == 1: + self.mark_as_coming_online(user_id) def _end(): # We check that the user_id is in user_to_num_current_syncs because @@ -159,6 +191,10 @@ class SynchrotronPresence(object): if affect_presence and user_id in self.user_to_num_current_syncs: self.user_to_num_current_syncs[user_id] -= 1 + # If we went from one in flight sync to non, notify replication + if self.user_to_num_current_syncs[user_id] == 0: + self.mark_as_going_offline(user_id) + @contextlib.contextmanager def _user_syncing(): try: @@ -166,49 +202,7 @@ class SynchrotronPresence(object): finally: _end() - defer.returnValue(_user_syncing()) - - @defer.inlineCallbacks - def _on_shutdown(self): - # When the synchrotron is shutdown tell the master to clear the in - # progress syncs for this process - self.user_to_num_current_syncs.clear() - yield self._send_syncing_users_now() - - def _send_syncing_users_regularly(self): - # Only send an update if we aren't in the middle of sending one. - if not self._sending_sync: - preserve_fn(self._send_syncing_users_now)() - - @defer.inlineCallbacks - def _send_syncing_users_now(self): - if self._sending_sync: - # We don't want to race with sending another update. - # Instead we wait for that update to finish and send another - # update afterwards. - self._need_to_send_sync = True - return - - # Flag that we are sending an update. - self._sending_sync = True - - yield self.http_client.post_json_get_json(self.syncing_users_url, { - "process_id": self.process_id, - "syncing_users": [ - user_id for user_id, count in self.user_to_num_current_syncs.items() - if count > 0 - ], - }) - - # Unset the flag as we are no longer sending an update. - self._sending_sync = False - if self._need_to_send_sync: - # If something happened while we were sending the update then - # we might need to send another update. - # TODO: Check if the update that was sent matches the current state - # as we only need to send an update if they are different. - self._need_to_send_sync = False - yield self._send_syncing_users_now() + return defer.succeed(_user_syncing()) @defer.inlineCallbacks def notify_from_replication(self, states, stream_id): @@ -223,26 +217,24 @@ class SynchrotronPresence(object): ) @defer.inlineCallbacks - def process_replication(self, result): - stream = result.get("presence", {"rows": []}) - states = [] - for row in stream["rows"]: - ( - position, user_id, state, last_active_ts, - last_federation_update_ts, last_user_sync_ts, status_msg, - currently_active - ) = row - state = UserPresenceState( - user_id, state, last_active_ts, - last_federation_update_ts, last_user_sync_ts, status_msg, - currently_active - ) - self.user_to_current_state[user_id] = state - states.append(state) + def process_replication_rows(self, token, rows): + states = [UserPresenceState( + row.user_id, row.state, row.last_active_ts, + row.last_federation_update_ts, row.last_user_sync_ts, row.status_msg, + row.currently_active + ) for row in rows] + + for state in states: + self.user_to_current_state[row.user_id] = state - if states and "position" in stream: - stream_id = int(stream["position"]) - yield self.notify_from_replication(states, stream_id) + stream_id = token + yield self.notify_from_replication(states, stream_id) + + def get_currently_syncing_users(self): + return [ + user_id for user_id, count in self.user_to_num_current_syncs.iteritems() + if count > 0 + ] class SynchrotronTyping(object): @@ -257,16 +249,12 @@ class SynchrotronTyping(object): # value which we *must* use for the next replication request. return {"typing": self._latest_room_serial} - def process_replication(self, result): - stream = result.get("typing") - if stream: - self._latest_room_serial = int(stream["position"]) + def process_replication_rows(self, token, rows): + self._latest_room_serial = token - for row in stream["rows"]: - position, room_id, typing_json = row - typing = json.loads(typing_json) - self._room_serials[room_id] = position - self._room_typing[room_id] = typing + for row in rows: + self._room_serials[row.room_id] = token + self._room_typing[row.room_id] = row.user_ids class SynchrotronApplicationService(object): @@ -351,124 +339,89 @@ class SynchrotronServer(HomeServer): else: logger.warn("Unrecognized listener type: %s", listener["type"]) - @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url - notifier = self.get_notifier() - presence_handler = self.get_presence_handler() - typing_handler = self.get_typing_handler() - - def notify_from_stream( - result, stream_name, stream_key, room=None, user=None - ): - stream = result.get(stream_name) - if stream: - position_index = stream["field_names"].index("position") - if room: - room_index = stream["field_names"].index(room) - if user: - user_index = stream["field_names"].index(user) - - users = () - rooms = () - for row in stream["rows"]: - position = row[position_index] - - if user: - users = (row[user_index],) - - if room: - rooms = (row[room_index],) - - notifier.on_new_event( - stream_key, position, users=users, rooms=rooms - ) + self.get_tcp_replication().start_replication(self) - @defer.inlineCallbacks - def notify_device_list_update(result): - stream = result.get("device_lists") - if not stream: - return + def build_tcp_replication(self): + return SyncReplicationHandler(self) - position_index = stream["field_names"].index("position") - user_index = stream["field_names"].index("user_id") + def build_presence_handler(self): + return SynchrotronPresence(self) - for row in stream["rows"]: - position = row[position_index] - user_id = row[user_index] + def build_typing_handler(self): + return SynchrotronTyping(self) - room_ids = yield store.get_rooms_for_user(user_id) - notifier.on_new_event( - "device_list_key", position, rooms=room_ids, - ) +class SyncReplicationHandler(ReplicationClientHandler): + def __init__(self, hs): + super(SyncReplicationHandler, self).__init__(hs.get_datastore()) - @defer.inlineCallbacks - def notify(result): - stream = result.get("events") - if stream: - max_position = stream["position"] - - event_map = yield store.get_events([row[1] for row in stream["rows"]]) - - for row in stream["rows"]: - position = row[0] - event_id = row[1] - event = event_map.get(event_id, None) - if not event: - continue - - extra_users = () - if event.type == EventTypes.Member: - extra_users = (event.state_key,) - notifier.on_new_room_event( - event, position, max_position, extra_users - ) + self.store = hs.get_datastore() + self.typing_handler = hs.get_typing_handler() + self.presence_handler = hs.get_presence_handler() + self.notifier = hs.get_notifier() - notify_from_stream( - result, "push_rules", "push_rules_key", user="user_id" - ) - notify_from_stream( - result, "user_account_data", "account_data_key", user="user_id" - ) - notify_from_stream( - result, "room_account_data", "account_data_key", user="user_id" + self.presence_handler.sync_callback = self.send_user_sync + + def on_rdata(self, stream_name, token, rows): + super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) + + preserve_fn(self.process_and_notify)(stream_name, token, rows) + + def get_streams_to_replicate(self): + args = super(SyncReplicationHandler, self).get_streams_to_replicate() + args.update(self.typing_handler.stream_positions()) + return args + + def get_currently_syncing_users(self): + return self.presence_handler.get_currently_syncing_users() + + @defer.inlineCallbacks + def process_and_notify(self, stream_name, token, rows): + if stream_name == "events": + # We shouldn't get multiple rows per token for events stream, so + # we don't need to optimise this for multiple rows. + for row in rows: + event = yield self.store.get_event(row.event_id) + extra_users = () + if event.type == EventTypes.Member: + extra_users = (event.state_key,) + max_token = self.store.get_room_max_stream_ordering() + self.notifier.on_new_room_event( + event, token, max_token, extra_users + ) + elif stream_name == "push_rules": + self.notifier.on_new_event( + "push_rules_key", token, users=[row.user_id for row in rows], ) - notify_from_stream( - result, "tag_account_data", "account_data_key", user="user_id" + elif stream_name in ("account_data", "tag_account_data",): + self.notifier.on_new_event( + "account_data_key", token, users=[row.user_id for row in rows], ) - notify_from_stream( - result, "receipts", "receipt_key", room="room_id" + elif stream_name == "receipts": + self.notifier.on_new_event( + "receipt_key", token, rooms=[row.room_id for row in rows], ) - notify_from_stream( - result, "typing", "typing_key", room="room_id" + elif stream_name == "typing": + self.typing_handler.process_replication_rows(token, rows) + self.notifier.on_new_event( + "typing_key", token, rooms=[row.room_id for row in rows], ) - notify_from_stream( - result, "to_device", "to_device_key", user="user_id" + elif stream_name == "to_device": + entities = [row.entity for row in rows if row.entity.startswith("@")] + if entities: + self.notifier.on_new_event( + "to_device_key", token, users=entities, + ) + elif stream_name == "device_lists": + all_room_ids = set() + for row in rows: + room_ids = yield self.store.get_rooms_for_user(row.user_id) + all_room_ids.update(room_ids) + self.notifier.on_new_event( + "device_list_key", token, rooms=all_room_ids, ) - yield notify_device_list_update(result) - - while True: - try: - args = store.stream_positions() - args.update(typing_handler.stream_positions()) - args["timeout"] = 30000 - result = yield http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - typing_handler.process_replication(result) - yield presence_handler.process_replication(result) - yield notify(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(5) - - def build_presence_handler(self): - return SynchrotronPresence(self) - - def build_typing_handler(self): - return SynchrotronTyping(self) + elif stream_name == "presence": + yield self.presence_handler.process_replication_rows(token, rows) def start(config_options): @@ -514,7 +467,6 @@ def start(config_options): def start(): ss.get_datastore().start_profiling() - ss.replicate() ss.get_state_handler().start_caching() reactor.callWhenRunning(start) diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py index e8218d01ad..8223734845 100755 --- a/synapse/app/synctl.py +++ b/synapse/app/synctl.py @@ -125,7 +125,7 @@ def main(): "configfile", nargs="?", default="homeserver.yaml", - help="the homeserver config file, defaults to homserver.yaml", + help="the homeserver config file, defaults to homeserver.yaml", ) parser.add_argument( "-w", "--worker", diff --git a/synapse/config/workers.py b/synapse/config/workers.py index b165c67ee7..ea48d931a1 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -28,7 +28,9 @@ class WorkerConfig(Config): self.worker_pid_file = config.get("worker_pid_file") self.worker_log_file = config.get("worker_log_file") self.worker_log_config = config.get("worker_log_config") - self.worker_replication_url = config.get("worker_replication_url") + self.worker_replication_host = config.get("worker_replication_host", None) + self.worker_replication_port = config.get("worker_replication_port", None) + self.worker_name = config.get("worker_name", self.worker_app) if self.worker_listeners: for listener in self.worker_listeners: diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 4bde66fbf8..748548bbe2 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -31,21 +31,19 @@ Events are replicated via a separate events stream. from .units import Edu +from synapse.storage.presence import UserPresenceState from synapse.util.metrics import Measure import synapse.metrics from blist import sorteddict -import ujson +from collections import namedtuple +import logging -metrics = synapse.metrics.get_metrics_for(__name__) +logger = logging.getLogger(__name__) -PRESENCE_TYPE = "p" -KEYED_EDU_TYPE = "k" -EDU_TYPE = "e" -FAILURE_TYPE = "f" -DEVICE_MESSAGE_TYPE = "d" +metrics = synapse.metrics.get_metrics_for(__name__) class FederationRemoteSendQueue(object): @@ -240,6 +238,8 @@ class FederationRemoteSendQueue(object): if from_token > self.pos: from_token = -1 + # list of tuple(int, BaseFederationRow), where the first is the position + # of the federation stream. rows = [] # There should be only one reader, so lets delete everything its @@ -258,10 +258,10 @@ class FederationRemoteSendQueue(object): ) for (key, (dest, user_id)) in dest_user_ids: - rows.append((key, PRESENCE_TYPE, ujson.dumps({ - "destination": dest, - "state": self.presence_map[user_id].as_dict(), - }))) + rows.append((key, PresenceRow( + destination=dest, + state=self.presence_map[user_id], + ))) # Fetch changes keyed edus keys = self.keyed_edu_changed.keys() @@ -270,12 +270,10 @@ class FederationRemoteSendQueue(object): keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:j]) for (pos, (destination, edu_key)) in keyed_edus: - rows.append( - (pos, KEYED_EDU_TYPE, ujson.dumps({ - "key": edu_key, - "edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(), - })) - ) + rows.append((pos, KeyedEduRow( + key=edu_key, + edu=self.keyed_edu[(destination, edu_key)], + ))) # Fetch changed edus keys = self.edus.keys() @@ -284,7 +282,7 @@ class FederationRemoteSendQueue(object): edus = set((k, self.edus[k]) for k in keys[i:j]) for (pos, edu) in edus: - rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict()))) + rows.append((pos, EduRow(edu))) # Fetch changed failures keys = self.failures.keys() @@ -293,10 +291,10 @@ class FederationRemoteSendQueue(object): failures = set((k, self.failures[k]) for k in keys[i:j]) for (pos, (destination, failure)) in failures: - rows.append((pos, FAILURE_TYPE, ujson.dumps({ - "destination": destination, - "failure": failure, - }))) + rows.append((pos, FailureRow( + destination=destination, + failure=failure, + ))) # Fetch changed device messages keys = self.device_messages.keys() @@ -305,11 +303,229 @@ class FederationRemoteSendQueue(object): device_messages = set((k, self.device_messages[k]) for k in keys[i:j]) for (pos, destination) in device_messages: - rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({ - "destination": destination, - }))) + rows.append((pos, DeviceRow( + destination=destination, + ))) # Sort rows based on pos rows.sort() - return rows + return [(pos, row.TypeId, row.to_data()) for pos, row in rows] + + +class BaseFederationRow(object): + """Base class for rows to be sent in the federation stream. + + Specifies how to identify, serialize and deserialize the different types. + """ + + TypeId = None # Unique string that ids the type. Must be overriden in sub classes. + + @staticmethod + def from_data(data): + """Parse the data from the federation stream into a row. + + Args: + data: The value of ``data`` from FederationStreamRow.data, type + depends on the type of stream + """ + raise NotImplementedError() + + def to_data(self): + """Serialize this row to be sent over the federation stream. + + Returns: + The value to be sent in FederationStreamRow.data. The type depends + on the type of stream. + """ + raise NotImplementedError() + + def add_to_buffer(self, buff): + """Add this row to the appropriate field in the buffer ready for this + to be sent over federation. + + We use a buffer so that we can batch up events that have come in at + the same time and send them all at once. + + Args: + buff (BufferedToSend) + """ + raise NotImplementedError() + + +class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", ( + "destination", # str + "state", # UserPresenceState +))): + TypeId = "p" + + @staticmethod + def from_data(data): + return PresenceRow( + destination=data["destination"], + state=UserPresenceState.from_dict(data["state"]) + ) + + def to_data(self): + return { + "destination": self.destination, + "state": self.state.as_dict() + } + + def add_to_buffer(self, buff): + buff.presence.setdefault(self.destination, []).append(self.state) + + +class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", ( + "key", # tuple(str) - the edu key passed to send_edu + "edu", # Edu +))): + TypeId = "k" + + @staticmethod + def from_data(data): + return KeyedEduRow( + key=tuple(data["key"]), + edu=Edu(**data["edu"]), + ) + + def to_data(self): + return { + "key": self.key, + "edu": self.edu.get_internal_dict(), + } + + def add_to_buffer(self, buff): + buff.keyed_edus.setdefault( + self.edu.destination, {} + )[self.key] = self.edu + + +class EduRow(BaseFederationRow, namedtuple("EduRow", ( + "edu", # Edu +))): + TypeId = "e" + + @staticmethod + def from_data(data): + return EduRow(Edu(**data)) + + def to_data(self): + return self.edu.get_internal_dict() + + def add_to_buffer(self, buff): + buff.edus.setdefault(self.edu.destination, []).append(self.edu) + + +class FailureRow(BaseFederationRow, namedtuple("FailureRow", ( + "destination", # str + "failure", +))): + TypeId = "f" + + @staticmethod + def from_data(data): + return FailureRow( + destination=data["destination"], + failure=data["failure"], + ) + + def to_data(self): + return { + "destination": self.destination, + "failure": self.failure, + } + + def add_to_buffer(self, buff): + buff.failures.setdefault(self.destination, []).append(self.failure) + + +class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ( + "destination", # str +))): + TypeId = "d" + + @staticmethod + def from_data(data): + return DeviceRow(destination=data["destination"]) + + def to_data(self): + return {"destination": self.destination} + + def add_to_buffer(self, buff): + buff.device_destinations.add(self.destination) + + +TypeToRow = { + Row.TypeId: Row + for Row in ( + PresenceRow, + KeyedEduRow, + EduRow, + FailureRow, + DeviceRow, + ) +} + + +ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", ( + "presence", # dict of destination -> [UserPresenceState] + "keyed_edus", # dict of destination -> { key -> Edu } + "edus", # dict of destination -> [Edu] + "failures", # dict of destination -> [failures] + "device_destinations", # set of destinations +)) + + +def process_rows_for_federation(transaction_queue, rows): + """Parse a list of rows from the federation stream and put them in the + transaction queue ready for sending to the relevant homeservers. + + Args: + transaction_queue (TransactionQueue) + rows (list(synapse.replication.tcp.streams.FederationStreamRow)) + """ + + # The federation stream contains a bunch of different types of + # rows that need to be handled differently. We parse the rows, put + # them into the appropriate collection and then send them off. + + buff = ParsedFederationStreamData( + presence={}, + keyed_edus={}, + edus={}, + failures={}, + device_destinations=set(), + ) + + # Parse the rows in the stream and add to the buffer + for row in rows: + if row.type not in TypeToRow: + logger.error("Unrecognized federation row type %r", row.type) + continue + + RowType = TypeToRow[row.type] + parsed_row = RowType.from_data(row.data) + parsed_row.add_to_buffer(buff) + + for destination, states in buff.presence.iteritems(): + transaction_queue.send_presence(destination, states) + + for destination, edu_map in buff.keyed_edus.iteritems(): + for key, edu in edu_map.items(): + transaction_queue.send_edu( + edu.destination, edu.edu_type, edu.content, key=key, + ) + + for destination, edu_list in buff.edus.iteritems(): + for edu in edu_list: + transaction_queue.send_edu( + edu.destination, edu.edu_type, edu.content, key=None, + ) + + for destination, failure_list in buff.failures.iteritems(): + for failure in failure_list: + transaction_queue.send_failure(destination, failure) + + for destination in buff.device_destinations: + transaction_queue.send_device_messages(destination) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 53baf3e79a..9ed5af3cb4 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -596,14 +596,14 @@ class PresenceHandler(object): for user_id in user_ids } - missing = [user_id for user_id, state in states.items() if not state] + missing = [user_id for user_id, state in states.iteritems() if not state] if missing: # There are things not in our in memory cache. Lets pull them out of # the database. res = yield self.store.get_presence_for_users(missing) states.update(res) - missing = [user_id for user_id, state in states.items() if not state] + missing = [user_id for user_id, state in states.iteritems() if not state] if missing: new = { user_id: UserPresenceState.default(user_id) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index d6809862e0..3b7818af5c 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -24,7 +24,6 @@ from synapse.types import UserID, get_domain_from_id import logging from collections import namedtuple -import ujson as json logger = logging.getLogger(__name__) @@ -288,8 +287,7 @@ class TypingHandler(object): for room_id, serial in self._room_serials.items(): if last_id < serial and serial <= current_id: typing = self._room_typing[room_id] - typing_bytes = json.dumps(list(typing), ensure_ascii=False) - rows.append((serial, room_id, typing_bytes)) + rows.append((serial, room_id, list(typing))) rows.sort() return rows diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index ab133db872..b962641166 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -15,7 +15,6 @@ from synapse.storage._base import SQLBaseStore from synapse.storage.engines import PostgresEngine -from twisted.internet import defer from ._slaved_id_tracker import SlavedIdTracker @@ -34,8 +33,7 @@ class BaseSlavedStore(SQLBaseStore): else: self._cache_id_gen = None - self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache" - self.http_client = hs.get_simple_http_client() + self.hs = hs def stream_positions(self): pos = {} @@ -43,35 +41,20 @@ class BaseSlavedStore(SQLBaseStore): pos["caches"] = self._cache_id_gen.get_current_token() return pos - def process_replication(self, result): - stream = result.get("caches") - if stream: - for row in stream["rows"]: - ( - position, cache_func, keys, invalidation_ts, - ) = row - + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "caches": + self._cache_id_gen.advance(token) + for row in rows: try: - getattr(self, cache_func).invalidate(tuple(keys)) + getattr(self, row.cache_func).invalidate(tuple(row.keys)) except AttributeError: # We probably haven't pulled in the cache in this worker, # which is fine. pass - self._cache_id_gen.advance(int(stream["position"])) - return defer.succeed(None) def _invalidate_cache_and_stream(self, txn, cache_func, keys): txn.call_after(cache_func.invalidate, keys) txn.call_after(self._send_invalidation_poke, cache_func, keys) - @defer.inlineCallbacks def _send_invalidation_poke(self, cache_func, keys): - try: - yield self.http_client.post_json_get_json(self.expire_cache_url, { - "invalidate": [{ - "name": cache_func.__name__, - "keys": list(keys), - }] - }) - except: - logger.exception("Failed to poke on expire_cache") + self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys) diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 77c64722c7..efbd87918e 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -69,38 +69,25 @@ class SlavedAccountDataStore(BaseSlavedStore): result["tag_account_data"] = position return result - def process_replication(self, result): - stream = result.get("user_account_data") - if stream: - self._account_data_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - position, user_id, data_type = row[:3] - self.get_global_account_data_by_type_for_user.invalidate( - (data_type, user_id,) - ) - self.get_account_data_for_user.invalidate((user_id,)) + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "tag_account_data": + self._account_data_id_gen.advance(token) + for row in rows: + self.get_tags_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed( - user_id, position + row.user_id, token ) - - stream = result.get("room_account_data") - if stream: - self._account_data_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - position, user_id = row[:2] - self.get_account_data_for_user.invalidate((user_id,)) + elif stream_name == "account_data": + self._account_data_id_gen.advance(token) + for row in rows: + if not row.room_id: + self.get_global_account_data_by_type_for_user.invalidate( + (row.data_type, row.user_id,) + ) + self.get_account_data_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed( - user_id, position + row.user_id, token ) - - stream = result.get("tag_account_data") - if stream: - self._account_data_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - position, user_id = row[:2] - self.get_tags_for_user.invalidate((user_id,)) - self._account_data_stream_cache.entity_has_changed( - user_id, position - ) - - return super(SlavedAccountDataStore, self).process_replication(result) + return super(SlavedAccountDataStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index f9102e0d89..6f3fb64770 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -53,21 +53,18 @@ class SlavedDeviceInboxStore(BaseSlavedStore): result["to_device"] = self._device_inbox_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("to_device") - if stream: - self._device_inbox_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - stream_id = row[0] - entity = row[1] - - if entity.startswith("@"): + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "to_device": + self._device_inbox_id_gen.advance(token) + for row in rows: + if row.entity.startswith("@"): self._device_inbox_stream_cache.entity_has_changed( - entity, stream_id + row.entity, token ) else: self._device_federation_outbox_stream_cache.entity_has_changed( - entity, stream_id + row.entity, token ) - - return super(SlavedDeviceInboxStore, self).process_replication(result) + return super(SlavedDeviceInboxStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index ca46aa17b6..4d4a435471 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -51,22 +51,18 @@ class SlavedDeviceStore(BaseSlavedStore): result["device_lists"] = self._device_list_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("device_lists") - if stream: - self._device_list_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - stream_id = row[0] - user_id = row[1] - destination = row[2] - + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "device_lists": + self._device_list_id_gen.advance(token) + for row in rows: self._device_list_stream_cache.entity_has_changed( - user_id, stream_id + row.user_id, token ) - if destination: + if row.destination: self._device_list_federation_stream_cache.entity_has_changed( - destination, stream_id + row.destination, token ) - - return super(SlavedDeviceStore, self).process_replication(result) + return super(SlavedDeviceStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index d4db1e452e..5fd47706ef 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -201,48 +201,25 @@ class SlavedEventStore(BaseSlavedStore): result["backfill"] = -self._backfill_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("events") - if stream: - self._stream_id_gen.advance(int(stream["position"])) - - if stream["rows"]: - logger.info("Got %d event rows", len(stream["rows"])) - - for row in stream["rows"]: - self._process_replication_row( - row, backfilled=False, + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "events": + self._stream_id_gen.advance(token) + for row in rows: + self.invalidate_caches_for_event( + token, row.event_id, row.room_id, row.type, row.state_key, + row.redacts, + backfilled=False, ) - - stream = result.get("backfill") - if stream: - self._backfill_id_gen.advance(-int(stream["position"])) - for row in stream["rows"]: - self._process_replication_row( - row, backfilled=True, + elif stream_name == "backfill": + self._backfill_id_gen.advance(-token) + for row in rows: + self.invalidate_caches_for_event( + -token, row.event_id, row.room_id, row.type, row.state_key, + row.redacts, + backfilled=True, ) - - stream = result.get("forward_ex_outliers") - if stream: - self._stream_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - event_id = row[1] - self._invalidate_get_event_cache(event_id) - - stream = result.get("backward_ex_outliers") - if stream: - self._backfill_id_gen.advance(-int(stream["position"])) - for row in stream["rows"]: - event_id = row[1] - self._invalidate_get_event_cache(event_id) - - return super(SlavedEventStore, self).process_replication(result) - - def _process_replication_row(self, row, backfilled): - stream_ordering = row[0] if not backfilled else -row[0] - self.invalidate_caches_for_event( - stream_ordering, row[1], row[2], row[3], row[4], row[5], - backfilled=backfilled, + return super(SlavedEventStore, self).process_replication_rows( + stream_name, token, rows ) def invalidate_caches_for_event(self, stream_ordering, event_id, room_id, diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index e4a2414d78..dffc80adc3 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -48,15 +48,14 @@ class SlavedPresenceStore(BaseSlavedStore): result["presence"] = position return result - def process_replication(self, result): - stream = result.get("presence") - if stream: - self._presence_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - position, user_id = row[:2] + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "presence": + self._presence_id_gen.advance(token) + for row in rows: self.presence_stream_cache.entity_has_changed( - user_id, position + row.user_id, token ) - self._get_presence_for_user.invalidate((user_id,)) - - return super(SlavedPresenceStore, self).process_replication(result) + self._get_presence_for_user.invalidate((row.user_id,)) + return super(SlavedPresenceStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 21ceb0213a..83e880fdd2 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -50,18 +50,15 @@ class SlavedPushRuleStore(SlavedEventStore): result["push_rules"] = self._push_rules_stream_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("push_rules") - if stream: - for row in stream["rows"]: - position = row[0] - user_id = row[2] - self.get_push_rules_for_user.invalidate((user_id,)) - self.get_push_rules_enabled_for_user.invalidate((user_id,)) + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "push_rules": + self._push_rules_stream_id_gen.advance(token) + for row in rows: + self.get_push_rules_for_user.invalidate((row.user_id,)) + self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) self.push_rules_stream_cache.entity_has_changed( - user_id, position + row.user_id, token ) - - self._push_rules_stream_id_gen.advance(int(stream["position"])) - - return super(SlavedPushRuleStore, self).process_replication(result) + return super(SlavedPushRuleStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index d88206b3bb..4e8d68ece9 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -40,13 +40,9 @@ class SlavedPusherStore(BaseSlavedStore): result["pushers"] = self._pushers_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("pushers") - if stream: - self._pushers_id_gen.advance(int(stream["position"])) - - stream = result.get("deleted_pushers") - if stream: - self._pushers_id_gen.advance(int(stream["position"])) - - return super(SlavedPusherStore, self).process_replication(result) + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "pushers": + self._pushers_id_gen.advance(token) + return super(SlavedPusherStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index ac9662d399..b371574ece 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -65,20 +65,22 @@ class SlavedReceiptsStore(BaseSlavedStore): result["receipts"] = self._receipts_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("receipts") - if stream: - self._receipts_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - position, room_id, receipt_type, user_id = row[:4] - self.invalidate_caches_for_receipt(room_id, receipt_type, user_id) - self._receipts_stream_cache.entity_has_changed(room_id, position) - - return super(SlavedReceiptsStore, self).process_replication(result) - def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): self.get_receipts_for_user.invalidate((user_id, receipt_type)) self.get_linearized_receipts_for_room.invalidate_many((room_id,)) self.get_last_receipt_event_id_for_user.invalidate( (user_id, room_id, receipt_type) ) + + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "receipts": + self._receipts_id_gen.advance(token) + for row in rows: + self.invalidate_caches_for_receipt( + row.room_id, row.receipt_type, row.user_id + ) + self._receipts_stream_cache.entity_has_changed(row.room_id, token) + + return super(SlavedReceiptsStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 6df9a25ef3..f510384033 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -46,9 +46,10 @@ class RoomStore(BaseSlavedStore): result["public_rooms"] = self._public_room_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("public_rooms") - if stream: - self._public_room_id_gen.advance(int(stream["position"])) + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "public_rooms": + self._public_room_id_gen.advance(token) - return super(RoomStore, self).process_replication(result) + return super(RoomStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py new file mode 100644 index 0000000000..90fb6c1336 --- /dev/null +++ b/synapse/replication/tcp/client.py @@ -0,0 +1,196 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A replication client for use by synapse workers. +""" + +from twisted.internet import reactor, defer +from twisted.internet.protocol import ReconnectingClientFactory + +from .commands import ( + FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand, +) +from .protocol import ClientReplicationStreamProtocol + +import logging + +logger = logging.getLogger(__name__) + + +class ReplicationClientFactory(ReconnectingClientFactory): + """Factory for building connections to the master. Will reconnect if the + connection is lost. + + Accepts a handler that will be called when new data is available or data + is required. + """ + maxDelay = 5 # Try at least once every N seconds + + def __init__(self, hs, client_name, handler): + self.client_name = client_name + self.handler = handler + self.server_name = hs.config.server_name + self._clock = hs.get_clock() # As self.clock is defined in super class + + reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying) + + def startedConnecting(self, connector): + logger.info("Connecting to replication: %r", connector.getDestination()) + + def buildProtocol(self, addr): + logger.info("Connected to replication: %r", addr) + self.resetDelay() + return ClientReplicationStreamProtocol( + self.client_name, self.server_name, self._clock, self.handler + ) + + def clientConnectionLost(self, connector, reason): + logger.error("Lost replication conn: %r", reason) + ReconnectingClientFactory.clientConnectionLost(self, connector, reason) + + def clientConnectionFailed(self, connector, reason): + logger.error("Failed to connect to replication: %r", reason) + ReconnectingClientFactory.clientConnectionFailed( + self, connector, reason + ) + + +class ReplicationClientHandler(object): + """A base handler that can be passed to the ReplicationClientFactory. + + By default proxies incoming replication data to the SlaveStore. + """ + def __init__(self, store): + self.store = store + + # The current connection. None if we are currently (re)connecting + self.connection = None + + # Any pending commands to be sent once a new connection has been + # established + self.pending_commands = [] + + # Map from string -> deferred, to wake up when receiveing a SYNC with + # the given string. + # Used for tests. + self.awaiting_syncs = {} + + def start_replication(self, hs): + """Helper method to start a replication connection to the remote server + using TCP. + """ + client_name = hs.config.worker_name + factory = ReplicationClientFactory(hs, client_name, self) + host = hs.config.worker_replication_host + port = hs.config.worker_replication_port + reactor.connectTCP(host, port, factory) + + def on_rdata(self, stream_name, token, rows): + """Called when we get new replication data. By default this just pokes + the slave store. + + Can be overriden in subclasses to handle more. + """ + logger.info("Received rdata %s -> %s", stream_name, token) + self.store.process_replication_rows(stream_name, token, rows) + + def on_position(self, stream_name, token): + """Called when we get new position data. By default this just pokes + the slave store. + + Can be overriden in subclasses to handle more. + """ + self.store.process_replication_rows(stream_name, token, []) + + def on_sync(self, data): + """When we received a SYNC we wake up any deferreds that were waiting + for the sync with the given data. + + Used by tests. + """ + d = self.awaiting_syncs.pop(data, None) + if d: + d.callback(data) + + def get_streams_to_replicate(self): + """Called when a new connection has been established and we need to + subscribe to streams. + + Returns a dictionary of stream name to token. + """ + args = self.store.stream_positions() + user_account_data = args.pop("user_account_data", None) + room_account_data = args.pop("room_account_data", None) + if user_account_data: + args["account_data"] = user_account_data + elif room_account_data: + args["account_data"] = room_account_data + return args + + def get_currently_syncing_users(self): + """Get the list of currently syncing users (if any). This is called + when a connection has been established and we need to send the + currently syncing users. (Overriden by the synchrotron's only) + """ + return [] + + def send_command(self, cmd): + """Send a command to master (when we get establish a connection if we + don't have one already.) + """ + if self.connection: + self.connection.send_command(cmd) + else: + logger.warn("Queuing command as not connected: %r", cmd.NAME) + self.pending_commands.append(cmd) + + def send_federation_ack(self, token): + """Ack data for the federation stream. This allows the master to drop + data stored purely in memory. + """ + self.send_command(FederationAckCommand(token)) + + def send_user_sync(self, user_id, is_syncing, last_sync_ms): + """Poke the master that a user has started/stopped syncing. + """ + self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms)) + + def send_remove_pusher(self, app_id, push_key, user_id): + """Poke the master to remove a pusher for a user + """ + cmd = RemovePusherCommand(app_id, push_key, user_id) + self.send_command(cmd) + + def send_invalidate_cache(self, cache_func, keys): + """Poke the master to invalidate a cache. + """ + cmd = InvalidateCacheCommand(cache_func.__name__, keys) + self.send_command(cmd) + + def await_sync(self, data): + """Returns a deferred that is resolved when we receive a SYNC command + with given data. + + Used by tests. + """ + return self.awaiting_syncs.setdefault(data, defer.Deferred()) + + def update_connection(self, connection): + """Called when a connection has been established (or lost with None). + """ + self.connection = connection + if connection: + for cmd in self.pending_commands: + connection.send_command(cmd) + self.pending_commands = [] diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 6864204616..5770b7125a 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -51,6 +51,7 @@ indicate which side is sending, these are *not* included on the wire:: from twisted.internet import defer from twisted.protocols.basic import LineOnlyReceiver +from twisted.python.failure import Failure from commands import ( COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, @@ -60,6 +61,7 @@ from commands import ( from streams import STREAMS_MAP from synapse.util.stringutils import random_string +from synapse.metrics.metric import CounterMetric import logging import synapse.metrics @@ -69,11 +71,8 @@ import fcntl metrics = synapse.metrics.get_metrics_for(__name__) -inbound_commands_counter = metrics.register_counter( - "inbound_commands", labels=["command", "name", "conn_id"], -) -outbound_commands_counter = metrics.register_counter( - "outbound_commands", labels=["command", "name", "conn_id"], +connection_close_counter = metrics.register_counter( + "close_reason", labels=["reason_type"], ) @@ -86,6 +85,8 @@ logger = logging.getLogger(__name__) PING_TIME = 5000 +PING_TIMEOUT_MULTIPLIER = 5 +PING_TIMEOUT_MS = PING_TIME * PING_TIMEOUT_MULTIPLIER class ConnectionStates(object): @@ -135,6 +136,13 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # The LoopingCall for sending pings. self._send_ping_loop = None + self.inbound_commands_counter = CounterMetric( + "inbound_commands", labels=["command"], + ) + self.outbound_commands_counter = CounterMetric( + "outbound_commands", labels=["command"], + ) + def connectionMade(self): logger.info("[%s] Connection established", self.id()) @@ -160,7 +168,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): now = self.clock.time_msec() if self.time_we_closed: - if now - self.time_we_closed > PING_TIME * 3: + if now - self.time_we_closed > PING_TIMEOUT_MS: logger.info( "[%s] Failed to close connection gracefully, aborting", self.id() ) @@ -169,7 +177,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if now - self.last_sent_command >= PING_TIME: self.send_command(PingCommand(now)) - if self.received_ping and now - self.last_received_command > PING_TIME * 3: + if self.received_ping and now - self.last_received_command > PING_TIMEOUT_MS: logger.info( "[%s] Connection hasn't received command in %r ms. Closing.", self.id(), now - self.last_received_command @@ -193,7 +201,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.last_received_command = self.clock.time_msec() - inbound_commands_counter.inc(cmd_name, self.name, self.conn_id) + self.inbound_commands_counter.inc(cmd_name) cmd_cls = COMMAND_MAP[cmd_name] try: @@ -214,6 +222,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): logger.exception("[%s] Failed to handle line: %r", self.id(), line) def close(self): + logger.warn("[%s] Closing connection", self.id()) self.time_we_closed = self.clock.time_msec() self.transport.loseConnection() self.on_connection_closed() @@ -242,7 +251,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self._queue_command(cmd) return - outbound_commands_counter.inc(cmd.NAME, self.name, self.conn_id) + self.outbound_commands_counter.inc(cmd.NAME) string = "%s %s" % (cmd.NAME, cmd.to_line(),) if "\n" in string: @@ -307,6 +316,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def connectionLost(self, reason): logger.info("[%s] Replication connection closed: %r", self.id(), reason) + if isinstance(reason, Failure): + connection_close_counter.inc(reason.type.__name__) + else: + connection_close_counter.inc(reason.__class__.__name__) try: # Remove us from list of connections to be monitored @@ -495,7 +508,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) - self.transport.abortConnection() + self.send_error("Wrong remote") def on_RDATA(self, cmd): try: @@ -604,3 +617,24 @@ metrics.register_callback( }, labels=["name", "conn_id"], ) + + +metrics.register_callback( + "inbound_commands", + lambda: { + (k[0], p.name, p.conn_id): count + for p in connected_connections + for k, count in p.inbound_commands_counter.counts.iteritems() + }, + labels=["command", "name", "conn_id"], +) + +metrics.register_callback( + "outbound_commands", + lambda: { + (k[0], p.name, p.conn_id): count + for p in connected_connections + for k, count in p.outbound_commands_counter.counts.iteritems() + }, + labels=["command", "name", "conn_id"], +) diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py index 4de4ebe84d..369d5f2428 100644 --- a/synapse/replication/tcp/streams.py +++ b/synapse/replication/tcp/streams.py @@ -36,34 +36,82 @@ logger = logging.getLogger(__name__) MAX_EVENTS_BEHIND = 10000 -EventStreamRow = namedtuple("EventStreamRow", - ("event_id", "room_id", "type", "state_key", "redacts")) -BackfillStreamRow = namedtuple("BackfillStreamRow", - ("event_id", "room_id", "type", "state_key", "redacts")) -PresenceStreamRow = namedtuple("PresenceStreamRow", - ("user_id", "state", "last_active_ts", - "last_federation_update_ts", "last_user_sync_ts", - "status_msg", "currently_active")) -TypingStreamRow = namedtuple("TypingStreamRow", - ("room_id", "user_ids")) -ReceiptsStreamRow = namedtuple("ReceiptsStreamRow", - ("room_id", "receipt_type", "user_id", "event_id", - "data")) -PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) -PushersStreamRow = namedtuple("PushersStreamRow", - ("user_id", "app_id", "pushkey", "deleted",)) -CachesStreamRow = namedtuple("CachesStreamRow", - ("cache_func", "keys", "invalidation_ts",)) -PublicRoomsStreamRow = namedtuple("PublicRoomsStreamRow", - ("room_id", "visibility", "appservice_id", - "network_id",)) -DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", ("user_id", "destination",)) -ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) -FederationStreamRow = namedtuple("FederationStreamRow", ("type", "data",)) -TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", - ("user_id", "room_id", "data")) -AccountDataStreamRow = namedtuple("AccountDataStream", - ("user_id", "room_id", "data_type", "data")) +EventStreamRow = namedtuple("EventStreamRow", ( + "event_id", # str + "room_id", # str + "type", # str + "state_key", # str, optional + "redacts", # str, optional +)) +BackfillStreamRow = namedtuple("BackfillStreamRow", ( + "event_id", # str + "room_id", # str + "type", # str + "state_key", # str, optional + "redacts", # str, optional +)) +PresenceStreamRow = namedtuple("PresenceStreamRow", ( + "user_id", # str + "state", # str + "last_active_ts", # int + "last_federation_update_ts", # int + "last_user_sync_ts", # int + "status_msg", # str + "currently_active", # bool +)) +TypingStreamRow = namedtuple("TypingStreamRow", ( + "room_id", # str + "user_ids", # list(str) +)) +ReceiptsStreamRow = namedtuple("ReceiptsStreamRow", ( + "room_id", # str + "receipt_type", # str + "user_id", # str + "event_id", # str + "data", # dict +)) +PushRulesStreamRow = namedtuple("PushRulesStreamRow", ( + "user_id", # str +)) +PushersStreamRow = namedtuple("PushersStreamRow", ( + "user_id", # str + "app_id", # str + "pushkey", # str + "deleted", # bool +)) +CachesStreamRow = namedtuple("CachesStreamRow", ( + "cache_func", # str + "keys", # list(str) + "invalidation_ts", # int +)) +PublicRoomsStreamRow = namedtuple("PublicRoomsStreamRow", ( + "room_id", # str + "visibility", # str + "appservice_id", # str, optional + "network_id", # str, optional +)) +DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", ( + "user_id", # str + "destination", # str +)) +ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ( + "entity", # str +)) +FederationStreamRow = namedtuple("FederationStreamRow", ( + "type", # str, the type of data as defined in the BaseFederationRows + "data", # dict, serialization of a federation.send_queue.BaseFederationRow +)) +TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", ( + "user_id", # str + "room_id", # str + "data", # dict +)) +AccountDataStreamRow = namedtuple("AccountDataStream", ( + "user_id", # str + "room_id", # str + "data_type", # str + "data", # dict +)) class Stream(object): diff --git a/synapse/server.py b/synapse/server.py index c577032041..6310152560 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -132,6 +132,7 @@ class HomeServer(object): 'federation_sender', 'receipts_handler', 'macaroon_generator', + 'tcp_replication', ] def __init__(self, hostname, **kwargs): @@ -290,6 +291,9 @@ class HomeServer(object): def build_receipts_handler(self): return ReceiptsHandler(self) + def build_tcp_replication(self): + raise NotImplementedError() + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 2714519d21..0b62b493d5 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -325,23 +325,26 @@ class DeviceInboxStore(BackgroundUpdateStore): # we return. upper_pos = min(current_pos, last_pos + limit) sql = ( - "SELECT stream_id, user_id" + "SELECT max(stream_id), user_id" " FROM device_inbox" " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" + " GROUP BY user_id" ) txn.execute(sql, (last_pos, upper_pos)) rows = txn.fetchall() sql = ( - "SELECT stream_id, destination" + "SELECT max(stream_id), destination" " FROM device_federation_outbox" " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" + " GROUP BY destination" ) txn.execute(sql, (last_pos, upper_pos)) rows.extend(txn) + # Order by ascending stream ordering + rows.sort() + return rows return self.runInteraction( diff --git a/synapse/storage/state.py b/synapse/storage/state.py index fb23f6f462..acd69944c4 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,7 +14,7 @@ # limitations under the License. from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches import intern_string from synapse.storage.engines import PostgresEngine @@ -69,17 +69,33 @@ class StateStore(SQLBaseStore): where_clause="type='m.room.member'", ) - @cachedInlineCallbacks(max_entries=100000, iterable=True) + @cached(max_entries=100000, iterable=True) def get_current_state_ids(self, room_id): - rows = yield self._simple_select_list( - table="current_state_events", - keyvalues={"room_id": room_id}, - retcols=["event_id", "type", "state_key"], - desc="_calculate_state_delta", + """Get the current state event ids for a room based on the + current_state_events table. + + Args: + room_id (str) + + Returns: + deferred: dict of (type, state_key) -> event_id + """ + def _get_current_state_ids_txn(txn): + txn.execute( + """SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? + """, + (room_id,) + ) + + return { + (r[0], r[1]): r[2] for r in txn + } + + return self.runInteraction( + "get_current_state_ids", + _get_current_state_ids_txn, ) - defer.returnValue({ - (r["type"], r["state_key"]): r["event_id"] for r in rows - }) @defer.inlineCallbacks def get_state_groups_ids(self, room_id, event_ids): diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index b82868054d..81063f19a1 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +from twisted.internet import defer, reactor from tests import unittest from mock import Mock, NonCallableMock from tests.utils import setup_test_homeserver -from synapse.replication.resource import ReplicationResource +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.replication.tcp.client import ( + ReplicationClientHandler, ReplicationClientFactory, +) class BaseSlavedStoreTestCase(unittest.TestCase): @@ -33,18 +36,29 @@ class BaseSlavedStoreTestCase(unittest.TestCase): ) self.hs.get_ratelimiter().send_message.return_value = (True, 0) - self.replication = ReplicationResource(self.hs) - self.master_store = self.hs.get_datastore() self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs) self.event_id = 0 + server_factory = ReplicationStreamProtocolFactory(self.hs) + listener = reactor.listenUNIX("\0xxx", server_factory) + self.addCleanup(listener.stopListening) + self.streamer = server_factory.streamer + + self.replication_handler = ReplicationClientHandler(self.slaved_store) + client_factory = ReplicationClientFactory( + self.hs, "client_name", self.replication_handler + ) + client_connector = reactor.connectUNIX("\0xxx", client_factory) + self.addCleanup(client_factory.stopTrying) + self.addCleanup(client_connector.disconnect) + @defer.inlineCallbacks def replicate(self): - streams = self.slaved_store.stream_positions() - writer = yield self.replication.replicate(streams, 100) - result = writer.finish() - yield self.slaved_store.process_replication(result) + yield self.streamer.on_notifier_poke() + d = self.replication_handler.await_sync("replication_test") + self.streamer.send_sync_to_all_connections("replication_test") + yield d @defer.inlineCallbacks def check(self, method, args, expected_result=None): |