summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8513.feature1
-rw-r--r--synapse/handlers/account_validity.py3
-rw-r--r--synapse/handlers/deactivate_account.py2
-rw-r--r--synapse/handlers/message.py18
-rw-r--r--synapse/handlers/pagination.py2
-rw-r--r--synapse/handlers/profile.py29
-rw-r--r--synapse/server.py12
-rw-r--r--synapse/storage/databases/main/profile.py82
-rw-r--r--synapse/storage/databases/main/registration.py52
-rw-r--r--synapse/storage/databases/main/room.py168
-rw-r--r--tests/handlers/test_typing.py48
11 files changed, 196 insertions, 221 deletions
diff --git a/changelog.d/8513.feature b/changelog.d/8513.feature
new file mode 100644
index 0000000000..542993110b
--- /dev/null
+++ b/changelog.d/8513.feature
@@ -0,0 +1 @@
+Allow running background tasks in a separate worker process.
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 4caf6d591a..f33044e97a 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -70,7 +70,8 @@ class AccountValidityHandler:
                     "send_renewals", self._send_renewal_emails
                 )
 
-            self.clock.looping_call(send_emails, 30 * 60 * 1000)
+            if hs.config.run_background_tasks:
+                self.clock.looping_call(send_emails, 30 * 60 * 1000)
 
     async def _send_renewal_emails(self):
         """Gets the list of users whose account is expiring in the amount of time
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 58c9f12686..4efe6c530a 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -45,7 +45,7 @@ class DeactivateAccountHandler(BaseHandler):
 
         # Start the user parter loop so it can resume parting users from rooms where
         # it left off (if it has work left to do).
-        if hs.config.worker_app is None:
+        if hs.config.run_background_tasks:
             hs.get_reactor().callWhenRunning(self._start_user_parting)
 
         self._account_validity_enabled = hs.config.account_validity.enabled
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index b0da938aa9..c52e6824d3 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -402,21 +402,23 @@ class EventCreationHandler:
             self.config.block_events_without_consent_error
         )
 
+        # we need to construct a ConsentURIBuilder here, as it checks that the necessary
+        # config options, but *only* if we have a configuration for which we are
+        # going to need it.
+        if self._block_events_without_consent_error:
+            self._consent_uri_builder = ConsentURIBuilder(self.config)
+
         # Rooms which should be excluded from dummy insertion. (For instance,
         # those without local users who can send events into the room).
         #
         # map from room id to time-of-last-attempt.
         #
         self._rooms_to_exclude_from_dummy_event_insertion = {}  # type: Dict[str, int]
-
-        # we need to construct a ConsentURIBuilder here, as it checks that the necessary
-        # config options, but *only* if we have a configuration for which we are
-        # going to need it.
-        if self._block_events_without_consent_error:
-            self._consent_uri_builder = ConsentURIBuilder(self.config)
+        # The number of forward extremeities before a dummy event is sent.
+        self._dummy_events_threshold = hs.config.dummy_events_threshold
 
         if (
-            not self.config.worker_app
+            self.config.run_background_tasks
             and self.config.cleanup_extremities_with_dummy_events
         ):
             self.clock.looping_call(
@@ -431,8 +433,6 @@ class EventCreationHandler:
 
         self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
 
-        self._dummy_events_threshold = hs.config.dummy_events_threshold
-
     async def create_event(
         self,
         requester: Requester,
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 085b685959..426b58da9e 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -92,7 +92,7 @@ class PaginationHandler:
         self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min
         self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max
 
-        if hs.config.retention_enabled:
+        if hs.config.run_background_tasks and hs.config.retention_enabled:
             # Run the purge jobs described in the configuration file.
             for job in hs.config.retention_purge_jobs:
                 logger.info("Setting up purge job with config: %s", job)
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 5453e6dfc8..b784938755 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -35,14 +35,16 @@ MAX_DISPLAYNAME_LEN = 256
 MAX_AVATAR_URL_LEN = 1000
 
 
-class BaseProfileHandler(BaseHandler):
+class ProfileHandler(BaseHandler):
     """Handles fetching and updating user profile information.
 
