summary refs log tree commit diff
diff options
context:
space:
mode:
authorWill Hunt <will@half-shot.uk>2020-09-21 15:10:06 +0100
committerWill Hunt <will@half-shot.uk>2020-09-21 15:10:06 +0100
commitae724db89986938db60d187db1ef1ab92f7e7753 (patch)
treeed2051aa52e448f905eb21dd88e090a472b7cf79
parentAppservice API changes (diff)
downloadsynapse-ae724db89986938db60d187db1ef1ab92f7e7753.tar.xz
Changes to handlers to support fetching events for appservices
-rw-r--r--synapse/handlers/appservice.py49
-rw-r--r--synapse/handlers/receipts.py22
-rw-r--r--synapse/handlers/typing.py19
-rw-r--r--synapse/storage/databases/main/receipts.py53
4 files changed, 143 insertions, 0 deletions
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 9d4e87dad6..e8cc166fde 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -20,6 +20,20 @@ from prometheus_client import Counter
 from twisted.internet import defer
 
 import synapse
+from typing import (
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+)
+
+from synapse.types import RoomStreamToken
 from synapse.api.constants import EventTypes
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics import (
@@ -43,6 +57,7 @@ class ApplicationServicesHandler:
         self.started_scheduler = False
         self.clock = hs.get_clock()
         self.notify_appservices = hs.config.notify_appservices
+        self.event_sources = hs.get_event_sources()
 
         self.current_max = 0
         self.is_processing = False
@@ -158,6 +173,40 @@ class ApplicationServicesHandler:
             finally:
                 self.is_processing = False
 
+    async def notify_interested_services_ephemeral(self, stream_key: str, new_token: Union[int, RoomStreamToken]):
+        services = [service for service in self.store.get_app_services() if service.supports_ephemeral]
+        if not services or not self.notify_appservices:
+            return
+        logger.info("Checking interested services for %s" % (stream_key))
+        with Measure(self.clock, "notify_interested_services_ephemeral"):
+            for service in services:
+                events = []
+                if stream_key == "typing_key":
+                    from_key = new_token - 1
+                    typing_source = self.event_sources.sources["typing"]
+                    # Get the typing events from just before current
+                    typing, _typing_key = await typing_source.get_new_events_as(
+                        service=service,
+                        from_key=from_key
+                    )
+                    events = typing
+                elif stream_key == "receipt_key":
+                    from_key = new_token - 1
+                    receipts_source = self.event_sources.sources["receipt"]
+                    receipts, _receipts_key = await receipts_source.get_new_events_as(
+                        service=service,
+                        from_key=from_key
+                    )
+                    events = receipts
+                elif stream_key == "presence":
+                    # TODO: This. Presence means trying to determine all the
+                    # users the appservice cares about, which means checking
+                    # all the rooms the appservice is in.
+                if events:
+                    # TODO: Do in background?
+                    await self.scheduler.submit_ephemeral_events_for_as(service, events)
+        
+
     async def query_user_exists(self, user_id):
         """Check if any application service knows this user_id exists.
 
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 7225923757..d9e4b1c271 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -140,5 +140,27 @@ class ReceiptEventSource:
 
         return (events, to_key)
 
+    async def get_new_events_as(self, from_key, service, **kwargs):
+        from_key = int(from_key)
+        to_key = self.get_current_key()
+
+        if from_key == to_key:
+            return [], to_key
+
+        # We first need to fetch all new receipts
+        rooms_to_events = await self.store.get_linearized_receipts_for_all_rooms(
+            from_key=from_key, to_key=to_key
+        )
+
+        # Then filter down to rooms that the AS can read
+        events = []
+        for room_id, event in rooms_to_events.items():
+            if not await service.matches_user_in_member_list(room_id, self.store):
+                continue
+
+            events.append(event)
+
+        return (events, to_key)
+
     def get_current_key(self, direction="f"):
         return self.store.get_max_receipt_stream_id()
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 3cbfc2d780..1747e4c872 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -19,6 +19,7 @@ from collections import namedtuple
 from typing import TYPE_CHECKING, List, Set, Tuple
 
 from synapse.api.errors import AuthError, ShadowBanError, SynapseError
+from synapse.appservice import ApplicationService
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.streams import TypingStream
 from synapse.types import UserID, get_domain_from_id
@@ -430,6 +431,24 @@ class TypingNotificationEventSource:
             "content": {"user_ids": list(typing)},
         }
 
+    async def get_new_events_as(self, from_key, service, **kwargs):
+        with Measure(self.clock, "typing.get_new_events_as"):
+            from_key = int(from_key)
+            handler = self.get_typing_handler()
+
+            events = []
+            for room_id in handler._room_serials.keys():
+                if handler._room_serials[room_id] <= from_key:
+                    print("Key too old")
+                    continue
+                # XXX: Store gut wrenching
+                if not await service.matches_user_in_member_list(room_id, handler.store):
+                    continue
+
+                events.append(self._make_event_for(room_id))
+
+            return (events, handler._latest_room_serial)
+
     async def get_new_events(self, from_key, room_ids, **kwargs):
         with Measure(self.clock, "typing.get_new_events"):
             from_key = int(from_key)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index f880b5e562..5867d52b62 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -123,6 +123,15 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
             for row in rows
         }
 
+    async def get_linearized_receipts_for_all_rooms(
+        self, to_key: int, from_key: Optional[int] = None
+    ) -> List[dict]:
+        results = await self._get_linearized_receipts_for_all_rooms(
+            to_key, from_key=from_key
+        )
+
+        return results
+
     async def get_linearized_receipts_for_rooms(
         self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
     ) -> List[dict]:
@@ -274,6 +283,50 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
         }
         return results
 
+    @cached(
+        num_args=2,
+    )
+    async def _get_linearized_receipts_for_all_rooms(self, to_key, from_key=None):
+        def f(txn):
+            if from_key:
+                sql = """
+                    SELECT * FROM receipts_linearized WHERE
+                    stream_id > ? AND stream_id <= ?
+                """
+                txn.execute(sql, [from_key, to_key])
+            else:
+                sql = """
+                    SELECT * FROM receipts_linearized WHERE
+                    stream_id <= ?
+                """
+
+                txn.execute(sql, [to_key])
+
+            return self.db_pool.cursor_to_dict(txn)
+
+        txn_results = await self.db_pool.runInteraction(
+            "_get_linearized_receipts_for_all_rooms", f
+        )
+
+        results = {}
+        for row in txn_results:
+            # We want a single event per room, since we want to batch the
+            # receipts by room, event and type.
+            room_event = results.setdefault(
+                row["room_id"],
+                {"type": "m.receipt", "room_id": row["room_id"], "content": {}},
+            )
+
+            # The content is of the form:
+            # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
+            event_entry = room_event["content"].setdefault(row["event_id"], {})
+            receipt_type = event_entry.setdefault(row["receipt_type"], {})
+
+            receipt_type[row["user_id"]] = db_to_json(row["data"])
+
+        return results
+
+
     async def get_users_sent_receipts_between(
         self, last_id: int, current_id: int
     ) -> List[str]: