summary refs log tree commit diff
path: root/synapse/handlers/appservice.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/appservice.py')
-rw-r--r--synapse/handlers/appservice.py222
1 files changed, 186 insertions, 36 deletions
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 9d4e87dad6..5c6458eb52 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -12,8 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 from prometheus_client import Counter
 
@@ -21,21 +21,32 @@ from twisted.internet import defer
 
 import synapse
 from synapse.api.constants import EventTypes
+from synapse.appservice import ApplicationService
+from synapse.events import EventBase
+from synapse.handlers.presence import format_user_presence_state
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics import (
     event_processing_loop_counter,
     event_processing_loop_room_count,
 )
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+    run_as_background_process,
+    wrap_as_background_process,
+)
+from synapse.storage.databases.main.directory import RoomAliasMapping
+from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, UserID
 from synapse.util.metrics import Measure
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
 
 
 class ApplicationServicesHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.is_mine_id = hs.is_mine_id
         self.appservice_api = hs.get_application_service_api()
@@ -43,19 +54,22 @@ class ApplicationServicesHandler:
         self.started_scheduler = False
         self.clock = hs.get_clock()
         self.notify_appservices = hs.config.notify_appservices
+        self.event_sources = hs.get_event_sources()
 
         self.current_max = 0
         self.is_processing = False
 
-    async def notify_interested_services(self, current_id):
+    def notify_interested_services(self, max_token: RoomStreamToken):
         """Notifies (pushes) all application services interested in this event.
 
         Pushing is done asynchronously, so this method won't block for any
         prolonged length of time.
-
-        Args:
-            current_id(int): The current maximum ID.
         """
+        # We just use the minimum stream ordering and ignore the vector clock
+        # component. This is safe to do as long as we *always* ignore the vector
+        # clock components.
+        current_id = max_token.stream
+
         services = self.store.get_app_services()
         if not services or not self.notify_appservices:
             return
@@ -64,6 +78,12 @@ class ApplicationServicesHandler:
         if self.is_processing:
             return
 
+        # We only start a new background process if necessary rather than
+        # optimistically (to cut down on overhead).
+        self._notify_interested_services(max_token)
+
+    @wrap_as_background_process("notify_interested_services")
+    async def _notify_interested_services(self, max_token: RoomStreamToken):
         with Measure(self.clock, "notify_interested_services"):
             self.is_processing = True
             try:
@@ -79,7 +99,7 @@ class ApplicationServicesHandler:
                     if not events:
                         break
 
-                    events_by_room = {}
+                    events_by_room = {}  # type: Dict[str, List[EventBase]]
                     for event in events:
                         events_by_room.setdefault(event.room_id, []).append(event)
 
@@ -158,11 +178,139 @@ class ApplicationServicesHandler:
             finally:
                 self.is_processing = False
 
-    async def query_user_exists(self, user_id):
+    def notify_interested_services_ephemeral(
+        self,
+        stream_key: str,
+        new_token: Optional[int],
+        users: Collection[Union[str, UserID]] = [],
+    ):
+        """This is called by the notifier in the background
+        when a ephemeral event handled by the homeserver.
+
+        This will determine which appservices
+        are interested in the event, and submit them.
+
+        Events will only be pushed to appservices
+        that have opted into ephemeral events
+
+        Args:
+            stream_key: The stream the event came from.
+            new_token: The latest stream token
+            users: The user(s) involved with the event.
+        """
+        if not self.notify_appservices:
+            return
+
+        if stream_key not in ("typing_key", "receipt_key", "presence_key"):
+            return
+
+        services = [
+            service
+            for service in self.store.get_app_services()
+            if service.supports_ephemeral
+        ]
+        if not services:
+            return
+
+        # We only start a new background process if necessary rather than
+        # optimistically (to cut down on overhead).
+        self._notify_interested_services_ephemeral(
+            services, stream_key, new_token, users
+        )
+
+    @wrap_as_background_process("notify_interested_services_ephemeral")
+    async def _notify_interested_services_ephemeral(
+        self,
+        services: List[ApplicationService],
+        stream_key: str,
+        new_token: Optional[int],
+        users: Collection[Union[str, UserID]],
+    ):
+        logger.debug("Checking interested services for %s" % (stream_key))
+        with Measure(self.clock, "notify_interested_services_ephemeral"):
+            for service in services:
+                # Only handle typing if we have the latest token
+                if stream_key == "typing_key" and new_token is not None:
+                    events = await self._handle_typing(service, new_token)
+                    if events:
+                        self.scheduler.submit_ephemeral_events_for_as(service, events)
+                    # We don't persist the token for typing_key for performance reasons
+                elif stream_key == "receipt_key":
+                    events = await self._handle_receipts(service)
+                    if events:
+                        self.scheduler.submit_ephemeral_events_for_as(service, events)
+                    await self.store.set_type_stream_id_for_appservice(
+                        service, "read_receipt", new_token
+                    )
+                elif stream_key == "presence_key":
+                    events = await self._handle_presence(service, users)
+                    if events:
+                        self.scheduler.submit_ephemeral_events_for_as(service, events)
+                    await self.store.set_type_stream_id_for_appservice(
+                        service, "presence", new_token
+                    )
+
+    async def _handle_typing(
+        self, service: ApplicationService, new_token: int
+    ) -> List[JsonDict]:
+        typing_source = self.event_sources.sources["typing"]
+        # Get the typing events from just before current
+        typing, _ = await typing_source.get_new_events_as(
+            service=service,
+            # For performance reasons, we don't persist the previous
+            # token in the DB and instead fetch the latest typing information
+            # for appservices.
+            from_key=new_token - 1,
+        )
+        return typing
+
+    async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]:
+        from_key = await self.store.get_type_stream_id_for_appservice(
+            service, "read_receipt"
+        )
+        receipts_source = self.event_sources.sources["receipt"]
+        receipts, _ = await receipts_source.get_new_events_as(
+            service=service, from_key=from_key
+        )
+        return receipts
+
+    async def _handle_presence(
+        self, service: ApplicationService, users: Collection[Union[str, UserID]]
+    ) -> List[JsonDict]:
+        events = []  # type: List[JsonDict]
+        presence_source = self.event_sources.sources["presence"]
+        from_key = await self.store.get_type_stream_id_for_appservice(
+            service, "presence"
+        )
+        for user in users:
+            if isinstance(user, str):
+                user = UserID.from_string(user)
+
+            interested = await service.is_interested_in_presence(user, self.store)
+            if not interested:
+                continue
+            presence_events, _ = await presence_source.get_new_events(
+                user=user, service=service, from_key=from_key,
+            )
+            time_now = self.clock.time_msec()
+            events.extend(
+                {
+                    "type": "m.presence",
+                    "sender": event.user_id,
+                    "content": format_user_presence_state(
+                        event, time_now, include_user_id=False
+                    ),
+                }
+                for event in presence_events
+            )
+
+        return events
+
+    async def query_user_exists(self, user_id: str) -> bool:
         """Check if any application service knows this user_id exists.
 
         Args:
-            user_id(str): The user to query if they exist on any AS.
+            user_id: The user to query if they exist on any AS.
         Returns:
             True if this user exists on at least one application service.
         """
