summary refs log tree commit diff
path: root/synapse/app/pusher.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/app/pusher.py')
-rw-r--r--synapse/app/pusher.py237
1 files changed, 80 insertions, 157 deletions
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 073f2c2489..26930d1b3b 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -13,39 +13,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+import sys
 
 import synapse
-
-from synapse.server import HomeServer
+from synapse import events
+from synapse.app import _base
 from synapse.config._base import ConfigError
-from synapse.config.logger import setup_logging
 from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
 from synapse.http.site import SynapseSite
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
-from synapse.storage.roommember import RoomMemberStore
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
 from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.pushers import SlavedPusherStore
 from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
-from synapse.storage.engines import create_engine
+from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.server import HomeServer
 from synapse.storage import DataStore
-from synapse.util.async import sleep
+from synapse.storage.engines import create_engine
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
 from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
-
-from synapse import events
-
-from twisted.internet import reactor, defer
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import gc
+from twisted.internet import defer, reactor
+from twisted.web.resource import NoResource
 
 logger = logging.getLogger("synapse.app.pusher")
 
@@ -82,42 +74,15 @@ class PusherSlaveStore(
         DataStore.get_profile_displayname.__func__
     )
 
-    who_forgot_in_room = (
-        RoomMemberStore.__dict__["who_forgot_in_room"]
-    )
-
 
 class PusherServer(HomeServer):
