summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/appservice.py51
-rw-r--r--synapse/storage/databases/main/deviceinbox.py13
-rw-r--r--synapse/storage/databases/main/devices.py26
-rw-r--r--synapse/storage/databases/main/receipts.py5
4 files changed, 55 insertions, 40 deletions
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py

index 1ebe4504fd..91c0b52b34 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py
@@ -350,47 +350,36 @@ class ApplicationServiceTransactionWorkerStore( events = await self.get_events_as_list(event_ids) return upper_bound, events - - async def get_device_messages_token_for_appservice(self, service): - txn.execute( - "SELECT device_message_stream_id FROM application_services_state WHERE as_id=?", - (service.id,), - ) - last_txn_id = txn.fetchone() - if last_txn_id is None or last_txn_id[0] is None: # no row exists - return 0 - else: - return int(last_txn_id[0]) # select 'last_txn' col - async def set_device_messages_token_for_appservice(self, service, pos) -> None: - def set_appservice_last_pos_txn(txn): + async def get_type_stream_id_for_appservice(self, service, type: str) -> int: + def get_type_stream_id_for_appservice_txn(txn): + stream_id_type = "%s_stream_id" % type txn.execute( - "UPDATE application_services_state SET device_message_stream_id = ? WHERE as_id=?", (pos, service.id) + "SELECT ? FROM application_services_state WHERE as_id=?", + (stream_id_type, service.id,), ) + last_txn_id = txn.fetchone() + if last_txn_id is None or last_txn_id[0] is None: # no row exists + return 0 + else: + return int(last_txn_id[0]) - await self.db_pool.runInteraction( - "set_device_messages_token_for_appservice", set_appservice_last_pos_txn - ) - - async def get_device_list_token_for_appservice(self, service): - txn.execute( - "SELECT device_list_stream_id FROM application_services_state WHERE as_id=?", - (service.id,), + return await self.db_pool.runInteraction( + "get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn ) - last_txn_id = txn.fetchone() - if last_txn_id is None or last_txn_id[0] is None: # no row exists - return 0 - else: - return int(last_txn_id[0]) # select 'last_txn' col - async def set_device_list_token_for_appservice(self, service, pos) -> None: - def set_appservice_last_pos_txn(txn): + async def set_type_stream_id_for_appservice( + self, service, type: str, pos: int + ) -> None: + def set_type_stream_id_for_appservice_txn(txn): + stream_id_type = "%s_stream_id" % type txn.execute( - "UPDATE application_services_state SET device_list_stream_id = ?", (pos, service.id) + "UPDATE ? SET device_list_stream_id = ? WHERE as_id=?", + (stream_id_type, pos, service.id), ) await self.db_pool.runInteraction( - "set_device_list_token_for_appservice", set_appservice_last_pos_txn + "set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn ) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 2d151b9134..8897e27b1f 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py
@@ -16,12 +16,12 @@ import logging from typing import List, Tuple +from synapse.appservice import ApplicationService from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.util import json_encoder from synapse.util.caches.expiringcache import ExpiringCache -from synapse.appservice import ApplicationService logger = logging.getLogger(__name__) @@ -44,15 +44,18 @@ class DeviceInboxWorkerStore(SQLBaseStore): " ORDER BY stream_id ASC" " LIMIT ?" ) - txn.execute( - sql, (last_stream_id, current_stream_id, limit) - ) + txn.execute(sql, (last_stream_id, current_stream_id, limit)) messages = [] for row in txn: stream_pos = row[0] if service.is_interested_in_user(row.user_id): - messages.append(db_to_json(row[1])) + msg = db_to_json(row[1]) + msg.recipient = { + "device_id": row.device_id, + "user_id": row.user_id, + } + messages.append(msg) if len(messages) < limit: stream_pos = current_stream_id return messages, stream_pos diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index fdf394c612..bf32cc6c06 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -19,6 +19,7 @@ import logging from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import Codes, StoreError +from synapse.appservice import ApplicationService from synapse.logging.opentracing import ( get_active_span_text_map, set_tag, @@ -525,6 +526,31 @@ class DeviceWorkerStore(SQLBaseStore): "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn ) + async def get_device_changes_for_as( + self, + service: ApplicationService, + last_stream_id: int, + current_stream_id: int, + limit: int = 100, + ) -> Tuple[List[dict], int]: + def get_device_changes_for_as_txn(txn): + sql = ( + "SELECT DISTINCT user_ids FROM device_lists_stream" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_stream_id, current_stream_id, limit)) + rows = txn.fetchall() + users = [] + for user in db_to_json(rows[0]): + if await service.is_interested_in_presence(user): + users.append(user) + + return await self.db_pool.runInteraction( + "get_device_changes_for_as", get_device_changes_for_as_txn + ) + async def get_users_whose_signatures_changed( self, user_id: str, from_key: int ) -> Set[str]: diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index c10a16ffa3..d26c315ed4 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py
@@ -283,9 +283,7 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): } return results - @cached( - num_args=2, - ) + @cached(num_args=2,) async def _get_linearized_receipts_for_all_rooms(self, to_key, from_key=None): def f(txn): if from_key: @@ -326,7 +324,6 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): return results - async def get_users_sent_receipts_between( self, last_id: int, current_id: int ) -> List[str]: