summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/config/_base.py38
-rw-r--r--synapse/config/_base.pyi5
-rw-r--r--synapse/config/federation.py37
-rw-r--r--synapse/config/push.py5
-rw-r--r--synapse/federation/sender/__init__.py16
-rw-r--r--synapse/federation/sender/per_destination_queue.py2
-rw-r--r--synapse/push/pusherpool.py78
7 files changed, 99 insertions, 82 deletions
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 1391e5fc43..fd137853b1 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -19,9 +19,11 @@ import argparse
 import errno
 import os
 from collections import OrderedDict
+from hashlib import sha256
 from textwrap import dedent
-from typing import Any, MutableMapping, Optional
+from typing import Any, List, MutableMapping, Optional
 
+import attr
 import yaml
 
 
@@ -717,4 +719,36 @@ def find_config_files(search_paths):
     return config_files
 
 
-__all__ = ["Config", "RootConfig"]
+@attr.s
+class ShardedWorkerHandlingConfig:
+    """Algorithm for choosing which instance is responsible for handling some
+    sharded work.
+
+    For example, the federation senders use this to determine which instances
+    handles sending stuff to a given destination (which is used as the `key`
+    below).
+    """
+
+    instances = attr.ib(type=List[str])
+
+    def should_handle(self, instance_name: str, key: str) -> bool:
+        """Whether this instance is responsible for handling the given key.
+        """
+
+        # If multiple instances are not defined we always return true.
+        if not self.instances or len(self.instances) == 1:
+            return True
+
+        # We shard by taking the hash, modulo it by the number of instances and
+        # then checking whether this instance matches the instance at that
+        # index.
+        #
+        # (Technically this introduces some bias and is not entirely uniform,
+        # but since the hash is so large the bias is ridiculously small).
+        dest_hash = sha256(key.encode("utf8")).digest()
+        dest_int = int.from_bytes(dest_hash, byteorder="little")
+        remainder = dest_int % (len(self.instances))
+        return self.instances[remainder] == instance_name
+
+
+__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 9e576060d4..eb911e8f9f 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -137,3 +137,8 @@ class Config:
 
 def read_config_files(config_files: List[str]): ...
 def find_config_files(search_paths: List[str]): ...
+
+class ShardedWorkerHandlingConfig:
+    instances: List[str]
+    def __init__(self, instances: List[str]) -> None: ...
+    def should_handle(self, instance_name: str, key: str) -> bool: ...
diff --git a/synapse/config/federation.py b/synapse/config/federation.py
index 7782ab4c9d..82ff9664de 100644
--- a/synapse/config/federation.py
+++ b/synapse/config/federation.py
@@ -13,42 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from hashlib import sha256
-from typing import List, Optional
+from typing import Optional
 
-import attr
 from netaddr import IPSet
 
-from ._base import Config, ConfigError
-
-
-@attr.s
-class ShardedFederationSendingConfig:
-    """Algorithm for choosing which federation sender instance is responsible
-    for which destionation host.
-    """
-
-    instances = attr.ib(type=List[str])
-
-    def should_send_to(self, instance_name: str, destination: str) -> bool:
-        """Whether this instance is responsible for sending transcations for
-        the given host.
-        """
-
-        # If multiple federation senders are not defined we always return true.
-        if not self.instances or len(self.instances) == 1:
-            return True
-
-        # We shard by taking the hash, modulo it by the number of federation
-        # senders and then checking whether this instance matches the instance
-        # at that index.
-        #
-        # (Technically this introduces some bias and is not entirely uniform, but
-        # since the hash is so large the bias is ridiculously small).
-        dest_hash = sha256(destination.encode("utf8")).digest()
-        dest_int = int.from_bytes(dest_hash, byteorder="little")
-        remainder = dest_int % (len(self.instances))
-        return self.instances[remainder] == instance_name
+from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
 
 
 class FederationConfig(Config):
@@ -61,7 +30,7 @@ class FederationConfig(Config):
         self.send_federation = config.get("send_federation", True)
 
         federation_sender_instances = config.get("federation_sender_instances") or []
-        self.federation_shard_config = ShardedFederationSendingConfig(
+        self.federation_shard_config = ShardedWorkerHandlingConfig(
             federation_sender_instances
         )
 
diff --git a/synapse/config/push.py b/synapse/config/push.py
index 6f2b3a7faa..a1f3752c8a 100644
--- a/synapse/config/push.py
+++ b/synapse/config/push.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import Config
+from ._base import Config, ShardedWorkerHandlingConfig
 
 
 class PushConfig(Config):
@@ -24,6 +24,9 @@ class PushConfig(Config):
         push_config = config.get("push", {})
         self.push_include_content = push_config.get("include_content", True)
 
+        pusher_instances = config.get("pusher_instances") or []
+        self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
+
         # There was a a 'redact_content' setting but mistakenly read from the
         # 'email'section'. Check for the flag in the 'push' section, and log,
         # but do not honour it to avoid nasty surprises when people upgrade.
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 4b63a0755f..b328a4df09 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -197,7 +197,7 @@ class FederationSender(object):
                     destinations = {
                         d
                         for d in destinations
-                        if self._federation_shard_config.should_send_to(
+                        if self._federation_shard_config.should_handle(
                             self._instance_name, d
                         )
                     }
@@ -335,7 +335,7 @@ class FederationSender(object):
             d
             for d in domains
             if d != self.server_name
-            and self._federation_shard_config.should_send_to(self._instance_name, d)
+            and self._federation_shard_config.should_handle(self._instance_name, d)
         ]
         if not domains:
             return
@@ -441,7 +441,7 @@ class FederationSender(object):
         for destination in destinations:
             if destination == self.server_name:
                 continue
-            if not self._federation_shard_config.should_send_to(
+            if not self._federation_shard_config.should_handle(
                 self._instance_name, destination
             ):
                 continue
@@ -460,7 +460,7 @@ class FederationSender(object):
                 if destination == self.server_name:
                     continue
 
-                if not self._federation_shard_config.should_send_to(
+                if not self._federation_shard_config.should_handle(
                     self._instance_name, destination
                 ):
                     continue
@@ -486,7 +486,7 @@ class FederationSender(object):
             logger.info("Not sending EDU to ourselves")
             return
 
-        if not self._federation_shard_config.should_send_to(
+        if not self._federation_shard_config.should_handle(
             self._instance_name, destination
         ):
             return
@@ -507,7 +507,7 @@ class FederationSender(object):
             edu: edu to send
             key: clobbering key for this edu
         """
-        if not self._federation_shard_config.should_send_to(
+        if not self._federation_shard_config.should_handle(
             self._instance_name, edu.destination
         ):
             return
@@ -523,7 +523,7 @@ class FederationSender(object):
             logger.warning("Not sending device update to ourselves")
             return
 
-        if not self._federation_shard_config.should_send_to(
+        if not self._federation_shard_config.should_handle(
             self._instance_name, destination
         ):
             return
@@ -541,7 +541,7 @@ class FederationSender(object):
             logger.warning("Not waking up ourselves")
             return
 
-        if not self._federation_shard_config.should_send_to(
+        if not self._federation_shard_config.should_handle(
             self._instance_name, destination
         ):
             return
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 6402136e8a..3436741783 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -78,7 +78,7 @@ class PerDestinationQueue(object):
         self._federation_shard_config = hs.config.federation.federation_shard_config
 
         self._should_send_on_this_instance = True
-        if not self._federation_shard_config.should_send_to(
+        if not self._federation_shard_config.should_handle(
             self._instance_name, destination
         ):
             # We don't raise an exception here to avoid taking out any other
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index f6a5458681..2456f12f46 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -15,13 +15,12 @@
 # limitations under the License.
 
 import logging
-from collections import defaultdict
-from threading import Lock
-from typing import Dict, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Union
+
+from prometheus_client import Gauge
 
 from twisted.internet import defer
 
-from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.push import PusherConfigException
 from synapse.push.emailpusher import EmailPusher
@@ -29,9 +28,18 @@ from synapse.push.httppusher import HttpPusher
 from synapse.push.pusher import PusherFactory
 from synapse.util.async_helpers import concurrently_execute
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
 logger = logging.getLogger(__name__)
 
 
+synapse_pushers = Gauge(
+    "synapse_pushers", "Number of active synapse pushers", ["kind", "app_id"]
+)
+
+
 class PusherPool:
     """
     The pusher pool. This is responsible for dispatching notifications of new events to
@@ -47,36 +55,20 @@ class PusherPool:
     Pusher.on_new_receipts are not expected to return deferreds.
     """
 
-    def __init__(self, _hs):
-        self.hs = _hs
-        self.pusher_factory = PusherFactory(_hs)
-        self._should_start_pushers = _hs.config.start_pushers
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.pusher_factory = PusherFactory(hs)
+        self._should_start_pushers = hs.config.start_pushers
         self.store = self.hs.get_datastore()
         self.clock = self.hs.get_clock()
 
+        # We shard the handling of push notifications by user ID.
+        self._pusher_shard_config = hs.config.push.pusher_shard_config
+        self._instance_name = hs.get_instance_name()
+
         # map from user id to app_id:pushkey to pusher
         self.pushers = {}  # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
 
-        # a lock for the pushers dict, since `count_pushers` is called from an different
-        # and we otherwise get concurrent modification errors
-        self._pushers_lock = Lock()
-
-        def count_pushers():
-            results = defaultdict(int)  # type: Dict[Tuple[str, str], int]
-            with self._pushers_lock:
-                for pushers in self.pushers.values():
-                    for pusher in pushers.values():
-                        k = (type(pusher).__name__, pusher.app_id)
-                        results[k] += 1
-            return results
-
-        LaterGauge(
-            name="synapse_pushers",
-            desc="the number of active pushers",
-            labels=["kind", "app_id"],
-            caller=count_pushers,
-        )
-
     def start(self):
         """Starts the pushers off in a background process.
         """
@@ -104,6 +96,7 @@ class PusherPool:
         Returns:
             Deferred[EmailPusher|HttpPusher]
         """
+
         time_now_msec = self.clock.time_msec()
 
         # we try to create the pusher just to validate the config: it
@@ -176,6 +169,9 @@ class PusherPool:
             access_tokens (Iterable[int]): access token *ids* to remove pushers
                 for
         """
+        if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
+            return
+
         tokens = set(access_tokens)
         for p in (yield self.store.get_pushers_by_user_id(user_id)):
             if p["access_token"] in tokens:
@@ -237,6 +233,9 @@ class PusherPool:
         if not self._should_start_pushers:
             return
 
+        if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
+            return
+
         resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
 
         pusher_dict = None
@@ -275,6 +274,11 @@ class PusherPool:
         Returns:
             Deferred[EmailPusher|HttpPusher]
         """
+        if not self._pusher_shard_config.should_handle(
+            self._instance_name, pusherdict["user_name"]
+        ):
+            return
+
         try:
             p = self.pusher_factory.create_pusher(pusherdict)
         except PusherConfigException as e:
@@ -298,11 +302,12 @@ class PusherPool:
 
         appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
 
-        with self._pushers_lock:
-            byuser = self.pushers.setdefault(pusherdict["user_name"], {})
-            if appid_pushkey in byuser:
-                byuser[appid_pushkey].on_stop()
-            byuser[appid_pushkey] = p
+        byuser = self.pushers.setdefault(pusherdict["user_name"], {})
+        if appid_pushkey in byuser:
+            byuser[appid_pushkey].on_stop()
+        byuser[appid_pushkey] = p
+
+        synapse_pushers.labels(type(p).__name__, p.app_id).inc()
 
         # Check if there *may* be push to process. We do this as this check is a
         # lot cheaper to do than actually fetching the exact rows we need to
@@ -330,9 +335,10 @@ class PusherPool:
 
         if appid_pushkey in byuser:
             logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
-            byuser[appid_pushkey].on_stop()
-            with self._pushers_lock:
-                del byuser[appid_pushkey]
+            pusher = byuser.pop(appid_pushkey)
+            pusher.on_stop()
+
+            synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
 
         yield self.store.delete_pusher_by_app_id_pushkey_user_id(
             app_id, pushkey, user_id