-
-    def get_db_conn(self, run_new_connection=True):
-        # Any param beginning with cp_ is a parameter for adbapi, and should
-        # not be passed to the database engine.
-        db_params = {
-            k: v for k, v in self.db_config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.database_engine.module.connect(**db_params)
-
-        if run_new_connection:
-            self.database_engine.on_new_connection(db_conn)
-        return db_conn
-
     def setup(self):
         logger.info("Setting up.")
         self.datastore = PusherSlaveStore(self.get_db_conn(), self)
         logger.info("Finished setting up.")
 
     def remove_pusher(self, app_id, push_key, user_id):
-        http_client = self.get_simple_http_client()
-        replication_url = self.config.worker_replication_url
-        url = replication_url + "/remove_pushers"
-        return http_client.post_json_get_json(url, {
-            "remove": [{
-                "app_id": app_id,
-                "push_key": push_key,
-                "user_id": user_id,
-            }]
-        })
+        self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
 
     def _listen_http(self, listener_config):
         port = listener_config["port"]
@@ -129,19 +94,18 @@ class PusherServer(HomeServer):
                 if name == "metrics":
                     resources[METRICS_PREFIX] = MetricsResource(self)
 
-        root_resource = create_resource_tree(resources, Resource())
-
-        for address in bind_addresses:
-            reactor.listenTCP(
-                port,
-                SynapseSite(
-                    "synapse.access.http.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                ),
-                interface=address
+        root_resource = create_resource_tree(resources, NoResource())
+
+        _base.listen_tcp(
+            bind_addresses,
+            port,
+            SynapseSite(
+                "synapse.access.http.%s" % (site_tag,),
+                site_tag,
+                listener_config,
+                root_resource,
             )
+        )
 
         logger.info("Synapse pusher now listening on port %d", port)
 
@@ -150,88 +114,67 @@ class PusherServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_addresses = listener["bind_addresses"]
-
-                for address in bind_addresses:
-                    reactor.listenTCP(
-                        listener["port"],
-                        manhole(
-                            username="matrix",
-                            password="rabbithole",
-                            globals={"hs": self},
-                        ),
-                        interface=address
+                _base.listen_tcp(
+                    listener["bind_addresses"],
+                    listener["port"],
+                    manhole(
+                        username="matrix",
+                        password="rabbithole",
+                        globals={"hs": self},
                     )
+                )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])
 
+        self.get_tcp_replication().start_replication(self)
+
+    def build_tcp_replication(self):
+        return PusherReplicationHandler(self)
+
+
+class PusherReplicationHandler(ReplicationClientHandler):
+    def __init__(self, hs):
+        super(PusherReplicationHandler, self).__init__(hs.get_datastore())
+
+        self.pusher_pool = hs.get_pusherpool()
+
+    def on_rdata(self, stream_name, token, rows):
+        super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
+        run_in_background(self.poke_pushers, stream_name, token, rows)
+
     @defer.inlineCallbacks
-    def replicate(self):
-        http_client = self.get_simple_http_client()
-        store = self.get_datastore()
-        replication_url = self.config.worker_replication_url
-        pusher_pool = self.get_pusherpool()
-
-        def stop_pusher(user_id, app_id, pushkey):
-            key = "%s:%s" % (app_id, pushkey)
-            pushers_for_user = pusher_pool.pushers.get(user_id, {})
-            pusher = pushers_for_user.pop(key, None)
-            if pusher is None:
-                return
-            logger.info("Stopping pusher %r / %r", user_id, key)
-            pusher.on_stop()
-
-        def start_pusher(user_id, app_id, pushkey):
-            key = "%s:%s" % (app_id, pushkey)
-            logger.info("Starting pusher %r / %r", user_id, key)
-            return pusher_pool._refresh_pusher(app_id, pushkey, user_id)
-
-        @defer.inlineCallbacks
-        def poke_pushers(results):
-            pushers_rows = set(
-                map(tuple, results.get("pushers", {}).get("rows", []))
-            )
-            deleted_pushers_rows = set(
-                map(tuple, results.get("deleted_pushers", {}).get("rows", []))
-            )
-            for row in sorted(pushers_rows | deleted_pushers_rows):
-                if row in deleted_pushers_rows:
-                    user_id, app_id, pushkey = row[1:4]
-                    stop_pusher(user_id, app_id, pushkey)
-                elif row in pushers_rows:
-                    user_id = row[1]
-                    app_id = row[5]
-                    pushkey = row[8]
-                    yield start_pusher(user_id, app_id, pushkey)
-
-            stream = results.get("events")
-            if stream and stream["rows"]:
-                min_stream_id = stream["rows"][0][0]
-                max_stream_id = stream["position"]
-                preserve_fn(pusher_pool.on_new_notifications)(
-                    min_stream_id, max_stream_id
+    def poke_pushers(self, stream_name, token, rows):
+        try:
+            if stream_name == "pushers":
+                for row in rows:
+                    if row.deleted:
+                        yield self.stop_pusher(row.user_id, row.app_id, row.pushkey)
+                    else:
+                        yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
+            elif stream_name == "events":
+                yield self.pusher_pool.on_new_notifications(
+                    token, token,
                 )
-
-            stream = results.get("receipts")
-            if stream and stream["rows"]:
-                rows = stream["rows"]
-                affected_room_ids = set(row[1] for row in rows)
-                min_stream_id = rows[0][0]
-                max_stream_id = stream["position"]
-                preserve_fn(pusher_pool.on_new_receipts)(
-                    min_stream_id, max_stream_id, affected_room_ids
+            elif stream_name == "receipts":
+                yield self.pusher_pool.on_new_receipts(
+                    token, token, set(row.room_id for row in rows)
                 )
+        except Exception:
+            logger.exception("Error poking pushers")
+
+    def stop_pusher(self, user_id, app_id, pushkey):
+        key = "%s:%s" % (app_id, pushkey)
+        pushers_for_user = self.pusher_pool.pushers.get(user_id, {})
+        pusher = pushers_for_user.pop(key, None)
+        if pusher is None:
+            return
+        logger.info("Stopping pusher %r / %r", user_id, key)
+        pusher.on_stop()
 
-        while True:
-            try:
-                args = store.stream_positions()
-                args["timeout"] = 30000
-                result = yield http_client.get_json(replication_url, args=args)
-                yield store.process_replication(result)
-                poke_pushers(result)
-            except:
-                logger.exception("Error replicating from %r", replication_url)
-                yield sleep(30)
+    def start_pusher(self, user_id, app_id, pushkey):
+        key = "%s:%s" % (app_id, pushkey)
+        logger.info("Starting pusher %r / %r", user_id, key)
+        return self.pusher_pool._refresh_pusher(app_id, pushkey, user_id)
 
 
 def start(config_options):
@@ -245,7 +188,7 @@ def start(config_options):
 
     assert config.worker_app == "synapse.app.pusher"
 
-    setup_logging(config.worker_log_config, config.worker_log_file)
+    setup_logging(config, use_worker_options=True)
 
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
@@ -274,34 +217,14 @@ def start(config_options):
     ps.setup()
     ps.start_listening(config.worker_listeners)
 
-    def run():
-        with LoggingContext("run"):
-            logger.info("Running")
-            change_resource_limit(config.soft_file_limit)
-            if config.gc_thresholds:
-                gc.set_threshold(*config.gc_thresholds)
-            reactor.run()
-
     def start():
-        ps.replicate()
         ps.get_pusherpool().start()
         ps.get_datastore().start_profiling()
         ps.get_state_handler().start_caching()
 
     reactor.callWhenRunning(start)
 
-    if config.worker_daemonize:
-        daemon = Daemonize(
-            app="synapse-pusher",
-            pid=config.worker_pid_file,
-            action=run,
-            auto_close_fds=False,
-            verbose=True,
-            logger=logger,
-        )
-        daemon.start()
-    else:
-        run()
+    _base.start_worker_reactor("synapse-pusher", config)
 
 
 if __name__ == '__main__':