-    BaseProfileHandler can be instantiated directly on workers and will
-    delegate to master when necessary. The master process should use the
-    subclass MasterProfileHandler
+    ProfileHandler can be instantiated directly on workers and will
+    delegate to master when necessary.
     """
 
+    PROFILE_UPDATE_MS = 60 * 1000
+    PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
+
     def __init__(self, hs):
         super().__init__(hs)
 
@@ -53,6 +55,11 @@ class BaseProfileHandler(BaseHandler):
 
         self.user_directory_handler = hs.get_user_directory_handler()
 
+        if hs.config.run_background_tasks:
+            self.clock.looping_call(
+                self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS
+            )
+
     async def get_profile(self, user_id):
         target_user = UserID.from_string(user_id)
 
@@ -363,20 +370,6 @@ class BaseProfileHandler(BaseHandler):
                 raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN)
             raise
 
-
-class MasterProfileHandler(BaseProfileHandler):
-    PROFILE_UPDATE_MS = 60 * 1000
-    PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
-
-    def __init__(self, hs):
-        super().__init__(hs)
-
-        assert hs.config.worker_app is None
-
-        self.clock.looping_call(
-            self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS
-        )
-
     def _start_update_remote_profile_cache(self):
         return run_as_background_process(
             "Update remote profile", self._update_remote_profile_cache
diff --git a/synapse/server.py b/synapse/server.py
index e793793cdc..f921ee4b53 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -75,7 +75,7 @@ from synapse.handlers.message import EventCreationHandler, MessageHandler
 from synapse.handlers.pagination import PaginationHandler
 from synapse.handlers.password_policy import PasswordPolicyHandler
 from synapse.handlers.presence import PresenceHandler
-from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler
+from synapse.handlers.profile import ProfileHandler
 from synapse.handlers.read_marker import ReadMarkerHandler
 from synapse.handlers.receipts import ReceiptsHandler
 from synapse.handlers.register import RegistrationHandler
@@ -191,7 +191,12 @@ class HomeServer(metaclass=abc.ABCMeta):
     """
 
     REQUIRED_ON_BACKGROUND_TASK_STARTUP = [
+        "account_validity",
         "auth",
+        "deactivate_account",
+        "message",
+        "pagination",
+        "profile",
         "stats",
     ]
 
@@ -462,10 +467,7 @@ class HomeServer(metaclass=abc.ABCMeta):
 
     @cache_in_self
     def get_profile_handler(self):
-        if self.config.worker_app:
-            return BaseProfileHandler(self)
-        else:
-            return MasterProfileHandler(self)
+        return ProfileHandler(self)
 
     @cache_in_self
     def get_event_creation_handler(self) -> EventCreationHandler:
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index d2e0685e9e..1681caa1f0 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -91,27 +91,6 @@ class ProfileWorkerStore(SQLBaseStore):
             desc="set_profile_avatar_url",
         )
 
-
-class ProfileStore(ProfileWorkerStore):
-    async def add_remote_profile_cache(
-        self, user_id: str, displayname: str, avatar_url: str
-    ) -> None:
-        """Ensure we are caching the remote user's profiles.
-
-        This should only be called when `is_subscribed_remote_profile_for_user`
-        would return true for the user.
-        """
-        await self.db_pool.simple_upsert(
-            table="remote_profile_cache",
-            keyvalues={"user_id": user_id},
-            values={
-                "displayname": displayname,
-                "avatar_url": avatar_url,
-                "last_check": self._clock.time_msec(),
-            },
-            desc="add_remote_profile_cache",
-        )
-
     async def update_remote_profile_cache(
         self, user_id: str, displayname: str, avatar_url: str
     ) -> int:
@@ -138,6 +117,31 @@ class ProfileStore(ProfileWorkerStore):
                 desc="delete_remote_profile_cache",
             )
 
+    async def is_subscribed_remote_profile_for_user(self, user_id):
+        """Check whether we are interested in a remote user's profile.
+        """
+        res = await self.db_pool.simple_select_one_onecol(
+            table="group_users",
+            keyvalues={"user_id": user_id},
+            retcol="user_id",
+            allow_none=True,
+            desc="should_update_remote_profile_cache_for_user",
+        )
+
+        if res:
+            return True
+
+        res = await self.db_pool.simple_select_one_onecol(
+            table="group_invites",
+            keyvalues={"user_id": user_id},
+            retcol="user_id",
+            allow_none=True,
+            desc="should_update_remote_profile_cache_for_user",
+        )
+
+        if res:
+            return True
+
     async def get_remote_profile_cache_entries_that_expire(
         self, last_checked: int
     ) -> Dict[str, str]:
@@ -160,27 +164,23 @@ class ProfileStore(ProfileWorkerStore):
             _get_remote_profile_cache_entries_that_expire_txn,
         )
 
-    async def is_subscribed_remote_profile_for_user(self, user_id):
-        """Check whether we are interested in a remote user's profile.
-        """
-        res = await self.db_pool.simple_select_one_onecol(
-            table="group_users",
-            keyvalues={"user_id": user_id},
-            retcol="user_id",
-            allow_none=True,
-            desc="should_update_remote_profile_cache_for_user",
-        )
 
-        if res:
-            return True
+class ProfileStore(ProfileWorkerStore):
+    async def add_remote_profile_cache(
+        self, user_id: str, displayname: str, avatar_url: str
+    ) -> None:
+        """Ensure we are caching the remote user's profiles.
 
-        res = await self.db_pool.simple_select_one_onecol(
-            table="group_invites",
+        This should only be called when `is_subscribed_remote_profile_for_user`
+        would return true for the user.
+        """
+        await self.db_pool.simple_upsert(
+            table="remote_profile_cache",
             keyvalues={"user_id": user_id},
-            retcol="user_id",
-            allow_none=True,
-            desc="should_update_remote_profile_cache_for_user",
+            values={
+                "displayname": displayname,
+                "avatar_url": avatar_url,
+                "last_check": self._clock.time_msec(),
+            },
+            desc="add_remote_profile_cache",
         )
-
-        if res:
-            return True
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 9a003e30f9..4c843b7679 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -862,6 +862,32 @@ class RegistrationWorkerStore(SQLBaseStore):
             values={"expiration_ts_ms": expiration_ts, "email_sent": False},
         )
 
+    async def get_user_pending_deactivation(self) -> Optional[str]:
+        """
+        Gets one user from the table of users waiting to be parted from all the rooms
+        they're in.
+        """
+        return await self.db_pool.simple_select_one_onecol(
+            "users_pending_deactivation",
+            keyvalues={},
+            retcol="user_id",
+            allow_none=True,
+            desc="get_users_pending_deactivation",
+        )
+
+    async def del_user_pending_deactivation(self, user_id: str) -> None:
+        """
+        Removes the given user to the table of users who need to be parted from all the
+        rooms they're in, effectively marking that user as fully deactivated.
+        """
+        # XXX: This should be simple_delete_one but we failed to put a unique index on
+        # the table, so somehow duplicate entries have ended up in it.
+        await self.db_pool.simple_delete(
+            "users_pending_deactivation",
+            keyvalues={"user_id": user_id},
+            desc="del_user_pending_deactivation",
+        )
+
 
 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
@@ -1371,32 +1397,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="add_user_pending_deactivation",
         )
 
-    async def del_user_pending_deactivation(self, user_id: str) -> None:
-        """
-        Removes the given user to the table of users who need to be parted from all the
-        rooms they're in, effectively marking that user as fully deactivated.
-        """
-        # XXX: This should be simple_delete_one but we failed to put a unique index on
-        # the table, so somehow duplicate entries have ended up in it.
-        await self.db_pool.simple_delete(
-            "users_pending_deactivation",
-            keyvalues={"user_id": user_id},
-            desc="del_user_pending_deactivation",
-        )
-
-    async def get_user_pending_deactivation(self) -> Optional[str]:
-        """
-        Gets one user from the table of users waiting to be parted from all the rooms
-        they're in.
-        """
-        return await self.db_pool.simple_select_one_onecol(
-            "users_pending_deactivation",
-            keyvalues={},
-            retcol="user_id",
-            allow_none=True,
-            desc="get_users_pending_deactivation",
-        )
-
     async def validate_threepid_session(
         self, session_id: str, client_secret: str, token: str, current_ts: int
     ) -> Optional[str]:
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index c0f2af0785..e83d961c20 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -869,6 +869,89 @@ class RoomWorkerStore(SQLBaseStore):
             "get_all_new_public_rooms", get_all_new_public_rooms
         )
 
