summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/appservice/__init__.py2
-rw-r--r--synapse/appservice/api.py6
-rw-r--r--synapse/handlers/appservice.py141
-rw-r--r--synapse/handlers/typing.py9
-rw-r--r--synapse/notifier.py17
-rw-r--r--synapse/storage/databases/main/appservice.py51
-rw-r--r--synapse/storage/databases/main/deviceinbox.py13
-rw-r--r--synapse/storage/databases/main/devices.py26
-rw-r--r--synapse/storage/databases/main/receipts.py5
9 files changed, 167 insertions, 103 deletions
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 5ed9f54bc7..a93c08ca81 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING
 from synapse.api.constants import EventTypes
 from synapse.appservice.api import ApplicationServiceApi
 from synapse.types import GroupID, get_domain_from_id
-from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches.descriptors import cached
 
 if TYPE_CHECKING:
     from synapse.storage.databases.main import DataStore
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 2c1dae5984..364c1a88f3 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -213,7 +213,11 @@ class ApplicationServiceApi(SimpleHttpClient):
         try:
             await self.put_json(
                 uri=uri,
-                json_body={"events": events, "device_messages": to_device, "device_lists": device_lists},
+                json_body={
+                    "events": events,
+                    "device_messages": to_device,
+                    "device_lists": device_lists,
+                },
                 args={"access_token": service.hs_token},
             )
             return True
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index dbbde3db18..6abc2891cf 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -14,36 +14,24 @@
 # limitations under the License.
 
 import logging
+from typing import Collection, List, Union
 
 from prometheus_client import Counter
 
 from twisted.internet import defer
 
 import synapse
-from typing import (
-    Awaitable,
-    Callable,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    TypeVar,
-    Union,
-    Collection,
-)
-
-from synapse.types import RoomStreamToken, UserID
 from synapse.api.constants import EventTypes
+from synapse.appservice import ApplicationService
+from synapse.handlers.presence import format_user_presence_state
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics import (
     event_processing_loop_counter,
     event_processing_loop_room_count,
 )
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import RoomStreamToken, UserID
 from synapse.util.metrics import Measure
-from synapse.handlers.presence import format_user_presence_state
 
 logger = logging.getLogger(__name__)
 
@@ -175,8 +163,17 @@ class ApplicationServicesHandler:
             finally:
                 self.is_processing = False
 
-    async def notify_interested_services_ephemeral(self, stream_key: str, new_token: Union[int, RoomStreamToken], users: Collection[UserID] = []):
-        services = [service for service in self.store.get_app_services() if service.supports_ephemeral]
+    async def notify_interested_services_ephemeral(
+        self,
+        stream_key: str,
+        new_token: Union[int, RoomStreamToken],
+        users: Collection[UserID] = [],
+    ):
+        services = [
+            service
+            for service in self.store.get_app_services()
+            if service.supports_ephemeral
+        ]
         if not services or not self.notify_appservices:
             return
         logger.info("Checking interested services for %s" % (stream_key))
@@ -184,65 +181,99 @@ class ApplicationServicesHandler:
             for service in services:
                 events = []
                 if stream_key == "typing_key":
-                    from_key = new_token - 1
-                    typing_source = self.event_sources.sources["typing"]
-                    # Get the typing events from just before current
-                    typing, _key = await typing_source.get_new_events_as(
-                        service=service,
-                        from_key=from_key
-                    )
-                    events = typing
+                    events = await self._handle_typing(service, new_token)
                 elif stream_key == "receipt_key":
-                    from_key = new_token - 1
-                    receipts_source = self.event_sources.sources["receipt"]
-                    receipts, _key = await receipts_source.get_new_events_as(
-                        service=service,
-                        from_key=from_key
-                    )
-                    events = receipts
+                    events = await self._handle_receipts(service)
                 elif stream_key == "presence_key":
                     events = await self._handle_as_presence(service, users)
                 elif stream_key == "device_list_key":
                     # Check if the device lists have changed for any of the users we are interested in
-                    print("device_list_key", users)
+                    events = await self._handle_device_list(service, users, new_token)
                 elif stream_key == "to_device_key":
-                    # Check the inbox for any users the bridge owns 
-                    events, to_device_token = await self._handle_to_device(service, users, new_token)
-                    if events:
-                        # TODO: Do in background?
-                        await self.scheduler.submit_ephemeral_events_for_as(service, events, new_token)
-                        if stream_key == "to_device_key":
-                            # Update database with new token
-                            await self.store.set_device_messages_token_for_appservice(service, to_device_token)
-                        return
+                    # Check the inbox for any users the bridge owns
+                    events = await self._handle_to_device(service, users, new_token)
                 if events:
                     # TODO: Do in background?
-                    await self.scheduler.submit_ephemeral_events_for_as(service, events, new_token)
+                    await self.scheduler.submit_ephemeral_events_for_as(
+                        service, events, new_token
+                    )
+                    # We don't persist the token for typing_key
+                    if stream_key == "presence_key":
+                        await self.store.set_type_stream_id_for_appservice(
+                            service, "presence", new_token
+                        )
+                    elif stream_key == "receipt_key":
+                        await self.store.set_type_stream_id_for_appservice(
+                            service, "read_receipt", new_token
+                        )
+                    elif stream_key == "to_device_key":
+                        await self.store.set_type_stream_id_for_appservice(
+                            service, "to_device", new_token
+                        )
 
-    async def _handle_device_list(self, service, users, token):
-        if not any([True for u in users if service.is_interested_in_user(u)]):
-            return False
+    async def _handle_typing(self, service, new_token):
+        typing_source = self.event_sources.sources["typing"]
+        # Get the typing events from just before current
+        typing, _key = await typing_source.get_new_events_as(
+            service=service,
+            # For performance reasons, we don't persist the previous
+            # token in the DB and instead fetch the latest typing information
+            # for appservices.
+            from_key=new_token - 1,
+        )
+        return typing
+
+    async def _handle_receipts(self, service, token: int):
+        from_key = await self.store.get_type_stream_id_for_appservice(
+            service, "read_receipt"
+        )
+        receipts_source = self.event_sources.sources["receipt"]
+        receipts, _ = await receipts_source.get_new_events_as(
+            service=service, from_key=from_key
+        )
+        return receipts
+
+    async def _handle_device_list(
+        self, service: ApplicationService, users: List[str], new_token: int
+    ):
+        # TODO: Determine if any user have left and report those
+        from_token = await self.store.get_type_stream_id_for_appservice(
+            service, "device_list"
+        )
+        changed_user_ids = await self.store.get_device_changes_for_as(
+            service, from_token, new_token
+        )
+        # Return the
+        return {
+            "type": "m.device_list_update",
+            "content": {"changed": changed_user_ids,}, 
+        }
 
     async def _handle_to_device(self, service, users, token):
         if not any([True for u in users if service.is_interested_in_user(u)]):
             return False
-        
-        since_token = await self.store.get_device_messages_token_for_appservice(service)
-        
-        messages, new_token = await self.store.get_new_messages_for_as(service, since_token, token)
-        return messages, new_token
+
+        since_token = await self.store.get_type_stream_id_for_appservice(
+            service, "to_device"
+        )
+        messages, _ = await self.store.get_new_messages_for_as(
+            service, since_token, token
+        )
+        # This returns user_id -> device_id -> message
+        return messages
 
     async def _handle_as_presence(self, service, users):
         events = []
         presence_source = self.event_sources.sources["presence"]
+        from_key = await self.store.get_type_stream_id_for_appservice(
+            service, "presence"
+        )
         for user in users:
             interested = await service.is_interested_in_presence(user, self.store)
             if not interested:
                 continue
             presence_events, _key = await presence_source.get_new_events(
-                user=user,
-                service=service,
-                from_key=None, # TODO: I don't think this is required?
+                user=user, service=service, from_key=from_key,
             )
             time_now = self.clock.time_msec()
             presence_events = [
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 1747e4c872..8a8f480777 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -431,7 +431,9 @@ class TypingNotificationEventSource:
             "content": {"user_ids": list(typing)},
         }
 
