diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index e06b0bc56d..e6a42a53bb 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -19,14 +19,174 @@ from canonicaljson import json
from twisted.internet import defer
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.util.caches.expiringcache import ExpiringCache
-from .background_updates import BackgroundUpdateStore
-
logger = logging.getLogger(__name__)
-class DeviceInboxStore(BackgroundUpdateStore):
+class DeviceInboxWorkerStore(SQLBaseStore):
+ def get_to_device_stream_token(self):
+ return self._device_inbox_id_gen.get_current_token()
+
+ def get_new_messages_for_device(
+ self, user_id, device_id, last_stream_id, current_stream_id, limit=100
+ ):
+ """
+ Args:
+ user_id(str): The recipient user_id.
+ device_id(str): The recipient device_id.
+ current_stream_id(int): The current position of the to device
+ message stream.
+ Returns:
+ Deferred ([dict], int): List of messages for the device and where
+ in the stream the messages got to.
+ """
+ has_changed = self._device_inbox_stream_cache.has_entity_changed(
+ user_id, last_stream_id
+ )
+ if not has_changed:
+ return defer.succeed(([], current_stream_id))
+
+ def get_new_messages_for_device_txn(txn):
+ sql = (
+ "SELECT stream_id, message_json FROM device_inbox"
+ " WHERE user_id = ? AND device_id = ?"
+ " AND ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC"
+ " LIMIT ?"
+ )
+ txn.execute(sql, (
+ user_id, device_id, last_stream_id, current_stream_id, limit
+ ))
+ messages = []
+ for row in txn:
+ stream_pos = row[0]
+ messages.append(json.loads(row[1]))
+ if len(messages) < limit:
+ stream_pos = current_stream_id
+ return (messages, stream_pos)
+
+ return self.runInteraction(
+ "get_new_messages_for_device", get_new_messages_for_device_txn,
+ )
+
+ @defer.inlineCallbacks
+ def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
+ """
+ Args:
+ user_id(str): The recipient user_id.
+ device_id(str): The recipient device_id.
+ up_to_stream_id(int): Where to delete messages up to.
+ Returns:
+ A deferred that resolves to the number of messages deleted.
+ """
+ # If we have cached the last stream id we've deleted up to, we can
+ # check if there is likely to be anything that needs deleting
+ last_deleted_stream_id = self._last_device_delete_cache.get(
+ (user_id, device_id), None
+ )
+ if last_deleted_stream_id:
+ has_changed = self._device_inbox_stream_cache.has_entity_changed(
+ user_id, last_deleted_stream_id
+ )
+ if not has_changed:
+ defer.returnValue(0)
+
+ def delete_messages_for_device_txn(txn):
+ sql = (
+ "DELETE FROM device_inbox"
+ " WHERE user_id = ? AND device_id = ?"
+ " AND stream_id <= ?"
+ )
+ txn.execute(sql, (user_id, device_id, up_to_stream_id))
+ return txn.rowcount
+
+ count = yield self.runInteraction(
+ "delete_messages_for_device", delete_messages_for_device_txn
+ )
+
+ # Update the cache, ensuring that we only ever increase the value
+ last_deleted_stream_id = self._last_device_delete_cache.get(
+ (user_id, device_id), 0
+ )
+ self._last_device_delete_cache[(user_id, device_id)] = max(
+ last_deleted_stream_id, up_to_stream_id
+ )
+
+ defer.returnValue(count)
+
+ def get_new_device_msgs_for_remote(
+ self, destination, last_stream_id, current_stream_id, limit=100
+ ):
+ """
+ Args:
+ destination(str): The name of the remote server.
+ last_stream_id(int|long): The last position of the device message stream
+ that the server sent up to.
+ current_stream_id(int|long): The current position of the device
+ message stream.
+ Returns:
+ Deferred ([dict], int|long): List of messages for the device and where
+ in the stream the messages got to.
+ """
+
+ has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
+ destination, last_stream_id
+ )
+ if not has_changed or last_stream_id == current_stream_id:
+ return defer.succeed(([], current_stream_id))
+
+ def get_new_messages_for_remote_destination_txn(txn):
+ sql = (
+ "SELECT stream_id, messages_json FROM device_federation_outbox"
+ " WHERE destination = ?"
+ " AND ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC"
+ " LIMIT ?"
+ )
+ txn.execute(sql, (
+ destination, last_stream_id, current_stream_id, limit
+ ))
+ messages = []
+ for row in txn:
+ stream_pos = row[0]
+ messages.append(json.loads(row[1]))
+ if len(messages) < limit:
+ stream_pos = current_stream_id
+ return (messages, stream_pos)
+
+ return self.runInteraction(
+ "get_new_device_msgs_for_remote",
+ get_new_messages_for_remote_destination_txn,
+ )
+
+ def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
+ """Used to delete messages when the remote destination acknowledges
+ their receipt.
+
+ Args:
+ destination(str): The destination server_name
+ up_to_stream_id(int): Where to delete messages up to.
+ Returns:
+ A deferred that resolves when the messages have been deleted.
+ """
+ def delete_messages_for_remote_destination_txn(txn):
+ sql = (
+ "DELETE FROM device_federation_outbox"
+ " WHERE destination = ?"
+ " AND stream_id <= ?"
+ )
+ txn.execute(sql, (destination, up_to_stream_id))
+
+ return self.runInteraction(
+ "delete_device_msgs_for_remote",
+ delete_messages_for_remote_destination_txn
+ )
+
+
+class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, db_conn, hs):
@@ -220,93 +380,6 @@ class DeviceInboxStore(BackgroundUpdateStore):
txn.executemany(sql, rows)
- def get_new_messages_for_device(
- self, user_id, device_id, last_stream_id, current_stream_id, limit=100
- ):
- """
- Args:
- user_id(str): The recipient user_id.
- device_id(str): The recipient device_id.
- current_stream_id(int): The current position of the to device
- message stream.
- Returns:
- Deferred ([dict], int): List of messages for the device and where
- in the stream the messages got to.
- """
- has_changed = self._device_inbox_stream_cache.has_entity_changed(
- user_id, last_stream_id
- )
- if not has_changed:
- return defer.succeed(([], current_stream_id))
-
- def get_new_messages_for_device_txn(txn):
- sql = (
- "SELECT stream_id, message_json FROM device_inbox"
- " WHERE user_id = ? AND device_id = ?"
- " AND ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC"
- " LIMIT ?"
- )
- txn.execute(sql, (
- user_id, device_id, last_stream_id, current_stream_id, limit
- ))
- messages = []
- for row in txn:
- stream_pos = row[0]
- messages.append(json.loads(row[1]))
- if len(messages) < limit:
- stream_pos = current_stream_id
- return (messages, stream_pos)
-
- return self.runInteraction(
- "get_new_messages_for_device", get_new_messages_for_device_txn,
- )
-
- @defer.inlineCallbacks
- def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
- """
- Args:
- user_id(str): The recipient user_id.
- device_id(str): The recipient device_id.
- up_to_stream_id(int): Where to delete messages up to.
- Returns:
- A deferred that resolves to the number of messages deleted.
- """
- # If we have cached the last stream id we've deleted up to, we can
- # check if there is likely to be anything that needs deleting
- last_deleted_stream_id = self._last_device_delete_cache.get(
- (user_id, device_id), None
- )
- if last_deleted_stream_id:
- has_changed = self._device_inbox_stream_cache.has_entity_changed(
- user_id, last_deleted_stream_id
- )
- if not has_changed:
- defer.returnValue(0)
-
- def delete_messages_for_device_txn(txn):
- sql = (
- "DELETE FROM device_inbox"
- " WHERE user_id = ? AND device_id = ?"
- " AND stream_id <= ?"
- )
- txn.execute(sql, (user_id, device_id, up_to_stream_id))
- return txn.rowcount
-
- count = yield self.runInteraction(
- "delete_messages_for_device", delete_messages_for_device_txn
- )
-
- # Update the cache, ensuring that we only ever increase the value
- last_deleted_stream_id = self._last_device_delete_cache.get(
- (user_id, device_id), 0
- )
- self._last_device_delete_cache[(user_id, device_id)] = max(
- last_deleted_stream_id, up_to_stream_id
- )
-
- defer.returnValue(count)
-
def get_all_new_device_messages(self, last_pos, current_pos, limit):
"""
Args:
@@ -351,77 +424,6 @@ class DeviceInboxStore(BackgroundUpdateStore):
"get_all_new_device_messages", get_all_new_device_messages_txn
)
- def get_to_device_stream_token(self):
- return self._device_inbox_id_gen.get_current_token()
-
- def get_new_device_msgs_for_remote(
- self, destination, last_stream_id, current_stream_id, limit=100
- ):
- """
- Args:
- destination(str): The name of the remote server.
- last_stream_id(int|long): The last position of the device message stream
- that the server sent up to.
- current_stream_id(int|long): The current position of the device
- message stream.
- Returns:
- Deferred ([dict], int|long): List of messages for the device and where
- in the stream the messages got to.
- """
-
- has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
- destination, last_stream_id
- )
- if not has_changed or last_stream_id == current_stream_id:
- return defer.succeed(([], current_stream_id))
-
- def get_new_messages_for_remote_destination_txn(txn):
- sql = (
- "SELECT stream_id, messages_json FROM device_federation_outbox"
- " WHERE destination = ?"
- " AND ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC"
- " LIMIT ?"
- )
- txn.execute(sql, (
- destination, last_stream_id, current_stream_id, limit
- ))
- messages = []
- for row in txn:
- stream_pos = row[0]
- messages.append(json.loads(row[1]))
- if len(messages) < limit:
- stream_pos = current_stream_id
- return (messages, stream_pos)
-
- return self.runInteraction(
- "get_new_device_msgs_for_remote",
- get_new_messages_for_remote_destination_txn,
- )
-
- def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
- """Used to delete messages when the remote destination acknowledges
- their receipt.
-
- Args:
- destination(str): The destination server_name
- up_to_stream_id(int): Where to delete messages up to.
- Returns:
- A deferred that resolves when the messages have been deleted.
- """
- def delete_messages_for_remote_destination_txn(txn):
- sql = (
- "DELETE FROM device_federation_outbox"
- " WHERE destination = ?"
- " AND stream_id <= ?"
- )
- txn.execute(sql, (destination, up_to_stream_id))
-
- return self.runInteraction(
- "delete_device_msgs_for_remote",
- delete_messages_for_remote_destination_txn
- )
-
@defer.inlineCallbacks
def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index ecdab34e7d..e716dc1437 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -22,11 +22,10 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import Cache, SQLBaseStore, db_to_json
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
-from ._base import Cache, db_to_json
-
logger = logging.getLogger(__name__)
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
@@ -34,7 +33,343 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
)
-class DeviceStore(BackgroundUpdateStore):
+class DeviceWorkerStore(SQLBaseStore):
+ def get_device(self, user_id, device_id):
+ """Retrieve a device.
+
+ Args:
+ user_id (str): The ID of the user which owns the device
+ device_id (str): The ID of the device to retrieve
+ Returns:
+ defer.Deferred for a dict containing the device information
+ Raises:
+ StoreError: if the device is not found
+ """
+ return self._simple_select_one(
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ retcols=("user_id", "device_id", "display_name"),
+ desc="get_device",
+ )
+
+ @defer.inlineCallbacks
+ def get_devices_by_user(self, user_id):
+ """Retrieve all of a user's registered devices.
+
+ Args:
+ user_id (str):
+ Returns:
+ defer.Deferred: resolves to a dict from device_id to a dict
+ containing "device_id", "user_id" and "display_name" for each
+ device.
+ """
+ devices = yield self._simple_select_list(
+ table="devices",
+ keyvalues={"user_id": user_id},
+ retcols=("user_id", "device_id", "display_name"),
+ desc="get_devices_by_user"
+ )
+
+ defer.returnValue({d["device_id"]: d for d in devices})
+
+ def get_devices_by_remote(self, destination, from_stream_id):
+ """Get stream of updates to send to remote servers
+
+ Returns:
+ (int, list[dict]): current stream id and list of updates
+ """
+ now_stream_id = self._device_list_id_gen.get_current_token()
+
+ has_changed = self._device_list_federation_stream_cache.has_entity_changed(
+ destination, int(from_stream_id)
+ )
+ if not has_changed:
+ return (now_stream_id, [])
+
+ return self.runInteraction(
+ "get_devices_by_remote", self._get_devices_by_remote_txn,
+ destination, from_stream_id, now_stream_id,
+ )
+
+ def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
+ now_stream_id):
+ sql = """
+ SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
+ WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
+ GROUP BY user_id, device_id
+ LIMIT 20
+ """
+ txn.execute(
+ sql, (destination, from_stream_id, now_stream_id, False)
+ )
+
+ # maps (user_id, device_id) -> stream_id
+ query_map = {(r[0], r[1]): r[2] for r in txn}
+ if not query_map:
+ return (now_stream_id, [])
+
+ if len(query_map) >= 20:
+ now_stream_id = max(stream_id for stream_id in itervalues(query_map))
+
+ devices = self._get_e2e_device_keys_txn(
+ txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
+ )
+
+ prev_sent_id_sql = """
+ SELECT coalesce(max(stream_id), 0) as stream_id
+ FROM device_lists_outbound_last_success
+ WHERE destination = ? AND user_id = ? AND stream_id <= ?
+ """
+
+ results = []
+ for user_id, user_devices in iteritems(devices):
+ # The prev_id for the first row is always the last row before
+ # `from_stream_id`
+ txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
+ rows = txn.fetchall()
+ prev_id = rows[0][0]
+ for device_id, device in iteritems(user_devices):
+ stream_id = query_map[(user_id, device_id)]
+ result = {
+ "user_id": user_id,
+ "device_id": device_id,
+ "prev_id": [prev_id] if prev_id else [],
+ "stream_id": stream_id,
+ }
+
+ prev_id = stream_id
+
+ if device is not None:
+ key_json = device.get("key_json", None)
+ if key_json:
+ result["keys"] = db_to_json(key_json)
+ device_display_name = device.get("device_display_name", None)
+ if device_display_name:
+ result["device_display_name"] = device_display_name
+ else:
+ result["deleted"] = True
+
+ results.append(result)
+
+ return (now_stream_id, results)
+
+ def mark_as_sent_devices_by_remote(self, destination, stream_id):
+ """Mark that updates have successfully been sent to the destination.
+ """
+ return self.runInteraction(
+ "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
+ destination, stream_id,
+ )
+
+ def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
+ # We update the device_lists_outbound_last_success with the successfully
+ # poked users. We do the join to see which users need to be inserted and
+ # which updated.
+ sql = """
+ SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
+ FROM device_lists_outbound_pokes as o
+ LEFT JOIN device_lists_outbound_last_success as s
+ USING (destination, user_id)
+ WHERE destination = ? AND o.stream_id <= ?
+ GROUP BY user_id
+ """
+ txn.execute(sql, (destination, stream_id,))
+ rows = txn.fetchall()
+
+ sql = """
+ UPDATE device_lists_outbound_last_success
+ SET stream_id = ?
+ WHERE destination = ? AND user_id = ?
+ """
+ txn.executemany(
+ sql, ((row[1], destination, row[0],) for row in rows if row[2])
+ )
+
+ sql = """
+ INSERT INTO device_lists_outbound_last_success
+ (destination, user_id, stream_id) VALUES (?, ?, ?)
+ """
+ txn.executemany(
+ sql, ((destination, row[0], row[1],) for row in rows if not row[2])
+ )
+
+ # Delete all sent outbound pokes
+ sql = """
+ DELETE FROM device_lists_outbound_pokes
+ WHERE destination = ? AND stream_id <= ?
+ """
+ txn.execute(sql, (destination, stream_id,))
+
+ def get_device_stream_token(self):
+ return self._device_list_id_gen.get_current_token()
+
+ @defer.inlineCallbacks
+ def get_user_devices_from_cache(self, query_list):
+ """Get the devices (and keys if any) for remote users from the cache.
+
+ Args:
+ query_list(list): List of (user_id, device_ids), if device_ids is
+ falsey then return all device ids for that user.
+
+ Returns:
+ (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
+ a set of user_ids and results_map is a mapping of
+ user_id -> device_id -> device_info
+ """
+ user_ids = set(user_id for user_id, _ in query_list)
+ user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+ user_ids_in_cache = set(
+ user_id for user_id, stream_id in user_map.items() if stream_id
+ )
+ user_ids_not_in_cache = user_ids - user_ids_in_cache
+
+ results = {}
+ for user_id, device_id in query_list:
+ if user_id not in user_ids_in_cache:
+ continue
+
+ if device_id:
+ device = yield self._get_cached_user_device(user_id, device_id)
+ results.setdefault(user_id, {})[device_id] = device
+ else:
+ results[user_id] = yield self._get_cached_devices_for_user(user_id)
+
+ defer.returnValue((user_ids_not_in_cache, results))
+
+ @cachedInlineCallbacks(num_args=2, tree=True)
+ def _get_cached_user_device(self, user_id, device_id):
+ content = yield self._simple_select_one_onecol(
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ retcol="content",
+ desc="_get_cached_user_device",
+ )
+ defer.returnValue(db_to_json(content))
+
+ @cachedInlineCallbacks()
+ def _get_cached_devices_for_user(self, user_id):
+ devices = yield self._simple_select_list(
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ },
+ retcols=("device_id", "content"),
+ desc="_get_cached_devices_for_user",
+ )
+ defer.returnValue({
+ device["device_id"]: db_to_json(device["content"])
+ for device in devices
+ })
+
+ def get_devices_with_keys_by_user(self, user_id):
+ """Get all devices (with any device keys) for a user
+
+ Returns:
+ (stream_id, devices)
+ """
+ return self.runInteraction(
+ "get_devices_with_keys_by_user",
+ self._get_devices_with_keys_by_user_txn, user_id,
+ )
+
+ def _get_devices_with_keys_by_user_txn(self, txn, user_id):
+ now_stream_id = self._device_list_id_gen.get_current_token()
+
+ devices = self._get_e2e_device_keys_txn(
+ txn, [(user_id, None)], include_all_devices=True
+ )
+
+ if devices:
+ user_devices = devices[user_id]
+ results = []
+ for device_id, device in iteritems(user_devices):
+ result = {
+ "device_id": device_id,
+ }
+
+ key_json = device.get("key_json", None)
+ if key_json:
+ result["keys"] = db_to_json(key_json)
+ device_display_name = device.get("device_display_name", None)
+ if device_display_name:
+ result["device_display_name"] = device_display_name
+
+ results.append(result)
+
+ return now_stream_id, results
+
+ return now_stream_id, []
+
+ @defer.inlineCallbacks
+ def get_user_whose_devices_changed(self, from_key):
+ """Get set of users whose devices have changed since `from_key`.
+ """
+ from_key = int(from_key)
+ changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
+ if changed is not None:
+ defer.returnValue(set(changed))
+
+ sql = """
+ SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
+ """
+ rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
+ defer.returnValue(set(row[0] for row in rows))
+
+ def get_all_device_list_changes_for_remotes(self, from_key, to_key):
+ """Return a list of `(stream_id, user_id, destination)` which is the
+ combined list of changes to devices, and which destinations need to be
+ poked. `destination` may be None if no destinations need to be poked.
+ """
+ # We do a group by here as there can be a large number of duplicate
+ # entries, since we throw away device IDs.
+ sql = """
+ SELECT MAX(stream_id) AS stream_id, user_id, destination
+ FROM device_lists_stream
+ LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
+ WHERE ? < stream_id AND stream_id <= ?
+ GROUP BY user_id, destination
+ """
+ return self._execute(
+ "get_all_device_list_changes_for_remotes", None,
+ sql, from_key, to_key
+ )
+
+ @cached(max_entries=10000)
+ def get_device_list_last_stream_id_for_remote(self, user_id):
+ """Get the last stream_id we got for a user. May be None if we haven't
+ got any information for them.
+ """
+ return self._simple_select_one_onecol(
+ table="device_lists_remote_extremeties",
+ keyvalues={"user_id": user_id},
+ retcol="stream_id",
+ desc="get_device_list_last_stream_id_for_remote",
+ allow_none=True,
+ )
+
+ @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
+ list_name="user_ids", inlineCallbacks=True)
+ def get_device_list_last_stream_id_for_remotes(self, user_ids):
+ rows = yield self._simple_select_many_batch(
+ table="device_lists_remote_extremeties",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id", "stream_id",),
+ desc="get_device_list_last_stream_id_for_remotes",
+ )
+
+ results = {user_id: None for user_id in user_ids}
+ results.update({
+ row["user_id"]: row["stream_id"] for row in rows
+ })
+
+ defer.returnValue(results)
+
+
+class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(DeviceStore, self).__init__(db_conn, hs)
@@ -121,24 +456,6 @@ class DeviceStore(BackgroundUpdateStore):
initial_device_display_name, e)
raise StoreError(500, "Problem storing device.")
- def get_device(self, user_id, device_id):
- """Retrieve a device.
-
- Args:
- user_id (str): The ID of the user which owns the device
- device_id (str): The ID of the device to retrieve
- Returns:
- defer.Deferred for a dict containing the device information
- Raises:
- StoreError: if the device is not found
- """
- return self._simple_select_one(
- table="devices",
- keyvalues={"user_id": user_id, "device_id": device_id},
- retcols=("user_id", "device_id", "display_name"),
- desc="get_device",
- )
-
@defer.inlineCallbacks
def delete_device(self, user_id, device_id):
"""Delete a device.
@@ -203,57 +520,6 @@ class DeviceStore(BackgroundUpdateStore):
)
@defer.inlineCallbacks
- def get_devices_by_user(self, user_id):
- """Retrieve all of a user's registered devices.
-
- Args:
- user_id (str):
- Returns:
- defer.Deferred: resolves to a dict from device_id to a dict
- containing "device_id", "user_id" and "display_name" for each
- device.
- """
- devices = yield self._simple_select_list(
- table="devices",
- keyvalues={"user_id": user_id},
- retcols=("user_id", "device_id", "display_name"),
- desc="get_devices_by_user"
- )
-
- defer.returnValue({d["device_id"]: d for d in devices})
-
- @cached(max_entries=10000)
- def get_device_list_last_stream_id_for_remote(self, user_id):
- """Get the last stream_id we got for a user. May be None if we haven't
- got any information for them.
- """
- return self._simple_select_one_onecol(
- table="device_lists_remote_extremeties",
- keyvalues={"user_id": user_id},
- retcol="stream_id",
- desc="get_device_list_remote_extremity",
- allow_none=True,
- )
-
- @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
- list_name="user_ids", inlineCallbacks=True)
- def get_device_list_last_stream_id_for_remotes(self, user_ids):
- rows = yield self._simple_select_many_batch(
- table="device_lists_remote_extremeties",
- column="user_id",
- iterable=user_ids,
- retcols=("user_id", "stream_id",),
- desc="get_user_devices_from_cache",
- )
-
- results = {user_id: None for user_id in user_ids}
- results.update({
- row["user_id"]: row["stream_id"] for row in rows
- })
-
- defer.returnValue(results)
-
- @defer.inlineCallbacks
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
"""
@@ -405,268 +671,6 @@ class DeviceStore(BackgroundUpdateStore):
lock=False,
)
- def get_devices_by_remote(self, destination, from_stream_id):
- """Get stream of updates to send to remote servers
-
- Returns:
- (int, list[dict]): current stream id and list of updates
- """
- now_stream_id = self._device_list_id_gen.get_current_token()
-
- has_changed = self._device_list_federation_stream_cache.has_entity_changed(
- destination, int(from_stream_id)
- )
- if not has_changed:
- return (now_stream_id, [])
-
- return self.runInteraction(
- "get_devices_by_remote", self._get_devices_by_remote_txn,
- destination, from_stream_id, now_stream_id,
- )
-
- def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
- now_stream_id):
- sql = """
- SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
- WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
- GROUP BY user_id, device_id
- LIMIT 20
- """
- txn.execute(
- sql, (destination, from_stream_id, now_stream_id, False)
- )
-
- # maps (user_id, device_id) -> stream_id
- query_map = {(r[0], r[1]): r[2] for r in txn}
- if not query_map:
- return (now_stream_id, [])
-
- if len(query_map) >= 20:
- now_stream_id = max(stream_id for stream_id in itervalues(query_map))
-
- devices = self._get_e2e_device_keys_txn(
- txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
- )
-
- prev_sent_id_sql = """
- SELECT coalesce(max(stream_id), 0) as stream_id
- FROM device_lists_outbound_last_success
- WHERE destination = ? AND user_id = ? AND stream_id <= ?
- """
-
- results = []
- for user_id, user_devices in iteritems(devices):
- # The prev_id for the first row is always the last row before
- # `from_stream_id`
- txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
- rows = txn.fetchall()
- prev_id = rows[0][0]
- for device_id, device in iteritems(user_devices):
- stream_id = query_map[(user_id, device_id)]
- result = {
- "user_id": user_id,
- "device_id": device_id,
- "prev_id": [prev_id] if prev_id else [],
- "stream_id": stream_id,
- }
-
- prev_id = stream_id
-
- if device is not None:
- key_json = device.get("key_json", None)
- if key_json:
- result["keys"] = db_to_json(key_json)
- device_display_name = device.get("device_display_name", None)
- if device_display_name:
- result["device_display_name"] = device_display_name
- else:
- result["deleted"] = True
-
- results.append(result)
-
- return (now_stream_id, results)
-
- @defer.inlineCallbacks
- def get_user_devices_from_cache(self, query_list):
- """Get the devices (and keys if any) for remote users from the cache.
-
- Args:
- query_list(list): List of (user_id, device_ids), if device_ids is
- falsey then return all device ids for that user.
-
- Returns:
- (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
- a set of user_ids and results_map is a mapping of
- user_id -> device_id -> device_info
- """
- user_ids = set(user_id for user_id, _ in query_list)
- user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
- user_ids_in_cache = set(
- user_id for user_id, stream_id in user_map.items() if stream_id
- )
- user_ids_not_in_cache = user_ids - user_ids_in_cache
-
- results = {}
- for user_id, device_id in query_list:
- if user_id not in user_ids_in_cache:
- continue
-
- if device_id:
- device = yield self._get_cached_user_device(user_id, device_id)
- results.setdefault(user_id, {})[device_id] = device
- else:
- results[user_id] = yield self._get_cached_devices_for_user(user_id)
-
- defer.returnValue((user_ids_not_in_cache, results))
-
- @cachedInlineCallbacks(num_args=2, tree=True)
- def _get_cached_user_device(self, user_id, device_id):
- content = yield self._simple_select_one_onecol(
- table="device_lists_remote_cache",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- },
- retcol="content",
- desc="_get_cached_user_device",
- )
- defer.returnValue(db_to_json(content))
-
- @cachedInlineCallbacks()
- def _get_cached_devices_for_user(self, user_id):
- devices = yield self._simple_select_list(
- table="device_lists_remote_cache",
- keyvalues={
- "user_id": user_id,
- },
- retcols=("device_id", "content"),
- desc="_get_cached_devices_for_user",
- )
- defer.returnValue({
- device["device_id"]: db_to_json(device["content"])
- for device in devices
- })
-
- def get_devices_with_keys_by_user(self, user_id):
- """Get all devices (with any device keys) for a user
-
- Returns:
- (stream_id, devices)
- """
- return self.runInteraction(
- "get_devices_with_keys_by_user",
- self._get_devices_with_keys_by_user_txn, user_id,
- )
-
- def _get_devices_with_keys_by_user_txn(self, txn, user_id):
- now_stream_id = self._device_list_id_gen.get_current_token()
-
- devices = self._get_e2e_device_keys_txn(
- txn, [(user_id, None)], include_all_devices=True
- )
-
- if devices:
- user_devices = devices[user_id]
- results = []
- for device_id, device in iteritems(user_devices):
- result = {
- "device_id": device_id,
- }
-
- key_json = device.get("key_json", None)
- if key_json:
- result["keys"] = db_to_json(key_json)
- device_display_name = device.get("device_display_name", None)
- if device_display_name:
- result["device_display_name"] = device_display_name
-
- results.append(result)
-
- return now_stream_id, results
-
- return now_stream_id, []
-
- def mark_as_sent_devices_by_remote(self, destination, stream_id):
- """Mark that updates have successfully been sent to the destination.
- """
- return self.runInteraction(
- "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
- destination, stream_id,
- )
-
- def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
- # We update the device_lists_outbound_last_success with the successfully
- # poked users. We do the join to see which users need to be inserted and
- # which updated.
- sql = """
- SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
- FROM device_lists_outbound_pokes as o
- LEFT JOIN device_lists_outbound_last_success as s
- USING (destination, user_id)
- WHERE destination = ? AND o.stream_id <= ?
- GROUP BY user_id
- """
- txn.execute(sql, (destination, stream_id,))
- rows = txn.fetchall()
-
- sql = """
- UPDATE device_lists_outbound_last_success
- SET stream_id = ?
- WHERE destination = ? AND user_id = ?
- """
- txn.executemany(
- sql, ((row[1], destination, row[0],) for row in rows if row[2])
- )
-
- sql = """
- INSERT INTO device_lists_outbound_last_success
- (destination, user_id, stream_id) VALUES (?, ?, ?)
- """
- txn.executemany(
- sql, ((destination, row[0], row[1],) for row in rows if not row[2])
- )
-
- # Delete all sent outbound pokes
- sql = """
- DELETE FROM device_lists_outbound_pokes
- WHERE destination = ? AND stream_id <= ?
- """
- txn.execute(sql, (destination, stream_id,))
-
- @defer.inlineCallbacks
- def get_user_whose_devices_changed(self, from_key):
- """Get set of users whose devices have changed since `from_key`.
- """
- from_key = int(from_key)
- changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
- if changed is not None:
- defer.returnValue(set(changed))
-
- sql = """
- SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
- """
- rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
- defer.returnValue(set(row[0] for row in rows))
-
- def get_all_device_list_changes_for_remotes(self, from_key, to_key):
- """Return a list of `(stream_id, user_id, destination)` which is the
- combined list of changes to devices, and which destinations need to be
- poked. `destination` may be None if no destinations need to be poked.
- """
- # We do a group by here as there can be a large number of duplicate
- # entries, since we throw away device IDs.
- sql = """
- SELECT MAX(stream_id) AS stream_id, user_id, destination
- FROM device_lists_stream
- LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
- WHERE ? < stream_id AND stream_id <= ?
- GROUP BY user_id, destination
- """
- return self._execute(
- "get_all_device_list_changes_for_remotes", None,
- sql, from_key, to_key
- )
-
@defer.inlineCallbacks
def add_device_change_to_streams(self, user_id, device_ids, hosts):
"""Persist that a user's devices have been updated, and which hosts
@@ -732,9 +736,6 @@ class DeviceStore(BackgroundUpdateStore):
]
)
- def get_device_stream_token(self):
- return self._device_list_id_gen.get_current_token()
-
def _prune_old_outbound_device_pokes(self):
"""Delete old entries out of the device_lists_outbound_pokes to ensure
that we don't fill up due to dead servers. We keep one entry per
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 2a0f6cfca9..e381e472a2 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -23,49 +23,7 @@ from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore, db_to_json
-class EndToEndKeyStore(SQLBaseStore):
- def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
- """Stores device keys for a device. Returns whether there was a change
- or the keys were already in the database.
- """
- def _set_e2e_device_keys_txn(txn):
- old_key_json = self._simple_select_one_onecol_txn(
- txn,
- table="e2e_device_keys_json",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- },
- retcol="key_json",
- allow_none=True,
- )
-
- # In py3 we need old_key_json to match new_key_json type. The DB
- # returns unicode while encode_canonical_json returns bytes.
- new_key_json = encode_canonical_json(device_keys).decode("utf-8")
-
- if old_key_json == new_key_json:
- return False
-
- self._simple_upsert_txn(
- txn,
- table="e2e_device_keys_json",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- },
- values={
- "ts_added_ms": time_now,
- "key_json": new_key_json,
- }
- )
-
- return True
-
- return self.runInteraction(
- "set_e2e_device_keys", _set_e2e_device_keys_txn
- )
-
+class EndToEndKeyWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_e2e_device_keys(
self, query_list, include_all_devices=False,
@@ -238,6 +196,50 @@ class EndToEndKeyStore(SQLBaseStore):
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
+
+class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
+ def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
+ """Stores device keys for a device. Returns whether there was a change
+ or the keys were already in the database.
+ """
+ def _set_e2e_device_keys_txn(txn):
+ old_key_json = self._simple_select_one_onecol_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ retcol="key_json",
+ allow_none=True,
+ )
+
+ # In py3 we need old_key_json to match new_key_json type. The DB
+ # returns unicode while encode_canonical_json returns bytes.
+ new_key_json = encode_canonical_json(device_keys).decode("utf-8")
+
+ if old_key_json == new_key_json:
+ return False
+
+ self._simple_upsert_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ values={
+ "ts_added_ms": time_now,
+ "key_json": new_key_json,
+ }
+ )
+
+ return True
+
+ return self.runInteraction(
+ "set_e2e_device_keys", _set_e2e_device_keys_txn
+ )
+
def claim_e2e_one_time_keys(self, query_list):
"""Take a list of one time keys out of the database"""
def _claim_e2e_one_time_keys(txn):
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 38809ed0fc..a8d90456e3 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -442,6 +442,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
event_results.reverse()
return event_results
+ @defer.inlineCallbacks
+ def get_successor_events(self, event_ids):
+ """Fetch all events that have the given events as a prev event
+
+ Args:
+ event_ids (iterable[str])
+
+ Returns:
+ Deferred[list[str]]
+ """
+ rows = yield self._simple_select_many_batch(
+ table="event_edges",
+ column="prev_event_id",
+ iterable=event_ids,
+ retcols=("event_id",),
+ desc="get_successor_events"
+ )
+
+ defer.returnValue([
+ row["event_id"] for row in rows
+ ])
+
class EventFederationStore(EventFederationWorkerStore):
""" Responsible for storing and serving up the various graphs associated
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 06db9e56e6..428300ea0a 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -537,6 +537,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
new_events = [
event for event, ctx in event_contexts
if not event.internal_metadata.is_outlier() and not ctx.rejected
+ and not event.internal_metadata.is_soft_failed()
]
# start with the existing forward extremities
@@ -1406,21 +1407,6 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
values=state_values,
)
- self._simple_insert_many_txn(
- txn,
- table="event_edges",
- values=[
- {
- "event_id": event.event_id,
- "prev_event_id": prev_id,
- "room_id": event.room_id,
- "is_state": True,
- }
- for event, _ in state_events_and_contexts
- for prev_id, _ in event.prev_state
- ],
- )
-
# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 6a5028961d..4b8438c3e9 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -186,6 +186,63 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
defer.returnValue(results)
@defer.inlineCallbacks
+ def move_push_rule_from_room_to_room(
+ self, new_room_id, user_id, rule,
+ ):
+ """Move a single push rule from one room to another for a specific user.
+
+ Args:
+ new_room_id (str): ID of the new room.
+ user_id (str): ID of user the push rule belongs to.
+ rule (Dict): A push rule.
+ """
+ # Create new rule id
+ rule_id_scope = '/'.join(rule["rule_id"].split('/')[:-1])
+ new_rule_id = rule_id_scope + "/" + new_room_id
+
+ # Change room id in each condition
+ for condition in rule.get("conditions", []):
+ if condition.get("key") == "room_id":
+ condition["pattern"] = new_room_id
+
+ # Add the rule for the new room
+ yield self.add_push_rule(
+ user_id=user_id,
+ rule_id=new_rule_id,
+ priority_class=rule["priority_class"],
+ conditions=rule["conditions"],
+ actions=rule["actions"],
+ )
+
+ # Delete push rule for the old room
+ yield self.delete_push_rule(user_id, rule["rule_id"])
+
+ @defer.inlineCallbacks
+ def move_push_rules_from_room_to_room_for_user(
+ self, old_room_id, new_room_id, user_id,
+ ):
+ """Move all of the push rules from one room to another for a specific
+ user.
+
+ Args:
+ old_room_id (str): ID of the old room.
+ new_room_id (str): ID of the new room.
+ user_id (str): ID of user to copy push rules for.
+ """
+ # Retrieve push rules for this user
+ user_push_rules = yield self.get_push_rules_for_user(user_id)
+
+ # Get rules relating to the old room, move them to the new room, then
+ # delete them from the old room
+ for rule in user_push_rules:
+ conditions = rule.get("conditions", [])
+ if any((c.get("key") == "room_id" and
+ c.get("pattern") == old_room_id) for c in conditions):
+ self.move_push_rule_from_room_to_room(
+ new_room_id, user_id, rule,
+ )
+
+ @defer.inlineCallbacks
def bulk_get_push_rules_for_room(self, event, context):
state_group = context.state_group
if not state_group:
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 0ac665e967..0fd1ccc40a 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -346,15 +346,23 @@ class ReceiptsStore(ReceiptsWorkerStore):
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id):
+ """Inserts a read-receipt into the database if it's newer than the current RR
+
+ Returns: int|None
+ None if the RR is older than the current RR
+ otherwise, the rx timestamp of the event that the RR corresponds to
+ (or 0 if the event is unknown)
+ """
res = self._simple_select_one_txn(
txn,
table="events",
- retcols=["topological_ordering", "stream_ordering"],
+ retcols=["stream_ordering", "received_ts"],
keyvalues={"event_id": event_id},
allow_none=True
)
stream_ordering = int(res["stream_ordering"]) if res else None
+ rx_ts = res["received_ts"] if res else 0
# We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts
@@ -373,7 +381,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"one for later event %s",
event_id, eid,
)
- return False
+ return None
txn.call_after(
self.get_receipts_for_room.invalidate, (room_id, receipt_type)
@@ -429,7 +437,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
stream_ordering=stream_ordering,
)
- return True
+ return rx_ts
@defer.inlineCallbacks
def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
@@ -466,7 +474,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
stream_id_manager = self._receipts_id_gen.get_next()
with stream_id_manager as stream_id:
- have_persisted = yield self.runInteraction(
+ event_ts = yield self.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
room_id, receipt_type, user_id, linearized_event_id,
@@ -474,8 +482,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
stream_id=stream_id,
)
- if not have_persisted:
- defer.returnValue(None)
+ if event_ts is None:
+ defer.returnValue(None)
+
+ now = self._clock.time_msec()
+ logger.debug(
+ "RR for event %s in %s (%i ms old)",
+ linearized_event_id, room_id, now - event_ts,
+ )
yield self.insert_graph_receipt(
room_id, receipt_type, user_id, event_ids, data
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 9b9572890b..9b6c28892c 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -295,6 +295,39 @@ class RegistrationWorkerStore(SQLBaseStore):
return ret['user_id']
return None
+ @defer.inlineCallbacks
+ def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
+ yield self._simple_upsert("user_threepids", {
+ "medium": medium,
+ "address": address,
+ }, {
+ "user_id": user_id,
+ "validated_at": validated_at,
+ "added_at": added_at,
+ })
+
+ @defer.inlineCallbacks
+ def user_get_threepids(self, user_id):
+ ret = yield self._simple_select_list(
+ "user_threepids", {
+ "user_id": user_id
+ },
+ ['medium', 'address', 'validated_at', 'added_at'],
+ 'user_get_threepids'
+ )
+ defer.returnValue(ret)
+
+ def user_delete_threepid(self, user_id, medium, address):
+ return self._simple_delete(
+ "user_threepids",
+ keyvalues={
+ "user_id": user_id,
+ "medium": medium,
+ "address": address,
+ },
+ desc="user_delete_threepids",
+ )
+
class RegistrationStore(RegistrationWorkerStore,
background_updates.BackgroundUpdateStore):
@@ -633,39 +666,6 @@ class RegistrationStore(RegistrationWorkerStore,
defer.returnValue(res if res else False)
@defer.inlineCallbacks
- def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
- yield self._simple_upsert("user_threepids", {
- "medium": medium,
- "address": address,
- }, {
- "user_id": user_id,
- "validated_at": validated_at,
- "added_at": added_at,
- })
-
- @defer.inlineCallbacks
- def user_get_threepids(self, user_id):
- ret = yield self._simple_select_list(
- "user_threepids", {
- "user_id": user_id
- },
- ['medium', 'address', 'validated_at', 'added_at'],
- 'user_get_threepids'
- )
- defer.returnValue(ret)
-
- def user_delete_threepid(self, user_id, medium, address):
- return self._simple_delete(
- "user_threepids",
- keyvalues={
- "user_id": user_id,
- "medium": medium,
- "address": address,
- },
- desc="user_delete_threepids",
- )
-
- @defer.inlineCallbacks
def save_or_get_3pid_guest_access_token(
self, medium, address, access_token, inviter_user_id
):
diff --git a/synapse/storage/schema/delta/53/user_share.sql b/synapse/storage/schema/delta/53/user_share.sql
new file mode 100644
index 0000000000..14424ded0c
--- /dev/null
+++ b/synapse/storage/schema/delta/53/user_share.sql
@@ -0,0 +1,47 @@
+/* Copyright 2017 Vector Creations Ltd, 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Old disused version of the tables below.
+DROP TABLE IF EXISTS users_who_share_rooms;
+
+-- This is no longer used because it's duplicated by the users_who_share_public_rooms
+DROP TABLE IF EXISTS users_in_public_rooms;
+
+-- Tables keeping track of what users share rooms. This is a map of local users
+-- to local or remote users, per room. Remote users cannot be in the user_id
+-- column, only the other_user_id column. There are two tables, one for public
+-- rooms and those for private rooms.
+CREATE TABLE IF NOT EXISTS users_who_share_public_rooms (
+ user_id TEXT NOT NULL,
+ other_user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS users_who_share_private_rooms (
+ user_id TEXT NOT NULL,
+ other_user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX users_who_share_public_rooms_u_idx ON users_who_share_public_rooms(user_id, other_user_id, room_id);
+CREATE INDEX users_who_share_public_rooms_r_idx ON users_who_share_public_rooms(room_id);
+CREATE INDEX users_who_share_public_rooms_o_idx ON users_who_share_public_rooms(other_user_id);
+
+CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms(user_id, other_user_id, room_id);
+CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_rooms(room_id);
+CREATE INDEX users_who_share_private_rooms_o_idx ON users_who_share_private_rooms(other_user_id);
+
+-- Make sure that we populate the tables initially by resetting the stream ID
+UPDATE user_directory_stream_pos SET stream_id = NULL;
diff --git a/synapse/storage/schema/full_schemas/11/event_edges.sql b/synapse/storage/schema/full_schemas/11/event_edges.sql
index 52eec88357..bccd1c6f74 100644
--- a/synapse/storage/schema/full_schemas/11/event_edges.sql
+++ b/synapse/storage/schema/full_schemas/11/event_edges.sql
@@ -37,6 +37,8 @@ CREATE TABLE IF NOT EXISTS event_edges(
event_id TEXT NOT NULL,
prev_event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
+ -- We no longer insert prev_state into this table, so all new rows will have
+ -- is_state as false.
is_state BOOL NOT NULL,
UNIQUE (event_id, prev_event_id, room_id, is_state)
);
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index d6cfdba519..580fafeb3a 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -191,6 +191,25 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
@defer.inlineCallbacks
def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0,
order='DESC'):
+ """Get new room events in stream ordering since `from_key`.
+
+ Args:
+ room_id (str)
+ from_key (str): Token from which no events are returned before
+ to_key (str): Token from which no events are returned after. (This
+ is typically the current stream token)
+ limit (int): Maximum number of events to return
+ order (str): Either "DESC" or "ASC". Determines which events are
+ returned when the result is limited. If "DESC" then the most
+ recent `limit` events are returned, otherwise returns the
+ oldest `limit` events.
+
+ Returns:
+ Deferred[dict[str,tuple[list[FrozenEvent], str]]]
+ A map from room id to a tuple containing:
+ - list of recent events in the room
+ - stream ordering key for the start of the chunk of events returned.
+ """
from_id = RoomStreamToken.parse_stream_token(from_key).stream
room_ids = yield self._events_stream_cache.get_entities_changed(
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index fea866c043..2317d22ed6 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -63,31 +63,14 @@ class UserDirectoryStore(SQLBaseStore):
defer.returnValue(False)
- @defer.inlineCallbacks
- def add_users_to_public_room(self, room_id, user_ids):
- """Add user to the list of users in public rooms
-
- Args:
- room_id (str): A room_id that all users are in that is world_readable
- or publically joinable
- user_ids (list(str)): Users to add
- """
- yield self._simple_insert_many(
- table="users_in_public_rooms",
- values=[{"user_id": user_id, "room_id": room_id} for user_id in user_ids],
- desc="add_users_to_public_room",
- )
- for user_id in user_ids:
- self.get_user_in_public_room.invalidate((user_id,))
-
- def add_profiles_to_user_dir(self, room_id, users_with_profile):
+ def add_profiles_to_user_dir(self, users_with_profile):
"""Add profiles to the user directory
Args:
- room_id (str): A room_id that all users are joined to
users_with_profile (dict): Users to add to directory in the form of
mapping of user_id -> ProfileInfo
"""
+
if isinstance(self.database_engine, PostgresEngine):
# We weight the loclpart most highly, then display name and finally
# server name
@@ -113,7 +96,7 @@ class UserDirectoryStore(SQLBaseStore):
INSERT INTO user_directory_search(user_id, value)
VALUES (?,?)
"""
- args = (
+ args = tuple(
(
user_id,
"%s %s" % (user_id, p.display_name) if p.display_name else user_id,
@@ -132,7 +115,7 @@ class UserDirectoryStore(SQLBaseStore):
values=[
{
"user_id": user_id,
- "room_id": room_id,
+ "room_id": None,
"display_name": profile.display_name,
"avatar_url": profile.avatar_url,
}
@@ -250,16 +233,6 @@ class UserDirectoryStore(SQLBaseStore):
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
)
- @defer.inlineCallbacks
- def update_user_in_public_user_list(self, user_id, room_id):
- yield self._simple_update_one(
- table="users_in_public_rooms",
- keyvalues={"user_id": user_id},
- updatevalues={"room_id": room_id},
- desc="update_user_in_public_user_list",
- )
- self.get_user_in_public_room.invalidate((user_id,))
-
def remove_from_user_dir(self, user_id):
def _remove_from_user_dir_txn(txn):
self._simple_delete_txn(
@@ -269,62 +242,50 @@ class UserDirectoryStore(SQLBaseStore):
txn, table="user_directory_search", keyvalues={"user_id": user_id}
)
self._simple_delete_txn(
- txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
+ txn,
+ table="users_who_share_public_rooms",
+ keyvalues={"user_id": user_id},
+ )
+ self._simple_delete_txn(
+ txn,
+ table="users_who_share_public_rooms",
+ keyvalues={"other_user_id": user_id},
+ )
+ self._simple_delete_txn(
+ txn,
+ table="users_who_share_private_rooms",
+ keyvalues={"user_id": user_id},
+ )
+ self._simple_delete_txn(
+ txn,
+ table="users_who_share_private_rooms",
+ keyvalues={"other_user_id": user_id},
)
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
- txn.call_after(self.get_user_in_public_room.invalidate, (user_id,))
return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
@defer.inlineCallbacks
- def remove_from_user_in_public_room(self, user_id):
- yield self._simple_delete(
- table="users_in_public_rooms",
- keyvalues={"user_id": user_id},
- desc="remove_from_user_in_public_room",
- )
- self.get_user_in_public_room.invalidate((user_id,))
-
- def get_users_in_public_due_to_room(self, room_id):
- """Get all user_ids that are in the room directory because they're
- in the given room_id
- """
- return self._simple_select_onecol(
- table="users_in_public_rooms",
- keyvalues={"room_id": room_id},
- retcol="user_id",
- desc="get_users_in_public_due_to_room",
- )
-
- @defer.inlineCallbacks
def get_users_in_dir_due_to_room(self, room_id):
"""Get all user_ids that are in the room directory because they're
in the given room_id
"""
- user_ids_dir = yield self._simple_select_onecol(
- table="user_directory",
- keyvalues={"room_id": room_id},
- retcol="user_id",
- desc="get_users_in_dir_due_to_room",
- )
-
- user_ids_pub = yield self._simple_select_onecol(
- table="users_in_public_rooms",
+ user_ids_share_pub = yield self._simple_select_onecol(
+ table="users_who_share_public_rooms",
keyvalues={"room_id": room_id},
- retcol="user_id",
+ retcol="other_user_id",
desc="get_users_in_dir_due_to_room",
)
- user_ids_share = yield self._simple_select_onecol(
- table="users_who_share_rooms",
+ user_ids_share_priv = yield self._simple_select_onecol(
+ table="users_who_share_private_rooms",
keyvalues={"room_id": room_id},
- retcol="user_id",
+ retcol="other_user_id",
desc="get_users_in_dir_due_to_room",
)
- user_ids = set(user_ids_dir)
- user_ids.update(user_ids_pub)
- user_ids.update(user_ids_share)
+ user_ids = set(user_ids_share_pub)
+ user_ids.update(user_ids_share_priv)
defer.returnValue(user_ids)
@@ -351,7 +312,7 @@ class UserDirectoryStore(SQLBaseStore):
defer.returnValue([name for name, in rows])
def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
- """Insert entries into the users_who_share_rooms table. The first
+ """Insert entries into the users_who_share_*_rooms table. The first
user should be a local user.
Args:
@@ -361,109 +322,71 @@ class UserDirectoryStore(SQLBaseStore):
"""
def _add_users_who_share_room_txn(txn):
- self._simple_insert_many_txn(
+
+ if share_private:
+ tbl = "users_who_share_private_rooms"
+ else:
+ tbl = "users_who_share_public_rooms"
+
+ self._simple_upsert_many_txn(
txn,
- table="users_who_share_rooms",
- values=[
- {
- "user_id": user_id,
- "other_user_id": other_user_id,
- "room_id": room_id,
- "share_private": share_private,
- }
+ table=tbl,
+ key_names=["user_id", "other_user_id", "room_id"],
+ key_values=[
+ (user_id, other_user_id, room_id)
for user_id, other_user_id in user_id_tuples
],
+ value_names=(),
+ value_values=None,
)
for user_id, other_user_id in user_id_tuples:
txn.call_after(
self.get_users_who_share_room_from_dir.invalidate, (user_id,)
)
- txn.call_after(
- self.get_if_users_share_a_room.invalidate, (user_id, other_user_id)
- )
return self.runInteraction(
"add_users_who_share_room", _add_users_who_share_room_txn
)
- def update_users_who_share_room(self, room_id, share_private, user_id_sets):
- """Updates entries in the users_who_share_rooms table. The first
- user should be a local user.
-
- Args:
- room_id (str)
- share_private (bool): Is the room private
- user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
+ def remove_user_who_share_room(self, user_id, room_id):
"""
-
- def _update_users_who_share_room_txn(txn):
- sql = """
- UPDATE users_who_share_rooms
- SET room_id = ?, share_private = ?
- WHERE user_id = ? AND other_user_id = ?
- """
- txn.executemany(
- sql, ((room_id, share_private, uid, oid) for uid, oid in user_id_sets)
- )
- for user_id, other_user_id in user_id_sets:
- txn.call_after(
- self.get_users_who_share_room_from_dir.invalidate, (user_id,)
- )
- txn.call_after(
- self.get_if_users_share_a_room.invalidate, (user_id, other_user_id)
- )
-
- return self.runInteraction(
- "update_users_who_share_room", _update_users_who_share_room_txn
- )
-
- def remove_user_who_share_room(self, user_id, other_user_id):
- """Deletes entries in the users_who_share_rooms table. The first
+ Deletes entries in the users_who_share_*_rooms table. The first
user should be a local user.
Args:
+ user_id (str)
room_id (str)
- share_private (bool): Is the room private
- user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
"""
def _remove_user_who_share_room_txn(txn):
self._simple_delete_txn(
txn,
- table="users_who_share_rooms",
- keyvalues={"user_id": user_id, "other_user_id": other_user_id},
+ table="users_who_share_private_rooms",
+ keyvalues={"user_id": user_id, "room_id": room_id},
)
- txn.call_after(
- self.get_users_who_share_room_from_dir.invalidate, (user_id,)
+ self._simple_delete_txn(
+ txn,
+ table="users_who_share_private_rooms",
+ keyvalues={"other_user_id": user_id, "room_id": room_id},
+ )
+ self._simple_delete_txn(
+ txn,
+ table="users_who_share_public_rooms",
+ keyvalues={"user_id": user_id, "room_id": room_id},
+ )
+ self._simple_delete_txn(
+ txn,
+ table="users_who_share_public_rooms",
+ keyvalues={"other_user_id": user_id, "room_id": room_id},
)
txn.call_after(
- self.get_if_users_share_a_room.invalidate, (user_id, other_user_id)
+ self.get_users_who_share_room_from_dir.invalidate, (user_id,)
)
return self.runInteraction(
"remove_user_who_share_room", _remove_user_who_share_room_txn
)
- @cached(max_entries=500000)
- def get_if_users_share_a_room(self, user_id, other_user_id):
- """Gets if users share a room.
-
- Args:
- user_id (str): Must be a local user_id
- other_user_id (str)
-
- Returns:
- bool|None: None if they don't share a room, otherwise whether they
- share a private room or not.
- """
- return self._simple_select_one_onecol(
- table="users_who_share_rooms",
- keyvalues={"user_id": user_id, "other_user_id": other_user_id},
- retcol="share_private",
- allow_none=True,
- desc="get_if_users_share_a_room",
- )
-
@cachedInlineCallbacks(max_entries=500000, iterable=True)
def get_users_who_share_room_from_dir(self, user_id):
"""Returns the set of users who share a room with `user_id`
@@ -472,32 +395,29 @@ class UserDirectoryStore(SQLBaseStore):
user_id(str): Must be a local user
Returns:
- dict: user_id -> share_private mapping
+ list: user_id
"""
- rows = yield self._simple_select_list(
- table="users_who_share_rooms",
+ rows = yield self._simple_select_onecol(
+ table="users_who_share_private_rooms",
+ keyvalues={"user_id": user_id},
+ retcol="other_user_id",
+ desc="get_users_who_share_room_with_user",
+ )
+
+ pub_rows = yield self._simple_select_onecol(
+ table="users_who_share_public_rooms",
keyvalues={"user_id": user_id},
- retcols=("other_user_id", "share_private"),
+ retcol="other_user_id",
desc="get_users_who_share_room_with_user",
)
- defer.returnValue({row["other_user_id"]: row["share_private"] for row in rows})
+ users = set(pub_rows)
+ users.update(rows)
- def get_users_in_share_dir_with_room_id(self, user_id, room_id):
- """Get all user tuples that are in the users_who_share_rooms due to the
- given room_id.
+ # Remove the user themselves from this list.
+ users.discard(user_id)
- Returns:
- [(user_id, other_user_id)]: where one of the two will match the given
- user_id.
- """
- sql = """
- SELECT user_id, other_user_id FROM users_who_share_rooms
- WHERE room_id = ? AND (user_id = ? OR other_user_id = ?)
- """
- return self._execute(
- "get_users_in_share_dir_with_room_id", None, sql, room_id, user_id, user_id
- )
+ defer.returnValue(list(users))
@defer.inlineCallbacks
def get_rooms_in_common_for_users(self, user_id, other_user_id):
@@ -532,12 +452,10 @@ class UserDirectoryStore(SQLBaseStore):
def _delete_all_from_user_dir_txn(txn):
txn.execute("DELETE FROM user_directory")
txn.execute("DELETE FROM user_directory_search")
- txn.execute("DELETE FROM users_in_public_rooms")
- txn.execute("DELETE FROM users_who_share_rooms")
+ txn.execute("DELETE FROM users_who_share_public_rooms")
+ txn.execute("DELETE FROM users_who_share_private_rooms")
txn.call_after(self.get_user_in_directory.invalidate_all)
- txn.call_after(self.get_user_in_public_room.invalidate_all)
txn.call_after(self.get_users_who_share_room_from_dir.invalidate_all)
- txn.call_after(self.get_if_users_share_a_room.invalidate_all)
return self.runInteraction(
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
@@ -548,21 +466,11 @@ class UserDirectoryStore(SQLBaseStore):
return self._simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
- retcols=("room_id", "display_name", "avatar_url"),
+ retcols=("display_name", "avatar_url"),
allow_none=True,
desc="get_user_in_directory",
)
- @cached()
- def get_user_in_public_room(self, user_id):
- return self._simple_select_one(
- table="users_in_public_rooms",
- keyvalues={"user_id": user_id},
- retcols=("room_id",),
- allow_none=True,
- desc="get_user_in_public_room",
- )
-
def get_user_directory_stream_pos(self):
return self._simple_select_one_onecol(
table="user_directory_stream_pos",
@@ -660,14 +568,15 @@ class UserDirectoryStore(SQLBaseStore):
where_clause = "1=1"
else:
join_clause = """
- LEFT JOIN users_in_public_rooms AS p USING (user_id)
LEFT JOIN (
- SELECT other_user_id AS user_id FROM users_who_share_rooms
- WHERE user_id = ? AND share_private
- ) AS s USING (user_id)
+ SELECT other_user_id AS user_id FROM users_who_share_public_rooms
+ UNION
+ SELECT other_user_id AS user_id FROM users_who_share_private_rooms
+ WHERE user_id = ?
+ ) AS p USING (user_id)
"""
join_args = (user_id,)
- where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
+ where_clause = "p.user_id IS NOT NULL"
if isinstance(self.database_engine, PostgresEngine):
full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
@@ -686,7 +595,7 @@ class UserDirectoryStore(SQLBaseStore):
%s
AND vector @@ to_tsquery('english', ?)
ORDER BY
- (CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
+ (CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
* (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END)
* (CASE WHEN avatar_url IS NOT NULL THEN 1.2 ELSE 1.0 END)
* (
|