@@ -173,11 +321,13 @@ class ApplicationServicesHandler:
                 return True
         return False
 
-    async def query_room_alias_exists(self, room_alias):
+    async def query_room_alias_exists(
+        self, room_alias: RoomAlias
+    ) -> Optional[RoomAliasMapping]:
         """Check if an application service knows this room alias exists.
 
         Args:
-            room_alias(RoomAlias): The room alias to query.
+            room_alias: The room alias to query.
         Returns:
             namedtuple: with keys "room_id" and "servers" or None if no
             association can be found.
@@ -193,10 +343,13 @@ class ApplicationServicesHandler:
             )
             if is_known_alias:
                 # the alias exists now so don't query more ASes.
-                result = await self.store.get_association_from_room_alias(room_alias)
-                return result
+                return await self.store.get_association_from_room_alias(room_alias)
 
-    async def query_3pe(self, kind, protocol, fields):
+        return None
+
+    async def query_3pe(
+        self, kind: str, protocol: str, fields: Dict[bytes, List[bytes]]
+    ) -> List[JsonDict]:
         services = self._get_services_for_3pn(protocol)
 
         results = await make_deferred_yieldable(
@@ -218,9 +371,11 @@ class ApplicationServicesHandler:
 
         return ret
 
-    async def get_3pe_protocols(self, only_protocol=None):
+    async def get_3pe_protocols(
+        self, only_protocol: Optional[str] = None
+    ) -> Dict[str, JsonDict]:
         services = self.store.get_app_services()
-        protocols = {}
+        protocols = {}  # type: Dict[str, List[JsonDict]]
 
         # Collect up all the individual protocol responses out of the ASes
         for s in services:
@@ -236,7 +391,7 @@ class ApplicationServicesHandler:
                 if info is not None:
                     protocols[p].append(info)
 
-        def _merge_instances(infos):
+        def _merge_instances(infos: List[JsonDict]) -> JsonDict:
             if not infos:
                 return {}
 
@@ -251,19 +406,17 @@ class ApplicationServicesHandler:
 
             return combined
 
-        for p in protocols.keys():
-            protocols[p] = _merge_instances(protocols[p])
-
-        return protocols
+        return {p: _merge_instances(protocols[p]) for p in protocols.keys()}
 
-    async def _get_services_for_event(self, event):
+    async def _get_services_for_event(
+        self, event: EventBase
+    ) -> List[ApplicationService]:
         """Retrieve a list of application services interested in this event.
 
         Args:
-            event(Event): The event to check. Can be None if alias_list is not.
+            event: The event to check. Can be None if alias_list is not.
         Returns:
-            list<ApplicationService>: A list of services interested in this
-            event based on the service regex.
+            A list of services interested in this event based on the service regex.
         """
         services = self.store.get_app_services()
 
@@ -277,17 +430,15 @@ class ApplicationServicesHandler:
 
         return interested_list
 
-    def _get_services_for_user(self, user_id):
+    def _get_services_for_user(self, user_id: str) -> List[ApplicationService]:
         services = self.store.get_app_services()
-        interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
-        return interested_list
+        return [s for s in services if (s.is_interested_in_user(user_id))]
 
-    def _get_services_for_3pn(self, protocol):
+    def _get_services_for_3pn(self, protocol: str) -> List[ApplicationService]:
         services = self.store.get_app_services()
-        interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
-        return interested_list
+        return [s for s in services if s.is_interested_in_protocol(protocol)]
 
-    async def _is_unknown_user(self, user_id):
+    async def _is_unknown_user(self, user_id: str) -> bool:
         if not self.is_mine_id(user_id):
             # we don't know if they are unknown or not since it isn't one of our
             # users. We can't poke ASes.
@@ -302,9 +453,8 @@ class ApplicationServicesHandler:
         service_list = [s for s in services if s.sender == user_id]
         return len(service_list) == 0
 
-    async def _check_user_exists(self, user_id):
+    async def _check_user_exists(self, user_id: str) -> bool:
         unknown_user = await self._is_unknown_user(user_id)
         if unknown_user:
-            exists = await self.query_user_exists(user_id)
-            return exists
+            return await self.query_user_exists(user_id)
         return True