-    async def get_new_events_as(self, from_key, service, **kwargs):
+    async def get_new_events_as(
+        self, from_key: int, service: ApplicationService, **kwargs
+    ):
         with Measure(self.clock, "typing.get_new_events_as"):
             from_key = int(from_key)
             handler = self.get_typing_handler()
@@ -441,8 +443,9 @@ class TypingNotificationEventSource:
                 if handler._room_serials[room_id] <= from_key:
                     print("Key too old")
                     continue
-                # XXX: Store gut wrenching
-                if not await service.matches_user_in_member_list(room_id, handler.store):
+                if not await service.matches_user_in_member_list(
+                    room_id, handler.store
+                ):
                     continue
 
                 events.append(self._make_event_for(room_id))
diff --git a/synapse/notifier.py b/synapse/notifier.py
index b6b231c15d..7396fe96c5 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -329,9 +329,16 @@ class Notifier:
         except Exception:
             logger.exception("Error notifying application services of event")
 
-    async def _notify_app_services_ephemeral(self, stream_key: str, new_token: Union[int, RoomStreamToken], users: Collection[UserID] = []):
+    async def _notify_app_services_ephemeral(
+        self,
+        stream_key: str,
+        new_token: Union[int, RoomStreamToken],
+        users: Collection[UserID] = [],
+    ):
         try:
-            await self.appservice_handler.notify_interested_services_ephemeral(stream_key, new_token, users)
+            await self.appservice_handler.notify_interested_services_ephemeral(
+                stream_key, new_token, users
+            )
         except Exception:
             logger.exception("Error notifying application services of event")
 
