summary refs log tree commit diff
path: root/synapse/handlers/appservice.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/appservice.py')
-rw-r--r--synapse/handlers/appservice.py121
1 files changed, 82 insertions, 39 deletions
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 64dea23fc5..9fc8444228 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -12,9 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
-from typing import Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 from prometheus_client import Counter
 
@@ -30,17 +29,24 @@ from synapse.metrics import (
     event_processing_loop_counter,
     event_processing_loop_room_count,
 )
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import Collection, JsonDict, RoomStreamToken, UserID
+from synapse.metrics.background_process_metrics import (
+    run_as_background_process,
+    wrap_as_background_process,
+)
+from synapse.storage.databases.main.directory import RoomAliasMapping
+from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, UserID
 from synapse.util.metrics import Measure
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
 
 
 class ApplicationServicesHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.is_mine_id = hs.is_mine_id
         self.appservice_api = hs.get_application_service_api()
@@ -53,7 +59,7 @@ class ApplicationServicesHandler:
         self.current_max = 0
         self.is_processing = False
 
-    async def notify_interested_services(self, max_token: RoomStreamToken):
+    def notify_interested_services(self, max_token: RoomStreamToken):
         """Notifies (pushes) all application services interested in this event.
 
         Pushing is done asynchronously, so this method won't block for any
@@ -72,6 +78,12 @@ class ApplicationServicesHandler:
         if self.is_processing:
             return
 
+        # We only start a new background process if necessary rather than
+        # optimistically (to cut down on overhead).
+        self._notify_interested_services(max_token)
+
+    @wrap_as_background_process("notify_interested_services")
+    async def _notify_interested_services(self, max_token: RoomStreamToken):
         with Measure(self.clock, "notify_interested_services"):
             self.is_processing = True
             try:
@@ -166,8 +178,11 @@ class ApplicationServicesHandler:
             finally:
                 self.is_processing = False
 
-    async def notify_interested_services_ephemeral(
-        self, stream_key: str, new_token: Optional[int], users: Collection[UserID] = [],
+    def notify_interested_services_ephemeral(
+        self,
+        stream_key: str,
+        new_token: Optional[int],
+        users: Collection[Union[str, UserID]] = [],
     ):
         """This is called by the notifier in the background
         when a ephemeral event handled by the homeserver.
@@ -183,13 +198,34 @@ class ApplicationServicesHandler:
             new_token: The latest stream token
             users: The user(s) involved with the event.
         """
+        if not self.notify_appservices:
+            return
+
+        if stream_key not in ("typing_key", "receipt_key", "presence_key"):
+            return
+
         services = [
             service
             for service in self.store.get_app_services()
             if service.supports_ephemeral
         ]
-        if not services or not self.notify_appservices:
+        if not services:
             return
+
+        # We only start a new background process if necessary rather than
+        # optimistically (to cut down on overhead).
+        self._notify_interested_services_ephemeral(
+            services, stream_key, new_token, users
+        )
+
+    @wrap_as_background_process("notify_interested_services_ephemeral")
+    async def _notify_interested_services_ephemeral(
+        self,
+        services: List[ApplicationService],
+        stream_key: str,
+        new_token: Optional[int],
+        users: Collection[Union[str, UserID]],
+    ):
         logger.info("Checking interested services for %s" % (stream_key))
         with Measure(self.clock, "notify_interested_services_ephemeral"):
             for service in services:
@@ -214,7 +250,9 @@ class ApplicationServicesHandler:
                         service, "presence", new_token
                     )
 
-    async def _handle_typing(self, service: ApplicationService, new_token: int):
+    async def _handle_typing(
+        self, service: ApplicationService, new_token: int
+    ) -> List[JsonDict]:
         typing_source = self.event_sources.sources["typing"]
         # Get the typing events from just before current
         typing, _ = await typing_source.get_new_events_as(
@@ -226,7 +264,7 @@ class ApplicationServicesHandler:
         )
         return typing
 
-    async def _handle_receipts(self, service: ApplicationService):
+    async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]:
         from_key = await self.store.get_type_stream_id_for_appservice(
             service, "read_receipt"
         )
@@ -237,7 +275,7 @@ class ApplicationServicesHandler:
         return receipts
 
     async def _handle_presence(
-        self, service: ApplicationService, users: Collection[UserID]
+        self, service: ApplicationService, users: Collection[Union[str, UserID]]
     ) -> List[JsonDict]:
         events = []  # type: List[JsonDict]
         presence_source = self.event_sources.sources["presence"]
@@ -245,6 +283,9 @@ class ApplicationServicesHandler:
             service, "presence"
         )
         for user in users:
+            if isinstance(user, str):
+                user = UserID.from_string(user)
+
             interested = await service.is_interested_in_presence(user, self.store)
             if not interested:
                 continue
@@ -265,11 +306,11 @@ class ApplicationServicesHandler:
 
         return events
 
-    async def query_user_exists(self, user_id):
+    async def query_user_exists(self, user_id: str) -> bool:
         """Check if any application service knows this user_id exists.
 
         Args:
