diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index e93c3de66c..73fb334dd6 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -14,6 +14,8 @@
# limitations under the License.
from twisted.internet import defer
+
+from synapse.storage.devices import DeviceStore
from .appservice import (
ApplicationServiceStore, ApplicationServiceTransactionStore
)
@@ -80,6 +82,7 @@ class DataStore(RoomMemberStore, RoomStore,
EventPushActionsStore,
OpenIdStore,
ClientIpStore,
+ DeviceStore,
):
def __init__(self, db_conn, hs):
@@ -92,7 +95,8 @@ class DataStore(RoomMemberStore, RoomStore,
extra_tables=[("local_invites", "stream_id")]
)
self._backfill_id_gen = StreamIdGenerator(
- db_conn, "events", "stream_ordering", step=-1
+ db_conn, "events", "stream_ordering", step=-1,
+ extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
)
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 32c6677d47..0117fdc639 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -597,10 +597,13 @@ class SQLBaseStore(object):
more rows, returning the result as a list of dicts.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the rows with,
- or None to not apply a WHERE clause.
- retcols : list of strings giving the names of the columns to return
+ table (str): the table name
+ keyvalues (dict[str, Any] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
+ retcols (iterable[str]): the names of the columns to return
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]]
"""
return self.runInteraction(
desc,
@@ -615,9 +618,11 @@ class SQLBaseStore(object):
Args:
txn : Transaction object
- table : string giving the table name
- keyvalues : dict of column names and values to select the rows with
- retcols : list of strings giving the names of the columns to return
+ table (str): the table name
+            keyvalues (dict[str, Any] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
+ retcols (iterable[str]): the names of the columns to return
"""
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s" % (
@@ -807,6 +812,11 @@ class SQLBaseStore(object):
if txn.rowcount > 1:
raise StoreError(500, "more than one row matched")
+ def _simple_delete(self, table, keyvalues, desc):
+ return self.runInteraction(
+ desc, self._simple_delete_txn, table, keyvalues
+ )
+
@staticmethod
def _simple_delete_txn(txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % (
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 66a995157d..30d0e4c5dc 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -14,6 +14,7 @@
# limitations under the License.
from ._base import SQLBaseStore
+from . import engines
from twisted.internet import defer
@@ -87,10 +88,12 @@ class BackgroundUpdateStore(SQLBaseStore):
@defer.inlineCallbacks
def start_doing_background_updates(self):
- while True:
- if self._background_update_timer is not None:
- return
+ assert self._background_update_timer is None, \
+ "background updates already running"
+
+ logger.info("Starting background schema updates")
+ while True:
sleep = defer.Deferred()
self._background_update_timer = self._clock.call_later(
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
@@ -101,22 +104,23 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_timer = None
try:
- result = yield self.do_background_update(
+ result = yield self.do_next_background_update(
self.BACKGROUND_UPDATE_DURATION_MS
)
except:
logger.exception("Error doing update")
-
- if result is None:
- logger.info(
- "No more background updates to do."
- " Unscheduling background update task."
- )
- return
+ else:
+ if result is None:
+ logger.info(
+ "No more background updates to do."
+ " Unscheduling background update task."
+ )
+ defer.returnValue(None)
@defer.inlineCallbacks
- def do_background_update(self, desired_duration_ms):
- """Does some amount of work on a background update
+ def do_next_background_update(self, desired_duration_ms):
+ """Does some amount of work on the next queued background update
+
Args:
desired_duration_ms(float): How long we want to spend
updating.
@@ -135,11 +139,21 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_queue.append(update['update_name'])
if not self._background_update_queue:
+ # no work left to do
defer.returnValue(None)
+ # pop from the front, and add back to the back
update_name = self._background_update_queue.pop(0)
self._background_update_queue.append(update_name)
+ res = yield self._do_background_update(update_name, desired_duration_ms)
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def _do_background_update(self, update_name, desired_duration_ms):
+ logger.info("Starting update batch on background update '%s'",
+ update_name)
+
update_handler = self._background_update_handlers[update_name]
performance = self._background_update_performance.get(update_name)
@@ -202,6 +216,64 @@ class BackgroundUpdateStore(SQLBaseStore):
"""
self._background_update_handlers[update_name] = update_handler
+ def register_background_index_update(self, update_name, index_name,
+ table, columns):
+ """Helper for store classes to do a background index addition
+
+ To use:
+
+ 1. use a schema delta file to add a background update. Example:
+ INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('my_new_index', '{}');
+
+ 2. In the Store constructor, call this method
+
+ Args:
+ update_name (str): update_name to register for
+ index_name (str): name of index to add
+ table (str): table to add index to
+ columns (list[str]): columns/expressions to include in index
+ """
+
+ # if this is postgres, we add the indexes concurrently. Otherwise
+ # we fall back to doing it inline
+ if isinstance(self.database_engine, engines.PostgresEngine):
+ conc = True
+ else:
+ conc = False
+
+ sql = "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)" \
+ % {
+ "conc": "CONCURRENTLY" if conc else "",
+ "name": index_name,
+ "table": table,
+ "columns": ", ".join(columns),
+ }
+
+ def create_index_concurrently(conn):
+ conn.rollback()
+ # postgres insists on autocommit for the index
+ conn.set_session(autocommit=True)
+ c = conn.cursor()
+ c.execute(sql)
+ conn.set_session(autocommit=False)
+
+ def create_index(conn):
+ c = conn.cursor()
+ c.execute(sql)
+
+ @defer.inlineCallbacks
+ def updater(progress, batch_size):
+ logger.info("Adding index %s to %s", index_name, table)
+ if conc:
+ yield self.runWithConnection(create_index_concurrently)
+ else:
+ yield self.runWithConnection(create_index)
+ yield self._end_background_update(update_name)
+ defer.returnValue(1)
+
+ self.register_background_update_handler(update_name, updater)
+
def start_background_update(self, update_name, progress):
"""Starts a background update running.
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index a90990e006..71e5ea112f 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -13,10 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore, Cache
+import logging
from twisted.internet import defer
+from ._base import Cache
+from . import background_updates
+
+logger = logging.getLogger(__name__)
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
# times give more inserts into the database even for readonly API hits
@@ -24,8 +28,7 @@ from twisted.internet import defer
LAST_SEEN_GRANULARITY = 120 * 1000
-class ClientIpStore(SQLBaseStore):
-
+class ClientIpStore(background_updates.BackgroundUpdateStore):
def __init__(self, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
@@ -34,8 +37,15 @@ class ClientIpStore(SQLBaseStore):
super(ClientIpStore, self).__init__(hs)
+ self.register_background_index_update(
+ "user_ips_device_index",
+ index_name="user_ips_device_id",
+ table="user_ips",
+ columns=["user_id", "device_id", "last_seen"],
+ )
+
@defer.inlineCallbacks
- def insert_client_ip(self, user, access_token, ip, user_agent):
+ def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())
key = (user.to_string(), access_token, ip)
@@ -59,6 +69,7 @@ class ClientIpStore(SQLBaseStore):
"access_token": access_token,
"ip": ip,
"user_agent": user_agent,
+ "device_id": device_id,
},
values={
"last_seen": now,
@@ -66,3 +77,69 @@ class ClientIpStore(SQLBaseStore):
desc="insert_client_ip",
lock=False,
)
+
+ @defer.inlineCallbacks
+ def get_last_client_ip_by_device(self, devices):
+ """For each device_id listed, give the user_ip it was last seen on
+
+ Args:
+ devices (iterable[(str, str)]): list of (user_id, device_id) pairs
+
+ Returns:
+            defer.Deferred: resolves to a dict whose keys are (user_id,
+            device_id) tuples. Each value is a dict mapping the returned
+            column names to their values.
+ """
+
+ res = yield self.runInteraction(
+ "get_last_client_ip_by_device",
+ self._get_last_client_ip_by_device_txn,
+ retcols=(
+ "user_id",
+ "access_token",
+ "ip",
+ "user_agent",
+ "device_id",
+ "last_seen",
+ ),
+ devices=devices
+ )
+
+ ret = {(d["user_id"], d["device_id"]): d for d in res}
+ defer.returnValue(ret)
+
+ @classmethod
+ def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols):
+ where_clauses = []
+ bindings = []
+ for (user_id, device_id) in devices:
+ if device_id is None:
+ where_clauses.append("(user_id = ? AND device_id IS NULL)")
+ bindings.extend((user_id, ))
+ else:
+ where_clauses.append("(user_id = ? AND device_id = ?)")
+ bindings.extend((user_id, device_id))
+
+ inner_select = (
+ "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips "
+ "WHERE %(where)s "
+ "GROUP BY user_id, device_id"
+ ) % {
+ "where": " OR ".join(where_clauses),
+ }
+
+ sql = (
+ "SELECT %(retcols)s FROM user_ips "
+ "JOIN (%(inner_select)s) ips ON"
+ " user_ips.last_seen = ips.mls AND"
+ " user_ips.user_id = ips.user_id AND"
+ " (user_ips.device_id = ips.device_id OR"
+ " (user_ips.device_id IS NULL AND ips.device_id IS NULL)"
+ " )"
+ ) % {
+ "retcols": ",".join("user_ips." + c for c in retcols),
+ "inner_select": inner_select,
+ }
+
+ txn.execute(sql, bindings)
+ return cls.cursor_to_dict(txn)
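A hedged usage sketch of get_last_client_ip_by_device (the calling function below is hypothetical): the store returns a dict keyed by (user_id, device_id), each value carrying the columns requested above.

    from twisted.internet import defer

    @defer.inlineCallbacks
    def annotate_devices_with_last_seen(store, user_id, device_ids):
        # `store` is assumed to be a DataStore that mixes in ClientIpStore
        ips = yield store.get_last_client_ip_by_device(
            [(user_id, device_id) for device_id in device_ids]
        )
        result = {}
        for device_id in device_ids:
            row = ips.get((user_id, device_id))
            if row is not None:
                # row keys: user_id, access_token, ip, user_agent,
                # device_id, last_seen
                result[device_id] = (row["ip"], row["last_seen"])
        defer.returnValue(result)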
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
new file mode 100644
index 0000000000..afd6530cab
--- /dev/null
+++ b/synapse/storage/devices.py
@@ -0,0 +1,137 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.errors import StoreError
+from ._base import SQLBaseStore
+
+logger = logging.getLogger(__name__)
+
+
+class DeviceStore(SQLBaseStore):
+ @defer.inlineCallbacks
+ def store_device(self, user_id, device_id,
+ initial_device_display_name,
+ ignore_if_known=True):
+ """Ensure the given device is known; add it to the store if not
+
+ Args:
+ user_id (str): id of user associated with the device
+ device_id (str): id of device
+ initial_device_display_name (str): initial displayname of the
+ device
+ ignore_if_known (bool): ignore integrity errors which mean the
+ device is already known
+ Returns:
+ defer.Deferred
+ Raises:
+ StoreError: if ignore_if_known is False and the device was already
+ known
+ """
+ try:
+ yield self._simple_insert(
+ "devices",
+ values={
+ "user_id": user_id,
+ "device_id": device_id,
+ "display_name": initial_device_display_name
+ },
+ desc="store_device",
+ or_ignore=ignore_if_known,
+ )
+ except Exception as e:
+ logger.error("store_device with device_id=%s failed: %s",
+ device_id, e)
+ raise StoreError(500, "Problem storing device.")
+
+ def get_device(self, user_id, device_id):
+ """Retrieve a device.
+
+ Args:
+ user_id (str): The ID of the user which owns the device
+ device_id (str): The ID of the device to retrieve
+ Returns:
+ defer.Deferred for a dict containing the device information
+ Raises:
+ StoreError: if the device is not found
+ """
+ return self._simple_select_one(
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ retcols=("user_id", "device_id", "display_name"),
+ desc="get_device",
+ )
+
+ def delete_device(self, user_id, device_id):
+ """Delete a device.
+
+ Args:
+ user_id (str): The ID of the user which owns the device
+ device_id (str): The ID of the device to delete
+ Returns:
+ defer.Deferred
+ """
+ return self._simple_delete_one(
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ desc="delete_device",
+ )
+
+ def update_device(self, user_id, device_id, new_display_name=None):
+ """Update a device.
+
+ Args:
+ user_id (str): The ID of the user which owns the device
+ device_id (str): The ID of the device to update
+ new_display_name (str|None): new displayname for device; None
+ to leave unchanged
+ Raises:
+ StoreError: if the device is not found
+ Returns:
+ defer.Deferred
+ """
+ updates = {}
+ if new_display_name is not None:
+ updates["display_name"] = new_display_name
+ if not updates:
+ return defer.succeed(None)
+ return self._simple_update_one(
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ updatevalues=updates,
+ desc="update_device",
+ )
+
+ @defer.inlineCallbacks
+ def get_devices_by_user(self, user_id):
+ """Retrieve all of a user's registered devices.
+
+ Args:
+ user_id (str):
+ Returns:
+ defer.Deferred: resolves to a dict from device_id to a dict
+ containing "device_id", "user_id" and "display_name" for each
+ device.
+ """
+ devices = yield self._simple_select_list(
+ table="devices",
+ keyvalues={"user_id": user_id},
+ retcols=("user_id", "device_id", "display_name"),
+ desc="get_devices_by_user"
+ )
+
+ defer.returnValue({d["device_id"]: d for d in devices})
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 2e89066515..385d607056 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -12,6 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import collections
+
+import twisted.internet.defer
from ._base import SQLBaseStore
@@ -36,24 +39,49 @@ class EndToEndKeyStore(SQLBaseStore):
query_list(list): List of pairs of user_ids and device_ids.
Returns:
Dict mapping from user-id to dict mapping from device_id to
- key json byte strings.
+ dict containing "key_json", "device_display_name".
"""
- def _get_e2e_device_keys(txn):
- result = {}
- for user_id, device_id in query_list:
- user_result = result.setdefault(user_id, {})
- keyvalues = {"user_id": user_id}
- if device_id:
- keyvalues["device_id"] = device_id
- rows = self._simple_select_list_txn(
- txn, table="e2e_device_keys_json",
- keyvalues=keyvalues,
- retcols=["device_id", "key_json"]
- )
- for row in rows:
- user_result[row["device_id"]] = row["key_json"]
- return result
- return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys)
+ if not query_list:
+ return {}
+
+ return self.runInteraction(
+ "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list
+ )
+
+ def _get_e2e_device_keys_txn(self, txn, query_list):
+ query_clauses = []
+ query_params = []
+
+ for (user_id, device_id) in query_list:
+ query_clause = "k.user_id = ?"
+ query_params.append(user_id)
+
+ if device_id:
+ query_clause += " AND k.device_id = ?"
+ query_params.append(device_id)
+
+ query_clauses.append(query_clause)
+
+ sql = (
+ "SELECT k.user_id, k.device_id, "
+ " d.display_name AS device_display_name, "
+ " k.key_json"
+ " FROM e2e_device_keys_json k"
+ " LEFT JOIN devices d ON d.user_id = k.user_id"
+ " AND d.device_id = k.device_id"
+ " WHERE %s"
+ ) % (
+ " OR ".join("(" + q + ")" for q in query_clauses)
+ )
+
+ txn.execute(sql, query_params)
+ rows = self.cursor_to_dict(txn)
+
+ result = collections.defaultdict(dict)
+ for row in rows:
+ result[row["user_id"]][row["device_id"]] = row
+
+ return result
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
def _add_e2e_one_time_keys(txn):
@@ -123,3 +151,16 @@ class EndToEndKeyStore(SQLBaseStore):
return self.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
+
+ @twisted.internet.defer.inlineCallbacks
+ def delete_e2e_keys_by_device(self, user_id, device_id):
+ yield self._simple_delete(
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ desc="delete_e2e_device_keys_by_device"
+ )
+ yield self._simple_delete(
+ table="e2e_one_time_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ desc="delete_e2e_one_time_keys_by_device"
+ )
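A sketch of consuming the new return shape of get_e2e_device_keys: a nested dict of rows (including the display name from the LEFT JOIN against devices) rather than bare key_json strings. The flattening function below is hypothetical.

    from twisted.internet import defer

    @defer.inlineCallbacks
    def flatten_device_keys(store, query_list):
        results = yield store.get_e2e_device_keys(query_list)
        flat = {}
        for user_id, device_keys in results.items():
            for device_id, row in device_keys.items():
                # Each row carries the key JSON plus the display name (which
                # may be None if there is no matching devices row).
                flat[(user_id, device_id)] = {
                    "key_json": row["key_json"],
                    "device_display_name": row["device_display_name"],
                }
        defer.returnValue(flat)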
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 940e11d7a2..df4000d0da 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -16,6 +16,8 @@
from ._base import SQLBaseStore
from twisted.internet import defer
from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.types import RoomStreamToken
+from .stream import lower_bound
import logging
import ujson as json
@@ -73,6 +75,9 @@ class EventPushActionsStore(SQLBaseStore):
stream_ordering = results[0][0]
topological_ordering = results[0][1]
+ token = RoomStreamToken(
+ topological_ordering, stream_ordering
+ )
sql = (
"SELECT sum(notif), sum(highlight)"
@@ -80,15 +85,10 @@ class EventPushActionsStore(SQLBaseStore):
" WHERE"
" user_id = ?"
" AND room_id = ?"
- " AND ("
- " topological_ordering > ?"
- " OR (topological_ordering = ? AND stream_ordering > ?)"
- ")"
- )
- txn.execute(sql, (
- user_id, room_id,
- topological_ordering, topological_ordering, stream_ordering
- ))
+ " AND %s"
+ ) % (lower_bound(token, self.database_engine, inclusive=False),)
+
+ txn.execute(sql, (user_id, room_id))
row = txn.fetchone()
if row:
return {
@@ -117,24 +117,42 @@ class EventPushActionsStore(SQLBaseStore):
defer.returnValue(ret)
@defer.inlineCallbacks
- def get_unread_push_actions_for_user_in_range(self, user_id,
- min_stream_ordering,
- max_stream_ordering=None,
- limit=20):
+ def get_unread_push_actions_for_user_in_range_for_http(
+ self, user_id, min_stream_ordering, max_stream_ordering, limit=20
+ ):
+ """Get a list of the most recent unread push actions for a given user,
+ within the given stream ordering range. Called by the httppusher.
+
+ Args:
+ user_id (str): The user to fetch push actions for.
+            min_stream_ordering (int): The exclusive lower bound on the
+                stream ordering of event push actions to fetch.
+            max_stream_ordering (int): The inclusive upper bound on the
+                stream ordering of event push actions to fetch.
+ limit (int): The maximum number of rows to return.
+ Returns:
+ A promise which resolves to a list of dicts with the keys "event_id",
+ "room_id", "stream_ordering", "actions".
+ The list will be ordered by ascending stream_ordering.
+            The list will have between 0 and limit entries.
+ """
+ # find rooms that have a read receipt in them and return the next
+ # push actions
def get_after_receipt(txn):
sql = (
- "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, "
- "e.received_ts "
- "FROM ("
- " SELECT room_id, user_id, "
- " max(topological_ordering) as topological_ordering, "
- " max(stream_ordering) as stream_ordering "
- " FROM events"
- " NATURAL JOIN receipts_linearized WHERE receipt_type = 'm.read'"
- " GROUP BY room_id, user_id"
+ "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions"
+ " FROM ("
+ " SELECT room_id,"
+ " MAX(topological_ordering) as topological_ordering,"
+ " MAX(stream_ordering) as stream_ordering"
+ " FROM events"
+ " INNER JOIN receipts_linearized USING (room_id, event_id)"
+ " WHERE receipt_type = 'm.read' AND user_id = ?"
+ " GROUP BY room_id"
") AS rl,"
" event_push_actions AS ep"
- " INNER JOIN events AS e USING (room_id, event_id)"
" WHERE"
" ep.room_id = rl.room_id"
" AND ("
@@ -144,46 +162,163 @@ class EventPushActionsStore(SQLBaseStore):
" AND ep.stream_ordering > rl.stream_ordering"
" )"
" )"
- " AND ep.stream_ordering > ?"
" AND ep.user_id = ?"
- " AND ep.user_id = rl.user_id"
+ " AND ep.stream_ordering > ?"
+ " AND ep.stream_ordering <= ?"
+ " ORDER BY ep.stream_ordering ASC LIMIT ?"
)
- args = [min_stream_ordering, user_id]
- if max_stream_ordering is not None:
- sql += " AND ep.stream_ordering <= ?"
- args.append(max_stream_ordering)
- sql += " ORDER BY ep.stream_ordering ASC LIMIT ?"
- args.append(limit)
+ args = [
+ user_id, user_id,
+ min_stream_ordering, max_stream_ordering, limit,
+ ]
txn.execute(sql, args)
return txn.fetchall()
after_read_receipt = yield self.runInteraction(
- "get_unread_push_actions_for_user_in_range", get_after_receipt
+ "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
)
+        # Some rooms have push actions in them but the user has no read
+        # receipt there (e.g. rooms they have been invited to), so also fetch
+        # push actions for rooms without a read receipt.
def get_no_receipt(txn):
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" e.received_ts"
" FROM event_push_actions AS ep"
- " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
- " WHERE ep.room_id not in ("
- " SELECT room_id FROM events NATURAL JOIN receipts_linearized"
+ " INNER JOIN events AS e USING (room_id, event_id)"
+ " WHERE"
+ " ep.room_id NOT IN ("
+ " SELECT room_id FROM receipts_linearized"
+ " WHERE receipt_type = 'm.read' AND user_id = ?"
+ " GROUP BY room_id"
+ " )"
+ " AND ep.user_id = ?"
+ " AND ep.stream_ordering > ?"
+ " AND ep.stream_ordering <= ?"
+ " ORDER BY ep.stream_ordering ASC LIMIT ?"
+ )
+ args = [
+ user_id, user_id,
+ min_stream_ordering, max_stream_ordering, limit,
+ ]
+ txn.execute(sql, args)
+ return txn.fetchall()
+ no_read_receipt = yield self.runInteraction(
+ "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
+ )
+
+ notifs = [
+ {
+ "event_id": row[0],
+ "room_id": row[1],
+ "stream_ordering": row[2],
+ "actions": json.loads(row[3]),
+ } for row in after_read_receipt + no_read_receipt
+ ]
+
+ # Now sort it so it's ordered correctly, since currently it will
+ # contain results from the first query, correctly ordered, followed
+ # by results from the second query, but we want them all ordered
+ # by stream_ordering, oldest first.
+ notifs.sort(key=lambda r: r['stream_ordering'])
+
+ # Take only up to the limit. We have to stop at the limit because
+ # one of the subqueries may have hit the limit.
+ defer.returnValue(notifs[:limit])
+
+ @defer.inlineCallbacks
+ def get_unread_push_actions_for_user_in_range_for_email(
+ self, user_id, min_stream_ordering, max_stream_ordering, limit=20
+ ):
+ """Get a list of the most recent unread push actions for a given user,
+ within the given stream ordering range. Called by the emailpusher
+
+ Args:
+ user_id (str): The user to fetch push actions for.
+            min_stream_ordering (int): The exclusive lower bound on the
+                stream ordering of event push actions to fetch.
+            max_stream_ordering (int): The inclusive upper bound on the
+                stream ordering of event push actions to fetch.
+ limit (int): The maximum number of rows to return.
+ Returns:
+ A promise which resolves to a list of dicts with the keys "event_id",
+ "room_id", "stream_ordering", "actions", "received_ts".
+ The list will be ordered by descending received_ts.
+            The list will have between 0 and limit entries.
+ """
+ # find rooms that have a read receipt in them and return the most recent
+ # push actions
+ def get_after_receipt(txn):
+ sql = (
+ "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
+ " e.received_ts"
+ " FROM ("
+ " SELECT room_id,"
+ " MAX(topological_ordering) as topological_ordering,"
+ " MAX(stream_ordering) as stream_ordering"
+ " FROM events"
+ " INNER JOIN receipts_linearized USING (room_id, event_id)"
" WHERE receipt_type = 'm.read' AND user_id = ?"
" GROUP BY room_id"
- ") AND ep.user_id = ? AND ep.stream_ordering > ?"
+ ") AS rl,"
+ " event_push_actions AS ep"
+ " INNER JOIN events AS e USING (room_id, event_id)"
+ " WHERE"
+ " ep.room_id = rl.room_id"
+ " AND ("
+ " ep.topological_ordering > rl.topological_ordering"
+ " OR ("
+ " ep.topological_ordering = rl.topological_ordering"
+ " AND ep.stream_ordering > rl.stream_ordering"
+ " )"
+ " )"
+ " AND ep.user_id = ?"
+ " AND ep.stream_ordering > ?"
+ " AND ep.stream_ordering <= ?"
+ " ORDER BY ep.stream_ordering DESC LIMIT ?"
)
- args = [user_id, user_id, min_stream_ordering]
- if max_stream_ordering is not None:
- sql += " AND ep.stream_ordering <= ?"
- args.append(max_stream_ordering)
- sql += " ORDER BY ep.stream_ordering ASC"
+ args = [
+ user_id, user_id,
+ min_stream_ordering, max_stream_ordering, limit,
+ ]
+ txn.execute(sql, args)
+ return txn.fetchall()
+ after_read_receipt = yield self.runInteraction(
+ "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
+ )
+
+        # Some rooms have push actions in them but the user has no read
+        # receipt there (e.g. rooms they have been invited to), so also fetch
+        # push actions for rooms without a read receipt.
+ def get_no_receipt(txn):
+ sql = (
+ "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
+ " e.received_ts"
+ " FROM event_push_actions AS ep"
+ " INNER JOIN events AS e USING (room_id, event_id)"
+ " WHERE"
+ " ep.room_id NOT IN ("
+ " SELECT room_id FROM receipts_linearized"
+ " WHERE receipt_type = 'm.read' AND user_id = ?"
+ " GROUP BY room_id"
+ " )"
+ " AND ep.user_id = ?"
+ " AND ep.stream_ordering > ?"
+ " AND ep.stream_ordering <= ?"
+ " ORDER BY ep.stream_ordering DESC LIMIT ?"
+ )
+ args = [
+ user_id, user_id,
+ min_stream_ordering, max_stream_ordering, limit,
+ ]
txn.execute(sql, args)
return txn.fetchall()
no_read_receipt = yield self.runInteraction(
- "get_unread_push_actions_for_user_in_range", get_no_receipt
+ "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
)
- defer.returnValue([
+ # Make a list of dicts from the two sets of results.
+ notifs = [
{
"event_id": row[0],
"room_id": row[1],
@@ -191,7 +326,16 @@ class EventPushActionsStore(SQLBaseStore):
"actions": json.loads(row[3]),
"received_ts": row[4],
} for row in after_read_receipt + no_read_receipt
- ])
+ ]
+
+ # Now sort it so it's ordered correctly, since currently it will
+ # contain results from the first query, correctly ordered, followed
+ # by results from the second query, but we want them all ordered
+ # by received_ts (most recent first)
+ notifs.sort(key=lambda r: -(r['received_ts'] or 0))
+
+ # Now return the first `limit`
+ defer.returnValue(notifs[:limit])
@defer.inlineCallbacks
def get_time_of_last_push_action_before(self, stream_ordering):
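A standalone sketch of the merge step both pushers now share: each sub-query is ordered and limited on its own, so the two result sets are concatenated, re-sorted, and truncated to the overall limit in Python.

    def merge_push_actions(after_read_receipt, no_read_receipt, limit,
                           newest_first=False):
        # after_read_receipt / no_read_receipt: lists of notification dicts
        notifs = list(after_read_receipt) + list(no_read_receipt)
        if newest_first:
            # email pusher: most recent received_ts first, treating NULL as 0
            notifs.sort(key=lambda r: -(r["received_ts"] or 0))
        else:
            # http pusher: oldest stream_ordering first
            notifs.sort(key=lambda r: r["stream_ordering"])
        # either sub-query may already have hit `limit`, so re-apply it here
        return notifs[:limit]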
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 6d978ffcd5..d2feee8dbb 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -23,9 +23,11 @@ from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes
+from synapse.api.errors import SynapseError
from canonicaljson import encode_canonical_json
-from collections import deque, namedtuple
+from collections import deque, namedtuple, OrderedDict
+from functools import wraps
import synapse
import synapse.metrics
@@ -149,8 +151,29 @@ class _EventPeristenceQueue(object):
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+def _retry_on_integrity_error(func):
+ """Wraps a database function so that it gets retried on IntegrityError,
+ with `delete_existing=True` passed in.
+
+ Args:
+ func: function that returns a Deferred and accepts a `delete_existing` arg
+ """
+ @wraps(func)
+ @defer.inlineCallbacks
+ def f(self, *args, **kwargs):
+ try:
+ res = yield func(self, *args, **kwargs)
+ except self.database_engine.module.IntegrityError:
+ logger.exception("IntegrityError, retrying.")
+ res = yield func(self, *args, delete_existing=True, **kwargs)
+ defer.returnValue(res)
+
+ return f
+
+
class EventsStore(SQLBaseStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
+ EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
def __init__(self, hs):
super(EventsStore, self).__init__(hs)
@@ -158,6 +181,10 @@ class EventsStore(SQLBaseStore):
self.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
)
+ self.register_background_update_handler(
+ self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
+ self._background_reindex_fields_sender,
+ )
self._event_persist_queue = _EventPeristenceQueue()
@@ -223,8 +250,10 @@ class EventsStore(SQLBaseStore):
self._event_persist_queue.handle_queue(room_id, persisting_queue)
+ @_retry_on_integrity_error
@defer.inlineCallbacks
- def _persist_events(self, events_and_contexts, backfilled=False):
+ def _persist_events(self, events_and_contexts, backfilled=False,
+ delete_existing=False):
if not events_and_contexts:
return
@@ -267,12 +296,15 @@ class EventsStore(SQLBaseStore):
self._persist_events_txn,
events_and_contexts=chunk,
backfilled=backfilled,
+ delete_existing=delete_existing,
)
persist_event_counter.inc_by(len(chunk))
+ @_retry_on_integrity_error
@defer.inlineCallbacks
@log_function
- def _persist_event(self, event, context, current_state=None, backfilled=False):
+ def _persist_event(self, event, context, current_state=None, backfilled=False,
+ delete_existing=False):
try:
with self._stream_id_gen.get_next() as stream_ordering:
with self._state_groups_id_gen.get_next() as state_group_id:
@@ -285,6 +317,7 @@ class EventsStore(SQLBaseStore):
context=context,
current_state=current_state,
backfilled=backfilled,
+ delete_existing=delete_existing,
)
persist_event_counter.inc()
except _RollbackButIsFineException:
@@ -317,7 +350,7 @@ class EventsStore(SQLBaseStore):
)
if not events and not allow_none:
- raise RuntimeError("Could not find event %s" % (event_id,))
+ raise SynapseError(404, "Could not find event %s" % (event_id,))
defer.returnValue(events[0] if events else None)
@@ -347,7 +380,8 @@ class EventsStore(SQLBaseStore):
defer.returnValue({e.event_id: e for e in events})
@log_function
- def _persist_event_txn(self, txn, event, context, current_state, backfilled=False):
+ def _persist_event_txn(self, txn, event, context, current_state, backfilled=False,
+ delete_existing=False):
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
@@ -355,7 +389,6 @@ class EventsStore(SQLBaseStore):
txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
- txn.call_after(self.get_room_name_and_aliases.invalidate, (event.room_id,))
# Add an entry to the current_state_resets table to record the point
# where we clobbered the current state
@@ -388,10 +421,38 @@ class EventsStore(SQLBaseStore):
txn,
[(event, context)],
backfilled=backfilled,
+ delete_existing=delete_existing,
)
@log_function
- def _persist_events_txn(self, txn, events_and_contexts, backfilled):
+ def _persist_events_txn(self, txn, events_and_contexts, backfilled,
+ delete_existing=False):
+ """Insert some number of room events into the necessary database tables.
+
+        Rejected events are only inserted into the events table, the event_json table,
+        and the rejections table. Anything reading from those tables will need to check
+ whether the event was rejected.
+
+ If delete_existing is True then existing events will be purged from the
+ database before insertion. This is useful when retrying due to IntegrityError.
+ """
+ # Ensure that we don't have the same event twice.
+ # Pick the earliest non-outlier if there is one, else the earliest one.
+ new_events_and_contexts = OrderedDict()
+ for event, context in events_and_contexts:
+ prev_event_context = new_events_and_contexts.get(event.event_id)
+ if prev_event_context:
+ if not event.internal_metadata.is_outlier():
+ if prev_event_context[0].internal_metadata.is_outlier():
+ # To ensure correct ordering we pop, as OrderedDict is
+ # ordered by first insertion.
+ new_events_and_contexts.pop(event.event_id, None)
+ new_events_and_contexts[event.event_id] = (event, context)
+ else:
+ new_events_and_contexts[event.event_id] = (event, context)
+
+ events_and_contexts = new_events_and_contexts.values()
+
depth_updates = {}
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
@@ -402,21 +463,11 @@ class EventsStore(SQLBaseStore):
event.room_id, event.internal_metadata.stream_ordering,
)
- if not event.internal_metadata.is_outlier():
+ if not event.internal_metadata.is_outlier() and not context.rejected:
depth_updates[event.room_id] = max(
event.depth, depth_updates.get(event.room_id, event.depth)
)
- if context.push_actions:
- self._set_push_actions_for_event_and_users_txn(
- txn, event, context.push_actions
- )
-
- if event.type == EventTypes.Redaction and event.redacts is not None:
- self._remove_push_actions_for_event_id_txn(
- txn, event.room_id, event.redacts
- )
-
for room_id, depth in depth_updates.items():
self._update_min_depth_for_room_txn(txn, room_id, depth)
@@ -426,30 +477,21 @@ class EventsStore(SQLBaseStore):
),
[event.event_id for event, _ in events_and_contexts]
)
+
have_persisted = {
event_id: outlier
for event_id, outlier in txn.fetchall()
}
- event_map = {}
to_remove = set()
for event, context in events_and_contexts:
- # Handle the case of the list including the same event multiple
- # times. The tricky thing here is when they differ by whether
- # they are an outlier.
- if event.event_id in event_map:
- other = event_map[event.event_id]
-
- if not other.internal_metadata.is_outlier():
- to_remove.add(event)
- continue
- elif not event.internal_metadata.is_outlier():
+ if context.rejected:
+ # If the event is rejected then we don't care if the event
+ # was an outlier or not.
+ if event.event_id in have_persisted:
+ # If we have already seen the event then ignore it.
to_remove.add(event)
- continue
- else:
- to_remove.add(other)
-
- event_map[event.event_id] = event
+ continue
if event.event_id not in have_persisted:
continue
@@ -458,6 +500,12 @@ class EventsStore(SQLBaseStore):
outlier_persisted = have_persisted[event.event_id]
if not event.internal_metadata.is_outlier() and outlier_persisted:
+ # We received a copy of an event that we had already stored as
+ # an outlier in the database. We now have some state at that
+ # so we need to update the state_groups table with that state.
+
+ # insert into the state_group, state_groups_state and
+ # event_to_state_groups tables.
self._store_mult_state_groups_txn(txn, ((event, context),))
metadata_json = encode_json(
@@ -473,6 +521,8 @@ class EventsStore(SQLBaseStore):
(metadata_json, event.event_id,)
)
+ # Add an entry to the ex_outlier_stream table to replicate the
+ # change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering
state_group_id = context.state_group or context.new_state_group_id
self._simple_insert_txn(
@@ -494,6 +544,8 @@ class EventsStore(SQLBaseStore):
(False, event.event_id,)
)
+ # Update the event_backward_extremities table now that this
+ # event isn't an outlier any more.
self._update_extremeties(txn, [event])
events_and_contexts = [
@@ -501,38 +553,12 @@ class EventsStore(SQLBaseStore):
]
if not events_and_contexts:
+ # Make sure we don't pass an empty list to functions that expect to
+ # be storing at least one element.
return
- self._store_mult_state_groups_txn(txn, events_and_contexts)
-
- self._handle_mult_prev_events(
- txn,
- events=[event for event, _ in events_and_contexts],
- )
-
- for event, _ in events_and_contexts:
- if event.type == EventTypes.Name:
- self._store_room_name_txn(txn, event)
- elif event.type == EventTypes.Topic:
- self._store_room_topic_txn(txn, event)
- elif event.type == EventTypes.Message:
- self._store_room_message_txn(txn, event)
- elif event.type == EventTypes.Redaction:
- self._store_redaction(txn, event)
- elif event.type == EventTypes.RoomHistoryVisibility:
- self._store_history_visibility_txn(txn, event)
- elif event.type == EventTypes.GuestAccess:
- self._store_guest_access_txn(txn, event)
-
- self._store_room_members_txn(
- txn,
- [
- event
- for event, _ in events_and_contexts
- if event.type == EventTypes.Member
- ],
- backfilled=backfilled,
- )
+ # From this point onwards the events are only events that we haven't
+ # seen before.
def event_dict(event):
return {
@@ -544,6 +570,43 @@ class EventsStore(SQLBaseStore):
]
}
+ if delete_existing:
+ # For paranoia reasons, we go and delete all the existing entries
+ # for these events so we can reinsert them.
+ # This gets around any problems with some tables already having
+ # entries.
+
+ logger.info("Deleting existing")
+
+ for table in (
+ "events",
+ "event_auth",
+ "event_json",
+ "event_content_hashes",
+ "event_destinations",
+ "event_edge_hashes",
+ "event_edges",
+ "event_forward_extremities",
+ "event_push_actions",
+ "event_reference_hashes",
+ "event_search",
+ "event_signatures",
+ "event_to_state_groups",
+ "guest_access",
+ "history_visibility",
+ "local_invites",
+ "room_names",
+ "state_events",
+ "rejections",
+ "redactions",
+ "room_memberships",
+ "state_events"
+ ):
+ txn.executemany(
+ "DELETE FROM %s WHERE event_id = ?" % (table,),
+ [(ev.event_id,) for ev, _ in events_and_contexts]
+ )
+
self._simple_insert_many_txn(
txn,
table="event_json",
@@ -576,15 +639,51 @@ class EventsStore(SQLBaseStore):
"content": encode_json(event.content).decode("UTF-8"),
"origin_server_ts": int(event.origin_server_ts),
"received_ts": self._clock.time_msec(),
+ "sender": event.sender,
+ "contains_url": (
+ "url" in event.content
+ and isinstance(event.content["url"], basestring)
+ ),
}
for event, _ in events_and_contexts
],
)
- if context.rejected:
- self._store_rejections_txn(
- txn, event.event_id, context.rejected
- )
+ # Remove the rejected events from the list now that we've added them
+ # to the events table and the events_json table.
+ to_remove = set()
+ for event, context in events_and_contexts:
+ if context.rejected:
+ # Insert the event_id into the rejections table
+ self._store_rejections_txn(
+ txn, event.event_id, context.rejected
+ )
+ to_remove.add(event)
+
+ events_and_contexts = [
+ ec for ec in events_and_contexts if ec[0] not in to_remove
+ ]
+
+ if not events_and_contexts:
+ # Make sure we don't pass an empty list to functions that expect to
+ # be storing at least one element.
+ return
+
+ # From this point onwards the events are only ones that weren't rejected.
+
+ for event, context in events_and_contexts:
+ # Insert all the push actions into the event_push_actions table.
+ if context.push_actions:
+ self._set_push_actions_for_event_and_users_txn(
+ txn, event, context.push_actions
+ )
+
+ if event.type == EventTypes.Redaction and event.redacts is not None:
+ # Remove the entries in the event_push_actions table for the
+ # redacted event.
+ self._remove_push_actions_for_event_id_txn(
+ txn, event.room_id, event.redacts
+ )
self._simple_insert_many_txn(
txn,
@@ -600,6 +699,49 @@ class EventsStore(SQLBaseStore):
],
)
+ # Insert into the state_groups, state_groups_state, and
+ # event_to_state_groups tables.
+ self._store_mult_state_groups_txn(txn, events_and_contexts)
+
+ # Update the event_forward_extremities, event_backward_extremities and
+ # event_edges tables.
+ self._handle_mult_prev_events(
+ txn,
+ events=[event for event, _ in events_and_contexts],
+ )
+
+ for event, _ in events_and_contexts:
+ if event.type == EventTypes.Name:
+ # Insert into the room_names and event_search tables.
+ self._store_room_name_txn(txn, event)
+ elif event.type == EventTypes.Topic:
+ # Insert into the topics table and event_search table.
+ self._store_room_topic_txn(txn, event)
+ elif event.type == EventTypes.Message:
+ # Insert into the event_search table.
+ self._store_room_message_txn(txn, event)
+ elif event.type == EventTypes.Redaction:
+ # Insert into the redactions table.
+ self._store_redaction(txn, event)
+ elif event.type == EventTypes.RoomHistoryVisibility:
+ # Insert into the event_search table.
+ self._store_history_visibility_txn(txn, event)
+ elif event.type == EventTypes.GuestAccess:
+ # Insert into the event_search table.
+ self._store_guest_access_txn(txn, event)
+
+ # Insert into the room_memberships table.
+ self._store_room_members_txn(
+ txn,
+ [
+ event
+ for event, _ in events_and_contexts
+ if event.type == EventTypes.Member
+ ],
+ backfilled=backfilled,
+ )
+
+ # Insert event_reference_hashes table.
self._store_event_reference_hashes_txn(
txn, [event for event, _ in events_and_contexts]
)
@@ -644,6 +786,7 @@ class EventsStore(SQLBaseStore):
],
)
+ # Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
if backfilled:
@@ -656,22 +799,11 @@ class EventsStore(SQLBaseStore):
# Outlier events shouldn't clobber the current state.
continue
- if context.rejected:
- # If the event failed it's auth checks then it shouldn't
- # clobbler the current state.
- continue
-
txn.call_after(
self._get_current_state_for_key.invalidate,
(event.room_id, event.type, event.state_key,)
)
- if event.type in [EventTypes.Name, EventTypes.Aliases]:
- txn.call_after(
- self.get_room_name_and_aliases.invalidate,
- (event.room_id,)
- )
-
self._simple_upsert_txn(
txn,
"current_state_events",
@@ -1122,6 +1254,78 @@ class EventsStore(SQLBaseStore):
defer.returnValue(ret)
@defer.inlineCallbacks
+ def _background_reindex_fields_sender(self, progress, batch_size):
+ target_min_stream_id = progress["target_min_stream_id_inclusive"]
+ max_stream_id = progress["max_stream_id_exclusive"]
+ rows_inserted = progress.get("rows_inserted", 0)
+
+ INSERT_CLUMP_SIZE = 1000
+
+ def reindex_txn(txn):
+ sql = (
+ "SELECT stream_ordering, event_id, json FROM events"
+ " INNER JOIN event_json USING (event_id)"
+ " WHERE ? <= stream_ordering AND stream_ordering < ?"
+ " ORDER BY stream_ordering DESC"
+ " LIMIT ?"
+ )
+
+ txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
+
+ rows = txn.fetchall()
+ if not rows:
+ return 0
+
+ min_stream_id = rows[-1][0]
+
+ update_rows = []
+ for row in rows:
+ try:
+ event_id = row[1]
+ event_json = json.loads(row[2])
+ sender = event_json["sender"]
+ content = event_json["content"]
+
+ contains_url = "url" in content
+ if contains_url:
+ contains_url &= isinstance(content["url"], basestring)
+ except (KeyError, AttributeError):
+ # If the event is missing a necessary field then
+ # skip over it.
+ continue
+
+ update_rows.append((sender, contains_url, event_id))
+
+ sql = (
+ "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
+ )
+
+ for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
+ clump = update_rows[index:index + INSERT_CLUMP_SIZE]
+ txn.executemany(sql, clump)
+
+ progress = {
+ "target_min_stream_id_inclusive": target_min_stream_id,
+ "max_stream_id_exclusive": min_stream_id,
+ "rows_inserted": rows_inserted + len(rows)
+ }
+
+ self._background_update_progress_txn(
+ txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
+ )
+
+ return len(rows)
+
+ result = yield self.runInteraction(
+ self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
+ )
+
+ if not result:
+ yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME)
+
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
def _background_reindex_origin_server_ts(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
@@ -1288,6 +1492,162 @@ class EventsStore(SQLBaseStore):
)
return self.runInteraction("get_all_new_events", get_all_new_events_txn)
+ def delete_old_state(self, room_id, topological_ordering):
+ return self.runInteraction(
+ "delete_old_state",
+ self._delete_old_state_txn, room_id, topological_ordering
+ )
+
+ def _delete_old_state_txn(self, txn, room_id, topological_ordering):
+ """Deletes old room state
+ """
+
+ # Tables that should be pruned:
+ # event_auth
+ # event_backward_extremities
+ # event_content_hashes
+ # event_destinations
+ # event_edge_hashes
+ # event_edges
+ # event_forward_extremities
+ # event_json
+ # event_push_actions
+ # event_reference_hashes
+ # event_search
+ # event_signatures
+ # event_to_state_groups
+ # events
+ # rejections
+ # room_depth
+ # state_groups
+ # state_groups_state
+
+        # First ensure that we're not about to delete all the forward extremities
+ txn.execute(
+ "SELECT e.event_id, e.depth FROM events as e "
+ "INNER JOIN event_forward_extremities as f "
+ "ON e.event_id = f.event_id "
+ "AND e.room_id = f.room_id "
+ "WHERE f.room_id = ?",
+ (room_id,)
+ )
+ rows = txn.fetchall()
+        max_depth = max(row[1] for row in rows)
+
+ if max_depth <= topological_ordering:
+            # We need to ensure we don't delete all the events from the database
+            # otherwise we wouldn't be able to send any events (due to not
+            # having any backward extremities)
+            raise SynapseError(
+                400, "topological_ordering is greater than forward extremities"
+ )
+
+ txn.execute(
+ "SELECT event_id, state_key FROM events"
+ " LEFT JOIN state_events USING (room_id, event_id)"
+ " WHERE room_id = ? AND topological_ordering < ?",
+ (room_id, topological_ordering,)
+ )
+ event_rows = txn.fetchall()
+
+        # We calculate the new entries for the backward extremities by finding
+ # all events that point to events that are to be purged
+ txn.execute(
+ "SELECT DISTINCT e.event_id FROM events as e"
+ " INNER JOIN event_edges as ed ON e.event_id = ed.prev_event_id"
+ " INNER JOIN events as e2 ON e2.event_id = ed.event_id"
+ " WHERE e.room_id = ? AND e.topological_ordering < ?"
+ " AND e2.topological_ordering >= ?",
+ (room_id, topological_ordering, topological_ordering)
+ )
+ new_backwards_extrems = txn.fetchall()
+
+ txn.execute(
+ "DELETE FROM event_backward_extremities WHERE room_id = ?",
+ (room_id,)
+ )
+
+        # Update backward extremities
+ txn.executemany(
+ "INSERT INTO event_backward_extremities (room_id, event_id)"
+ " VALUES (?, ?)",
+ [
+ (room_id, event_id) for event_id, in new_backwards_extrems
+ ]
+ )
+
+ # Get all state groups that are only referenced by events that are
+ # to be deleted.
+ txn.execute(
+ "SELECT state_group FROM event_to_state_groups"
+ " INNER JOIN events USING (event_id)"
+ " WHERE state_group IN ("
+ " SELECT DISTINCT state_group FROM events"
+ " INNER JOIN event_to_state_groups USING (event_id)"
+ " WHERE room_id = ? AND topological_ordering < ?"
+ " )"
+ " GROUP BY state_group HAVING MAX(topological_ordering) < ?",
+ (room_id, topological_ordering, topological_ordering)
+ )
+ state_rows = txn.fetchall()
+ txn.executemany(
+ "DELETE FROM state_groups_state WHERE state_group = ?",
+ state_rows
+ )
+ txn.executemany(
+ "DELETE FROM state_groups WHERE id = ?",
+ state_rows
+ )
+        # Delete the event_to_state_groups mappings for all events being purged
+ txn.executemany(
+ "DELETE FROM event_to_state_groups WHERE event_id = ?",
+ [(event_id,) for event_id, _ in event_rows]
+ )
+
+ txn.execute(
+ "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
+ (topological_ordering, room_id,)
+ )
+
+ # Delete all remote non-state events
+ to_delete = [
+ (event_id,) for event_id, state_key in event_rows
+ if state_key is None and not self.hs.is_mine_id(event_id)
+ ]
+ for table in (
+ "events",
+ "event_json",
+ "event_auth",
+ "event_content_hashes",
+ "event_destinations",
+ "event_edge_hashes",
+ "event_edges",
+ "event_forward_extremities",
+ "event_push_actions",
+ "event_reference_hashes",
+ "event_search",
+ "event_signatures",
+ "rejections",
+ ):
+ txn.executemany(
+ "DELETE FROM %s WHERE event_id = ?" % (table,),
+ to_delete
+ )
+
+ txn.executemany(
+ "DELETE FROM events WHERE event_id = ?",
+ to_delete
+ )
+ # Mark all state and own events as outliers
+ txn.executemany(
+ "UPDATE events SET outlier = ?"
+ " WHERE event_id = ?",
+ [
+ (True, event_id,) for event_id, state_key in event_rows
+ if state_key is not None or self.hs.is_mine_id(event_id)
+ ]
+ )
+
AllNewEventsResult = namedtuple("AllNewEventsResult", [
"new_forward_events", "new_backfill_events",
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index a495a8a7d9..86b37b9ddd 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -22,6 +22,10 @@ import OpenSSL
from signedjson.key import decode_verify_key_bytes
import hashlib
+import logging
+
+logger = logging.getLogger(__name__)
+
class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys and tls X.509 certificates
@@ -74,22 +78,22 @@ class KeyStore(SQLBaseStore):
)
@cachedInlineCallbacks()
- def get_all_server_verify_keys(self, server_name):
- rows = yield self._simple_select_list(
+ def _get_server_verify_key(self, server_name, key_id):
+ verify_key_bytes = yield self._simple_select_one_onecol(
table="server_signature_keys",
keyvalues={
"server_name": server_name,
+ "key_id": key_id,
},
- retcols=["key_id", "verify_key"],
- desc="get_all_server_verify_keys",
+ retcol="verify_key",
+ desc="_get_server_verify_key",
+ allow_none=True,
)
- defer.returnValue({
- row["key_id"]: decode_verify_key_bytes(
- row["key_id"], str(row["verify_key"])
- )
- for row in rows
- })
+ if verify_key_bytes:
+ defer.returnValue(decode_verify_key_bytes(
+ key_id, str(verify_key_bytes)
+ ))
@defer.inlineCallbacks
def get_server_verify_keys(self, server_name, key_ids):
@@ -101,12 +105,12 @@ class KeyStore(SQLBaseStore):
Returns:
(list of VerifyKey): The verification keys.
"""
- keys = yield self.get_all_server_verify_keys(server_name)
- defer.returnValue({
- k: keys[k]
- for k in key_ids
- if k in keys and keys[k]
- })
+ keys = {}
+ for key_id in key_ids:
+ key = yield self._get_server_verify_key(server_name, key_id)
+ if key:
+ keys[key_id] = key
+ defer.returnValue(keys)
@defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms,
@@ -133,8 +137,6 @@ class KeyStore(SQLBaseStore):
desc="store_server_verify_key",
)
- self.get_all_server_verify_keys.invalidate((server_name,))
-
def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes):
"""Stores the JSON bytes for a set of keys from a server
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index a820fcf07f..4c0f82353d 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -157,10 +157,25 @@ class MediaRepositoryStore(SQLBaseStore):
"created_ts": time_now_ms,
"upload_name": upload_name,
"filesystem_id": filesystem_id,
+ "last_access_ts": time_now_ms,
},
desc="store_cached_remote_media",
)
+ def update_cached_last_access_time(self, origin_id_tuples, time_ts):
+ def update_cache_txn(txn):
+ sql = (
+ "UPDATE remote_media_cache SET last_access_ts = ?"
+ " WHERE media_origin = ? AND media_id = ?"
+ )
+
+ txn.executemany(sql, (
+ (time_ts, media_origin, media_id)
+ for media_origin, media_id in origin_id_tuples
+ ))
+
+ return self.runInteraction("update_cached_last_access_time", update_cache_txn)
+
def get_remote_media_thumbnails(self, origin, media_id):
return self._simple_select_list(
"remote_media_cache_thumbnails",
@@ -190,3 +205,32 @@ class MediaRepositoryStore(SQLBaseStore):
},
desc="store_remote_media_thumbnail",
)
+
+ def get_remote_media_before(self, before_ts):
+ sql = (
+ "SELECT media_origin, media_id, filesystem_id"
+ " FROM remote_media_cache"
+ " WHERE last_access_ts < ?"
+ )
+
+ return self._execute(
+ "get_remote_media_before", self.cursor_to_dict, sql, before_ts
+ )
+
+ def delete_remote_media(self, media_origin, media_id):
+ def delete_remote_media_txn(txn):
+ self._simple_delete_txn(
+ txn,
+ "remote_media_cache",
+ keyvalues={
+ "media_origin": media_origin, "media_id": media_id
+ },
+ )
+ self._simple_delete_txn(
+ txn,
+ "remote_media_cache_thumbnails",
+ keyvalues={
+ "media_origin": media_origin, "media_id": media_id
+ },
+ )
+ return self.runInteraction("delete_remote_media", delete_remote_media_txn)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index c8487c8838..8801669a6b 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 32
+SCHEMA_VERSION = 33
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 3de9e0f709..7e7d32eb66 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -18,25 +18,40 @@ import re
from twisted.internet import defer
from synapse.api.errors import StoreError, Codes
-
-from ._base import SQLBaseStore
+from synapse.storage import background_updates
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-class RegistrationStore(SQLBaseStore):
+class RegistrationStore(background_updates.BackgroundUpdateStore):
def __init__(self, hs):
super(RegistrationStore, self).__init__(hs)
self.clock = hs.get_clock()
+ self.register_background_index_update(
+ "access_tokens_device_index",
+ index_name="access_tokens_device_id",
+ table="access_tokens",
+ columns=["user_id", "device_id"],
+ )
+
+ self.register_background_index_update(
+ "refresh_tokens_device_index",
+ index_name="refresh_tokens_device_id",
+ table="refresh_tokens",
+ columns=["user_id", "device_id"],
+ )
+
@defer.inlineCallbacks
- def add_access_token_to_user(self, user_id, token):
+ def add_access_token_to_user(self, user_id, token, device_id=None):
"""Adds an access token for the given user.
Args:
user_id (str): The user ID.
token (str): The new access token to add.
+ device_id (str): ID of the device to associate with the access
+ token
Raises:
StoreError if there was a problem adding this.
"""
@@ -47,18 +62,21 @@ class RegistrationStore(SQLBaseStore):
{
"id": next_id,
"user_id": user_id,
- "token": token
+ "token": token,
+ "device_id": device_id,
},
desc="add_access_token_to_user",
)
@defer.inlineCallbacks
- def add_refresh_token_to_user(self, user_id, token):
+ def add_refresh_token_to_user(self, user_id, token, device_id=None):
"""Adds a refresh token for the given user.
Args:
user_id (str): The user ID.
token (str): The new refresh token to add.
+            device_id (str): ID of the device to associate with the refresh
+                token
Raises:
StoreError if there was a problem adding this.
"""
@@ -69,20 +87,23 @@ class RegistrationStore(SQLBaseStore):
{
"id": next_id,
"user_id": user_id,
- "token": token
+ "token": token,
+ "device_id": device_id,
},
desc="add_refresh_token_to_user",
)
@defer.inlineCallbacks
- def register(self, user_id, token, password_hash,
+ def register(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None,
- create_profile_with_localpart=None):
+ create_profile_with_localpart=None, admin=False):
"""Attempts to register an account.
Args:
user_id (str): The desired user ID to register.
- token (str): The desired access token to use for this user.
+ token (str): The desired access token to use for this user. If this
+ is not None, the given access token is associated with the user
+ id.
password_hash (str): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
@@ -104,6 +125,7 @@ class RegistrationStore(SQLBaseStore):
make_guest,
appservice_id,
create_profile_with_localpart,
+ admin
)
self.get_user_by_id.invalidate((user_id,))
self.is_guest.invalidate((user_id,))
@@ -118,6 +140,7 @@ class RegistrationStore(SQLBaseStore):
make_guest,
appservice_id,
create_profile_with_localpart,
+ admin,
):
now = int(self.clock.time())
@@ -125,29 +148,48 @@ class RegistrationStore(SQLBaseStore):
try:
if was_guest:
- txn.execute("UPDATE users SET"
- " password_hash = ?,"
- " upgrade_ts = ?,"
- " is_guest = ?"
- " WHERE name = ?",
- [password_hash, now, 1 if make_guest else 0, user_id])
+ # Ensure that the guest user actually exists
+ # ``allow_none=False`` makes this raise an exception
+ # if the row isn't in the database.
+ self._simple_select_one_txn(
+ txn,
+ "users",
+ keyvalues={
+ "name": user_id,
+ "is_guest": 1,
+ },
+ retcols=("name",),
+ allow_none=False,
+ )
+
+ self._simple_update_one_txn(
+ txn,
+ "users",
+ keyvalues={
+ "name": user_id,
+ "is_guest": 1,
+ },
+ updatevalues={
+ "password_hash": password_hash,
+ "upgrade_ts": now,
+ "is_guest": 1 if make_guest else 0,
+ "appservice_id": appservice_id,
+ "admin": 1 if admin else 0,
+ }
+ )
else:
- txn.execute("INSERT INTO users "
- "("
- " name,"
- " password_hash,"
- " creation_ts,"
- " is_guest,"
- " appservice_id"
- ") "
- "VALUES (?,?,?,?,?)",
- [
- user_id,
- password_hash,
- now,
- 1 if make_guest else 0,
- appservice_id,
- ])
+ self._simple_insert_txn(
+ txn,
+ "users",
+ values={
+ "name": user_id,
+ "password_hash": password_hash,
+ "creation_ts": now,
+ "is_guest": 1 if make_guest else 0,
+ "appservice_id": appservice_id,
+ "admin": 1 if admin else 0,
+ }
+ )
except self.database_engine.module.IntegrityError:
raise StoreError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
@@ -209,16 +251,37 @@ class RegistrationStore(SQLBaseStore):
self.get_user_by_id.invalidate((user_id,))
@defer.inlineCallbacks
- def user_delete_access_tokens(self, user_id, except_token_ids=[]):
- def f(txn):
- sql = "SELECT token FROM access_tokens WHERE user_id = ?"
+ def user_delete_access_tokens(self, user_id, except_token_ids=[],
+ device_id=None,
+ delete_refresh_tokens=False):
+ """
+ Invalidate access/refresh tokens belonging to a user
+
+ Args:
+ user_id (str): ID of user the tokens belong to
+            except_token_ids (list[str]): list of access token IDs which
+                should *not* be deleted
+ device_id (str|None): ID of device the tokens are associated with.
+ If None, tokens associated with any device (or no device) will
+ be deleted
+ delete_refresh_tokens (bool): True to delete refresh tokens as
+ well as access tokens.
+ Returns:
+            defer.Deferred: resolves once the tokens have been deleted
+ """
+ def f(txn, table, except_tokens, call_after_delete):
+ sql = "SELECT token FROM %s WHERE user_id = ?" % table
clauses = [user_id]
- if except_token_ids:
+ if device_id is not None:
+ sql += " AND device_id = ?"
+ clauses.append(device_id)
+
+ if except_tokens:
sql += " AND id NOT IN (%s)" % (
- ",".join(["?" for _ in except_token_ids]),
+ ",".join(["?" for _ in except_tokens]),
)
- clauses += except_token_ids
+ clauses += except_tokens
txn.execute(sql, clauses)
@@ -227,16 +290,33 @@ class RegistrationStore(SQLBaseStore):
n = 100
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
for chunk in chunks:
- for row in chunk:
- txn.call_after(self.get_user_by_access_token.invalidate, (row[0],))
+ if call_after_delete:
+ for row in chunk:
+ txn.call_after(call_after_delete, (row[0],))
txn.execute(
- "DELETE FROM access_tokens WHERE token in (%s)" % (
+ "DELETE FROM %s WHERE token in (%s)" % (
+ table,
",".join(["?" for _ in chunk]),
), [r[0] for r in chunk]
)
- yield self.runInteraction("user_delete_access_tokens", f)
+ # delete refresh tokens first, to stop new access tokens being
+ # allocated while our backs are turned
+ if delete_refresh_tokens:
+ yield self.runInteraction(
+ "user_delete_access_tokens", f,
+ table="refresh_tokens",
+ except_tokens=[],
+ call_after_delete=None,
+ )
+
+ yield self.runInteraction(
+ "user_delete_access_tokens", f,
+ table="access_tokens",
+ except_tokens=except_token_ids,
+ call_after_delete=self.get_user_by_access_token.invalidate,
+ )
def delete_access_token(self, access_token):
def f(txn):
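
Two hedged call-site sketches for the reworked deletion helper (names are placeholders; store is the DataStore instance):

    # Logging a single device out: drop its refresh tokens first (so no new
    # access tokens can be minted), then its access tokens, leaving the user's
    # other devices untouched.
    yield store.user_delete_access_tokens(
        user_id="@alice:example.com",
        device_id="HYPOTHETICAL_DEVICE_ID",
        delete_refresh_tokens=True,
    )

    # Password change: invalidate every access token for the user except the
    # one the requesting client is using.
    yield store.user_delete_access_tokens(
        user_id="@alice:example.com",
        except_token_ids=[current_token_id],   # hypothetical id of the caller's token
    )
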
@@ -259,9 +339,8 @@ class RegistrationStore(SQLBaseStore):
Args:
token (str): The access token of a user.
Returns:
- dict: Including the name (user_id) and the ID of their access token.
- Raises:
- StoreError if no user was found.
+ defer.Deferred: None, if the token did not match, otherwise dict
+ including the keys `name`, `is_guest`, `device_id`, `token_id`.
"""
return self.runInteraction(
"get_user_by_access_token",
@@ -270,18 +349,18 @@ class RegistrationStore(SQLBaseStore):
)
def exchange_refresh_token(self, refresh_token, token_generator):
- """Exchange a refresh token for a new access token and refresh token.
+ """Exchange a refresh token for a new one.
Doing so invalidates the old refresh token - refresh tokens are single
use.
Args:
- token (str): The refresh token of a user.
+ refresh_token (str): The refresh token of a user.
token_generator (fn: str -> str): Function which, when given a
user ID, returns a unique refresh token for that user. This
function must never return the same value twice.
Returns:
- tuple of (user_id, refresh_token)
+ tuple of (user_id, new_refresh_token, device_id)
Raises:
StoreError if no user was found with that refresh token.
"""
@@ -293,12 +372,13 @@ class RegistrationStore(SQLBaseStore):
)
def _exchange_refresh_token(self, txn, old_token, token_generator):
- sql = "SELECT user_id FROM refresh_tokens WHERE token = ?"
+ sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?"
txn.execute(sql, (old_token,))
rows = self.cursor_to_dict(txn)
if not rows:
raise StoreError(403, "Did not recognize refresh token")
user_id = rows[0]["user_id"]
+ device_id = rows[0]["device_id"]
# TODO(danielwh): Maybe perform a validation on the macaroon that
# macaroon.user_id == user_id.
@@ -307,7 +387,7 @@ class RegistrationStore(SQLBaseStore):
sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
txn.execute(sql, (new_token, old_token,))
- return user_id, new_token
+ return user_id, new_token, device_id
@defer.inlineCallbacks
def is_server_admin(self, user):
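
From the caller's side, the new three-element return looks like the sketch below (toy_token_generator is a stand-in for synapse's real token generator, which is not part of this file):

    import uuid

    def toy_token_generator(user_id):
        # Stand-in generator: must never hand out the same value twice.
        return "%s:%s" % (user_id, uuid.uuid4())

    # The old refresh token is overwritten inside the transaction, so it is
    # single-use; the device_id recorded against it is now returned as well.
    user_id, new_refresh_token, device_id = yield store.exchange_refresh_token(
        old_refresh_token,        # hypothetical: the token presented by the client
        toy_token_generator,
    )
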
@@ -335,7 +415,8 @@ class RegistrationStore(SQLBaseStore):
def _query_for_auth(self, txn, token):
sql = (
- "SELECT users.name, users.is_guest, access_tokens.id as token_id"
+ "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
+ " access_tokens.device_id"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
@@ -384,6 +465,15 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(ret['user_id'])
defer.returnValue(None)
+ def user_delete_threepids(self, user_id):
+ return self._simple_delete(
+ "user_threepids",
+ keyvalues={
+ "user_id": user_id,
+ },
+ desc="user_delete_threepids",
+ )
+
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 97f9f1929c..8251f58670 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -18,7 +18,6 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks
from .engines import PostgresEngine, Sqlite3Engine
import collections
@@ -192,49 +191,6 @@ class RoomStore(SQLBaseStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
- @cachedInlineCallbacks()
- def get_room_name_and_aliases(self, room_id):
- def get_room_name(txn):
- sql = (
- "SELECT name FROM room_names"
- " INNER JOIN current_state_events USING (room_id, event_id)"
- " WHERE room_id = ?"
- " LIMIT 1"
- )
-
- txn.execute(sql, (room_id,))
- rows = txn.fetchall()
- if rows:
- return rows[0][0]
- else:
- return None
-
- return [row[0] for row in txn.fetchall()]
-
- def get_room_aliases(txn):
- sql = (
- "SELECT content FROM current_state_events"
- " INNER JOIN events USING (room_id, event_id)"
- " WHERE room_id = ?"
- )
- txn.execute(sql, (room_id,))
- return [row[0] for row in txn.fetchall()]
-
- name = yield self.runInteraction("get_room_name", get_room_name)
- alias_contents = yield self.runInteraction("get_room_aliases", get_room_aliases)
-
- aliases = []
-
- for c in alias_contents:
- try:
- content = json.loads(c)
- except:
- continue
-
- aliases.extend(content.get('aliases', []))
-
- defer.returnValue((name, aliases))
-
def add_event_report(self, room_id, event_id, user_id, reason, content,
received_ts):
next_id = self._event_reports_id_gen.get_next()
diff --git a/synapse/storage/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/schema/delta/33/access_tokens_device_index.sql
new file mode 100644
index 0000000000..61ad3fe3e8
--- /dev/null
+++ b/synapse/storage/schema/delta/33/access_tokens_device_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('access_tokens_device_index', '{}');
diff --git a/synapse/storage/schema/delta/33/devices.sql b/synapse/storage/schema/delta/33/devices.sql
new file mode 100644
index 0000000000..eca7268d82
--- /dev/null
+++ b/synapse/storage/schema/delta/33/devices.sql
@@ -0,0 +1,21 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE devices (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ display_name TEXT,
+ CONSTRAINT device_uniqueness UNIQUE (user_id, device_id)
+);
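
Not part of this diff, but as an illustration of how the new table is meant to be used: one row per (user_id, device_id), with the uniqueness constraint turning duplicate inserts into an IntegrityError. A minimal sketch using the generic insert helper:

    # Illustrative only; "store_device" is a made-up desc label.
    yield self._simple_insert(
        "devices",
        values={
            "user_id": "@alice:example.com",
            "device_id": "HYPOTHETICAL_DEVICE_ID",
            "display_name": "Alice's phone",     # nullable, per the schema above
        },
        desc="store_device",
    )
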
diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql b/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql
new file mode 100644
index 0000000000..aa4a3b9f2f
--- /dev/null
+++ b/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql
@@ -0,0 +1,19 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- make sure that we have a device record for each set of E2E keys, so that the
+-- user can delete them if they like.
+INSERT INTO devices
+ SELECT user_id, device_id, NULL FROM e2e_device_keys_json;
diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql b/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
new file mode 100644
index 0000000000..6671573398
--- /dev/null
+++ b/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
@@ -0,0 +1,20 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- a previous version of the "devices_for_e2e_keys" delta set all the device
+-- names to "unknown device". This wasn't terribly helpful
+UPDATE devices
+ SET display_name = NULL
+ WHERE display_name = 'unknown device';
diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/schema/delta/33/event_fields.py
new file mode 100644
index 0000000000..83066cccc9
--- /dev/null
+++ b/synapse/storage/schema/delta/33/event_fields.py
@@ -0,0 +1,60 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.prepare_database import get_statements
+
+import logging
+import ujson
+
+logger = logging.getLogger(__name__)
+
+
+ALTER_TABLE = """
+ALTER TABLE events ADD COLUMN sender TEXT;
+ALTER TABLE events ADD COLUMN contains_url BOOLEAN;
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ for statement in get_statements(ALTER_TABLE.splitlines()):
+ cur.execute(statement)
+
+ cur.execute("SELECT MIN(stream_ordering) FROM events")
+ rows = cur.fetchall()
+ min_stream_id = rows[0][0]
+
+ cur.execute("SELECT MAX(stream_ordering) FROM events")
+ rows = cur.fetchall()
+ max_stream_id = rows[0][0]
+
+ if min_stream_id is not None and max_stream_id is not None:
+ progress = {
+ "target_min_stream_id_inclusive": min_stream_id,
+ "max_stream_id_exclusive": max_stream_id + 1,
+ "rows_inserted": 0,
+ }
+ progress_json = ujson.dumps(progress)
+
+ sql = (
+ "INSERT into background_updates (update_name, progress_json)"
+ " VALUES (?, ?)"
+ )
+
+ sql = database_engine.convert_param_style(sql)
+
+ cur.execute(sql, ("event_fields_sender_url", progress_json))
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+ pass
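
To make the progress bookkeeping concrete: for a hypothetical events table spanning stream_ordering 1 to 5000, run_create above registers a background_updates row equivalent to the sketch below (stdlib json is used purely for illustration; the delta itself uses ujson):

    import json

    min_stream_id, max_stream_id = 1, 5000      # hypothetical table bounds

    progress = {
        "target_min_stream_id_inclusive": min_stream_id,
        "max_stream_id_exclusive": max_stream_id + 1,
        "rows_inserted": 0,
    }

    # i.e. INSERT INTO background_updates (update_name, progress_json)
    #      VALUES ('event_fields_sender_url', '<the JSON below>')
    print(json.dumps(progress))
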
diff --git a/synapse/storage/schema/delta/33/refreshtoken_device.sql b/synapse/storage/schema/delta/33/refreshtoken_device.sql
new file mode 100644
index 0000000000..290bd6da86
--- /dev/null
+++ b/synapse/storage/schema/delta/33/refreshtoken_device.sql
@@ -0,0 +1,16 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE refresh_tokens ADD COLUMN device_id TEXT;
diff --git a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql b/synapse/storage/schema/delta/33/refreshtoken_device_index.sql
new file mode 100644
index 0000000000..bb225dafbf
--- /dev/null
+++ b/synapse/storage/schema/delta/33/refreshtoken_device_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('refresh_tokens_device_index', '{}');
diff --git a/synapse/storage/schema/delta/33/remote_media_ts.py b/synapse/storage/schema/delta/33/remote_media_ts.py
new file mode 100644
index 0000000000..55ae43f395
--- /dev/null
+++ b/synapse/storage/schema/delta/33/remote_media_ts.py
@@ -0,0 +1,31 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+
+
+ALTER_TABLE = "ALTER TABLE remote_media_cache ADD COLUMN last_access_ts BIGINT"
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ cur.execute(ALTER_TABLE)
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+ cur.execute(
+ database_engine.convert_param_style(
+ "UPDATE remote_media_cache SET last_access_ts = ?"
+ ),
+ (int(time.time() * 1000),)
+ )
diff --git a/synapse/storage/schema/delta/33/user_ips_index.sql b/synapse/storage/schema/delta/33/user_ips_index.sql
new file mode 100644
index 0000000000..473f75a78e
--- /dev/null
+++ b/synapse/storage/schema/delta/33/user_ips_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('user_ips_device_index', '{}');
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index b9ad965fd6..862c5c3ea1 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -40,6 +40,7 @@ from synapse.util.caches.descriptors import cached
from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken
from synapse.util.logcontext import preserve_fn
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
import logging
@@ -54,26 +55,92 @@ _STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological"
-def lower_bound(token):
+def lower_bound(token, engine, inclusive=False):
+ inclusive = "=" if inclusive else ""
if token.topological is None:
- return "(%d < %s)" % (token.stream, "stream_ordering")
+ return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering")
else:
- return "(%d < %s OR (%d = %s AND %d < %s))" % (
+ if isinstance(engine, PostgresEngine):
+ # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
+ # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
+            # use the latter form when running against postgres.
+ return "((%d,%d) <%s (%s,%s))" % (
+ token.topological, token.stream, inclusive,
+ "topological_ordering", "stream_ordering",
+ )
+ return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
token.topological, "topological_ordering",
token.topological, "topological_ordering",
- token.stream, "stream_ordering",
+ token.stream, inclusive, "stream_ordering",
)
-def upper_bound(token):
+def upper_bound(token, engine, inclusive=True):
+ inclusive = "=" if inclusive else ""
if token.topological is None:
- return "(%d >= %s)" % (token.stream, "stream_ordering")
+ return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering")
else:
- return "(%d > %s OR (%d = %s AND %d >= %s))" % (
+ if isinstance(engine, PostgresEngine):
+ # Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well
+ # as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
+            # use the latter form when running against postgres.
+ return "((%d,%d) >%s (%s,%s))" % (
+ token.topological, token.stream, inclusive,
+ "topological_ordering", "stream_ordering",
+ )
+ return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
token.topological, "topological_ordering",
token.topological, "topological_ordering",
- token.stream, "stream_ordering",
+ token.stream, inclusive, "stream_ordering",
+ )
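
To see what the two bound helpers emit, here is a standalone sketch of the same logic; use_tuple_comparison stands in for the isinstance(engine, PostgresEngine) check, and the printed strings match what lower_bound() produces for a token with topological=5, stream=3:

    def lower_bound_sketch(topological, stream, use_tuple_comparison, inclusive=False):
        op = "<=" if inclusive else "<"
        if use_tuple_comparison:
            # Postgres: a row-value comparison the planner can satisfy directly
            # from a (topological_ordering, stream_ordering) multicolumn index.
            return "((%d,%d) %s (topological_ordering,stream_ordering))" % (
                topological, stream, op,
            )
        # Generic form for engines without that optimisation (e.g. SQLite).
        return (
            "(%d < topological_ordering"
            " OR (%d = topological_ordering AND %d %s stream_ordering))"
        ) % (topological, topological, stream, op)

    print(lower_bound_sketch(5, 3, use_tuple_comparison=True))
    # -> ((5,3) < (topological_ordering,stream_ordering))
    print(lower_bound_sketch(5, 3, use_tuple_comparison=False))
    # -> (5 < topological_ordering OR (5 = topological_ordering AND 3 < stream_ordering))
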
+
+
+def filter_to_clause(event_filter):
+ # NB: This may create SQL clauses that don't optimise well (and we don't
+ # have indices on all possible clauses). E.g. it may create
+ # "room_id == X AND room_id != X", which postgres doesn't optimise.
+
+ if not event_filter:
+ return "", []
+
+ clauses = []
+ args = []
+
+ if event_filter.types:
+ clauses.append(
+ "(%s)" % " OR ".join("type = ?" for _ in event_filter.types)
+ )
+ args.extend(event_filter.types)
+
+ for typ in event_filter.not_types:
+ clauses.append("type != ?")
+ args.append(typ)
+
+ if event_filter.senders:
+ clauses.append(
+ "(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders)
)
+ args.extend(event_filter.senders)
+
+ for sender in event_filter.not_senders:
+ clauses.append("sender != ?")
+ args.append(sender)
+
+ if event_filter.rooms:
+ clauses.append(
+ "(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms)
+ )
+ args.extend(event_filter.rooms)
+
+ for room_id in event_filter.not_rooms:
+ clauses.append("room_id != ?")
+ args.append(room_id)
+
+ if event_filter.contains_url:
+ clauses.append("contains_url = ?")
+ args.append(event_filter.contains_url)
+
+ return " AND ".join(clauses), args
class StreamStore(SQLBaseStore):
@@ -301,25 +368,35 @@ class StreamStore(SQLBaseStore):
@defer.inlineCallbacks
def paginate_room_events(self, room_id, from_key, to_key=None,
- direction='b', limit=-1):
+ direction='b', limit=-1, event_filter=None):
# Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities.
args = [False, room_id]
if direction == 'b':
order = "DESC"
- bounds = upper_bound(RoomStreamToken.parse(from_key))
+ bounds = upper_bound(
+ RoomStreamToken.parse(from_key), self.database_engine
+ )
if to_key:
- bounds = "%s AND %s" % (
- bounds, lower_bound(RoomStreamToken.parse(to_key))
- )
+ bounds = "%s AND %s" % (bounds, lower_bound(
+ RoomStreamToken.parse(to_key), self.database_engine
+ ))
else:
order = "ASC"
- bounds = lower_bound(RoomStreamToken.parse(from_key))
+ bounds = lower_bound(
+ RoomStreamToken.parse(from_key), self.database_engine
+ )
if to_key:
- bounds = "%s AND %s" % (
- bounds, upper_bound(RoomStreamToken.parse(to_key))
- )
+ bounds = "%s AND %s" % (bounds, upper_bound(
+ RoomStreamToken.parse(to_key), self.database_engine
+ ))
+
+ filter_clause, filter_args = filter_to_clause(event_filter)
+
+ if filter_clause:
+ bounds += " AND " + filter_clause
+ args.extend(filter_args)
if int(limit) > 0:
args.append(int(limit))
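
And a hedged caller-side sketch of the new event_filter parameter, reusing the FakeFilter stub from the previous sketch (store, tokens and room id are placeholders; the shape of the returned value is unchanged by this diff and not restated here):

    url_messages = FakeFilter(
        types=["m.room.message"], not_types=[],
        senders=[], not_senders=[],
        rooms=[], not_rooms=[],
        contains_url=True,
    )

    # Page backwards from a known token, keeping only message events whose
    # content carries a url (via the new contains_url column).
    result = yield store.paginate_room_events(
        room_id="!abc:example.com",
        from_key=from_token,          # hypothetical stream token string
        direction='b',
        limit=20,
        event_filter=url_messages,
    )
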
@@ -487,13 +564,13 @@ class StreamStore(SQLBaseStore):
row["topological_ordering"], row["stream_ordering"],)
)
- def get_max_topological_token_for_stream_and_room(self, room_id, stream_key):
+ def get_max_topological_token(self, room_id, stream_key):
sql = (
"SELECT max(topological_ordering) FROM events"
" WHERE room_id = ? AND stream_ordering < ?"
)
return self._execute(
- "get_max_topological_token_for_stream_and_room", None,
+ "get_max_topological_token", None,
sql, room_id, stream_key,
).addCallback(
lambda r: r[0][0] if r else 0
@@ -586,32 +663,60 @@ class StreamStore(SQLBaseStore):
retcols=["stream_ordering", "topological_ordering"],
)
- stream_ordering = results["stream_ordering"]
- topological_ordering = results["topological_ordering"]
-
- query_before = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND (topological_ordering < ?"
- " OR (topological_ordering = ? AND stream_ordering < ?))"
- " ORDER BY topological_ordering DESC, stream_ordering DESC"
- " LIMIT ?"
+ token = RoomStreamToken(
+ results["topological_ordering"],
+ results["stream_ordering"],
)
- query_after = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND (topological_ordering > ?"
- " OR (topological_ordering = ? AND stream_ordering > ?))"
- " ORDER BY topological_ordering ASC, stream_ordering ASC"
- " LIMIT ?"
- )
+ if isinstance(self.database_engine, Sqlite3Engine):
+            # SQLite3 doesn't optimise ``(x < a) OR (x = a AND y < b)``,
+            # so we pass it to SQLite3 as the UNION ALL of two simpler queries.
+
+ query_before = (
+ "SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND topological_ordering < ?"
+ " UNION ALL"
+ " SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering < ?"
+ " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
+ )
+ before_args = (
+ room_id, token.topological,
+ room_id, token.topological, token.stream,
+ before_limit,
+ )
- txn.execute(
- query_before,
- (
- room_id, topological_ordering, topological_ordering,
- stream_ordering, before_limit,
+ query_after = (
+ "SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND topological_ordering > ?"
+ " UNION ALL"
+ " SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering > ?"
+ " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
)
- )
+ after_args = (
+ room_id, token.topological,
+ room_id, token.topological, token.stream,
+ after_limit,
+ )
+ else:
+ query_before = (
+ "SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND %s"
+ " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
+ ) % (upper_bound(token, self.database_engine, inclusive=False),)
+
+ before_args = (room_id, before_limit)
+
+ query_after = (
+ "SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND %s"
+ " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
+ ) % (lower_bound(token, self.database_engine, inclusive=False),)
+
+ after_args = (room_id, after_limit)
+
+ txn.execute(query_before, before_args)
rows = self.cursor_to_dict(txn)
events_before = [r["event_id"] for r in rows]
@@ -623,17 +728,11 @@ class StreamStore(SQLBaseStore):
))
else:
start_token = str(RoomStreamToken(
- topological_ordering,
- stream_ordering - 1,
+ token.topological,
+ token.stream - 1,
))
- txn.execute(
- query_after,
- (
- room_id, topological_ordering, topological_ordering,
- stream_ordering, after_limit,
- )
- )
+ txn.execute(query_after, after_args)
rows = self.cursor_to_dict(txn)
events_after = [r["event_id"] for r in rows]
@@ -644,10 +743,7 @@ class StreamStore(SQLBaseStore):
rows[-1]["stream_ordering"],
))
else:
- end_token = str(RoomStreamToken(
- topological_ordering,
- stream_ordering,
- ))
+ end_token = str(token)
return {
"before": {
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 6c7481a728..6258ff1725 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -24,6 +24,7 @@ from collections import namedtuple
import itertools
import logging
+import ujson as json
logger = logging.getLogger(__name__)
@@ -101,7 +102,7 @@ class TransactionStore(SQLBaseStore):
)
if result and result["response_code"]:
- return result["response_code"], result["response_json"]
+ return result["response_code"], json.loads(str(result["response_json"]))
else:
return None
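
Callers of this lookup now receive a parsed dict rather than the raw JSON text; the str() coercion is presumably there for drivers that return the column as a buffer/unicode object. A trivial illustration with a made-up stored value:

    import ujson as json

    response_code = 200
    stored_response_json = '{"pdus": {}}'     # hypothetical value of the DB column

    result = (response_code, json.loads(str(stored_response_json)))
    # -> (200, {'pdus': {}})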