summary refs log tree commit diff
path: root/synapse/app/pusher.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/app/pusher.py')
-rw-r--r--synapse/app/pusher.py133
1 files changed, 53 insertions, 80 deletions
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 073f2c2489..f9114acfcb 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -27,11 +27,12 @@ from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.pushers import SlavedPusherStore
 from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
 from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
+from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.storage.engines import create_engine
 from synapse.storage import DataStore
-from synapse.util.async import sleep
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, preserve_fn, \
+    PreserveLoggingContext
 from synapse.util.manhole import manhole
 from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
@@ -88,7 +89,6 @@ class PusherSlaveStore(
 
 
 class PusherServer(HomeServer):
-
     def get_db_conn(self, run_new_connection=True):
         # Any param beginning with cp_ is a parameter for adbapi, and should
         # not be passed to the database engine.
@@ -108,16 +108,7 @@ class PusherServer(HomeServer):
         logger.info("Finished setting up.")
 
     def remove_pusher(self, app_id, push_key, user_id):
-        http_client = self.get_simple_http_client()
-        replication_url = self.config.worker_replication_url
-        url = replication_url + "/remove_pushers"
-        return http_client.post_json_get_json(url, {
-            "remove": [{
-                "app_id": app_id,
-                "push_key": push_key,
-                "user_id": user_id,
-            }]
-        })
+        self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
 
     def _listen_http(self, listener_config):
         port = listener_config["port"]
@@ -165,73 +156,52 @@ class PusherServer(HomeServer):
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])
 
+        self.get_tcp_replication().start_replication(self)
+
+    def build_tcp_replication(self):
+        return PusherReplicationHandler(self)
+
+
+class PusherReplicationHandler(ReplicationClientHandler):
+    def __init__(self, hs):
+        super(PusherReplicationHandler, self).__init__(hs.get_datastore())
+
+        self.pusher_pool = hs.get_pusherpool()
+
+    def on_rdata(self, stream_name, token, rows):
+        super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
+        preserve_fn(self.poke_pushers)(stream_name, token, rows)
+
     @defer.inlineCallbacks
-    def replicate(self):
-        http_client = self.get_simple_http_client()
-        store = self.get_datastore()
-        replication_url = self.config.worker_replication_url
-        pusher_pool = self.get_pusherpool()
-
-        def stop_pusher(user_id, app_id, pushkey):
-            key = "%s:%s" % (app_id, pushkey)
-            pushers_for_user = pusher_pool.pushers.get(user_id, {})
-            pusher = pushers_for_user.pop(key, None)
-            if pusher is None:
-                return
-            logger.info("Stopping pusher %r / %r", user_id, key)
-            pusher.on_stop()
-
-        def start_pusher(user_id, app_id, pushkey):
-            key = "%s:%s" % (app_id, pushkey)
-            logger.info("Starting pusher %r / %r", user_id, key)
-            return pusher_pool._refresh_pusher(app_id, pushkey, user_id)
-
-        @defer.inlineCallbacks
-        def poke_pushers(results):
-            pushers_rows = set(
-                map(tuple, results.get("pushers", {}).get("rows", []))
+    def poke_pushers(self, stream_name, token, rows):
+        if stream_name == "pushers":
+            for row in rows:
+                if row.deleted:
+                    yield self.stop_pusher(row.user_id, row.app_id, row.pushkey)
+                else:
+                    yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
+        elif stream_name == "events":
+            yield self.pusher_pool.on_new_notifications(
+                token, token,
             )
-            deleted_pushers_rows = set(
-                map(tuple, results.get("deleted_pushers", {}).get("rows", []))
+        elif stream_name == "receipts":
+            yield self.pusher_pool.on_new_receipts(
+                token, token, set(row.room_id for row in rows)
             )
-            for row in sorted(pushers_rows | deleted_pushers_rows):
-                if row in deleted_pushers_rows:
-                    user_id, app_id, pushkey = row[1:4]
-                    stop_pusher(user_id, app_id, pushkey)
-                elif row in pushers_rows:
-                    user_id = row[1]
-                    app_id = row[5]
-                    pushkey = row[8]
-                    yield start_pusher(user_id, app_id, pushkey)
-
-            stream = results.get("events")
-            if stream and stream["rows"]:
-                min_stream_id = stream["rows"][0][0]
-                max_stream_id = stream["position"]
-                preserve_fn(pusher_pool.on_new_notifications)(
-                    min_stream_id, max_stream_id
-                )
-
-            stream = results.get("receipts")
-            if stream and stream["rows"]:
-                rows = stream["rows"]
-                affected_room_ids = set(row[1] for row in rows)
-                min_stream_id = rows[0][0]
-                max_stream_id = stream["position"]
-                preserve_fn(pusher_pool.on_new_receipts)(
-                    min_stream_id, max_stream_id, affected_room_ids
-                )
-
-        while True:
-            try:
-                args = store.stream_positions()
-                args["timeout"] = 30000
-                result = yield http_client.get_json(replication_url, args=args)
-                yield store.process_replication(result)
-                poke_pushers(result)
-            except:
-                logger.exception("Error replicating from %r", replication_url)
-                yield sleep(30)
+
+    def stop_pusher(self, user_id, app_id, pushkey):
+        key = "%s:%s" % (app_id, pushkey)
+        pushers_for_user = self.pusher_pool.pushers.get(user_id, {})
+        pusher = pushers_for_user.pop(key, None)
+        if pusher is None:
+            return
+        logger.info("Stopping pusher %r / %r", user_id, key)
+        pusher.on_stop()
+
+    def start_pusher(self, user_id, app_id, pushkey):
+        key = "%s:%s" % (app_id, pushkey)
+        logger.info("Starting pusher %r / %r", user_id, key)
+        return self.pusher_pool._refresh_pusher(app_id, pushkey, user_id)
 
 
 def start(config_options):
@@ -245,7 +215,7 @@ def start(config_options):
 
     assert config.worker_app == "synapse.app.pusher"
 
-    setup_logging(config.worker_log_config, config.worker_log_file)
+    setup_logging(config, use_worker_options=True)
 
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
@@ -275,7 +245,11 @@ def start(config_options):
     ps.start_listening(config.worker_listeners)
 
     def run():
-        with LoggingContext("run"):
+        # make sure that we run the reactor with the sentinel log context,
+        # otherwise other PreserveLoggingContext instances will get confused
+        # and complain when they see the logcontext arbitrarily swapping
+        # between the sentinel and `run` logcontexts.
+        with PreserveLoggingContext():
             logger.info("Running")
             change_resource_limit(config.soft_file_limit)
             if config.gc_thresholds:
@@ -283,7 +257,6 @@ def start(config_options):
             reactor.run()
 
     def start():
-        ps.replicate()
         ps.get_pusherpool().start()
         ps.get_datastore().start_profiling()
         ps.get_state_handler().start_caching()