+    async def get_rooms_for_retention_period_in_range(
+        self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
+    ) -> Dict[str, dict]:
+        """Retrieves all of the rooms within the given retention range.
+
+        Optionally includes the rooms which don't have a retention policy.
+
+        Args:
+            min_ms: Duration in milliseconds that define the lower limit of
+                the range to handle (exclusive). If None, doesn't set a lower limit.
+            max_ms: Duration in milliseconds that define the upper limit of
+                the range to handle (inclusive). If None, doesn't set an upper limit.
+            include_null: Whether to include rooms which retention policy is NULL
+                in the returned set.
+
+        Returns:
+            The rooms within this range, along with their retention
+            policy. The key is "room_id", and maps to a dict describing the retention
+            policy associated with this room ID. The keys for this nested dict are
+            "min_lifetime" (int|None), and "max_lifetime" (int|None).
+        """
+
+        def get_rooms_for_retention_period_in_range_txn(txn):
+            range_conditions = []
+            args = []
+
+            if min_ms is not None:
+                range_conditions.append("max_lifetime > ?")
+                args.append(min_ms)
+
+            if max_ms is not None:
+                range_conditions.append("max_lifetime <= ?")
+                args.append(max_ms)
+
+            # Do a first query which will retrieve the rooms that have a retention policy
+            # in their current state.
+            sql = """
+                SELECT room_id, min_lifetime, max_lifetime FROM room_retention
+                INNER JOIN current_state_events USING (event_id, room_id)
+                """
+
+            if len(range_conditions):
+                sql += " WHERE (" + " AND ".join(range_conditions) + ")"
+
+                if include_null:
+                    sql += " OR max_lifetime IS NULL"
+
+            txn.execute(sql, args)
+
+            rows = self.db_pool.cursor_to_dict(txn)
+            rooms_dict = {}
+
+            for row in rows:
+                rooms_dict[row["room_id"]] = {
+                    "min_lifetime": row["min_lifetime"],
+                    "max_lifetime": row["max_lifetime"],
+                }
+
+            if include_null:
+                # If required, do a second query that retrieves all of the rooms we know
+                # of so we can handle rooms with no retention policy.
+                sql = "SELECT DISTINCT room_id FROM current_state_events"
+
+                txn.execute(sql)
+
+                rows = self.db_pool.cursor_to_dict(txn)
+
+                # If a room isn't already in the dict (i.e. it doesn't have a retention
+                # policy in its state), add it with a null policy.
+                for row in rows:
+                    if row["room_id"] not in rooms_dict:
+                        rooms_dict[row["room_id"]] = {
+                            "min_lifetime": None,
+                            "max_lifetime": None,
+                        }
+
+            return rooms_dict
+
+        return await self.db_pool.runInteraction(
+            "get_rooms_for_retention_period_in_range",
+            get_rooms_for_retention_period_in_range_txn,
+        )
+
 
 class RoomBackgroundUpdateStore(SQLBaseStore):
     REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -1446,88 +1529,3 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             self.is_room_blocked,
             (room_id,),
         )
