diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index a0333d5309..7e3903859b 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -767,18 +767,25 @@ class SQLBaseStore(object):
"""
allvalues = {}
allvalues.update(keyvalues)
- allvalues.update(values)
allvalues.update(insertion_values)
+ if not values:
+ latter = "NOTHING"
+ else:
+ allvalues.update(values)
+ latter = (
+ "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
+ )
+
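+    # For example, the block_room() change later in this diff calls the upsert
+    # with values={} and insertion_values={"user_id": ...}; on this
+    # native-upsert path that renders roughly as:
+    #   INSERT INTO blocked_rooms (room_id, user_id) VALUES (?, ?)
+    #   ON CONFLICT (room_id) DO NOTHING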
sql = (
"INSERT INTO %s (%s) VALUES (%s) "
- "ON CONFLICT (%s) DO UPDATE SET %s"
+ "ON CONFLICT (%s) DO %s"
) % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues),
", ".join(k for k in keyvalues),
- ", ".join(k + "=EXCLUDED." + k for k in values),
+ latter,
)
txn.execute(sql, list(allvalues.values()))
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 60cdc884e6..a2f8c23a65 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -52,7 +52,9 @@ class BackgroundUpdatePerformance(object):
Returns:
A duration in ms as a float
"""
- if self.total_item_count == 0:
+ if self.avg_duration_ms == 0:
+ return 0
+ elif self.total_item_count == 0:
return None
else:
# Use the exponential moving average so that we can adapt to
@@ -64,7 +66,9 @@ class BackgroundUpdatePerformance(object):
Returns:
A duration in ms as a float
"""
- if self.total_item_count == 0:
+ if self.total_duration_ms == 0:
+ return 0
+ elif self.total_item_count == 0:
return None
else:
return float(self.total_item_count) / float(self.total_duration_ms)
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index e06b0bc56d..e6a42a53bb 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -19,14 +19,174 @@ from canonicaljson import json
from twisted.internet import defer
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.util.caches.expiringcache import ExpiringCache
-from .background_updates import BackgroundUpdateStore
-
logger = logging.getLogger(__name__)
-class DeviceInboxStore(BackgroundUpdateStore):
+class DeviceInboxWorkerStore(SQLBaseStore):
+ def get_to_device_stream_token(self):
+ return self._device_inbox_id_gen.get_current_token()
+
+ def get_new_messages_for_device(
+ self, user_id, device_id, last_stream_id, current_stream_id, limit=100
+ ):
+ """
+ Args:
+ user_id(str): The recipient user_id.
+ device_id(str): The recipient device_id.
+ last_stream_id(int): The stream position that messages were last
+ returned up to for this device.
+ current_stream_id(int): The current position of the to device
+ message stream.
+ limit(int): The maximum number of messages to return.
+ Returns:
+ Deferred ([dict], int): List of messages for the device and where
+ in the stream the messages got to.
+ """
+ has_changed = self._device_inbox_stream_cache.has_entity_changed(
+ user_id, last_stream_id
+ )
+ if not has_changed:
+ return defer.succeed(([], current_stream_id))
+
+ def get_new_messages_for_device_txn(txn):
+ sql = (
+ "SELECT stream_id, message_json FROM device_inbox"
+ " WHERE user_id = ? AND device_id = ?"
+ " AND ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC"
+ " LIMIT ?"
+ )
+ txn.execute(sql, (
+ user_id, device_id, last_stream_id, current_stream_id, limit
+ ))
+ messages = []
+ for row in txn:
+ stream_pos = row[0]
+ messages.append(json.loads(row[1]))
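+ # Fewer rows than the limit means we have drained the stream, so
+ # report the caller's current_stream_id as the new position;
+ # otherwise report the stream_id of the last message returned.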
+ if len(messages) < limit:
+ stream_pos = current_stream_id
+ return (messages, stream_pos)
+
+ return self.runInteraction(
+ "get_new_messages_for_device", get_new_messages_for_device_txn,
+ )
+
+ @defer.inlineCallbacks
+ def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
+ """
+ Args:
+ user_id(str): The recipient user_id.
+ device_id(str): The recipient device_id.
+ up_to_stream_id(int): Where to delete messages up to.
+ Returns:
+ A deferred that resolves to the number of messages deleted.
+ """
+ # If we have cached the last stream id we've deleted up to, we can
+ # check if there is likely to be anything that needs deleting
+ last_deleted_stream_id = self._last_device_delete_cache.get(
+ (user_id, device_id), None
+ )
+ if last_deleted_stream_id:
+ has_changed = self._device_inbox_stream_cache.has_entity_changed(
+ user_id, last_deleted_stream_id
+ )
+ if not has_changed:
+ defer.returnValue(0)
+
+ def delete_messages_for_device_txn(txn):
+ sql = (
+ "DELETE FROM device_inbox"
+ " WHERE user_id = ? AND device_id = ?"
+ " AND stream_id <= ?"
+ )
+ txn.execute(sql, (user_id, device_id, up_to_stream_id))
+ return txn.rowcount
+
+ count = yield self.runInteraction(
+ "delete_messages_for_device", delete_messages_for_device_txn
+ )
+
+ # Update the cache, ensuring that we only ever increase the value
+ last_deleted_stream_id = self._last_device_delete_cache.get(
+ (user_id, device_id), 0
+ )
+ self._last_device_delete_cache[(user_id, device_id)] = max(
+ last_deleted_stream_id, up_to_stream_id
+ )
+
+ defer.returnValue(count)
+
+ def get_new_device_msgs_for_remote(
+ self, destination, last_stream_id, current_stream_id, limit=100
+ ):
+ """
+ Args:
+ destination(str): The name of the remote server.
+ last_stream_id(int|long): The last position of the device message stream
+ that the server sent up to.
+ current_stream_id(int|long): The current position of the device
+ message stream.
+ Returns:
+ Deferred ([dict], int|long): List of messages for the device and where
+ in the stream the messages got to.
+ """
+
+ has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
+ destination, last_stream_id
+ )
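+ # Besides the stream-change cache check, we also bail out early when the
+ # destination is already caught up (last_stream_id == current_stream_id),
+ # avoiding a database hit in the common no-new-messages case.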
+ if not has_changed or last_stream_id == current_stream_id:
+ return defer.succeed(([], current_stream_id))
+
+ def get_new_messages_for_remote_destination_txn(txn):
+ sql = (
+ "SELECT stream_id, messages_json FROM device_federation_outbox"
+ " WHERE destination = ?"
+ " AND ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC"
+ " LIMIT ?"
+ )
+ txn.execute(sql, (
+ destination, last_stream_id, current_stream_id, limit
+ ))
+ messages = []
+ for row in txn:
+ stream_pos = row[0]
+ messages.append(json.loads(row[1]))
+ if len(messages) < limit:
+ stream_pos = current_stream_id
+ return (messages, stream_pos)
+
+ return self.runInteraction(
+ "get_new_device_msgs_for_remote",
+ get_new_messages_for_remote_destination_txn,
+ )
+
+ def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
+ """Used to delete messages when the remote destination acknowledges
+ their receipt.
+
+ Args:
+ destination(str): The destination server_name
+ up_to_stream_id(int): Where to delete messages up to.
+ Returns:
+ A deferred that resolves when the messages have been deleted.
+ """
+ def delete_messages_for_remote_destination_txn(txn):
+ sql = (
+ "DELETE FROM device_federation_outbox"
+ " WHERE destination = ?"
+ " AND stream_id <= ?"
+ )
+ txn.execute(sql, (destination, up_to_stream_id))
+
+ return self.runInteraction(
+ "delete_device_msgs_for_remote",
+ delete_messages_for_remote_destination_txn
+ )
+
+
+class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, db_conn, hs):
@@ -220,93 +380,6 @@ class DeviceInboxStore(BackgroundUpdateStore):
txn.executemany(sql, rows)
- def get_new_messages_for_device(
- self, user_id, device_id, last_stream_id, current_stream_id, limit=100
- ):
- """
- Args:
- user_id(str): The recipient user_id.
- device_id(str): The recipient device_id.
- current_stream_id(int): The current position of the to device
- message stream.
- Returns:
- Deferred ([dict], int): List of messages for the device and where
- in the stream the messages got to.
- """
- has_changed = self._device_inbox_stream_cache.has_entity_changed(
- user_id, last_stream_id
- )
- if not has_changed:
- return defer.succeed(([], current_stream_id))
-
- def get_new_messages_for_device_txn(txn):
- sql = (
- "SELECT stream_id, message_json FROM device_inbox"
- " WHERE user_id = ? AND device_id = ?"
- " AND ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC"
- " LIMIT ?"
- )
- txn.execute(sql, (
- user_id, device_id, last_stream_id, current_stream_id, limit
- ))
- messages = []
- for row in txn:
- stream_pos = row[0]
- messages.append(json.loads(row[1]))
- if len(messages) < limit:
- stream_pos = current_stream_id
- return (messages, stream_pos)
-
- return self.runInteraction(
- "get_new_messages_for_device", get_new_messages_for_device_txn,
- )
-
- @defer.inlineCallbacks
- def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
- """
- Args:
- user_id(str): The recipient user_id.
- device_id(str): The recipient device_id.
- up_to_stream_id(int): Where to delete messages up to.
- Returns:
- A deferred that resolves to the number of messages deleted.
- """
- # If we have cached the last stream id we've deleted up to, we can
- # check if there is likely to be anything that needs deleting
- last_deleted_stream_id = self._last_device_delete_cache.get(
- (user_id, device_id), None
- )
- if last_deleted_stream_id:
- has_changed = self._device_inbox_stream_cache.has_entity_changed(
- user_id, last_deleted_stream_id
- )
- if not has_changed:
- defer.returnValue(0)
-
- def delete_messages_for_device_txn(txn):
- sql = (
- "DELETE FROM device_inbox"
- " WHERE user_id = ? AND device_id = ?"
- " AND stream_id <= ?"
- )
- txn.execute(sql, (user_id, device_id, up_to_stream_id))
- return txn.rowcount
-
- count = yield self.runInteraction(
- "delete_messages_for_device", delete_messages_for_device_txn
- )
-
- # Update the cache, ensuring that we only ever increase the value
- last_deleted_stream_id = self._last_device_delete_cache.get(
- (user_id, device_id), 0
- )
- self._last_device_delete_cache[(user_id, device_id)] = max(
- last_deleted_stream_id, up_to_stream_id
- )
-
- defer.returnValue(count)
-
def get_all_new_device_messages(self, last_pos, current_pos, limit):
"""
Args:
@@ -351,77 +424,6 @@ class DeviceInboxStore(BackgroundUpdateStore):
"get_all_new_device_messages", get_all_new_device_messages_txn
)
- def get_to_device_stream_token(self):
- return self._device_inbox_id_gen.get_current_token()
-
- def get_new_device_msgs_for_remote(
- self, destination, last_stream_id, current_stream_id, limit=100
- ):
- """
- Args:
- destination(str): The name of the remote server.
- last_stream_id(int|long): The last position of the device message stream
- that the server sent up to.
- current_stream_id(int|long): The current position of the device
- message stream.
- Returns:
- Deferred ([dict], int|long): List of messages for the device and where
- in the stream the messages got to.
- """
-
- has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
- destination, last_stream_id
- )
- if not has_changed or last_stream_id == current_stream_id:
- return defer.succeed(([], current_stream_id))
-
- def get_new_messages_for_remote_destination_txn(txn):
- sql = (
- "SELECT stream_id, messages_json FROM device_federation_outbox"
- " WHERE destination = ?"
- " AND ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC"
- " LIMIT ?"
- )
- txn.execute(sql, (
- destination, last_stream_id, current_stream_id, limit
- ))
- messages = []
- for row in txn:
- stream_pos = row[0]
- messages.append(json.loads(row[1]))
- if len(messages) < limit:
- stream_pos = current_stream_id
- return (messages, stream_pos)
-
- return self.runInteraction(
- "get_new_device_msgs_for_remote",
- get_new_messages_for_remote_destination_txn,
- )
-
- def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
- """Used to delete messages when the remote destination acknowledges
- their receipt.
-
- Args:
- destination(str): The destination server_name
- up_to_stream_id(int): Where to delete messages up to.
- Returns:
- A deferred that resolves when the messages have been deleted.
- """
- def delete_messages_for_remote_destination_txn(txn):
- sql = (
- "DELETE FROM device_federation_outbox"
- " WHERE destination = ?"
- " AND stream_id <= ?"
- )
- txn.execute(sql, (destination, up_to_stream_id))
-
- return self.runInteraction(
- "delete_device_msgs_for_remote",
- delete_messages_for_remote_destination_txn
- )
-
@defer.inlineCallbacks
def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index ecdab34e7d..e716dc1437 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -22,11 +22,10 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import Cache, SQLBaseStore, db_to_json
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
-from ._base import Cache, db_to_json
-
logger = logging.getLogger(__name__)
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
@@ -34,7 +33,343 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
)
-class DeviceStore(BackgroundUpdateStore):
+class DeviceWorkerStore(SQLBaseStore):
+ def get_device(self, user_id, device_id):
+ """Retrieve a device.
+
+ Args:
+ user_id (str): The ID of the user which owns the device
+ device_id (str): The ID of the device to retrieve
+ Returns:
+ defer.Deferred for a dict containing the device information
+ Raises:
+ StoreError: if the device is not found
+ """
+ return self._simple_select_one(
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ retcols=("user_id", "device_id", "display_name"),
+ desc="get_device",
+ )
+
+ @defer.inlineCallbacks
+ def get_devices_by_user(self, user_id):
+ """Retrieve all of a user's registered devices.
+
+ Args:
+ user_id (str):
+ Returns:
+ defer.Deferred: resolves to a dict from device_id to a dict
+ containing "device_id", "user_id" and "display_name" for each
+ device.
+ """
+ devices = yield self._simple_select_list(
+ table="devices",
+ keyvalues={"user_id": user_id},
+ retcols=("user_id", "device_id", "display_name"),
+ desc="get_devices_by_user"
+ )
+
+ defer.returnValue({d["device_id"]: d for d in devices})
+
+ def get_devices_by_remote(self, destination, from_stream_id):
+ """Get stream of updates to send to remote servers
+
+ Returns:
+ (int, list[dict]): current stream id and list of updates
+ """
+ now_stream_id = self._device_list_id_gen.get_current_token()
+
+ has_changed = self._device_list_federation_stream_cache.has_entity_changed(
+ destination, int(from_stream_id)
+ )
+ if not has_changed:
+ return (now_stream_id, [])
+
+ return self.runInteraction(
+ "get_devices_by_remote", self._get_devices_by_remote_txn,
+ destination, from_stream_id, now_stream_id,
+ )
+
+ def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
+ now_stream_id):
+ sql = """
+ SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
+ WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
+ GROUP BY user_id, device_id
+ LIMIT 20
+ """
+ txn.execute(
+ sql, (destination, from_stream_id, now_stream_id, False)
+ )
+
+ # maps (user_id, device_id) -> stream_id
+ query_map = {(r[0], r[1]): r[2] for r in txn}
+ if not query_map:
+ return (now_stream_id, [])
+
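+ # The query above is capped at 20 rows; if we hit that cap, pull the
+ # returned stream position back to the newest poke we actually fetched so
+ # that the remaining rows are picked up on the next call.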
+ if len(query_map) >= 20:
+ now_stream_id = max(stream_id for stream_id in itervalues(query_map))
+
+ devices = self._get_e2e_device_keys_txn(
+ txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
+ )
+
+ prev_sent_id_sql = """
+ SELECT coalesce(max(stream_id), 0) as stream_id
+ FROM device_lists_outbound_last_success
+ WHERE destination = ? AND user_id = ? AND stream_id <= ?
+ """
+
+ results = []
+ for user_id, user_devices in iteritems(devices):
+ # The prev_id for the first row is always the last row before
+ # `from_stream_id`
+ txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
+ rows = txn.fetchall()
+ prev_id = rows[0][0]
+ for device_id, device in iteritems(user_devices):
+ stream_id = query_map[(user_id, device_id)]
+ result = {
+ "user_id": user_id,
+ "device_id": device_id,
+ "prev_id": [prev_id] if prev_id else [],
+ "stream_id": stream_id,
+ }
+
+ prev_id = stream_id
+
+ if device is not None:
+ key_json = device.get("key_json", None)
+ if key_json:
+ result["keys"] = db_to_json(key_json)
+ device_display_name = device.get("device_display_name", None)
+ if device_display_name:
+ result["device_display_name"] = device_display_name
+ else:
+ result["deleted"] = True
+
+ results.append(result)
+
+ return (now_stream_id, results)
+
+ def mark_as_sent_devices_by_remote(self, destination, stream_id):
+ """Mark that updates have successfully been sent to the destination.
+ """
+ return self.runInteraction(
+ "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
+ destination, stream_id,
+ )
+
+ def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
+ # We update the device_lists_outbound_last_success with the successfully
+ # poked users. We do the join to see which users need to be inserted and
+ # which updated.
+ sql = """
+ SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
+ FROM device_lists_outbound_pokes as o
+ LEFT JOIN device_lists_outbound_last_success as s
+ USING (destination, user_id)
+ WHERE destination = ? AND o.stream_id <= ?
+ GROUP BY user_id
+ """
+ txn.execute(sql, (destination, stream_id,))
+ rows = txn.fetchall()
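+ # Each row is (user_id, highest stream_id now sent, whether a
+ # last_success row already exists); the boolean decides below whether we
+ # UPDATE an existing row or INSERT a new one.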
+
+ sql = """
+ UPDATE device_lists_outbound_last_success
+ SET stream_id = ?
+ WHERE destination = ? AND user_id = ?
+ """
+ txn.executemany(
+ sql, ((row[1], destination, row[0],) for row in rows if row[2])
+ )
+
+ sql = """
+ INSERT INTO device_lists_outbound_last_success
+ (destination, user_id, stream_id) VALUES (?, ?, ?)
+ """
+ txn.executemany(
+ sql, ((destination, row[0], row[1],) for row in rows if not row[2])
+ )
+
+ # Delete all sent outbound pokes
+ sql = """
+ DELETE FROM device_lists_outbound_pokes
+ WHERE destination = ? AND stream_id <= ?
+ """
+ txn.execute(sql, (destination, stream_id,))
+
+ def get_device_stream_token(self):
+ return self._device_list_id_gen.get_current_token()
+
+ @defer.inlineCallbacks
+ def get_user_devices_from_cache(self, query_list):
+ """Get the devices (and keys if any) for remote users from the cache.
+
+ Args:
+ query_list(list): List of (user_id, device_ids), if device_ids is
+ falsey then return all device ids for that user.
+
+ Returns:
+ (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
+ a set of user_ids and results_map is a mapping of
+ user_id -> device_id -> device_info
+ """
+ user_ids = set(user_id for user_id, _ in query_list)
+ user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+ user_ids_in_cache = set(
+ user_id for user_id, stream_id in user_map.items() if stream_id
+ )
+ user_ids_not_in_cache = user_ids - user_ids_in_cache
+
+ results = {}
+ for user_id, device_id in query_list:
+ if user_id not in user_ids_in_cache:
+ continue
+
+ if device_id:
+ device = yield self._get_cached_user_device(user_id, device_id)
+ results.setdefault(user_id, {})[device_id] = device
+ else:
+ results[user_id] = yield self._get_cached_devices_for_user(user_id)
+
+ defer.returnValue((user_ids_not_in_cache, results))
+
+ @cachedInlineCallbacks(num_args=2, tree=True)
+ def _get_cached_user_device(self, user_id, device_id):
+ content = yield self._simple_select_one_onecol(
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ retcol="content",
+ desc="_get_cached_user_device",
+ )
+ defer.returnValue(db_to_json(content))
+
+ @cachedInlineCallbacks()
+ def _get_cached_devices_for_user(self, user_id):
+ devices = yield self._simple_select_list(
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ },
+ retcols=("device_id", "content"),
+ desc="_get_cached_devices_for_user",
+ )
+ defer.returnValue({
+ device["device_id"]: db_to_json(device["content"])
+ for device in devices
+ })
+
+ def get_devices_with_keys_by_user(self, user_id):
+ """Get all devices (with any device keys) for a user
+
+ Returns:
+ (stream_id, devices)
+ """
+ return self.runInteraction(
+ "get_devices_with_keys_by_user",
+ self._get_devices_with_keys_by_user_txn, user_id,
+ )
+
+ def _get_devices_with_keys_by_user_txn(self, txn, user_id):
+ now_stream_id = self._device_list_id_gen.get_current_token()
+
+ devices = self._get_e2e_device_keys_txn(
+ txn, [(user_id, None)], include_all_devices=True
+ )
+
+ if devices:
+ user_devices = devices[user_id]
+ results = []
+ for device_id, device in iteritems(user_devices):
+ result = {
+ "device_id": device_id,
+ }
+
+ key_json = device.get("key_json", None)
+ if key_json:
+ result["keys"] = db_to_json(key_json)
+ device_display_name = device.get("device_display_name", None)
+ if device_display_name:
+ result["device_display_name"] = device_display_name
+
+ results.append(result)
+
+ return now_stream_id, results
+
+ return now_stream_id, []
+
+ @defer.inlineCallbacks
+ def get_user_whose_devices_changed(self, from_key):
+ """Get set of users whose devices have changed since `from_key`.
+ """
+ from_key = int(from_key)
+ changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
+ if changed is not None:
+ defer.returnValue(set(changed))
+
+ sql = """
+ SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
+ """
+ rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
+ defer.returnValue(set(row[0] for row in rows))
+
+ def get_all_device_list_changes_for_remotes(self, from_key, to_key):
+ """Return a list of `(stream_id, user_id, destination)` which is the
+ combined list of changes to devices, and which destinations need to be
+ poked. `destination` may be None if no destinations need to be poked.
+ """
+ # We do a group by here as there can be a large number of duplicate
+ # entries, since we throw away device IDs.
+ sql = """
+ SELECT MAX(stream_id) AS stream_id, user_id, destination
+ FROM device_lists_stream
+ LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
+ WHERE ? < stream_id AND stream_id <= ?
+ GROUP BY user_id, destination
+ """
+ return self._execute(
+ "get_all_device_list_changes_for_remotes", None,
+ sql, from_key, to_key
+ )
+
+ @cached(max_entries=10000)
+ def get_device_list_last_stream_id_for_remote(self, user_id):
+ """Get the last stream_id we got for a user. May be None if we haven't
+ got any information for them.
+ """
+ return self._simple_select_one_onecol(
+ table="device_lists_remote_extremeties",
+ keyvalues={"user_id": user_id},
+ retcol="stream_id",
+ desc="get_device_list_last_stream_id_for_remote",
+ allow_none=True,
+ )
+
+ @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
+ list_name="user_ids", inlineCallbacks=True)
+ def get_device_list_last_stream_id_for_remotes(self, user_ids):
+ rows = yield self._simple_select_many_batch(
+ table="device_lists_remote_extremeties",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id", "stream_id",),
+ desc="get_device_list_last_stream_id_for_remotes",
+ )
+
+ results = {user_id: None for user_id in user_ids}
+ results.update({
+ row["user_id"]: row["stream_id"] for row in rows
+ })
+
+ defer.returnValue(results)
+
+
+class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(DeviceStore, self).__init__(db_conn, hs)
@@ -121,24 +456,6 @@ class DeviceStore(BackgroundUpdateStore):
initial_device_display_name, e)
raise StoreError(500, "Problem storing device.")
- def get_device(self, user_id, device_id):
- """Retrieve a device.
-
- Args:
- user_id (str): The ID of the user which owns the device
- device_id (str): The ID of the device to retrieve
- Returns:
- defer.Deferred for a dict containing the device information
- Raises:
- StoreError: if the device is not found
- """
- return self._simple_select_one(
- table="devices",
- keyvalues={"user_id": user_id, "device_id": device_id},
- retcols=("user_id", "device_id", "display_name"),
- desc="get_device",
- )
-
@defer.inlineCallbacks
def delete_device(self, user_id, device_id):
"""Delete a device.
@@ -203,57 +520,6 @@ class DeviceStore(BackgroundUpdateStore):
)
@defer.inlineCallbacks
- def get_devices_by_user(self, user_id):
- """Retrieve all of a user's registered devices.
-
- Args:
- user_id (str):
- Returns:
- defer.Deferred: resolves to a dict from device_id to a dict
- containing "device_id", "user_id" and "display_name" for each
- device.
- """
- devices = yield self._simple_select_list(
- table="devices",
- keyvalues={"user_id": user_id},
- retcols=("user_id", "device_id", "display_name"),
- desc="get_devices_by_user"
- )
-
- defer.returnValue({d["device_id"]: d for d in devices})
-
- @cached(max_entries=10000)
- def get_device_list_last_stream_id_for_remote(self, user_id):
- """Get the last stream_id we got for a user. May be None if we haven't
- got any information for them.
- """
- return self._simple_select_one_onecol(
- table="device_lists_remote_extremeties",
- keyvalues={"user_id": user_id},
- retcol="stream_id",
- desc="get_device_list_remote_extremity",
- allow_none=True,
- )
-
- @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
- list_name="user_ids", inlineCallbacks=True)
- def get_device_list_last_stream_id_for_remotes(self, user_ids):
- rows = yield self._simple_select_many_batch(
- table="device_lists_remote_extremeties",
- column="user_id",
- iterable=user_ids,
- retcols=("user_id", "stream_id",),
- desc="get_user_devices_from_cache",
- )
-
- results = {user_id: None for user_id in user_ids}
- results.update({
- row["user_id"]: row["stream_id"] for row in rows
- })
-
- defer.returnValue(results)
-
- @defer.inlineCallbacks
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
"""
@@ -405,268 +671,6 @@ class DeviceStore(BackgroundUpdateStore):
lock=False,
)
- def get_devices_by_remote(self, destination, from_stream_id):
- """Get stream of updates to send to remote servers
-
- Returns:
- (int, list[dict]): current stream id and list of updates
- """
- now_stream_id = self._device_list_id_gen.get_current_token()
-
- has_changed = self._device_list_federation_stream_cache.has_entity_changed(
- destination, int(from_stream_id)
- )
- if not has_changed:
- return (now_stream_id, [])
-
- return self.runInteraction(
- "get_devices_by_remote", self._get_devices_by_remote_txn,
- destination, from_stream_id, now_stream_id,
- )
-
- def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
- now_stream_id):
- sql = """
- SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
- WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
- GROUP BY user_id, device_id
- LIMIT 20
- """
- txn.execute(
- sql, (destination, from_stream_id, now_stream_id, False)
- )
-
- # maps (user_id, device_id) -> stream_id
- query_map = {(r[0], r[1]): r[2] for r in txn}
- if not query_map:
- return (now_stream_id, [])
-
- if len(query_map) >= 20:
- now_stream_id = max(stream_id for stream_id in itervalues(query_map))
-
- devices = self._get_e2e_device_keys_txn(
- txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
- )
-
- prev_sent_id_sql = """
- SELECT coalesce(max(stream_id), 0) as stream_id
- FROM device_lists_outbound_last_success
- WHERE destination = ? AND user_id = ? AND stream_id <= ?
- """
-
- results = []
- for user_id, user_devices in iteritems(devices):
- # The prev_id for the first row is always the last row before
- # `from_stream_id`
- txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
- rows = txn.fetchall()
- prev_id = rows[0][0]
- for device_id, device in iteritems(user_devices):
- stream_id = query_map[(user_id, device_id)]
- result = {
- "user_id": user_id,
- "device_id": device_id,
- "prev_id": [prev_id] if prev_id else [],
- "stream_id": stream_id,
- }
-
- prev_id = stream_id
-
- if device is not None:
- key_json = device.get("key_json", None)
- if key_json:
- result["keys"] = db_to_json(key_json)
- device_display_name = device.get("device_display_name", None)
- if device_display_name:
- result["device_display_name"] = device_display_name
- else:
- result["deleted"] = True
-
- results.append(result)
-
- return (now_stream_id, results)
-
- @defer.inlineCallbacks
- def get_user_devices_from_cache(self, query_list):
- """Get the devices (and keys if any) for remote users from the cache.
-
- Args:
- query_list(list): List of (user_id, device_ids), if device_ids is
- falsey then return all device ids for that user.
-
- Returns:
- (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
- a set of user_ids and results_map is a mapping of
- user_id -> device_id -> device_info
- """
- user_ids = set(user_id for user_id, _ in query_list)
- user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
- user_ids_in_cache = set(
- user_id for user_id, stream_id in user_map.items() if stream_id
- )
- user_ids_not_in_cache = user_ids - user_ids_in_cache
-
- results = {}
- for user_id, device_id in query_list:
- if user_id not in user_ids_in_cache:
- continue
-
- if device_id:
- device = yield self._get_cached_user_device(user_id, device_id)
- results.setdefault(user_id, {})[device_id] = device
- else:
- results[user_id] = yield self._get_cached_devices_for_user(user_id)
-
- defer.returnValue((user_ids_not_in_cache, results))
-
- @cachedInlineCallbacks(num_args=2, tree=True)
- def _get_cached_user_device(self, user_id, device_id):
- content = yield self._simple_select_one_onecol(
- table="device_lists_remote_cache",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- },
- retcol="content",
- desc="_get_cached_user_device",
- )
- defer.returnValue(db_to_json(content))
-
- @cachedInlineCallbacks()
- def _get_cached_devices_for_user(self, user_id):
- devices = yield self._simple_select_list(
- table="device_lists_remote_cache",
- keyvalues={
- "user_id": user_id,
- },
- retcols=("device_id", "content"),
- desc="_get_cached_devices_for_user",
- )
- defer.returnValue({
- device["device_id"]: db_to_json(device["content"])
- for device in devices
- })
-
- def get_devices_with_keys_by_user(self, user_id):
- """Get all devices (with any device keys) for a user
-
- Returns:
- (stream_id, devices)
- """
- return self.runInteraction(
- "get_devices_with_keys_by_user",
- self._get_devices_with_keys_by_user_txn, user_id,
- )
-
- def _get_devices_with_keys_by_user_txn(self, txn, user_id):
- now_stream_id = self._device_list_id_gen.get_current_token()
-
- devices = self._get_e2e_device_keys_txn(
- txn, [(user_id, None)], include_all_devices=True
- )
-
- if devices:
- user_devices = devices[user_id]
- results = []
- for device_id, device in iteritems(user_devices):
- result = {
- "device_id": device_id,
- }
-
- key_json = device.get("key_json", None)
- if key_json:
- result["keys"] = db_to_json(key_json)
- device_display_name = device.get("device_display_name", None)
- if device_display_name:
- result["device_display_name"] = device_display_name
-
- results.append(result)
-
- return now_stream_id, results
-
- return now_stream_id, []
-
- def mark_as_sent_devices_by_remote(self, destination, stream_id):
- """Mark that updates have successfully been sent to the destination.
- """
- return self.runInteraction(
- "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
- destination, stream_id,
- )
-
- def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
- # We update the device_lists_outbound_last_success with the successfully
- # poked users. We do the join to see which users need to be inserted and
- # which updated.
- sql = """
- SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
- FROM device_lists_outbound_pokes as o
- LEFT JOIN device_lists_outbound_last_success as s
- USING (destination, user_id)
- WHERE destination = ? AND o.stream_id <= ?
- GROUP BY user_id
- """
- txn.execute(sql, (destination, stream_id,))
- rows = txn.fetchall()
-
- sql = """
- UPDATE device_lists_outbound_last_success
- SET stream_id = ?
- WHERE destination = ? AND user_id = ?
- """
- txn.executemany(
- sql, ((row[1], destination, row[0],) for row in rows if row[2])
- )
-
- sql = """
- INSERT INTO device_lists_outbound_last_success
- (destination, user_id, stream_id) VALUES (?, ?, ?)
- """
- txn.executemany(
- sql, ((destination, row[0], row[1],) for row in rows if not row[2])
- )
-
- # Delete all sent outbound pokes
- sql = """
- DELETE FROM device_lists_outbound_pokes
- WHERE destination = ? AND stream_id <= ?
- """
- txn.execute(sql, (destination, stream_id,))
-
- @defer.inlineCallbacks
- def get_user_whose_devices_changed(self, from_key):
- """Get set of users whose devices have changed since `from_key`.
- """
- from_key = int(from_key)
- changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
- if changed is not None:
- defer.returnValue(set(changed))
-
- sql = """
- SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
- """
- rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
- defer.returnValue(set(row[0] for row in rows))
-
- def get_all_device_list_changes_for_remotes(self, from_key, to_key):
- """Return a list of `(stream_id, user_id, destination)` which is the
- combined list of changes to devices, and which destinations need to be
- poked. `destination` may be None if no destinations need to be poked.
- """
- # We do a group by here as there can be a large number of duplicate
- # entries, since we throw away device IDs.
- sql = """
- SELECT MAX(stream_id) AS stream_id, user_id, destination
- FROM device_lists_stream
- LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
- WHERE ? < stream_id AND stream_id <= ?
- GROUP BY user_id, destination
- """
- return self._execute(
- "get_all_device_list_changes_for_remotes", None,
- sql, from_key, to_key
- )
-
@defer.inlineCallbacks
def add_device_change_to_streams(self, user_id, device_ids, hosts):
"""Persist that a user's devices have been updated, and which hosts
@@ -732,9 +736,6 @@ class DeviceStore(BackgroundUpdateStore):
]
)
- def get_device_stream_token(self):
- return self._device_list_id_gen.get_current_token()
-
def _prune_old_outbound_device_pokes(self):
"""Delete old entries out of the device_lists_outbound_pokes to ensure
that we don't fill up due to dead servers. We keep one entry per
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 2a0f6cfca9..e381e472a2 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -23,49 +23,7 @@ from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore, db_to_json
-class EndToEndKeyStore(SQLBaseStore):
- def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
- """Stores device keys for a device. Returns whether there was a change
- or the keys were already in the database.
- """
- def _set_e2e_device_keys_txn(txn):
- old_key_json = self._simple_select_one_onecol_txn(
- txn,
- table="e2e_device_keys_json",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- },
- retcol="key_json",
- allow_none=True,
- )
-
- # In py3 we need old_key_json to match new_key_json type. The DB
- # returns unicode while encode_canonical_json returns bytes.
- new_key_json = encode_canonical_json(device_keys).decode("utf-8")
-
- if old_key_json == new_key_json:
- return False
-
- self._simple_upsert_txn(
- txn,
- table="e2e_device_keys_json",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- },
- values={
- "ts_added_ms": time_now,
- "key_json": new_key_json,
- }
- )
-
- return True
-
- return self.runInteraction(
- "set_e2e_device_keys", _set_e2e_device_keys_txn
- )
-
+class EndToEndKeyWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_e2e_device_keys(
self, query_list, include_all_devices=False,
@@ -238,6 +196,50 @@ class EndToEndKeyStore(SQLBaseStore):
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
+
+class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
+ def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
+ """Stores device keys for a device. Returns whether there was a change
+ or the keys were already in the database.
+ """
+ def _set_e2e_device_keys_txn(txn):
+ old_key_json = self._simple_select_one_onecol_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ retcol="key_json",
+ allow_none=True,
+ )
+
+ # In py3 we need old_key_json to match new_key_json type. The DB
+ # returns unicode while encode_canonical_json returns bytes.
+ new_key_json = encode_canonical_json(device_keys).decode("utf-8")
+
+ if old_key_json == new_key_json:
+ return False
+
+ self._simple_upsert_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ values={
+ "ts_added_ms": time_now,
+ "key_json": new_key_json,
+ }
+ )
+
+ return True
+
+ return self.runInteraction(
+ "set_e2e_device_keys", _set_e2e_device_keys_txn
+ )
+
def claim_e2e_one_time_keys(self, query_list):
"""Take a list of one time keys out of the database"""
def _claim_e2e_one_time_keys(txn):
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 38809ed0fc..a8d90456e3 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -442,6 +442,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
event_results.reverse()
return event_results
+ @defer.inlineCallbacks
+ def get_successor_events(self, event_ids):
+ """Fetch all events that have the given events as a prev event
+
+ Args:
+ event_ids (iterable[str])
+
+ Returns:
+ Deferred[list[str]]
+ """
+ rows = yield self._simple_select_many_batch(
+ table="event_edges",
+ column="prev_event_id",
+ iterable=event_ids,
+ retcols=("event_id",),
+ desc="get_successor_events"
+ )
+
+ defer.returnValue([
+ row["event_id"] for row in rows
+ ])
+
class EventFederationStore(EventFederationWorkerStore):
""" Responsible for storing and serving up the various graphs associated
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 06db9e56e6..428300ea0a 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -537,6 +537,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
new_events = [
event for event, ctx in event_contexts
if not event.internal_metadata.is_outlier() and not ctx.rejected
+ and not event.internal_metadata.is_soft_failed()
]
# start with the existing forward extremities
@@ -1406,21 +1407,6 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
values=state_values,
)
- self._simple_insert_many_txn(
- txn,
- table="event_edges",
- values=[
- {
- "event_id": event.event_id,
- "prev_event_id": prev_id,
- "room_id": event.room_id,
- "is_state": True,
- }
- for event, _ in state_events_and_contexts
- for prev_id, _ in event.prev_state
- ],
- )
-
# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 6a5028961d..4b8438c3e9 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -186,6 +186,63 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
defer.returnValue(results)
@defer.inlineCallbacks
+ def move_push_rule_from_room_to_room(
+ self, new_room_id, user_id, rule,
+ ):
+ """Move a single push rule from one room to another for a specific user.
+
+ Args:
+ new_room_id (str): ID of the new room.
+ user_id (str): ID of user the push rule belongs to.
+ rule (Dict): A push rule.
+ """
+ # Create new rule id
+ rule_id_scope = '/'.join(rule["rule_id"].split('/')[:-1])
+ new_rule_id = rule_id_scope + "/" + new_room_id
+
+ # Change room id in each condition
+ for condition in rule.get("conditions", []):
+ if condition.get("key") == "room_id":
+ condition["pattern"] = new_room_id
+
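+ # Both the rule_id (rebuilt above) and the room_id condition now refer to
+ # the new room; the rule is re-added below with its original priority
+ # class and actions.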
+ # Add the rule for the new room
+ yield self.add_push_rule(
+ user_id=user_id,
+ rule_id=new_rule_id,
+ priority_class=rule["priority_class"],
+ conditions=rule["conditions"],
+ actions=rule["actions"],
+ )
+
+ # Delete push rule for the old room
+ yield self.delete_push_rule(user_id, rule["rule_id"])
+
+ @defer.inlineCallbacks
+ def move_push_rules_from_room_to_room_for_user(
+ self, old_room_id, new_room_id, user_id,
+ ):
+ """Move all of the push rules from one room to another for a specific
+ user.
+
+ Args:
+ old_room_id (str): ID of the old room.
+ new_room_id (str): ID of the new room.
+ user_id (str): ID of user to copy push rules for.
+ """
+ # Retrieve push rules for this user
+ user_push_rules = yield self.get_push_rules_for_user(user_id)
+
+ # Get rules relating to the old room, move them to the new room, then
+ # delete them from the old room
+ for rule in user_push_rules:
+ conditions = rule.get("conditions", [])
+ if any((c.get("key") == "room_id" and
+ c.get("pattern") == old_room_id) for c in conditions):
+ yield self.move_push_rule_from_room_to_room(
+ new_room_id, user_id, rule,
+ )
+
+ @defer.inlineCallbacks
def bulk_get_push_rules_for_room(self, event, context):
state_group = context.state_group
if not state_group:
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 0ac665e967..89a1f7e3d7 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -301,7 +301,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
args.append(limit)
txn.execute(sql, args)
- return txn.fetchall()
+ return (
+ r[0:5] + (json.loads(r[5]), ) for r in txn
+ )
return self.runInteraction(
"get_all_updated_receipts", get_all_updated_receipts_txn
)
@@ -346,15 +348,23 @@ class ReceiptsStore(ReceiptsWorkerStore):
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id):
+ """Inserts a read-receipt into the database if it's newer than the current RR
+
+ Returns: int|None
+ None if the RR is older than the current RR
+ otherwise, the rx timestamp of the event that the RR corresponds to
+ (or 0 if the event is unknown)
+ """
res = self._simple_select_one_txn(
txn,
table="events",
- retcols=["topological_ordering", "stream_ordering"],
+ retcols=["stream_ordering", "received_ts"],
keyvalues={"event_id": event_id},
allow_none=True
)
stream_ordering = int(res["stream_ordering"]) if res else None
+ rx_ts = res["received_ts"] if res else 0
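+ # rx_ts falls back to 0 when the event the receipt refers to is unknown
+ # to us, matching the docstring above.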
# We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts
@@ -373,7 +383,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"one for later event %s",
event_id, eid,
)
- return False
+ return None
txn.call_after(
self.get_receipts_for_room.invalidate, (room_id, receipt_type)
@@ -429,7 +439,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
stream_ordering=stream_ordering,
)
- return True
+ return rx_ts
@defer.inlineCallbacks
def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
@@ -466,7 +476,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
stream_id_manager = self._receipts_id_gen.get_next()
with stream_id_manager as stream_id:
- have_persisted = yield self.runInteraction(
+ event_ts = yield self.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
room_id, receipt_type, user_id, linearized_event_id,
@@ -474,8 +484,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
stream_id=stream_id,
)
- if not have_persisted:
- defer.returnValue(None)
+ if event_ts is None:
+ defer.returnValue(None)
+
+ now = self._clock.time_msec()
+ logger.debug(
+ "RR for event %s in %s (%i ms old)",
+ linearized_event_id, room_id, now - event_ts,
+ )
yield self.insert_graph_receipt(
room_id, receipt_type, user_id, event_ids, data
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 9b9572890b..9b6c28892c 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -295,6 +295,39 @@ class RegistrationWorkerStore(SQLBaseStore):
return ret['user_id']
return None
+ @defer.inlineCallbacks
+ def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
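+ # Note: the upsert is keyed on (medium, address), so re-adding an existing
+ # threepid re-points it at the given user and refreshes the timestamps
+ # rather than creating a duplicate row.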
+ yield self._simple_upsert("user_threepids", {
+ "medium": medium,
+ "address": address,
+ }, {
+ "user_id": user_id,
+ "validated_at": validated_at,
+ "added_at": added_at,
+ })
+
+ @defer.inlineCallbacks
+ def user_get_threepids(self, user_id):
+ ret = yield self._simple_select_list(
+ "user_threepids", {
+ "user_id": user_id
+ },
+ ['medium', 'address', 'validated_at', 'added_at'],
+ 'user_get_threepids'
+ )
+ defer.returnValue(ret)
+
+ def user_delete_threepid(self, user_id, medium, address):
+ return self._simple_delete(
+ "user_threepids",
+ keyvalues={
+ "user_id": user_id,
+ "medium": medium,
+ "address": address,
+ },
+ desc="user_delete_threepids",
+ )
+
class RegistrationStore(RegistrationWorkerStore,
background_updates.BackgroundUpdateStore):
@@ -633,39 +666,6 @@ class RegistrationStore(RegistrationWorkerStore,
defer.returnValue(res if res else False)
@defer.inlineCallbacks
- def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
- yield self._simple_upsert("user_threepids", {
- "medium": medium,
- "address": address,
- }, {
- "user_id": user_id,
- "validated_at": validated_at,
- "added_at": added_at,
- })
-
- @defer.inlineCallbacks
- def user_get_threepids(self, user_id):
- ret = yield self._simple_select_list(
- "user_threepids", {
- "user_id": user_id
- },
- ['medium', 'address', 'validated_at', 'added_at'],
- 'user_get_threepids'
- )
- defer.returnValue(ret)
-
- def user_delete_threepid(self, user_id, medium, address):
- return self._simple_delete(
- "user_threepids",
- keyvalues={
- "user_id": user_id,
- "medium": medium,
- "address": address,
- },
- desc="user_delete_threepids",
- )
-
- @defer.inlineCallbacks
def save_or_get_3pid_guest_access_token(
self, medium, address, access_token, inviter_user_id
):
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 41c65e112a..a979d4860a 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -500,10 +500,22 @@ class RoomStore(RoomWorkerStore, SearchStore):
@defer.inlineCallbacks
def block_room(self, room_id, user_id):
- yield self._simple_insert(
+ """Marks the room as blocked. Can be called multiple times.
+
+ Args:
+ room_id (str): Room to block
+ user_id (str): Who blocked it
+
+ Returns:
+ Deferred
+ """
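+ # With an empty `values` dict the upsert leaves an existing row untouched,
+ # so the user_id of whoever blocked the room first (supplied via
+ # insertion_values) is preserved on repeated calls.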
+ yield self._simple_upsert(
table="blocked_rooms",
- values={
+ keyvalues={
"room_id": room_id,
+ },
+ values={},
+ insertion_values={
"user_id": user_id,
},
desc="block_room",
diff --git a/synapse/storage/schema/delta/53/user_dir_populate.sql b/synapse/storage/schema/delta/53/user_dir_populate.sql
new file mode 100644
index 0000000000..ffcc896b58
--- /dev/null
+++ b/synapse/storage/schema/delta/53/user_dir_populate.sql
@@ -0,0 +1,30 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Set up staging tables
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('populate_user_directory_createtables', '{}');
+
+-- Run through each room and update the user directory according to who is in it
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+ ('populate_user_directory_process_rooms', '{}', 'populate_user_directory_createtables');
+
+-- Insert all users, if search_all_users is on
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+ ('populate_user_directory_process_users', '{}', 'populate_user_directory_process_rooms');
+
+-- Clean up staging tables
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+ ('populate_user_directory_cleanup', '{}', 'populate_user_directory_process_users');
diff --git a/synapse/storage/schema/delta/53/user_share.sql b/synapse/storage/schema/delta/53/user_share.sql
new file mode 100644
index 0000000000..5831b1a6f8
--- /dev/null
+++ b/synapse/storage/schema/delta/53/user_share.sql
@@ -0,0 +1,44 @@
+/* Copyright 2017 Vector Creations Ltd, 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Old disused version of the tables below.
+DROP TABLE IF EXISTS users_who_share_rooms;
+
+-- Tables keeping track of what users share rooms. This is a map of local users
+-- to local or remote users, per room. Remote users cannot be in the user_id
+-- column, only the other_user_id column. There are two tables, one for public
+-- rooms and one for private rooms.
+CREATE TABLE IF NOT EXISTS users_who_share_public_rooms (
+ user_id TEXT NOT NULL,
+ other_user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS users_who_share_private_rooms (
+ user_id TEXT NOT NULL,
+ other_user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX users_who_share_public_rooms_u_idx ON users_who_share_public_rooms(user_id, other_user_id, room_id);
+CREATE INDEX users_who_share_public_rooms_r_idx ON users_who_share_public_rooms(room_id);
+CREATE INDEX users_who_share_public_rooms_o_idx ON users_who_share_public_rooms(other_user_id);
+
+CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms(user_id, other_user_id, room_id);
+CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_rooms(room_id);
+CREATE INDEX users_who_share_private_rooms_o_idx ON users_who_share_private_rooms(other_user_id);
+
+-- Make sure that we populate the tables initially by resetting the stream ID
+UPDATE user_directory_stream_pos SET stream_id = NULL;
diff --git a/synapse/storage/schema/delta/53/users_in_public_rooms.sql b/synapse/storage/schema/delta/53/users_in_public_rooms.sql
new file mode 100644
index 0000000000..f7827ca6d2
--- /dev/null
+++ b/synapse/storage/schema/delta/53/users_in_public_rooms.sql
@@ -0,0 +1,28 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We don't need the old version of this table.
+DROP TABLE IF EXISTS users_in_public_rooms;
+
+-- Old version of users_in_public_rooms
+DROP TABLE IF EXISTS users_who_share_public_rooms;
+
+-- Track what users are in public rooms.
+CREATE TABLE IF NOT EXISTS users_in_public_rooms (
+ user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX users_in_public_rooms_u_idx ON users_in_public_rooms(user_id, room_id);
diff --git a/synapse/storage/schema/full_schemas/11/event_edges.sql b/synapse/storage/schema/full_schemas/11/event_edges.sql
index 52eec88357..bccd1c6f74 100644
--- a/synapse/storage/schema/full_schemas/11/event_edges.sql
+++ b/synapse/storage/schema/full_schemas/11/event_edges.sql
@@ -37,6 +37,8 @@ CREATE TABLE IF NOT EXISTS event_edges(
event_id TEXT NOT NULL,
prev_event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
+ -- We no longer insert prev_state into this table, so all new rows will have
+ -- is_state as false.
is_state BOOL NOT NULL,
UNIQUE (event_id, prev_event_id, room_id, is_state)
);
diff --git a/synapse/storage/state_deltas.py b/synapse/storage/state_deltas.py
new file mode 100644
index 0000000000..57bc45cdb9
--- /dev/null
+++ b/synapse/storage/state_deltas.py
@@ -0,0 +1,74 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.storage._base import SQLBaseStore
+
+logger = logging.getLogger(__name__)
+
+
+class StateDeltasStore(SQLBaseStore):
+
+ def get_current_state_deltas(self, prev_stream_id):
+ prev_stream_id = int(prev_stream_id)
+ if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id):
+ return []
+
+ def get_current_state_deltas_txn(txn):
+ # First we calculate the max stream id that will give us less than
+ # N results.
+ # We arbitrarily limit to 100 stream_id entries to ensure we don't
+ # select too many.
+ sql = """
+ SELECT stream_id, count(*)
+ FROM current_state_delta_stream
+ WHERE stream_id > ?
+ GROUP BY stream_id
+ ORDER BY stream_id ASC
+ LIMIT 100
+ """
+ txn.execute(sql, (prev_stream_id,))
+
+ total = 0
+ max_stream_id = prev_stream_id
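+ # Walk the per-stream_id counts in order, stopping on a whole stream_id
+ # boundary so that a single stream_id's deltas are never split across
+ # two batches.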
+ for max_stream_id, count in txn:
+ total += count
+ if total > 100:
+ # We arbitrarily limit to 100 entries to ensure we don't
+ # select too many.
+ break
+
+ # Now actually get the deltas
+ sql = """
+ SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
+ FROM current_state_delta_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ """
+ txn.execute(sql, (prev_stream_id, max_stream_id,))
+ return self.cursor_to_dict(txn)
+
+ return self.runInteraction(
+ "get_current_state_deltas", get_current_state_deltas_txn
+ )
+
+ def get_max_stream_id_in_current_state_deltas(self):
+ return self._simple_select_one_onecol(
+ table="current_state_delta_stream",
+ keyvalues={},
+ retcol="COALESCE(MAX(stream_id), -1)",
+ desc="get_max_stream_id_in_current_state_deltas",
+ )
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index d6cfdba519..580fafeb3a 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -191,6 +191,25 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
@defer.inlineCallbacks
def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0,
order='DESC'):
+ """Get new room events in stream ordering since `from_key`.
+
+ Args:
+ room_ids (iterable[str])
+ from_key (str): Token after which to start returning events (no
+ events at or before this token are returned)
+ to_key (str): Token up to which to return events. (This
+ is typically the current stream token)
+ limit (int): Maximum number of events to return
+ order (str): Either "DESC" or "ASC". Determines which events are
+ returned when the result is limited. If "DESC" then the most
+ recent `limit` events are returned, otherwise returns the
+ oldest `limit` events.
+
+ Returns:
+ Deferred[dict[str,tuple[list[FrozenEvent], str]]]
+ A map from room id to a tuple containing:
+ - list of recent events in the room
+ - stream ordering key for the start of the chunk of events returned.
+ """
from_id = RoomStreamToken.parse_stream_token(from_key).stream
room_ids = yield self._events_stream_cache.get_entities_changed(
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index fea866c043..4d60a5726f 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -16,22 +16,314 @@
import logging
import re
-from six import iteritems
-
from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules
+from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.state import StateFilter
+from synapse.storage.state_deltas import StateDeltasStore
from synapse.types import get_domain_from_id, get_localpart_from_id
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-
-from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
-class UserDirectoryStore(SQLBaseStore):
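+# Prefix for the temporary staging tables used while populating the user
+# directory in the background.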
+TEMP_TABLE = "_temp_populate_user_directory"
+
+
+class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
+
+    # How many records to accumulate before writing them out via
+    # add_users_who_share_private_room.
+ SHARE_PRIVATE_WORKING_SET = 500
+
+ def __init__(self, db_conn, hs):
+ super(UserDirectoryStore, self).__init__(db_conn, hs)
+
+ self.server_name = hs.hostname
+
+ self.register_background_update_handler(
+ "populate_user_directory_createtables",
+ self._populate_user_directory_createtables,
+ )
+ self.register_background_update_handler(
+ "populate_user_directory_process_rooms",
+ self._populate_user_directory_process_rooms,
+ )
+ self.register_background_update_handler(
+ "populate_user_directory_process_users",
+ self._populate_user_directory_process_users,
+ )
+ self.register_background_update_handler(
+ "populate_user_directory_cleanup", self._populate_user_directory_cleanup
+ )
+
+ @defer.inlineCallbacks
+ def _populate_user_directory_createtables(self, progress, batch_size):
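+        """Create the temporary staging tables used to populate the user
+        directory: one row per room (with its current state event count), the
+        stream position we are starting from, and, if search_all_users is
+        enabled, one row per local user.
+        """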
+
+ # Get all the rooms that we want to process.
+ def _make_staging_area(txn):
+ sql = (
+ "CREATE TABLE IF NOT EXISTS "
+ + TEMP_TABLE
+ + "_rooms(room_id TEXT NOT NULL, events BIGINT NOT NULL)"
+ )
+ txn.execute(sql)
+
+ sql = (
+ "CREATE TABLE IF NOT EXISTS "
+ + TEMP_TABLE
+ + "_position(position TEXT NOT NULL)"
+ )
+ txn.execute(sql)
+
+ # Get rooms we want to process from the database
+ sql = """
+ SELECT room_id, count(*) FROM current_state_events
+ GROUP BY room_id
+ """
+ txn.execute(sql)
+ rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()]
+ self._simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
+ del rooms
+
+            # If search_all_users is enabled, also stage every local user we
+            # want to add.
+ if self.hs.config.user_directory_search_all_users:
+ sql = (
+ "CREATE TABLE IF NOT EXISTS "
+ + TEMP_TABLE
+ + "_users(user_id TEXT NOT NULL)"
+ )
+ txn.execute(sql)
+
+ txn.execute("SELECT name FROM users")
+ users = [{"user_id": x[0]} for x in txn.fetchall()]
+
+ self._simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
+
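+        # Record the current stream position; it is stored in the _position
+        # staging table and becomes the user directory stream position once
+        # the cleanup step runs.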
+ new_pos = yield self.get_max_stream_id_in_current_state_deltas()
+ yield self.runInteraction(
+ "populate_user_directory_temp_build", _make_staging_area
+ )
+ yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
+
+ yield self._end_background_update("populate_user_directory_createtables")
+ defer.returnValue(1)
+
+ @defer.inlineCallbacks
+ def _populate_user_directory_cleanup(self, progress, batch_size):
+ """
+ Update the user directory stream position, then clean up the old tables.
+ """
+ position = yield self._simple_select_one_onecol(
+ TEMP_TABLE + "_position", None, "position"
+ )
+ yield self.update_user_directory_stream_pos(position)
+
+ def _delete_staging_area(txn):
+ txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
+ txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
+ txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
+
+ yield self.runInteraction(
+ "populate_user_directory_cleanup", _delete_staging_area
+ )
+
+ yield self._end_background_update("populate_user_directory_cleanup")
+ defer.returnValue(1)
+
+ @defer.inlineCallbacks
+ def _populate_user_directory_process_rooms(self, progress, batch_size):
+        """
+        Populate the user directory from the next batch of rooms in the
+        staging table.
+
+        Args:
+            progress (dict)
+            batch_size (int): Maximum number of state events to process
+                per cycle.
+        """
+ state = self.hs.get_state_handler()
+
+        # If we don't have any progress recorded yet, delete everything.
+ if not progress:
+ yield self.delete_all_from_user_dir()
+
+ def _get_next_batch(txn):
+            # Only fetch 250 rooms, so we don't fetch too many at once, even
+            # if those 250 rooms have fewer than batch_size state events.
+ sql = """
+ SELECT room_id, events FROM %s
+ ORDER BY events DESC
+ LIMIT 250
+ """ % (
+ TEMP_TABLE + "_rooms",
+ )
+ txn.execute(sql)
+ rooms_to_work_on = txn.fetchall()
+
+ if not rooms_to_work_on:
+ return None
+
+            # Get how many rooms are left to process, so we can report how far
+            # through we are.
+ txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
+ progress["remaining"] = txn.fetchone()[0]
+
+ return rooms_to_work_on
+
+ rooms_to_work_on = yield self.runInteraction(
+ "populate_user_directory_temp_read", _get_next_batch
+ )
+
+        # No more rooms -- this background update is complete.
+ if not rooms_to_work_on:
+ yield self._end_background_update("populate_user_directory_process_rooms")
+ defer.returnValue(1)
+
+ logger.info(
+ "Processing the next %d rooms of %d remaining"
+ % (len(rooms_to_work_on), progress["remaining"])
+ )
+
+ processed_event_count = 0
+
+ for room_id, event_count in rooms_to_work_on:
+ is_in_room = yield self.is_host_joined(room_id, self.server_name)
+
+ if is_in_room:
+ is_public = yield self.is_room_world_readable_or_publicly_joinable(
+ room_id
+ )
+
+ users_with_profile = yield state.get_current_user_in_room(room_id)
+ user_ids = set(users_with_profile)
+
+ # Update each user in the user directory.
+ for user_id, profile in users_with_profile.items():
+ yield self.update_profile_in_user_dir(
+ user_id, profile.display_name, profile.avatar_url
+ )
+
+ to_insert = set()
+
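+                # Public rooms contribute individual entries to
+                # users_in_public_rooms; private rooms contribute pairs of
+                # users who share them to users_who_share_private_rooms.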
+ if is_public:
+ for user_id in user_ids:
+ if self.get_if_app_services_interested_in_user(user_id):
+ continue
+
+ to_insert.add(user_id)
+
+ if to_insert:
+ yield self.add_users_in_public_rooms(room_id, to_insert)
+ to_insert.clear()
+ else:
+ for user_id in user_ids:
+ if not self.hs.is_mine_id(user_id):
+ continue
+
+ if self.get_if_app_services_interested_in_user(user_id):
+ continue
+
+ for other_user_id in user_ids:
+ if user_id == other_user_id:
+ continue
+
+ user_set = (user_id, other_user_id)
+ to_insert.add(user_set)
+
+ # If it gets too big, stop and write to the database
+ # to prevent storing too much in RAM.
+ if len(to_insert) >= self.SHARE_PRIVATE_WORKING_SET:
+ yield self.add_users_who_share_private_room(
+ room_id, to_insert
+ )
+ to_insert.clear()
+
+ if to_insert:
+ yield self.add_users_who_share_private_room(room_id, to_insert)
+ to_insert.clear()
+
+ # We've finished a room. Delete it from the table.
+ yield self._simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id})
+ # Update the remaining counter.
+ progress["remaining"] -= 1
+ yield self.runInteraction(
+ "populate_user_directory",
+ self._background_update_progress_txn,
+ "populate_user_directory_process_rooms",
+ progress,
+ )
+
+ processed_event_count += event_count
+
+ if processed_event_count > batch_size:
+ # Don't process any more rooms, we've hit our batch size.
+ defer.returnValue(processed_event_count)
+
+ defer.returnValue(processed_event_count)
+
+ @defer.inlineCallbacks
+ def _populate_user_directory_process_users(self, progress, batch_size):
+ """
+        If search_all_users is enabled, add every local user to the user directory.
+ """
+ if not self.hs.config.user_directory_search_all_users:
+ yield self._end_background_update("populate_user_directory_process_users")
+ defer.returnValue(1)
+
+ def _get_next_batch(txn):
+ sql = "SELECT user_id FROM %s LIMIT %s" % (
+ TEMP_TABLE + "_users",
+ str(batch_size),
+ )
+ txn.execute(sql)
+ users_to_work_on = txn.fetchall()
+
+ if not users_to_work_on:
+ return None
+
+ users_to_work_on = [x[0] for x in users_to_work_on]
+
+            # Get how many users are left to process, so we can report how far
+            # through we are.
+ sql = "SELECT COUNT(*) FROM " + TEMP_TABLE + "_users"
+ txn.execute(sql)
+ progress["remaining"] = txn.fetchone()[0]
+
+ return users_to_work_on
+
+ users_to_work_on = yield self.runInteraction(
+ "populate_user_directory_temp_read", _get_next_batch
+ )
+
+        # No more users -- this background update is complete.
+ if not users_to_work_on:
+ yield self._end_background_update("populate_user_directory_process_users")
+ defer.returnValue(1)
+
+ logger.info(
+ "Processing the next %d users of %d remaining"
+ % (len(users_to_work_on), progress["remaining"])
+ )
+
+ for user_id in users_to_work_on:
+ profile = yield self.get_profileinfo(get_localpart_from_id(user_id))
+ yield self.update_profile_in_user_dir(
+ user_id, profile.display_name, profile.avatar_url
+ )
+
+ # We've finished processing a user. Delete it from the table.
+ yield self._simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id})
+ # Update the remaining counter.
+ progress["remaining"] -= 1
+ yield self.runInteraction(
+ "populate_user_directory",
+ self._background_update_progress_txn,
+ "populate_user_directory_process_users",
+ progress,
+ )
+
+ defer.returnValue(len(users_to_work_on))
+
@defer.inlineCallbacks
def is_room_world_readable_or_publicly_joinable(self, room_id):
"""Check if the room is either world_readable or publically joinable
@@ -63,106 +355,16 @@ class UserDirectoryStore(SQLBaseStore):
defer.returnValue(False)
- @defer.inlineCallbacks
- def add_users_to_public_room(self, room_id, user_ids):
- """Add user to the list of users in public rooms
-
- Args:
- room_id (str): A room_id that all users are in that is world_readable
- or publically joinable
- user_ids (list(str)): Users to add
+ def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
"""
- yield self._simple_insert_many(
- table="users_in_public_rooms",
- values=[{"user_id": user_id, "room_id": room_id} for user_id in user_ids],
- desc="add_users_to_public_room",
- )
- for user_id in user_ids:
- self.get_user_in_public_room.invalidate((user_id,))
-
- def add_profiles_to_user_dir(self, room_id, users_with_profile):
- """Add profiles to the user directory
-
- Args:
- room_id (str): A room_id that all users are joined to
- users_with_profile (dict): Users to add to directory in the form of
- mapping of user_id -> ProfileInfo
+ Update or add a user's profile in the user directory.
"""
- if isinstance(self.database_engine, PostgresEngine):
- # We weight the loclpart most highly, then display name and finally
- # server name
- sql = """
- INSERT INTO user_directory_search(user_id, vector)
- VALUES (?,
- setweight(to_tsvector('english', ?), 'A')
- || setweight(to_tsvector('english', ?), 'D')
- || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
- )
- """
- args = (
- (
- user_id,
- get_localpart_from_id(user_id),
- get_domain_from_id(user_id),
- profile.display_name,
- )
- for user_id, profile in iteritems(users_with_profile)
- )
- elif isinstance(self.database_engine, Sqlite3Engine):
- sql = """
- INSERT INTO user_directory_search(user_id, value)
- VALUES (?,?)
- """
- args = (
- (
- user_id,
- "%s %s" % (user_id, p.display_name) if p.display_name else user_id,
- )
- for user_id, p in iteritems(users_with_profile)
- )
- else:
- # This should be unreachable.
- raise Exception("Unrecognized database engine")
-
- def _add_profiles_to_user_dir_txn(txn):
- txn.executemany(sql, args)
- self._simple_insert_many_txn(
- txn,
- table="user_directory",
- values=[
- {
- "user_id": user_id,
- "room_id": room_id,
- "display_name": profile.display_name,
- "avatar_url": profile.avatar_url,
- }
- for user_id, profile in iteritems(users_with_profile)
- ],
- )
- for user_id in users_with_profile:
- txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
- return self.runInteraction(
- "add_profiles_to_user_dir", _add_profiles_to_user_dir_txn
- )
-
- @defer.inlineCallbacks
- def update_user_in_user_dir(self, user_id, room_id):
- yield self._simple_update_one(
- table="user_directory",
- keyvalues={"user_id": user_id},
- updatevalues={"room_id": room_id},
- desc="update_user_in_user_dir",
- )
- self.get_user_in_directory.invalidate((user_id,))
-
- def update_profile_in_user_dir(self, user_id, display_name, avatar_url, room_id):
def _update_profile_in_user_dir_txn(txn):
new_entry = self._simple_upsert_txn(
txn,
table="user_directory",
keyvalues={"user_id": user_id},
- insertion_values={"room_id": room_id},
values={"display_name": display_name, "avatar_url": avatar_url},
lock=False, # We're only inserter
)
@@ -250,16 +452,6 @@ class UserDirectoryStore(SQLBaseStore):
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
)
- @defer.inlineCallbacks
- def update_user_in_public_user_list(self, user_id, room_id):
- yield self._simple_update_one(
- table="users_in_public_rooms",
- keyvalues={"user_id": user_id},
- updatevalues={"room_id": room_id},
- desc="update_user_in_public_user_list",
- )
- self.get_user_in_public_room.invalidate((user_id,))
-
def remove_from_user_dir(self, user_id):
def _remove_from_user_dir_txn(txn):
self._simple_delete_txn(
@@ -271,234 +463,154 @@ class UserDirectoryStore(SQLBaseStore):
self._simple_delete_txn(
txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
)
+ self._simple_delete_txn(
+ txn,
+ table="users_who_share_private_rooms",
+ keyvalues={"user_id": user_id},
+ )
+ self._simple_delete_txn(
+ txn,
+ table="users_who_share_private_rooms",
+ keyvalues={"other_user_id": user_id},
+ )
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
- txn.call_after(self.get_user_in_public_room.invalidate, (user_id,))
return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
@defer.inlineCallbacks
- def remove_from_user_in_public_room(self, user_id):
- yield self._simple_delete(
- table="users_in_public_rooms",
- keyvalues={"user_id": user_id},
- desc="remove_from_user_in_public_room",
- )
- self.get_user_in_public_room.invalidate((user_id,))
-
- def get_users_in_public_due_to_room(self, room_id):
- """Get all user_ids that are in the room directory because they're
- in the given room_id
- """
- return self._simple_select_onecol(
- table="users_in_public_rooms",
- keyvalues={"room_id": room_id},
- retcol="user_id",
- desc="get_users_in_public_due_to_room",
- )
-
- @defer.inlineCallbacks
def get_users_in_dir_due_to_room(self, room_id):
"""Get all user_ids that are in the room directory because they're
in the given room_id
"""
- user_ids_dir = yield self._simple_select_onecol(
- table="user_directory",
- keyvalues={"room_id": room_id},
- retcol="user_id",
- desc="get_users_in_dir_due_to_room",
- )
-
- user_ids_pub = yield self._simple_select_onecol(
+ user_ids_share_pub = yield self._simple_select_onecol(
table="users_in_public_rooms",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_dir_due_to_room",
)
- user_ids_share = yield self._simple_select_onecol(
- table="users_who_share_rooms",
+ user_ids_share_priv = yield self._simple_select_onecol(
+ table="users_who_share_private_rooms",
keyvalues={"room_id": room_id},
- retcol="user_id",
+ retcol="other_user_id",
desc="get_users_in_dir_due_to_room",
)
- user_ids = set(user_ids_dir)
- user_ids.update(user_ids_pub)
- user_ids.update(user_ids_share)
+ user_ids = set(user_ids_share_pub)
+ user_ids.update(user_ids_share_priv)
defer.returnValue(user_ids)
- @defer.inlineCallbacks
- def get_all_rooms(self):
- """Get all room_ids we've ever known about, in ascending order of "size"
- """
- sql = """
- SELECT room_id FROM current_state_events
- GROUP BY room_id
- ORDER BY count(*) ASC
- """
- rows = yield self._execute("get_all_rooms", None, sql)
- defer.returnValue([room_id for room_id, in rows])
-
- @defer.inlineCallbacks
- def get_all_local_users(self):
- """Get all local users
- """
- sql = """
- SELECT name FROM users
- """
- rows = yield self._execute("get_all_local_users", None, sql)
- defer.returnValue([name for name, in rows])
-
- def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
- """Insert entries into the users_who_share_rooms table. The first
+ def add_users_who_share_private_room(self, room_id, user_id_tuples):
+ """Insert entries into the users_who_share_private_rooms table. The first
user should be a local user.
Args:
room_id (str)
- share_private (bool): Is the room private
user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
"""
def _add_users_who_share_room_txn(txn):
- self._simple_insert_many_txn(
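+            # Key-only upsert: there are no value columns to update, so
+            # re-adding an existing (user_id, other_user_id, room_id) row
+            # should simply be a no-op.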
+ self._simple_upsert_many_txn(
txn,
- table="users_who_share_rooms",
- values=[
- {
- "user_id": user_id,
- "other_user_id": other_user_id,
- "room_id": room_id,
- "share_private": share_private,
- }
+ table="users_who_share_private_rooms",
+ key_names=["user_id", "other_user_id", "room_id"],
+ key_values=[
+ (user_id, other_user_id, room_id)
for user_id, other_user_id in user_id_tuples
],
+ value_names=(),
+ value_values=None,
)
- for user_id, other_user_id in user_id_tuples:
- txn.call_after(
- self.get_users_who_share_room_from_dir.invalidate, (user_id,)
- )
- txn.call_after(
- self.get_if_users_share_a_room.invalidate, (user_id, other_user_id)
- )
return self.runInteraction(
"add_users_who_share_room", _add_users_who_share_room_txn
)
- def update_users_who_share_room(self, room_id, share_private, user_id_sets):
- """Updates entries in the users_who_share_rooms table. The first
+ def add_users_in_public_rooms(self, room_id, user_ids):
+        """Insert entries into the users_in_public_rooms table. The first
user should be a local user.
Args:
room_id (str)
- share_private (bool): Is the room private
- user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
+ user_ids (list[str])
"""
- def _update_users_who_share_room_txn(txn):
- sql = """
- UPDATE users_who_share_rooms
- SET room_id = ?, share_private = ?
- WHERE user_id = ? AND other_user_id = ?
- """
- txn.executemany(
- sql, ((room_id, share_private, uid, oid) for uid, oid in user_id_sets)
+ def _add_users_in_public_rooms_txn(txn):
+
+ self._simple_upsert_many_txn(
+ txn,
+ table="users_in_public_rooms",
+ key_names=["user_id", "room_id"],
+ key_values=[(user_id, room_id) for user_id in user_ids],
+ value_names=(),
+ value_values=None,
)
- for user_id, other_user_id in user_id_sets:
- txn.call_after(
- self.get_users_who_share_room_from_dir.invalidate, (user_id,)
- )
- txn.call_after(
- self.get_if_users_share_a_room.invalidate, (user_id, other_user_id)
- )
return self.runInteraction(
- "update_users_who_share_room", _update_users_who_share_room_txn
+ "add_users_in_public_rooms", _add_users_in_public_rooms_txn
)
- def remove_user_who_share_room(self, user_id, other_user_id):
- """Deletes entries in the users_who_share_rooms table. The first
+ def remove_user_who_share_room(self, user_id, room_id):
+ """
+        Deletes entries from users_who_share_private_rooms and
+        users_in_public_rooms where the given user is involved. The
user should be a local user.
Args:
+ user_id (str)
room_id (str)
- share_private (bool): Is the room private
- user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
"""
def _remove_user_who_share_room_txn(txn):
self._simple_delete_txn(
txn,
- table="users_who_share_rooms",
- keyvalues={"user_id": user_id, "other_user_id": other_user_id},
+ table="users_who_share_private_rooms",
+ keyvalues={"user_id": user_id, "room_id": room_id},
)
- txn.call_after(
- self.get_users_who_share_room_from_dir.invalidate, (user_id,)
+ self._simple_delete_txn(
+ txn,
+ table="users_who_share_private_rooms",
+ keyvalues={"other_user_id": user_id, "room_id": room_id},
)
- txn.call_after(
- self.get_if_users_share_a_room.invalidate, (user_id, other_user_id)
+ self._simple_delete_txn(
+ txn,
+ table="users_in_public_rooms",
+ keyvalues={"user_id": user_id, "room_id": room_id},
)
return self.runInteraction(
"remove_user_who_share_room", _remove_user_who_share_room_txn
)
- @cached(max_entries=500000)
- def get_if_users_share_a_room(self, user_id, other_user_id):
- """Gets if users share a room.
-
- Args:
- user_id (str): Must be a local user_id
- other_user_id (str)
-
- Returns:
- bool|None: None if they don't share a room, otherwise whether they
- share a private room or not.
+ @defer.inlineCallbacks
+ def get_user_dir_rooms_user_is_in(self, user_id):
"""
- return self._simple_select_one_onecol(
- table="users_who_share_rooms",
- keyvalues={"user_id": user_id, "other_user_id": other_user_id},
- retcol="share_private",
- allow_none=True,
- desc="get_if_users_share_a_room",
- )
-
- @cachedInlineCallbacks(max_entries=500000, iterable=True)
- def get_users_who_share_room_from_dir(self, user_id):
- """Returns the set of users who share a room with `user_id`
+ Returns the rooms that a user is in.
Args:
user_id(str): Must be a local user
Returns:
- dict: user_id -> share_private mapping
+            list: room_ids that the user is in
"""
- rows = yield self._simple_select_list(
- table="users_who_share_rooms",
+ rows = yield self._simple_select_onecol(
+ table="users_who_share_private_rooms",
keyvalues={"user_id": user_id},
- retcols=("other_user_id", "share_private"),
- desc="get_users_who_share_room_with_user",
+ retcol="room_id",
+ desc="get_rooms_user_is_in",
)
- defer.returnValue({row["other_user_id"]: row["share_private"] for row in rows})
-
- def get_users_in_share_dir_with_room_id(self, user_id, room_id):
- """Get all user tuples that are in the users_who_share_rooms due to the
- given room_id.
-
- Returns:
- [(user_id, other_user_id)]: where one of the two will match the given
- user_id.
- """
- sql = """
- SELECT user_id, other_user_id FROM users_who_share_rooms
- WHERE room_id = ? AND (user_id = ? OR other_user_id = ?)
- """
- return self._execute(
- "get_users_in_share_dir_with_room_id", None, sql, room_id, user_id, user_id
+ pub_rows = yield self._simple_select_onecol(
+ table="users_in_public_rooms",
+ keyvalues={"user_id": user_id},
+ retcol="room_id",
+ desc="get_rooms_user_is_in",
)
+ users = set(pub_rows)
+ users.update(rows)
+ defer.returnValue(list(users))
+
@defer.inlineCallbacks
def get_rooms_in_common_for_users(self, user_id, other_user_id):
"""Given two user_ids find out the list of rooms they share.
@@ -533,11 +645,8 @@ class UserDirectoryStore(SQLBaseStore):
txn.execute("DELETE FROM user_directory")
txn.execute("DELETE FROM user_directory_search")
txn.execute("DELETE FROM users_in_public_rooms")
- txn.execute("DELETE FROM users_who_share_rooms")
+ txn.execute("DELETE FROM users_who_share_private_rooms")
txn.call_after(self.get_user_in_directory.invalidate_all)
- txn.call_after(self.get_user_in_public_room.invalidate_all)
- txn.call_after(self.get_users_who_share_room_from_dir.invalidate_all)
- txn.call_after(self.get_if_users_share_a_room.invalidate_all)
return self.runInteraction(
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
@@ -548,21 +657,11 @@ class UserDirectoryStore(SQLBaseStore):
return self._simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
- retcols=("room_id", "display_name", "avatar_url"),
+ retcols=("display_name", "avatar_url"),
allow_none=True,
desc="get_user_in_directory",
)
- @cached()
- def get_user_in_public_room(self, user_id):
- return self._simple_select_one(
- table="users_in_public_rooms",
- keyvalues={"user_id": user_id},
- retcols=("room_id",),
- allow_none=True,
- desc="get_user_in_public_room",
- )
-
def get_user_directory_stream_pos(self):
return self._simple_select_one_onecol(
table="user_directory_stream_pos",
@@ -579,59 +678,6 @@ class UserDirectoryStore(SQLBaseStore):
desc="update_user_directory_stream_pos",
)
- def get_current_state_deltas(self, prev_stream_id):
- prev_stream_id = int(prev_stream_id)
- if not self._curr_state_delta_stream_cache.has_any_entity_changed(
- prev_stream_id
- ):
- return []
-
- def get_current_state_deltas_txn(txn):
- # First we calculate the max stream id that will give us less than
- # N results.
- # We arbitarily limit to 100 stream_id entries to ensure we don't
- # select toooo many.
- sql = """
- SELECT stream_id, count(*)
- FROM current_state_delta_stream
- WHERE stream_id > ?
- GROUP BY stream_id
- ORDER BY stream_id ASC
- LIMIT 100
- """
- txn.execute(sql, (prev_stream_id,))
-
- total = 0
- max_stream_id = prev_stream_id
- for max_stream_id, count in txn:
- total += count
- if total > 100:
- # We arbitarily limit to 100 entries to ensure we don't
- # select toooo many.
- break
-
- # Now actually get the deltas
- sql = """
- SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
- FROM current_state_delta_stream
- WHERE ? < stream_id AND stream_id <= ?
- ORDER BY stream_id ASC
- """
- txn.execute(sql, (prev_stream_id, max_stream_id))
- return self.cursor_to_dict(txn)
-
- return self.runInteraction(
- "get_current_state_deltas", get_current_state_deltas_txn
- )
-
- def get_max_stream_id_in_current_state_deltas(self):
- return self._simple_select_one_onecol(
- table="current_state_delta_stream",
- keyvalues={},
- retcol="COALESCE(MAX(stream_id), -1)",
- desc="get_max_stream_id_in_current_state_deltas",
- )
-
@defer.inlineCallbacks
def search_user_dir(self, user_id, search_term, limit):
"""Searches for users in directory
@@ -652,22 +698,19 @@ class UserDirectoryStore(SQLBaseStore):
"""
if self.hs.config.user_directory_search_all_users:
- # make s.user_id null to keep the ordering algorithm happy
- join_clause = """
- CROSS JOIN (SELECT NULL as user_id) AS s
- """
- join_args = ()
- where_clause = "1=1"
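+            # Every user is searchable; just exclude the user doing the search.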
+ join_args = (user_id,)
+ where_clause = "user_id != ?"
else:
- join_clause = """
- LEFT JOIN users_in_public_rooms AS p USING (user_id)
- LEFT JOIN (
- SELECT other_user_id AS user_id FROM users_who_share_rooms
- WHERE user_id = ? AND share_private
- ) AS s USING (user_id)
- """
join_args = (user_id,)
- where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
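+            # Only return users that the searcher shares a room with, either a
+            # public room or a private room they are both in.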
+ where_clause = """
+ (
+ EXISTS (select 1 from users_in_public_rooms WHERE user_id = t.user_id)
+ OR EXISTS (
+ SELECT 1 FROM users_who_share_private_rooms
+ WHERE user_id = ? AND other_user_id = t.user_id
+ )
+ )
+ """
if isinstance(self.database_engine, PostgresEngine):
full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
@@ -679,14 +722,13 @@ class UserDirectoryStore(SQLBaseStore):
# search: (domain, _, display name, localpart)
sql = """
SELECT d.user_id AS user_id, display_name, avatar_url
- FROM user_directory_search
+ FROM user_directory_search as t
INNER JOIN user_directory AS d USING (user_id)
- %s
WHERE
%s
AND vector @@ to_tsquery('english', ?)
ORDER BY
- (CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
+ (CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
* (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END)
* (CASE WHEN avatar_url IS NOT NULL THEN 1.2 ELSE 1.0 END)
* (
@@ -708,7 +750,6 @@ class UserDirectoryStore(SQLBaseStore):
avatar_url IS NULL
LIMIT ?
""" % (
- join_clause,
where_clause,
)
args = join_args + (full_query, exact_query, prefix_query, limit + 1)
@@ -717,9 +758,8 @@ class UserDirectoryStore(SQLBaseStore):
sql = """
SELECT d.user_id AS user_id, display_name, avatar_url
- FROM user_directory_search
+ FROM user_directory_search as t
INNER JOIN user_directory AS d USING (user_id)
- %s
WHERE
%s
AND value MATCH ?
@@ -729,7 +769,6 @@ class UserDirectoryStore(SQLBaseStore):
avatar_url IS NULL
LIMIT ?
""" % (
- join_clause,
where_clause,
)
args = join_args + (search_query, limit + 1)
|