@@ -375,7 +382,11 @@ class Notifier:
 
                 # Notify appservices
                 run_as_background_process(
-                    "_notify_app_services_ephemeral", self._notify_app_services_ephemeral, stream_key, new_token, users,
+                    "_notify_app_services_ephemeral",
+                    self._notify_app_services_ephemeral,
+                    stream_key,
+                    new_token,
+                    users,
                 )
 
     def on_new_replication_data(self) -> None:
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 1ebe4504fd..91c0b52b34 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -350,47 +350,36 @@ class ApplicationServiceTransactionWorkerStore(
         events = await self.get_events_as_list(event_ids)
 
         return upper_bound, events
-    
-    async def get_device_messages_token_for_appservice(self, service):
-        txn.execute(
-            "SELECT device_message_stream_id FROM application_services_state WHERE as_id=?",
-            (service.id,),
-        )
-        last_txn_id = txn.fetchone()
-        if last_txn_id is None or last_txn_id[0] is None:  # no row exists
-            return 0
-        else:
-            return int(last_txn_id[0])  # select 'last_txn' col
 
-    async def set_device_messages_token_for_appservice(self, service, pos) -> None:
-        def set_appservice_last_pos_txn(txn):
+    async def get_type_stream_id_for_appservice(self, service, type: str) -> int:
+        def get_type_stream_id_for_appservice_txn(txn):
+            stream_id_type = "%s_stream_id" % type
             txn.execute(
-                "UPDATE application_services_state SET device_message_stream_id = ? WHERE as_id=?", (pos, service.id)
+                "SELECT ? FROM application_services_state WHERE as_id=?",
+                (stream_id_type, service.id,),
             )
+            last_txn_id = txn.fetchone()
+            if last_txn_id is None or last_txn_id[0] is None:  # no row exists
+                return 0
+            else:
+                return int(last_txn_id[0])
 
-        await self.db_pool.runInteraction(
-            "set_device_messages_token_for_appservice", set_appservice_last_pos_txn
-        )
-    
-    async def get_device_list_token_for_appservice(self, service):
-        txn.execute(
-            "SELECT device_list_stream_id FROM application_services_state WHERE as_id=?",
-            (service.id,),
+        return await self.db_pool.runInteraction(
+            "get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn
         )
-        last_txn_id = txn.fetchone()
-        if last_txn_id is None or last_txn_id[0] is None:  # no row exists
-            return 0
-        else:
-            return int(last_txn_id[0])  # select 'last_txn' col
 
-    async def set_device_list_token_for_appservice(self, service, pos) -> None:
-        def set_appservice_last_pos_txn(txn):
+    async def set_type_stream_id_for_appservice(
+        self, service, type: str, pos: int
+    ) -> None:
+        def set_type_stream_id_for_appservice_txn(txn):
+            stream_id_type = "%s_stream_id" % type
             txn.execute(
-                "UPDATE application_services_state SET device_list_stream_id = ?", (pos, service.id)
+                "UPDATE ? SET device_list_stream_id = ? WHERE as_id=?",
+                (stream_id_type, pos, service.id),
             )
 
         await self.db_pool.runInteraction(
-            "set_device_list_token_for_appservice", set_appservice_last_pos_txn
+            "set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn
         )
 
 
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 2d151b9134..8897e27b1f 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -16,12 +16,12 @@
 import logging
 from typing import List, Tuple
 
+from synapse.appservice import ApplicationService
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
 from synapse.util import json_encoder
 from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.appservice import ApplicationService
 
 logger = logging.getLogger(__name__)
 
@@ -44,15 +44,18 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 " ORDER BY stream_id ASC"
                 " LIMIT ?"
             )
-            txn.execute(
-                sql, (last_stream_id, current_stream_id, limit)
-            )
+            txn.execute(sql, (last_stream_id, current_stream_id, limit))
             messages = []
 
             for row in txn:
                 stream_pos = row[0]
                 if service.is_interested_in_user(row.user_id):
-                    messages.append(db_to_json(row[1]))
+                    msg = db_to_json(row[1])
+                    msg.recipient = {
+                        "device_id": row.device_id,
+                        "user_id": row.user_id,
+                    }
+                    messages.append(msg)
             if len(messages) < limit:
                 stream_pos = current_stream_id
             return messages, stream_pos
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index fdf394c612..bf32cc6c06 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -19,6 +19,7 @@ import logging
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api.errors import Codes, StoreError
+from synapse.appservice import ApplicationService
 from synapse.logging.opentracing import (
     get_active_span_text_map,
     set_tag,
@@ -525,6 +526,31 @@ class DeviceWorkerStore(SQLBaseStore):
             "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
         )
 
+    async def get_device_changes_for_as(
+        self,
+        service: ApplicationService,
+        last_stream_id: int,
+        current_stream_id: int,
+        limit: int = 100,
+    ) -> Tuple[List[dict], int]:
+        def get_device_changes_for_as_txn(txn):
+            sql = (
+                "SELECT DISTINCT user_ids FROM device_lists_stream"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC"
+                " LIMIT ?"
+            )
+            txn.execute(sql, (last_stream_id, current_stream_id, limit))
+            rows = txn.fetchall()
+            users = []
+            for user in db_to_json(rows[0]):
+                if await service.is_interested_in_presence(user):
+                    users.append(user)
+
+        return await self.db_pool.runInteraction(
+            "get_device_changes_for_as", get_device_changes_for_as_txn
+        )
+
     async def get_users_whose_signatures_changed(
         self, user_id: str, from_key: int
     ) -> Set[str]:
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index c10a16ffa3..d26c315ed4 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -283,9 +283,7 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
         }
         return results
 
-    @cached(
-        num_args=2,
-    )
+    @cached(num_args=2,)
     async def _get_linearized_receipts_for_all_rooms(self, to_key, from_key=None):
         def f(txn):
             if from_key:
@@ -326,7 +324,6 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
 
         return results
 
-
     async def get_users_sent_receipts_between(
         self, last_id: int, current_id: int
     ) -> List[str]: