summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/7855.feature1
-rw-r--r--synapse/config/_base.py38
-rw-r--r--synapse/config/_base.pyi5
-rw-r--r--synapse/config/federation.py37
-rw-r--r--synapse/config/push.py5
-rw-r--r--synapse/federation/sender/__init__.py16
-rw-r--r--synapse/federation/sender/per_destination_queue.py2
-rw-r--r--synapse/push/pusherpool.py78
-rw-r--r--tests/replication/test_pusher_shard.py193
9 files changed, 293 insertions, 82 deletions
diff --git a/changelog.d/7855.feature b/changelog.d/7855.feature
new file mode 100644
index 0000000000..2b6a9f0e71
--- /dev/null
+++ b/changelog.d/7855.feature
@@ -0,0 +1 @@
+Add experimental support for running multiple pusher workers.
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 1391e5fc43..fd137853b1 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -19,9 +19,11 @@ import argparse
 import errno
 import os
 from collections import OrderedDict
+from hashlib import sha256
 from textwrap import dedent
-from typing import Any, MutableMapping, Optional
+from typing import Any, List, MutableMapping, Optional
 
+import attr
 import yaml
 
 
@@ -717,4 +719,36 @@ def find_config_files(search_paths):
     return config_files
 
 
-__all__ = ["Config", "RootConfig"]
+@attr.s
+class ShardedWorkerHandlingConfig:
+    """Algorithm for choosing which instance is responsible for handling some
+    sharded work.
+
+    For example, the federation senders use this to determine which instances
+    handles sending stuff to a given destination (which is used as the `key`
+    below).
+    """
+
+    instances = attr.ib(type=List[str])
+
+    def should_handle(self, instance_name: str, key: str) -> bool:
+        """Whether this instance is responsible for handling the given key.
+        """
+
+        # If multiple instances are not defined we always return true.
+        if not self.instances or len(self.instances) == 1:
+            return True
+
+        # We shard by taking the hash, modulo it by the number of instances and
+        # then checking whether this instance matches the instance at that
+        # index.
+        #
+        # (Technically this introduces some bias and is not entirely uniform,
+        # but since the hash is so large the bias is ridiculously small).
+        dest_hash = sha256(key.encode("utf8")).digest()
+        dest_int = int.from_bytes(dest_hash, byteorder="little")
+        remainder = dest_int % (len(self.instances))
+        return self.instances[remainder] == instance_name
+
+
+__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 9e576060d4..eb911e8f9f 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -137,3 +137,8 @@ class Config:
 
 def read_config_files(config_files: List[str]): ...
 def find_config_files(search_paths: List[str]): ...
+
+class ShardedWorkerHandlingConfig:
+    instances: List[str]
+    def __init__(self, instances: List[str]) -> None: ...
+    def should_handle(self, instance_name: str, key: str) -> bool: ...
diff --git a/synapse/config/federation.py b/synapse/config/federation.py
index 7782ab4c9d..82ff9664de 100644
--- a/synapse/config/federation.py
+++ b/synapse/config/federation.py
@@ -13,42 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from hashlib import sha256
-from typing import List, Optional
+from typing import Optional
 
-import attr
 from netaddr import IPSet
 
-from ._base import Config, ConfigError
-
-
-@attr.s
-class ShardedFederationSendingConfig:
-    """Algorithm for choosing which federation sender instance is responsible
-    for which destionation host.
-    """
-
-    instances = attr.ib(type=List[str])
-
-    def should_send_to(self, instance_name: str, destination: str) -> bool:
-        """Whether this instance is responsible for sending transcations for
-        the given host.
-        """
-
-        # If multiple federation senders are not defined we always return true.
-        if not self.instances or len(self.instances) == 1:
-            return True
-
-        # We shard by taking the hash, modulo it by the number of federation
-        # senders and then checking whether this instance matches the instance
-        # at that index.
-        #
-        # (Technically this introduces some bias and is not entirely uniform, but
-        # since the hash is so large the bias is ridiculously small).
-        dest_hash = sha256(destination.encode("utf8")).digest()
-        dest_int = int.from_bytes(dest_hash, byteorder="little")
-        remainder = dest_int % (len(self.instances))
-        return self.instances[remainder] == instance_name
+from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
 
 
 class FederationConfig(Config):
@@ -61,7 +30,7 @@ class FederationConfig(Config):
         self.send_federation = config.get("send_federation", True)
 
         federation_sender_instances = config.get("federation_sender_instances") or []
-        self.federation_shard_config = ShardedFederationSendingConfig(
+        self.federation_shard_config = ShardedWorkerHandlingConfig(
             federation_sender_instances
         )
 
diff --git a/synapse/config/push.py b/synapse/config/push.py
index 6f2b3a7faa..a1f3752c8a 100644
--- a/synapse/config/push.py
+++ b/synapse/config/push.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import Config
+from ._base import Config, ShardedWorkerHandlingConfig
 
 
 class PushConfig(Config):
@@ -24,6 +24,9 @@ class PushConfig(Config):
         push_config = config.get("push", {})
         self.push_include_content = push_config.get("include_content", True)
 
+        pusher_instances = config.get("pusher_instances") or []
+        self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
+
         # There was a a 'redact_content' setting but mistakenly read from the
         # 'email'section'. Check for the flag in the 'push' section, and log,
         # but do not honour it to avoid nasty surprises when people upgrade.
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 4b63a0755f..b328a4df09 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -197,7 +197,7 @@ class FederationSender(object):
                     destinations = {
                         d
                         for d in destinations
-                        if self._federation_shard_config.should_send_to(
+                        if self._federation_shard_config.should_handle(
                             self._instance_name, d
                         )
                     }
@@ -335,7 +335,7 @@ class FederationSender(object):
             d
             for d in domains
             if d != self.server_name
-            and self._federation_shard_config.should_send_to(self._instance_name, d)
+            and self._federation_shard_config.should_handle(self._instance_name, d)
         ]
         if not domains:
             return
@@ -441,7 +441,7 @@ class FederationSender(object):
         for destination in destinations:
             if destination == self.server_name:
                 continue
-            if not self._federation_shard_config.should_send_to(
+            if not self._federation_shard_config.should_handle(
                 self._instance_name, destination
             ):
                 continue
@@ -460,7 +460,7 @@ class FederationSender(object):
                 if destination == self.server_name:
                     continue
 
-                if not self._federation_shard_config.should_send_to(
+                if not self._federation_shard_config.should_handle(
                     self._instance_name, destination
                 ):
                     continue
@@ -486,7 +486,7 @@ class FederationSender(object):
             logger.info("Not sending EDU to ourselves")
             return
 
-        if not self._federation_shard_config.should_send_to(
+        if not self._federation_shard_config.should_handle(
             self._instance_name, destination
         ):
             return
@@ -507,7 +507,7 @@ class FederationSender(object):
             edu: edu to send
             key: clobbering key for this edu
         """
-        if not self._federation_shard_config.should_send_to(
+        if not self._federation_shard_config.should_handle(
             self._instance_name, edu.destination
         ):
             return
@@ -523,7 +523,7 @@ class FederationSender(object):
             logger.warning("Not sending device update to ourselves")
             return
 
-        if not self._federation_shard_config.should_send_to(
+        if not self._federation_shard_config.should_handle(
             self._instance_name, destination
         ):
             return
@@ -541,7 +541,7 @@ class FederationSender(object):
             logger.warning("Not waking up ourselves")
             return
 
-        if not self._federation_shard_config.should_send_to(
+        if not self._federation_shard_config.should_handle(
             self._instance_name, destination
         ):
             return
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 6402136e8a..3436741783 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -78,7 +78,7 @@ class PerDestinationQueue(object):
         self._federation_shard_config = hs.config.federation.federation_shard_config
 
         self._should_send_on_this_instance = True
-        if not self._federation_shard_config.should_send_to(
+        if not self._federation_shard_config.should_handle(
             self._instance_name, destination
         ):
             # We don't raise an exception here to avoid taking out any other
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index f6a5458681..2456f12f46 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -15,13 +15,12 @@
 # limitations under the License.
 
 import logging
-from collections import defaultdict
-from threading import Lock
-from typing import Dict, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Union
+
+from prometheus_client import Gauge
 
 from twisted.internet import defer
 
-from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.push import PusherConfigException
 from synapse.push.emailpusher import EmailPusher
@@ -29,9 +28,18 @@ from synapse.push.httppusher import HttpPusher
 from synapse.push.pusher import PusherFactory
 from synapse.util.async_helpers import concurrently_execute
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
 logger = logging.getLogger(__name__)
 
 
+synapse_pushers = Gauge(
+    "synapse_pushers", "Number of active synapse pushers", ["kind", "app_id"]
+)
+
+
 class PusherPool:
     """
     The pusher pool. This is responsible for dispatching notifications of new events to
@@ -47,36 +55,20 @@ class PusherPool:
     Pusher.on_new_receipts are not expected to return deferreds.
     """
 
-    def __init__(self, _hs):
-        self.hs = _hs
-        self.pusher_factory = PusherFactory(_hs)
-        self._should_start_pushers = _hs.config.start_pushers
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.pusher_factory = PusherFactory(hs)
+        self._should_start_pushers = hs.config.start_pushers
         self.store = self.hs.get_datastore()
         self.clock = self.hs.get_clock()
 
+        # We shard the handling of push notifications by user ID.
+        self._pusher_shard_config = hs.config.push.pusher_shard_config
+        self._instance_name = hs.get_instance_name()
+
         # map from user id to app_id:pushkey to pusher
         self.pushers = {}  # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
 
-        # a lock for the pushers dict, since `count_pushers` is called from an different
-        # and we otherwise get concurrent modification errors
-        self._pushers_lock = Lock()
-
-        def count_pushers():
-            results = defaultdict(int)  # type: Dict[Tuple[str, str], int]
-            with self._pushers_lock:
-                for pushers in self.pushers.values():
-                    for pusher in pushers.values():
-                        k = (type(pusher).__name__, pusher.app_id)
-                        results[k] += 1
-            return results
-
-        LaterGauge(
-            name="synapse_pushers",
-            desc="the number of active pushers",
-            labels=["kind", "app_id"],
-            caller=count_pushers,
-        )
-
     def start(self):
         """Starts the pushers off in a background process.
         """
@@ -104,6 +96,7 @@ class PusherPool:
         Returns:
             Deferred[EmailPusher|HttpPusher]
         """
+
         time_now_msec = self.clock.time_msec()
 
         # we try to create the pusher just to validate the config: it
@@ -176,6 +169,9 @@ class PusherPool:
             access_tokens (Iterable[int]): access token *ids* to remove pushers
                 for
         """
+        if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
+            return
+
         tokens = set(access_tokens)
         for p in (yield self.store.get_pushers_by_user_id(user_id)):
             if p["access_token"] in tokens:
@@ -237,6 +233,9 @@ class PusherPool:
         if not self._should_start_pushers:
             return
 
+        if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
+            return
+
         resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
 
         pusher_dict = None
@@ -275,6 +274,11 @@ class PusherPool:
         Returns:
             Deferred[EmailPusher|HttpPusher]
         """
+        if not self._pusher_shard_config.should_handle(
+            self._instance_name, pusherdict["user_name"]
+        ):
+            return
+
         try:
             p = self.pusher_factory.create_pusher(pusherdict)
         except PusherConfigException as e:
@@ -298,11 +302,12 @@ class PusherPool:
 
         appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
 
-        with self._pushers_lock:
-            byuser = self.pushers.setdefault(pusherdict["user_name"], {})
-            if appid_pushkey in byuser:
-                byuser[appid_pushkey].on_stop()
-            byuser[appid_pushkey] = p
+        byuser = self.pushers.setdefault(pusherdict["user_name"], {})
+        if appid_pushkey in byuser:
+            byuser[appid_pushkey].on_stop()
+        byuser[appid_pushkey] = p
+
+        synapse_pushers.labels(type(p).__name__, p.app_id).inc()
 
         # Check if there *may* be push to process. We do this as this check is a
         # lot cheaper to do than actually fetching the exact rows we need to
@@ -330,9 +335,10 @@ class PusherPool:
 
         if appid_pushkey in byuser:
             logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
-            byuser[appid_pushkey].on_stop()
-            with self._pushers_lock:
-                del byuser[appid_pushkey]
+            pusher = byuser.pop(appid_pushkey)
+            pusher.on_stop()
+
+            synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
 
         yield self.store.delete_pusher_by_app_id_pushkey_user_id(
             app_id, pushkey, user_id
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
new file mode 100644
index 0000000000..2bdc6edbb1
--- /dev/null
+++ b/tests/replication/test_pusher_shard.py
@@ -0,0 +1,193 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+logger = logging.getLogger(__name__)
+
+
+class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
+    """Checks pusher sharding works
+    """
+
+    servlets = [
+        admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        # Register a user who sends a message that we'll get notified about
+        self.other_user_id = self.register_user("otheruser", "pass")
+        self.other_access_token = self.login("otheruser", "pass")
+
+    def default_config(self):
+        conf = super().default_config()
+        conf["start_pushers"] = False
+        return conf
+
+    def _create_pusher_and_send_msg(self, localpart):
+        # Create a user that will get push notifications
+        user_id = self.register_user(localpart, "pass")
+        access_token = self.login(localpart, "pass")
+
+        # Register a pusher
+        user_dict = self.get_success(
+            self.hs.get_datastore().get_user_by_access_token(access_token)
+        )
+        token_id = user_dict["token_id"]
+
+        self.get_success(
+            self.hs.get_pusherpool().add_pusher(
+                user_id=user_id,
+                access_token=token_id,
+                kind="http",
+                app_id="m.http",
+                app_display_name="HTTP Push Notifications",
+                device_display_name="pushy push",
+                pushkey="a@example.com",
+                lang=None,
+                data={"url": "https://push.example.com/push"},
+            )
+        )
+
+        self.pump()
+
+        # Create a room
+        room = self.helper.create_room_as(user_id, tok=access_token)
+
+        # The other user joins
+        self.helper.join(
+            room=room, user=self.other_user_id, tok=self.other_access_token
+        )
+
+        # The other user sends some messages
+        response = self.helper.send(room, body="Hi!", tok=self.other_access_token)
+        event_id = response["event_id"]
+
+        return event_id
+
+    def test_send_push_single_worker(self):
+        """Test that registration works when using a pusher worker.
+        """
+        http_client_mock = Mock(spec_set=["post_json_get_json"])
+        http_client_mock.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
+            {}
+        )
+
+        self.make_worker_hs(
+            "synapse.app.pusher",
+            {"start_pushers": True},
+            proxied_http_client=http_client_mock,
+        )
+
+        event_id = self._create_pusher_and_send_msg("user")
+
+        # Advance time a bit, so the pusher will register something has happened
+        self.pump()
+
+        http_client_mock.post_json_get_json.assert_called_once()
+        self.assertEqual(
+            http_client_mock.post_json_get_json.call_args[0][0],
+            "https://push.example.com/push",
+        )
+        self.assertEqual(
+            event_id,
+            http_client_mock.post_json_get_json.call_args[0][1]["notification"][
+                "event_id"
+            ],
+        )
+
+    def test_send_push_multiple_workers(self):
+        """Test that registration works when using sharded pusher workers.
+        """
+        http_client_mock1 = Mock(spec_set=["post_json_get_json"])
+        http_client_mock1.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
+            {}
+        )
+
+        self.make_worker_hs(
+            "synapse.app.pusher",
+            {
+                "start_pushers": True,
+                "worker_name": "pusher1",
+                "pusher_instances": ["pusher1", "pusher2"],
+            },
+            proxied_http_client=http_client_mock1,
+        )
+
+        http_client_mock2 = Mock(spec_set=["post_json_get_json"])
+        http_client_mock2.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
+            {}
+        )
+
+        self.make_worker_hs(
+            "synapse.app.pusher",
+            {
+                "start_pushers": True,
+                "worker_name": "pusher2",
+                "pusher_instances": ["pusher1", "pusher2"],
+            },
+            proxied_http_client=http_client_mock2,
+        )
+
+        # We choose a user name that we know should go to pusher1.
+        event_id = self._create_pusher_and_send_msg("user2")
+
+        # Advance time a bit, so the pusher will register something has happened
+        self.pump()
+
+        http_client_mock1.post_json_get_json.assert_called_once()
+        http_client_mock2.post_json_get_json.assert_not_called()
+        self.assertEqual(
+            http_client_mock1.post_json_get_json.call_args[0][0],
+            "https://push.example.com/push",
+        )
+        self.assertEqual(
+            event_id,
+            http_client_mock1.post_json_get_json.call_args[0][1]["notification"][
+                "event_id"
+            ],
+        )
+
+        http_client_mock1.post_json_get_json.reset_mock()
+        http_client_mock2.post_json_get_json.reset_mock()
+
+        # Now we choose a user name that we know should go to pusher2.
+        event_id = self._create_pusher_and_send_msg("user4")
+
+        # Advance time a bit, so the pusher will register something has happened
+        self.pump()
+
+        http_client_mock1.post_json_get_json.assert_not_called()
+        http_client_mock2.post_json_get_json.assert_called_once()
+        self.assertEqual(
+            http_client_mock2.post_json_get_json.call_args[0][0],
+            "https://push.example.com/push",
+        )
+        self.assertEqual(
+            event_id,
+            http_client_mock2.post_json_get_json.call_args[0][1]["notification"][
+                "event_id"
+            ],
+        )