-            user_id(str): The user to query if they exist on any AS.
+            user_id: The user to query if they exist on any AS.
         Returns:
             True if this user exists on at least one application service.
         """
@@ -280,11 +321,13 @@ class ApplicationServicesHandler:
                 return True
         return False
 
-    async def query_room_alias_exists(self, room_alias):
+    async def query_room_alias_exists(
+        self, room_alias: RoomAlias
+    ) -> Optional[RoomAliasMapping]:
         """Check if an application service knows this room alias exists.
 
         Args:
-            room_alias(RoomAlias): The room alias to query.
+            room_alias: The room alias to query.
         Returns:
             namedtuple: with keys "room_id" and "servers" or None if no
             association can be found.
@@ -300,10 +343,13 @@ class ApplicationServicesHandler:
             )
             if is_known_alias:
                 # the alias exists now so don't query more ASes.
-                result = await self.store.get_association_from_room_alias(room_alias)
-                return result
+                return await self.store.get_association_from_room_alias(room_alias)
 
-    async def query_3pe(self, kind, protocol, fields):
+        return None
+
+    async def query_3pe(
+        self, kind: str, protocol: str, fields: Dict[bytes, List[bytes]]
+    ) -> List[JsonDict]:
         services = self._get_services_for_3pn(protocol)
 
         results = await make_deferred_yieldable(
@@ -325,7 +371,9 @@ class ApplicationServicesHandler:
 
         return ret
 
-    async def get_3pe_protocols(self, only_protocol=None):
+    async def get_3pe_protocols(
+        self, only_protocol: Optional[str] = None
+    ) -> Dict[str, JsonDict]:
         services = self.store.get_app_services()
         protocols = {}  # type: Dict[str, List[JsonDict]]
 
@@ -343,7 +391,7 @@ class ApplicationServicesHandler:
                 if info is not None:
                     protocols[p].append(info)
 
-        def _merge_instances(infos):
+        def _merge_instances(infos: List[JsonDict]) -> JsonDict:
             if not infos:
                 return {}
 
@@ -358,19 +406,17 @@ class ApplicationServicesHandler:
 
             return combined
 
-        for p in protocols.keys():
-            protocols[p] = _merge_instances(protocols[p])
-
-        return protocols
+        return {p: _merge_instances(protocols[p]) for p in protocols.keys()}
 
-    async def _get_services_for_event(self, event):
+    async def _get_services_for_event(
+        self, event: EventBase
+    ) -> List[ApplicationService]:
         """Retrieve a list of application services interested in this event.
 
         Args:
-            event(Event): The event to check. Can be None if alias_list is not.
+            event: The event to check. Can be None if alias_list is not.
         Returns:
-            list<ApplicationService>: A list of services interested in this
-            event based on the service regex.
+            A list of services interested in this event based on the service regex.
         """
         services = self.store.get_app_services()
 
@@ -384,17 +430,15 @@ class ApplicationServicesHandler:
 
         return interested_list
 
-    def _get_services_for_user(self, user_id):
+    def _get_services_for_user(self, user_id: str) -> List[ApplicationService]:
         services = self.store.get_app_services()
-        interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
-        return interested_list
+        return [s for s in services if (s.is_interested_in_user(user_id))]
 
-    def _get_services_for_3pn(self, protocol):
+    def _get_services_for_3pn(self, protocol: str) -> List[ApplicationService]:
         services = self.store.get_app_services()
-        interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
-        return interested_list
+        return [s for s in services if s.is_interested_in_protocol(protocol)]
 
-    async def _is_unknown_user(self, user_id):
+    async def _is_unknown_user(self, user_id: str) -> bool:
         if not self.is_mine_id(user_id):
             # we don't know if they are unknown or not since it isn't one of our
             # users. We can't poke ASes.
@@ -409,9 +453,8 @@ class ApplicationServicesHandler:
         service_list = [s for s in services if s.sender == user_id]
         return len(service_list) == 0
 
-    async def _check_user_exists(self, user_id):
+    async def _check_user_exists(self, user_id: str) -> bool:
         unknown_user = await self._is_unknown_user(user_id)
         if unknown_user:
-            exists = await self.query_user_exists(user_id)
-            return exists
+            return await self.query_user_exists(user_id)
         return True