-
-    async def get_rooms_for_retention_period_in_range(
-        self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
-    ) -> Dict[str, dict]:
-        """Retrieves all of the rooms within the given retention range.
-
-        Optionally includes the rooms which don't have a retention policy.
-
-        Args:
-            min_ms: Duration in milliseconds that define the lower limit of
-                the range to handle (exclusive). If None, doesn't set a lower limit.
-            max_ms: Duration in milliseconds that define the upper limit of
-                the range to handle (inclusive). If None, doesn't set an upper limit.
-            include_null: Whether to include rooms which retention policy is NULL
-                in the returned set.
-
-        Returns:
-            The rooms within this range, along with their retention
-            policy. The key is "room_id", and maps to a dict describing the retention
-            policy associated with this room ID. The keys for this nested dict are
-            "min_lifetime" (int|None), and "max_lifetime" (int|None).
-        """
-
-        def get_rooms_for_retention_period_in_range_txn(txn):
-            range_conditions = []
-            args = []
-
-            if min_ms is not None:
-                range_conditions.append("max_lifetime > ?")
-                args.append(min_ms)
-
-            if max_ms is not None:
-                range_conditions.append("max_lifetime <= ?")
-                args.append(max_ms)
-
-            # Do a first query which will retrieve the rooms that have a retention policy
-            # in their current state.
-            sql = """
-                SELECT room_id, min_lifetime, max_lifetime FROM room_retention
-                INNER JOIN current_state_events USING (event_id, room_id)
-                """
-
-            if len(range_conditions):
-                sql += " WHERE (" + " AND ".join(range_conditions) + ")"
-
-                if include_null:
-                    sql += " OR max_lifetime IS NULL"
-
-            txn.execute(sql, args)
-
-            rows = self.db_pool.cursor_to_dict(txn)
-            rooms_dict = {}
-
-            for row in rows:
-                rooms_dict[row["room_id"]] = {
-                    "min_lifetime": row["min_lifetime"],
-                    "max_lifetime": row["max_lifetime"],
-                }
-
-            if include_null:
-                # If required, do a second query that retrieves all of the rooms we know
-                # of so we can handle rooms with no retention policy.
-                sql = "SELECT DISTINCT room_id FROM current_state_events"
-
-                txn.execute(sql)
-
-                rows = self.db_pool.cursor_to_dict(txn)
-
-                # If a room isn't already in the dict (i.e. it doesn't have a retention
-                # policy in its state), add it with a null policy.
-                for row in rows:
-                    if row["room_id"] not in rooms_dict:
-                        rooms_dict[row["room_id"]] = {
-                            "min_lifetime": None,
-                            "max_lifetime": None,
-                        }
-
-            return rooms_dict
-
-        rooms = await self.db_pool.runInteraction(
-            "get_rooms_for_retention_period_in_range",
-            get_rooms_for_retention_period_in_range_txn,
-        )
-
-        return rooms
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 3fec09ea8a..16ff2e22d2 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -65,26 +65,6 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         mock_federation_client = Mock(spec=["put_json"])
         mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
 
-        datastores = Mock()
-        datastores.main = Mock(
-            spec=[
-                # Bits that Federation needs
-                "prep_send_transaction",
-                "delivered_txn",
-                "get_received_txn_response",
-                "set_received_txn_response",
-                "get_destination_last_successful_stream_ordering",
-                "get_destination_retry_timings",
-                "get_devices_by_remote",
-                "maybe_store_room_on_invite",
-                # Bits that user_directory needs
-                "get_user_directory_stream_pos",
-                "get_current_state_deltas",
-                "get_device_updates_by_remote",
-                "get_room_max_stream_ordering",
-            ]
-        )
-
         # the tests assume that we are starting at unix time 1000
         reactor.pump((1000,))
 
@@ -95,8 +75,6 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             replication_streams={},
         )
 
-        hs.datastores = datastores
-
         return hs
 
     def prepare(self, reactor, clock, hs):
@@ -114,16 +92,16 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             "retry_interval": 0,
             "failure_ts": None,
         }
-        self.datastore.get_destination_retry_timings.return_value = defer.succeed(
-            retry_timings_res
+        self.datastore.get_destination_retry_timings = Mock(
+            return_value=defer.succeed(retry_timings_res)
         )
 
-        self.datastore.get_device_updates_by_remote.return_value = make_awaitable(
-            (0, [])
+        self.datastore.get_device_updates_by_remote = Mock(
+            return_value=make_awaitable((0, []))
         )
 
-        self.datastore.get_destination_last_successful_stream_ordering.return_value = make_awaitable(
-            None
+        self.datastore.get_destination_last_successful_stream_ordering = Mock(
+            return_value=make_awaitable(None)
         )
 
         def get_received_txn_response(*args):
@@ -145,17 +123,19 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
 
-        def get_users_in_room(room_id):
-            return defer.succeed({str(u) for u in self.room_members})
+        async def get_users_in_room(room_id):
+            return {str(u) for u in self.room_members}
 
         self.datastore.get_users_in_room = get_users_in_room
 
-        self.datastore.get_user_directory_stream_pos.side_effect = (
-            # we deliberately return a non-None stream pos to avoid doing an initial_spam
-            lambda: make_awaitable(1)
+        self.datastore.get_user_directory_stream_pos = Mock(
+            side_effect=(
+                # we deliberately return a non-None stream pos to avoid doing an initial_spam
+                lambda: make_awaitable(1)
+            )
         )
 
-        self.datastore.get_current_state_deltas.return_value = (0, None)
+        self.datastore.get_current_state_deltas = Mock(return_value=(0, None))
 
         self.datastore.get_to_device_stream_token = lambda: 0
         self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(