diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 49feb77779..73fb334dd6 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -14,10 +14,12 @@
# limitations under the License.
from twisted.internet import defer
+
+from synapse.storage.devices import DeviceStore
from .appservice import (
ApplicationServiceStore, ApplicationServiceTransactionStore
)
-from ._base import Cache
+from ._base import LoggingTransaction
from .directory import DirectoryStore
from .events import EventsStore
from .presence import PresenceStore, UserPresenceState
@@ -45,6 +47,7 @@ from .search import SearchStore
from .tags import TagsStore
from .account_data import AccountDataStore
from .openid import OpenIdStore
+from .client_ips import ClientIpStore
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
@@ -58,12 +61,6 @@ import logging
logger = logging.getLogger(__name__)
-# Number of msec of granularity to store the user IP 'last seen' time. Smaller
-# times give more inserts into the database even for readonly API hits
-# 120 seconds == 2 minutes
-LAST_SEEN_GRANULARITY = 120 * 1000
-
-
class DataStore(RoomMemberStore, RoomStore,
RegistrationStore, StreamStore, ProfileStore,
PresenceStore, TransactionStore,
@@ -84,6 +81,8 @@ class DataStore(RoomMemberStore, RoomStore,
AccountDataStore,
EventPushActionsStore,
OpenIdStore,
+ ClientIpStore,
+ DeviceStore,
):
def __init__(self, db_conn, hs):
@@ -91,17 +90,13 @@ class DataStore(RoomMemberStore, RoomStore,
self._clock = hs.get_clock()
self.database_engine = hs.database_engine
- self.client_ip_last_seen = Cache(
- name="client_ip_last_seen",
- keylen=4,
- )
-
self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering",
extra_tables=[("local_invites", "stream_id")]
)
self._backfill_id_gen = StreamIdGenerator(
- db_conn, "events", "stream_ordering", step=-1
+ db_conn, "events", "stream_ordering", step=-1,
+ extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
)
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
@@ -149,7 +144,7 @@ class DataStore(RoomMemberStore, RoomStore,
"AccountDataAndTagsChangeCache", account_max,
)
- self.__presence_on_startup = self._get_active_presence(db_conn)
+ self._presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self._get_cache_dict(
db_conn, "presence_stream",
@@ -174,7 +169,12 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=push_rules_prefill,
)
- cur = db_conn.cursor()
+ cur = LoggingTransaction(
+ db_conn.cursor(),
+ name="_find_stream_orderings_for_times_txn",
+ database_engine=self.database_engine,
+ after_callbacks=[]
+ )
self._find_stream_orderings_for_times_txn(cur)
cur.close()
@@ -185,8 +185,8 @@ class DataStore(RoomMemberStore, RoomStore,
super(DataStore, self).__init__(hs)
def take_presence_startup_info(self):
- active_on_startup = self.__presence_on_startup
- self.__presence_on_startup = None
+ active_on_startup = self._presence_on_startup
+ self._presence_on_startup = None
return active_on_startup
def _get_active_presence(self, db_conn):
@@ -212,39 +212,6 @@ class DataStore(RoomMemberStore, RoomStore,
return [UserPresenceState(**row) for row in rows]
@defer.inlineCallbacks
- def insert_client_ip(self, user, access_token, ip, user_agent):
- now = int(self._clock.time_msec())
- key = (user.to_string(), access_token, ip)
-
- try:
- last_seen = self.client_ip_last_seen.get(key)
- except KeyError:
- last_seen = None
-
- # Rate-limited inserts
- if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
- defer.returnValue(None)
-
- self.client_ip_last_seen.prefill(key, now)
-
- # It's safe not to lock here: a) no unique constraint,
- # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
- yield self._simple_upsert(
- "user_ips",
- keyvalues={
- "user_id": user.to_string(),
- "access_token": access_token,
- "ip": ip,
- "user_agent": user_agent,
- },
- values={
- "last_seen": now,
- },
- desc="insert_client_ip",
- lock=False,
- )
-
- @defer.inlineCallbacks
def count_daily_users(self):
"""
Counts the number of users who used this homeserver in the last 24 hours.
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 56a0dd80f3..0117fdc639 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -152,6 +152,7 @@ class SQLBaseStore(object):
def __init__(self, hs):
self.hs = hs
+ self._clock = hs.get_clock()
self._db_pool = hs.get_db_pool()
self._previous_txn_total_time = 0
@@ -596,10 +597,13 @@ class SQLBaseStore(object):
more rows, returning the result as a list of dicts.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the rows with,
- or None to not apply a WHERE clause.
- retcols : list of strings giving the names of the columns to return
+ table (str): the table name
+ keyvalues (dict[str, Any] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
+ retcols (iterable[str]): the names of the columns to return
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]]
"""
return self.runInteraction(
desc,
@@ -614,9 +618,11 @@ class SQLBaseStore(object):
Args:
txn : Transaction object
- table : string giving the table name
- keyvalues : dict of column names and values to select the rows with
- retcols : list of strings giving the names of the columns to return
+ table (str): the table name
+ keyvalues (dict[str, T] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
+ retcols (iterable[str]): the names of the columns to return
"""
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s" % (
@@ -806,6 +812,11 @@ class SQLBaseStore(object):
if txn.rowcount > 1:
raise StoreError(500, "more than one row matched")
+ def _simple_delete(self, table, keyvalues, desc):
+ return self.runInteraction(
+ desc, self._simple_delete_txn, table, keyvalues
+ )
+
@staticmethod
def _simple_delete_txn(txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % (
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index ec7e8d40d2..3fa226e92d 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -138,6 +138,9 @@ class AccountDataStore(SQLBaseStore):
A deferred pair of lists of tuples of stream_id int, user_id string,
room_id string, type string, and content string.
"""
+ if last_room_id == current_id and last_global_id == current_id:
+ return defer.succeed(([], []))
+
def get_updated_account_data_txn(txn):
sql = (
"SELECT stream_id, user_id, account_data_type, content"
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index feb9d228ae..d1ee533fac 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -298,6 +298,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
dict(txn_id=txn_id, as_id=service.id)
)
+ @defer.inlineCallbacks
def get_oldest_unsent_txn(self, service):
"""Get the oldest transaction which has not been sent for this
service.
@@ -308,12 +309,23 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
A Deferred which resolves to an AppServiceTransaction or
None.
"""
- return self.runInteraction(
+ entry = yield self.runInteraction(
"get_oldest_unsent_appservice_txn",
self._get_oldest_unsent_txn,
service
)
+ if not entry:
+ defer.returnValue(None)
+
+ event_ids = json.loads(entry["event_ids"])
+
+ events = yield self._get_events(event_ids)
+
+ defer.returnValue(AppServiceTransaction(
+ service=service, id=entry["txn_id"], events=events
+ ))
+
def _get_oldest_unsent_txn(self, txn, service):
# Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent)
@@ -328,12 +340,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
entry = rows[0]
- event_ids = json.loads(entry["event_ids"])
- events = self._get_events_txn(txn, event_ids)
-
- return AppServiceTransaction(
- service=service, id=entry["txn_id"], events=events
- )
+ return entry
def _get_last_txn(self, txn, service_id):
txn.execute(
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 66a995157d..30d0e4c5dc 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -14,6 +14,7 @@
# limitations under the License.
from ._base import SQLBaseStore
+from . import engines
from twisted.internet import defer
@@ -87,10 +88,12 @@ class BackgroundUpdateStore(SQLBaseStore):
@defer.inlineCallbacks
def start_doing_background_updates(self):
- while True:
- if self._background_update_timer is not None:
- return
+ assert self._background_update_timer is None, \
+ "background updates already running"
+
+ logger.info("Starting background schema updates")
+ while True:
sleep = defer.Deferred()
self._background_update_timer = self._clock.call_later(
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
@@ -101,22 +104,23 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_timer = None
try:
- result = yield self.do_background_update(
+ result = yield self.do_next_background_update(
self.BACKGROUND_UPDATE_DURATION_MS
)
except:
logger.exception("Error doing update")
-
- if result is None:
- logger.info(
- "No more background updates to do."
- " Unscheduling background update task."
- )
- return
+ else:
+ if result is None:
+ logger.info(
+ "No more background updates to do."
+ " Unscheduling background update task."
+ )
+ defer.returnValue(None)
@defer.inlineCallbacks
- def do_background_update(self, desired_duration_ms):
- """Does some amount of work on a background update
+ def do_next_background_update(self, desired_duration_ms):
+ """Does some amount of work on the next queued background update
+
Args:
desired_duration_ms(float): How long we want to spend
updating.
@@ -135,11 +139,21 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_queue.append(update['update_name'])
if not self._background_update_queue:
+ # no work left to do
defer.returnValue(None)
+ # pop from the front, and add back to the back
update_name = self._background_update_queue.pop(0)
self._background_update_queue.append(update_name)
+ res = yield self._do_background_update(update_name, desired_duration_ms)
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def _do_background_update(self, update_name, desired_duration_ms):
+ logger.info("Starting update batch on background update '%s'",
+ update_name)
+
update_handler = self._background_update_handlers[update_name]
performance = self._background_update_performance.get(update_name)
@@ -202,6 +216,64 @@ class BackgroundUpdateStore(SQLBaseStore):
"""
self._background_update_handlers[update_name] = update_handler
+ def register_background_index_update(self, update_name, index_name,
+ table, columns):
+ """Helper for store classes to do a background index addition
+
+ To use:
+
+ 1. use a schema delta file to add a background update. Example:
+ INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('my_new_index', '{}');
+
+ 2. In the Store constructor, call this method
+
+ Args:
+ update_name (str): update_name to register for
+ index_name (str): name of index to add
+ table (str): table to add index to
+ columns (list[str]): columns/expressions to include in index
+ """
+
+ # if this is postgres, we add the indexes concurrently. Otherwise
+ # we fall back to doing it inline
+ if isinstance(self.database_engine, engines.PostgresEngine):
+ conc = True
+ else:
+ conc = False
+
+ sql = "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)" \
+ % {
+ "conc": "CONCURRENTLY" if conc else "",
+ "name": index_name,
+ "table": table,
+ "columns": ", ".join(columns),
+ }
+
+ def create_index_concurrently(conn):
+ conn.rollback()
+ # postgres insists on autocommit for the index
+ conn.set_session(autocommit=True)
+ c = conn.cursor()
+ c.execute(sql)
+ conn.set_session(autocommit=False)
+
+ def create_index(conn):
+ c = conn.cursor()
+ c.execute(sql)
+
+ @defer.inlineCallbacks
+ def updater(progress, batch_size):
+ logger.info("Adding index %s to %s", index_name, table)
+ if conc:
+ yield self.runWithConnection(create_index_concurrently)
+ else:
+ yield self.runWithConnection(create_index)
+ yield self._end_background_update(update_name)
+ defer.returnValue(1)
+
+ self.register_background_update_handler(update_name, updater)
+
def start_background_update(self, update_name, progress):
"""Starts a background update running.
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
new file mode 100644
index 0000000000..71e5ea112f
--- /dev/null
+++ b/synapse/storage/client_ips.py
@@ -0,0 +1,145 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from ._base import Cache
+from . import background_updates
+
+logger = logging.getLogger(__name__)
+
+# Number of msec of granularity to store the user IP 'last seen' time. Smaller
+# times give more inserts into the database even for readonly API hits
+# 120 seconds == 2 minutes
+LAST_SEEN_GRANULARITY = 120 * 1000
+
+
+class ClientIpStore(background_updates.BackgroundUpdateStore):
+ def __init__(self, hs):
+ self.client_ip_last_seen = Cache(
+ name="client_ip_last_seen",
+ keylen=4,
+ )
+
+ super(ClientIpStore, self).__init__(hs)
+
+ self.register_background_index_update(
+ "user_ips_device_index",
+ index_name="user_ips_device_id",
+ table="user_ips",
+ columns=["user_id", "device_id", "last_seen"],
+ )
+
+ @defer.inlineCallbacks
+ def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
+ now = int(self._clock.time_msec())
+ key = (user.to_string(), access_token, ip)
+
+ try:
+ last_seen = self.client_ip_last_seen.get(key)
+ except KeyError:
+ last_seen = None
+
+ # Rate-limited inserts
+ if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
+ defer.returnValue(None)
+
+ self.client_ip_last_seen.prefill(key, now)
+
+ # It's safe not to lock here: a) no unique constraint,
+ # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
+ yield self._simple_upsert(
+ "user_ips",
+ keyvalues={
+ "user_id": user.to_string(),
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "device_id": device_id,
+ },
+ values={
+ "last_seen": now,
+ },
+ desc="insert_client_ip",
+ lock=False,
+ )
+
+ @defer.inlineCallbacks
+ def get_last_client_ip_by_device(self, devices):
+ """For each device_id listed, give the user_ip it was last seen on
+
+ Args:
+ devices (iterable[(str, str)]): list of (user_id, device_id) pairs
+
+ Returns:
+ defer.Deferred: resolves to a dict, where the keys
+ are (user_id, device_id) tuples. The values are also dicts, with
+ keys giving the column names
+ """
+
+ res = yield self.runInteraction(
+ "get_last_client_ip_by_device",
+ self._get_last_client_ip_by_device_txn,
+ retcols=(
+ "user_id",
+ "access_token",
+ "ip",
+ "user_agent",
+ "device_id",
+ "last_seen",
+ ),
+ devices=devices
+ )
+
+ ret = {(d["user_id"], d["device_id"]): d for d in res}
+ defer.returnValue(ret)
+
+ @classmethod
+ def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols):
+ where_clauses = []
+ bindings = []
+ for (user_id, device_id) in devices:
+ if device_id is None:
+ where_clauses.append("(user_id = ? AND device_id IS NULL)")
+ bindings.extend((user_id, ))
+ else:
+ where_clauses.append("(user_id = ? AND device_id = ?)")
+ bindings.extend((user_id, device_id))
+
+ inner_select = (
+ "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips "
+ "WHERE %(where)s "
+ "GROUP BY user_id, device_id"
+ ) % {
+ "where": " OR ".join(where_clauses),
+ }
+
+ sql = (
+ "SELECT %(retcols)s FROM user_ips "
+ "JOIN (%(inner_select)s) ips ON"
+ " user_ips.last_seen = ips.mls AND"
+ " user_ips.user_id = ips.user_id AND"
+ " (user_ips.device_id = ips.device_id OR"
+ " (user_ips.device_id IS NULL AND ips.device_id IS NULL)"
+ " )"
+ ) % {
+ "retcols": ",".join("user_ips." + c for c in retcols),
+ "inner_select": inner_select,
+ }
+
+ txn.execute(sql, bindings)
+ return cls.cursor_to_dict(txn)
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
new file mode 100644
index 0000000000..afd6530cab
--- /dev/null
+++ b/synapse/storage/devices.py
@@ -0,0 +1,137 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.errors import StoreError
+from ._base import SQLBaseStore
+
+logger = logging.getLogger(__name__)
+
+
+class DeviceStore(SQLBaseStore):
+ @defer.inlineCallbacks
+ def store_device(self, user_id, device_id,
+ initial_device_display_name,
+ ignore_if_known=True):
+ """Ensure the given device is known; add it to the store if not
+
+ Args:
+ user_id (str): id of user associated with the device
+ device_id (str): id of device
+ initial_device_display_name (str): initial displayname of the
+ device
+ ignore_if_known (bool): ignore integrity errors which mean the
+ device is already known
+ Returns:
+ defer.Deferred
+ Raises:
+ StoreError: if ignore_if_known is False and the device was already
+ known
+ """
+ try:
+ yield self._simple_insert(
+ "devices",
+ values={
+ "user_id": user_id,
+ "device_id": device_id,
+ "display_name": initial_device_display_name
+ },
+ desc="store_device",
+ or_ignore=ignore_if_known,
+ )
+ except Exception as e:
+ logger.error("store_device with device_id=%s failed: %s",
+ device_id, e)
+ raise StoreError(500, "Problem storing device.")
+
+ def get_device(self, user_id, device_id):
+ """Retrieve a device.
+
+ Args:
+ user_id (str): The ID of the user which owns the device
+ device_id (str): The ID of the device to retrieve
+ Returns:
+ defer.Deferred for a dict containing the device information
+ Raises:
+ StoreError: if the device is not found
+ """
+ return self._simple_select_one(
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ retcols=("user_id", "device_id", "display_name"),
+ desc="get_device",
+ )
+
+ def delete_device(self, user_id, device_id):
+ """Delete a device.
+
+ Args:
+ user_id (str): The ID of the user which owns the device
+ device_id (str): The ID of the device to delete
+ Returns:
+ defer.Deferred
+ """
+ return self._simple_delete_one(
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ desc="delete_device",
+ )
+
+ def update_device(self, user_id, device_id, new_display_name=None):
+ """Update a device.
+
+ Args:
+ user_id (str): The ID of the user which owns the device
+ device_id (str): The ID of the device to update
+ new_display_name (str|None): new displayname for device; None
+ to leave unchanged
+ Raises:
+ StoreError: if the device is not found
+ Returns:
+ defer.Deferred
+ """
+ updates = {}
+ if new_display_name is not None:
+ updates["display_name"] = new_display_name
+ if not updates:
+ return defer.succeed(None)
+ return self._simple_update_one(
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ updatevalues=updates,
+ desc="update_device",
+ )
+
+ @defer.inlineCallbacks
+ def get_devices_by_user(self, user_id):
+ """Retrieve all of a user's registered devices.
+
+ Args:
+ user_id (str):
+ Returns:
+ defer.Deferred: resolves to a dict from device_id to a dict
+ containing "device_id", "user_id" and "display_name" for each
+ device.
+ """
+ devices = yield self._simple_select_list(
+ table="devices",
+ keyvalues={"user_id": user_id},
+ retcols=("user_id", "device_id", "display_name"),
+ desc="get_devices_by_user"
+ )
+
+ defer.returnValue({d["device_id"]: d for d in devices})
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 2e89066515..385d607056 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -12,6 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import collections
+
+import twisted.internet.defer
from ._base import SQLBaseStore
@@ -36,24 +39,49 @@ class EndToEndKeyStore(SQLBaseStore):
query_list(list): List of pairs of user_ids and device_ids.
Returns:
Dict mapping from user-id to dict mapping from device_id to
- key json byte strings.
+ dict containing "key_json", "device_display_name".
"""
- def _get_e2e_device_keys(txn):
- result = {}
- for user_id, device_id in query_list:
- user_result = result.setdefault(user_id, {})
- keyvalues = {"user_id": user_id}
- if device_id:
- keyvalues["device_id"] = device_id
- rows = self._simple_select_list_txn(
- txn, table="e2e_device_keys_json",
- keyvalues=keyvalues,
- retcols=["device_id", "key_json"]
- )
- for row in rows:
- user_result[row["device_id"]] = row["key_json"]
- return result
- return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys)
+ if not query_list:
+ return {}
+
+ return self.runInteraction(
+ "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list
+ )
+
+ def _get_e2e_device_keys_txn(self, txn, query_list):
+ query_clauses = []
+ query_params = []
+
+ for (user_id, device_id) in query_list:
+ query_clause = "k.user_id = ?"
+ query_params.append(user_id)
+
+ if device_id:
+ query_clause += " AND k.device_id = ?"
+ query_params.append(device_id)
+
+ query_clauses.append(query_clause)
+
+ sql = (
+ "SELECT k.user_id, k.device_id, "
+ " d.display_name AS device_display_name, "
+ " k.key_json"
+ " FROM e2e_device_keys_json k"
+ " LEFT JOIN devices d ON d.user_id = k.user_id"
+ " AND d.device_id = k.device_id"
+ " WHERE %s"
+ ) % (
+ " OR ".join("(" + q + ")" for q in query_clauses)
+ )
+
+ txn.execute(sql, query_params)
+ rows = self.cursor_to_dict(txn)
+
+ result = collections.defaultdict(dict)
+ for row in rows:
+ result[row["user_id"]][row["device_id"]] = row
+
+ return result
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
def _add_e2e_one_time_keys(txn):
@@ -123,3 +151,16 @@ class EndToEndKeyStore(SQLBaseStore):
return self.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
+
+ @twisted.internet.defer.inlineCallbacks
+ def delete_e2e_keys_by_device(self, user_id, device_id):
+ yield self._simple_delete(
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ desc="delete_e2e_device_keys_by_device"
+ )
+ yield self._simple_delete(
+ table="e2e_one_time_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ desc="delete_e2e_one_time_keys_by_device"
+ )
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 7bb5de1fe7..338b495611 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -32,7 +32,7 @@ def create_engine(database_config):
if engine_class:
module = importlib.import_module(name)
- return engine_class(module)
+ return engine_class(module, database_config)
raise RuntimeError(
"Unsupported database engine '%s'" % (name,)
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index c2290943b4..a6ae79dfad 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -19,9 +19,10 @@ from ._base import IncorrectDatabaseSetup
class PostgresEngine(object):
single_threaded = False
- def __init__(self, database_module):
+ def __init__(self, database_module, database_config):
self.module = database_module
self.module.extensions.register_type(self.module.extensions.UNICODE)
+ self.synchronous_commit = database_config.get("synchronous_commit", True)
def check_database(self, txn):
txn.execute("SHOW SERVER_ENCODING")
@@ -40,9 +41,19 @@ class PostgresEngine(object):
db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
+ # Asynchronous commit, don't wait for the server to call fsync before
+ # ending the transaction.
+ # https://www.postgresql.org/docs/current/static/wal-async-commit.html
+ if not self.synchronous_commit:
+ cursor = db_conn.cursor()
+ cursor.execute("SET synchronous_commit TO OFF")
+ cursor.close()
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
+ # https://www.postgresql.org/docs/current/static/errcodes-appendix.html
+ # "40001" serialization_failure
+ # "40P01" deadlock_detected
return error.pgcode in ["40001", "40P01"]
return False
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py
index 14203aa500..755c9a1f07 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite3.py
@@ -21,7 +21,7 @@ import struct
class Sqlite3Engine(object):
single_threaded = True
- def __init__(self, database_module):
+ def __init__(self, database_module, database_config):
self.module = database_module
def check_database(self, txn):
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 5123072c44..0ba0310c0d 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -16,6 +16,8 @@
from ._base import SQLBaseStore
from twisted.internet import defer
from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.types import RoomStreamToken
+from .stream import lower_bound
import logging
import ujson as json
@@ -73,6 +75,9 @@ class EventPushActionsStore(SQLBaseStore):
stream_ordering = results[0][0]
topological_ordering = results[0][1]
+ token = RoomStreamToken(
+ topological_ordering, stream_ordering
+ )
sql = (
"SELECT sum(notif), sum(highlight)"
@@ -80,15 +85,10 @@ class EventPushActionsStore(SQLBaseStore):
" WHERE"
" user_id = ?"
" AND room_id = ?"
- " AND ("
- " topological_ordering > ?"
- " OR (topological_ordering = ? AND stream_ordering > ?)"
- ")"
- )
- txn.execute(sql, (
- user_id, room_id,
- topological_ordering, topological_ordering, stream_ordering
- ))
+ " AND %s"
+ ) % (lower_bound(token, self.database_engine, inclusive=False),)
+
+ txn.execute(sql, (user_id, room_id))
row = txn.fetchone()
if row:
return {
@@ -117,23 +117,42 @@ class EventPushActionsStore(SQLBaseStore):
defer.returnValue(ret)
@defer.inlineCallbacks
- def get_unread_push_actions_for_user_in_range(self, user_id,
- min_stream_ordering,
- max_stream_ordering=None):
+ def get_unread_push_actions_for_user_in_range_for_http(
+ self, user_id, min_stream_ordering, max_stream_ordering, limit=20
+ ):
+ """Get a list of the most recent unread push actions for a given user,
+ within the given stream ordering range. Called by the httppusher.
+
+ Args:
+ user_id (str): The user to fetch push actions for.
+ min_stream_ordering(int): The exclusive lower bound on the
+ stream ordering of event push actions to fetch.
+ max_stream_ordering(int): The inclusive upper bound on the
+ stream ordering of event push actions to fetch.
+ limit (int): The maximum number of rows to return.
+ Returns:
+ A promise which resolves to a list of dicts with the keys "event_id",
+ "room_id", "stream_ordering", "actions".
+ The list will be ordered by ascending stream_ordering.
+ The list will have between 0~limit entries.
+ """
+ # find rooms that have a read receipt in them and return the next
+ # push actions
def get_after_receipt(txn):
+ # find rooms that have a read receipt in them and return the next
+ # push actions
sql = (
- "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, "
- "e.received_ts "
- "FROM ("
- " SELECT room_id, user_id, "
- " max(topological_ordering) as topological_ordering, "
- " max(stream_ordering) as stream_ordering "
- " FROM events"
- " NATURAL JOIN receipts_linearized WHERE receipt_type = 'm.read'"
- " GROUP BY room_id, user_id"
+ "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions"
+ " FROM ("
+ " SELECT room_id,"
+ " MAX(topological_ordering) as topological_ordering,"
+ " MAX(stream_ordering) as stream_ordering"
+ " FROM events"
+ " INNER JOIN receipts_linearized USING (room_id, event_id)"
+ " WHERE receipt_type = 'm.read' AND user_id = ?"
+ " GROUP BY room_id"
") AS rl,"
" event_push_actions AS ep"
- " INNER JOIN events AS e USING (room_id, event_id)"
" WHERE"
" ep.room_id = rl.room_id"
" AND ("
@@ -143,45 +162,163 @@ class EventPushActionsStore(SQLBaseStore):
" AND ep.stream_ordering > rl.stream_ordering"
" )"
" )"
- " AND ep.stream_ordering > ?"
" AND ep.user_id = ?"
- " AND ep.user_id = rl.user_id"
+ " AND ep.stream_ordering > ?"
+ " AND ep.stream_ordering <= ?"
+ " ORDER BY ep.stream_ordering ASC LIMIT ?"
)
- args = [min_stream_ordering, user_id]
- if max_stream_ordering is not None:
- sql += " AND ep.stream_ordering <= ?"
- args.append(max_stream_ordering)
- sql += " ORDER BY ep.stream_ordering ASC"
+ args = [
+ user_id, user_id,
+ min_stream_ordering, max_stream_ordering, limit,
+ ]
txn.execute(sql, args)
return txn.fetchall()
after_read_receipt = yield self.runInteraction(
- "get_unread_push_actions_for_user_in_range", get_after_receipt
+ "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
)
+ # There are rooms with push actions in them but you don't have a read receipt in
+ # them e.g. rooms you've been invited to, so get push actions for rooms which do
+ # not have read receipts in them too.
def get_no_receipt(txn):
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" e.received_ts"
" FROM event_push_actions AS ep"
- " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
- " WHERE ep.room_id not in ("
- " SELECT room_id FROM events NATURAL JOIN receipts_linearized"
+ " INNER JOIN events AS e USING (room_id, event_id)"
+ " WHERE"
+ " ep.room_id NOT IN ("
+ " SELECT room_id FROM receipts_linearized"
+ " WHERE receipt_type = 'm.read' AND user_id = ?"
+ " GROUP BY room_id"
+ " )"
+ " AND ep.user_id = ?"
+ " AND ep.stream_ordering > ?"
+ " AND ep.stream_ordering <= ?"
+ " ORDER BY ep.stream_ordering ASC LIMIT ?"
+ )
+ args = [
+ user_id, user_id,
+ min_stream_ordering, max_stream_ordering, limit,
+ ]
+ txn.execute(sql, args)
+ return txn.fetchall()
+ no_read_receipt = yield self.runInteraction(
+ "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
+ )
+
+ notifs = [
+ {
+ "event_id": row[0],
+ "room_id": row[1],
+ "stream_ordering": row[2],
+ "actions": json.loads(row[3]),
+ } for row in after_read_receipt + no_read_receipt
+ ]
+
+ # Now sort it so it's ordered correctly, since currently it will
+ # contain results from the first query, correctly ordered, followed
+ # by results from the second query, but we want them all ordered
+ # by stream_ordering, oldest first.
+ notifs.sort(key=lambda r: r['stream_ordering'])
+
+ # Take only up to the limit. We have to stop at the limit because
+ # one of the subqueries may have hit the limit.
+ defer.returnValue(notifs[:limit])
+
+ @defer.inlineCallbacks
+ def get_unread_push_actions_for_user_in_range_for_email(
+ self, user_id, min_stream_ordering, max_stream_ordering, limit=20
+ ):
+ """Get a list of the most recent unread push actions for a given user,
+ within the given stream ordering range. Called by the emailpusher
+
+ Args:
+ user_id (str): The user to fetch push actions for.
+ min_stream_ordering(int): The exclusive lower bound on the
+ stream ordering of event push actions to fetch.
+ max_stream_ordering(int): The inclusive upper bound on the
+ stream ordering of event push actions to fetch.
+ limit (int): The maximum number of rows to return.
+ Returns:
+ A promise which resolves to a list of dicts with the keys "event_id",
+ "room_id", "stream_ordering", "actions", "received_ts".
+ The list will be ordered by descending received_ts.
+ The list will have between 0~limit entries.
+ """
+ # find rooms that have a read receipt in them and return the most recent
+ # push actions
+ def get_after_receipt(txn):
+ sql = (
+ "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
+ " e.received_ts"
+ " FROM ("
+ " SELECT room_id,"
+ " MAX(topological_ordering) as topological_ordering,"
+ " MAX(stream_ordering) as stream_ordering"
+ " FROM events"
+ " INNER JOIN receipts_linearized USING (room_id, event_id)"
" WHERE receipt_type = 'm.read' AND user_id = ?"
" GROUP BY room_id"
- ") AND ep.user_id = ? AND ep.stream_ordering > ?"
+ ") AS rl,"
+ " event_push_actions AS ep"
+ " INNER JOIN events AS e USING (room_id, event_id)"
+ " WHERE"
+ " ep.room_id = rl.room_id"
+ " AND ("
+ " ep.topological_ordering > rl.topological_ordering"
+ " OR ("
+ " ep.topological_ordering = rl.topological_ordering"
+ " AND ep.stream_ordering > rl.stream_ordering"
+ " )"
+ " )"
+ " AND ep.user_id = ?"
+ " AND ep.stream_ordering > ?"
+ " AND ep.stream_ordering <= ?"
+ " ORDER BY ep.stream_ordering DESC LIMIT ?"
)
- args = [user_id, user_id, min_stream_ordering]
- if max_stream_ordering is not None:
- sql += " AND ep.stream_ordering <= ?"
- args.append(max_stream_ordering)
- sql += " ORDER BY ep.stream_ordering ASC"
+ args = [
+ user_id, user_id,
+ min_stream_ordering, max_stream_ordering, limit,
+ ]
+ txn.execute(sql, args)
+ return txn.fetchall()
+ after_read_receipt = yield self.runInteraction(
+ "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
+ )
+
+ # There are rooms with push actions in them but you don't have a read receipt in
+ # them e.g. rooms you've been invited to, so get push actions for rooms which do
+ # not have read receipts in them too.
+ def get_no_receipt(txn):
+ sql = (
+ "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
+ " e.received_ts"
+ " FROM event_push_actions AS ep"
+ " INNER JOIN events AS e USING (room_id, event_id)"
+ " WHERE"
+ " ep.room_id NOT IN ("
+ " SELECT room_id FROM receipts_linearized"
+ " WHERE receipt_type = 'm.read' AND user_id = ?"
+ " GROUP BY room_id"
+ " )"
+ " AND ep.user_id = ?"
+ " AND ep.stream_ordering > ?"
+ " AND ep.stream_ordering <= ?"
+ " ORDER BY ep.stream_ordering DESC LIMIT ?"
+ )
+ args = [
+ user_id, user_id,
+ min_stream_ordering, max_stream_ordering, limit,
+ ]
txn.execute(sql, args)
return txn.fetchall()
no_read_receipt = yield self.runInteraction(
- "get_unread_push_actions_for_user_in_range", get_no_receipt
+ "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
)
- defer.returnValue([
+ # Make a list of dicts from the two sets of results.
+ notifs = [
{
"event_id": row[0],
"room_id": row[1],
@@ -189,7 +326,16 @@ class EventPushActionsStore(SQLBaseStore):
"actions": json.loads(row[3]),
"received_ts": row[4],
} for row in after_read_receipt + no_read_receipt
- ])
+ ]
+
+ # Now sort it so it's ordered correctly, since currently it will
+ # contain results from the first query, correctly ordered, followed
+ # by results from the second query, but we want them all ordered
+ # by received_ts (most recent first)
+ notifs.sort(key=lambda r: -(r['received_ts'] or 0))
+
+ # Now return the first `limit`
+ defer.returnValue(notifs[:limit])
@defer.inlineCallbacks
def get_push_actions_for_user(self, user_id, before=None, limit=50):
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 4655669ba0..d2feee8dbb 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -23,9 +23,14 @@ from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes
+from synapse.api.errors import SynapseError
from canonicaljson import encode_canonical_json
-from collections import deque, namedtuple
+from collections import deque, namedtuple, OrderedDict
+from functools import wraps
+
+import synapse
+import synapse.metrics
import logging
@@ -35,6 +40,10 @@ import ujson as json
logger = logging.getLogger(__name__)
+metrics = synapse.metrics.get_metrics_for(__name__)
+persist_event_counter = metrics.register_counter("persisted_events")
+
+
def encode_json(json_object):
if USE_FROZEN_DICTS:
# ujson doesn't like frozen_dicts
@@ -139,8 +148,32 @@ class _EventPeristenceQueue(object):
pass
+_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+
+
+def _retry_on_integrity_error(func):
+ """Wraps a database function so that it gets retried on IntegrityError,
+ with `delete_existing=True` passed in.
+
+ Args:
+ func: function that returns a Deferred and accepts a `delete_existing` arg
+ """
+ @wraps(func)
+ @defer.inlineCallbacks
+ def f(self, *args, **kwargs):
+ try:
+ res = yield func(self, *args, **kwargs)
+ except self.database_engine.module.IntegrityError:
+ logger.exception("IntegrityError, retrying.")
+ res = yield func(self, *args, delete_existing=True, **kwargs)
+ defer.returnValue(res)
+
+ return f
+
+
class EventsStore(SQLBaseStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
+ EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
def __init__(self, hs):
super(EventsStore, self).__init__(hs)
@@ -148,6 +181,10 @@ class EventsStore(SQLBaseStore):
self.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
)
+ self.register_background_update_handler(
+ self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
+ self._background_reindex_fields_sender,
+ )
self._event_persist_queue = _EventPeristenceQueue()
@@ -213,8 +250,10 @@ class EventsStore(SQLBaseStore):
self._event_persist_queue.handle_queue(room_id, persisting_queue)
+ @_retry_on_integrity_error
@defer.inlineCallbacks
- def _persist_events(self, events_and_contexts, backfilled=False):
+ def _persist_events(self, events_and_contexts, backfilled=False,
+ delete_existing=False):
if not events_and_contexts:
return
@@ -257,11 +296,15 @@ class EventsStore(SQLBaseStore):
self._persist_events_txn,
events_and_contexts=chunk,
backfilled=backfilled,
+ delete_existing=delete_existing,
)
+ persist_event_counter.inc_by(len(chunk))
+ @_retry_on_integrity_error
@defer.inlineCallbacks
@log_function
- def _persist_event(self, event, context, current_state=None, backfilled=False):
+ def _persist_event(self, event, context, current_state=None, backfilled=False,
+ delete_existing=False):
try:
with self._stream_id_gen.get_next() as stream_ordering:
with self._state_groups_id_gen.get_next() as state_group_id:
@@ -274,7 +317,9 @@ class EventsStore(SQLBaseStore):
context=context,
current_state=current_state,
backfilled=backfilled,
+ delete_existing=delete_existing,
)
+ persist_event_counter.inc()
except _RollbackButIsFineException:
pass
@@ -305,7 +350,7 @@ class EventsStore(SQLBaseStore):
)
if not events and not allow_none:
- raise RuntimeError("Could not find event %s" % (event_id,))
+ raise SynapseError(404, "Could not find event %s" % (event_id,))
defer.returnValue(events[0] if events else None)
@@ -335,18 +380,15 @@ class EventsStore(SQLBaseStore):
defer.returnValue({e.event_id: e for e in events})
@log_function
- def _persist_event_txn(self, txn, event, context, current_state, backfilled=False):
+ def _persist_event_txn(self, txn, event, context, current_state, backfilled=False,
+ delete_existing=False):
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
txn.call_after(self._get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
- txn.call_after(
- self.get_users_with_pushers_in_room.invalidate, (event.room_id,)
- )
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
- txn.call_after(self.get_room_name_and_aliases.invalidate, (event.room_id,))
# Add an entry to the current_state_resets table to record the point
# where we clobbered the current state
@@ -379,10 +421,38 @@ class EventsStore(SQLBaseStore):
txn,
[(event, context)],
backfilled=backfilled,
+ delete_existing=delete_existing,
)
@log_function
- def _persist_events_txn(self, txn, events_and_contexts, backfilled):
+ def _persist_events_txn(self, txn, events_and_contexts, backfilled,
+ delete_existing=False):
+ """Insert some number of room events into the necessary database tables.
+
+ Rejected events are only inserted into the events table, the events_json table,
+ and the rejections table. Things reading from those table will need to check
+ whether the event was rejected.
+
+ If delete_existing is True then existing events will be purged from the
+ database before insertion. This is useful when retrying due to IntegrityError.
+ """
+ # Ensure that we don't have the same event twice.
+ # Pick the earliest non-outlier if there is one, else the earliest one.
+ new_events_and_contexts = OrderedDict()
+ for event, context in events_and_contexts:
+ prev_event_context = new_events_and_contexts.get(event.event_id)
+ if prev_event_context:
+ if not event.internal_metadata.is_outlier():
+ if prev_event_context[0].internal_metadata.is_outlier():
+ # To ensure correct ordering we pop, as OrderedDict is
+ # ordered by first insertion.
+ new_events_and_contexts.pop(event.event_id, None)
+ new_events_and_contexts[event.event_id] = (event, context)
+ else:
+ new_events_and_contexts[event.event_id] = (event, context)
+
+ events_and_contexts = new_events_and_contexts.values()
+
depth_updates = {}
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
@@ -393,21 +463,11 @@ class EventsStore(SQLBaseStore):
event.room_id, event.internal_metadata.stream_ordering,
)
- if not event.internal_metadata.is_outlier():
+ if not event.internal_metadata.is_outlier() and not context.rejected:
depth_updates[event.room_id] = max(
event.depth, depth_updates.get(event.room_id, event.depth)
)
- if context.push_actions:
- self._set_push_actions_for_event_and_users_txn(
- txn, event, context.push_actions
- )
-
- if event.type == EventTypes.Redaction and event.redacts is not None:
- self._remove_push_actions_for_event_id_txn(
- txn, event.room_id, event.redacts
- )
-
for room_id, depth in depth_updates.items():
self._update_min_depth_for_room_txn(txn, room_id, depth)
@@ -417,30 +477,21 @@ class EventsStore(SQLBaseStore):
),
[event.event_id for event, _ in events_and_contexts]
)
+
have_persisted = {
event_id: outlier
for event_id, outlier in txn.fetchall()
}
- event_map = {}
to_remove = set()
for event, context in events_and_contexts:
- # Handle the case of the list including the same event multiple
- # times. The tricky thing here is when they differ by whether
- # they are an outlier.
- if event.event_id in event_map:
- other = event_map[event.event_id]
-
- if not other.internal_metadata.is_outlier():
- to_remove.add(event)
- continue
- elif not event.internal_metadata.is_outlier():
+ if context.rejected:
+ # If the event is rejected then we don't care if the event
+ # was an outlier or not.
+ if event.event_id in have_persisted:
+ # If we have already seen the event then ignore it.
to_remove.add(event)
- continue
- else:
- to_remove.add(other)
-
- event_map[event.event_id] = event
+ continue
if event.event_id not in have_persisted:
continue
@@ -449,6 +500,12 @@ class EventsStore(SQLBaseStore):
outlier_persisted = have_persisted[event.event_id]
if not event.internal_metadata.is_outlier() and outlier_persisted:
+ # We received a copy of an event that we had already stored as
+ # an outlier in the database. We now have some state at that
+ # so we need to update the state_groups table with that state.
+
+ # insert into the state_group, state_groups_state and
+ # event_to_state_groups tables.
self._store_mult_state_groups_txn(txn, ((event, context),))
metadata_json = encode_json(
@@ -464,6 +521,8 @@ class EventsStore(SQLBaseStore):
(metadata_json, event.event_id,)
)
+ # Add an entry to the ex_outlier_stream table to replicate the
+ # change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering
state_group_id = context.state_group or context.new_state_group_id
self._simple_insert_txn(
@@ -485,6 +544,8 @@ class EventsStore(SQLBaseStore):
(False, event.event_id,)
)
+ # Update the event_backward_extremities table now that this
+ # event isn't an outlier any more.
self._update_extremeties(txn, [event])
events_and_contexts = [
@@ -492,38 +553,12 @@ class EventsStore(SQLBaseStore):
]
if not events_and_contexts:
+ # Make sure we don't pass an empty list to functions that expect to
+ # be storing at least one element.
return
- self._store_mult_state_groups_txn(txn, events_and_contexts)
-
- self._handle_mult_prev_events(
- txn,
- events=[event for event, _ in events_and_contexts],
- )
-
- for event, _ in events_and_contexts:
- if event.type == EventTypes.Name:
- self._store_room_name_txn(txn, event)
- elif event.type == EventTypes.Topic:
- self._store_room_topic_txn(txn, event)
- elif event.type == EventTypes.Message:
- self._store_room_message_txn(txn, event)
- elif event.type == EventTypes.Redaction:
- self._store_redaction(txn, event)
- elif event.type == EventTypes.RoomHistoryVisibility:
- self._store_history_visibility_txn(txn, event)
- elif event.type == EventTypes.GuestAccess:
- self._store_guest_access_txn(txn, event)
-
- self._store_room_members_txn(
- txn,
- [
- event
- for event, _ in events_and_contexts
- if event.type == EventTypes.Member
- ],
- backfilled=backfilled,
- )
+ # From this point onwards the events are only events that we haven't
+ # seen before.
def event_dict(event):
return {
@@ -535,6 +570,43 @@ class EventsStore(SQLBaseStore):
]
}
+ if delete_existing:
+ # For paranoia reasons, we go and delete all the existing entries
+ # for these events so we can reinsert them.
+ # This gets around any problems with some tables already having
+ # entries.
+
+ logger.info("Deleting existing")
+
+ for table in (
+ "events",
+ "event_auth",
+ "event_json",
+ "event_content_hashes",
+ "event_destinations",
+ "event_edge_hashes",
+ "event_edges",
+ "event_forward_extremities",
+ "event_push_actions",
+ "event_reference_hashes",
+ "event_search",
+ "event_signatures",
+ "event_to_state_groups",
+ "guest_access",
+ "history_visibility",
+ "local_invites",
+ "room_names",
+ "state_events",
+ "rejections",
+ "redactions",
+ "room_memberships",
+ "state_events"
+ ):
+ txn.executemany(
+ "DELETE FROM %s WHERE event_id = ?" % (table,),
+ [(ev.event_id,) for ev, _ in events_and_contexts]
+ )
+
self._simple_insert_many_txn(
txn,
table="event_json",
@@ -567,15 +639,51 @@ class EventsStore(SQLBaseStore):
"content": encode_json(event.content).decode("UTF-8"),
"origin_server_ts": int(event.origin_server_ts),
"received_ts": self._clock.time_msec(),
+ "sender": event.sender,
+ "contains_url": (
+ "url" in event.content
+ and isinstance(event.content["url"], basestring)
+ ),
}
for event, _ in events_and_contexts
],
)
- if context.rejected:
- self._store_rejections_txn(
- txn, event.event_id, context.rejected
- )
+ # Remove the rejected events from the list now that we've added them
+ # to the events table and the events_json table.
+ to_remove = set()
+ for event, context in events_and_contexts:
+ if context.rejected:
+ # Insert the event_id into the rejections table
+ self._store_rejections_txn(
+ txn, event.event_id, context.rejected
+ )
+ to_remove.add(event)
+
+ events_and_contexts = [
+ ec for ec in events_and_contexts if ec[0] not in to_remove
+ ]
+
+ if not events_and_contexts:
+ # Make sure we don't pass an empty list to functions that expect to
+ # be storing at least one element.
+ return
+
+ # From this point onwards the events are only ones that weren't rejected.
+
+ for event, context in events_and_contexts:
+ # Insert all the push actions into the event_push_actions table.
+ if context.push_actions:
+ self._set_push_actions_for_event_and_users_txn(
+ txn, event, context.push_actions
+ )
+
+ if event.type == EventTypes.Redaction and event.redacts is not None:
+ # Remove the entries in the event_push_actions table for the
+ # redacted event.
+ self._remove_push_actions_for_event_id_txn(
+ txn, event.room_id, event.redacts
+ )
self._simple_insert_many_txn(
txn,
@@ -591,6 +699,49 @@ class EventsStore(SQLBaseStore):
],
)
+ # Insert into the state_groups, state_groups_state, and
+ # event_to_state_groups tables.
+ self._store_mult_state_groups_txn(txn, events_and_contexts)
+
+ # Update the event_forward_extremities, event_backward_extremities and
+ # event_edges tables.
+ self._handle_mult_prev_events(
+ txn,
+ events=[event for event, _ in events_and_contexts],
+ )
+
+ for event, _ in events_and_contexts:
+ if event.type == EventTypes.Name:
+ # Insert into the room_names and event_search tables.
+ self._store_room_name_txn(txn, event)
+ elif event.type == EventTypes.Topic:
+ # Insert into the topics table and event_search table.
+ self._store_room_topic_txn(txn, event)
+ elif event.type == EventTypes.Message:
+ # Insert into the event_search table.
+ self._store_room_message_txn(txn, event)
+ elif event.type == EventTypes.Redaction:
+ # Insert into the redactions table.
+ self._store_redaction(txn, event)
+ elif event.type == EventTypes.RoomHistoryVisibility:
+ # Insert into the event_search table.
+ self._store_history_visibility_txn(txn, event)
+ elif event.type == EventTypes.GuestAccess:
+ # Insert into the event_search table.
+ self._store_guest_access_txn(txn, event)
+
+ # Insert into the room_memberships table.
+ self._store_room_members_txn(
+ txn,
+ [
+ event
+ for event, _ in events_and_contexts
+ if event.type == EventTypes.Member
+ ],
+ backfilled=backfilled,
+ )
+
+ # Insert event_reference_hashes table.
self._store_event_reference_hashes_txn(
txn, [event for event, _ in events_and_contexts]
)
@@ -635,6 +786,9 @@ class EventsStore(SQLBaseStore):
],
)
+ # Prefill the event cache
+ self._add_to_cache(txn, events_and_contexts)
+
if backfilled:
# Backfilled events come before the current state so we don't need
# to update the current state table
@@ -645,22 +799,11 @@ class EventsStore(SQLBaseStore):
# Outlier events shouldn't clobber the current state.
continue
- if context.rejected:
- # If the event failed it's auth checks then it shouldn't
- # clobbler the current state.
- continue
-
txn.call_after(
self._get_current_state_for_key.invalidate,
(event.room_id, event.type, event.state_key,)
)
- if event.type in [EventTypes.Name, EventTypes.Aliases]:
- txn.call_after(
- self.get_room_name_and_aliases.invalidate,
- (event.room_id,)
- )
-
self._simple_upsert_txn(
txn,
"current_state_events",
@@ -676,6 +819,45 @@ class EventsStore(SQLBaseStore):
return
+ def _add_to_cache(self, txn, events_and_contexts):
+ to_prefill = []
+
+ rows = []
+ N = 200
+ for i in range(0, len(events_and_contexts), N):
+ ev_map = {
+ e[0].event_id: e[0]
+ for e in events_and_contexts[i:i + N]
+ }
+ if not ev_map:
+ break
+
+ sql = (
+ "SELECT "
+ " e.event_id as event_id, "
+ " r.redacts as redacts,"
+ " rej.event_id as rejects "
+ " FROM events as e"
+ " LEFT JOIN rejections as rej USING (event_id)"
+ " LEFT JOIN redactions as r ON e.event_id = r.redacts"
+ " WHERE e.event_id IN (%s)"
+ ) % (",".join(["?"] * len(ev_map)),)
+
+ txn.execute(sql, ev_map.keys())
+ rows = self.cursor_to_dict(txn)
+ for row in rows:
+ event = ev_map[row["event_id"]]
+ if not row["rejects"] and not row["redacts"]:
+ to_prefill.append(_EventCacheEntry(
+ event=event,
+ redacted_event=None,
+ ))
+
+ def prefill():
+ for cache_entry in to_prefill:
+ self._get_event_cache.prefill((cache_entry[0].event_id,), cache_entry)
+ txn.call_after(prefill)
+
def _store_redaction(self, txn, event):
# invalidate the cache for the redacted event
txn.call_after(self._invalidate_get_event_cache, event.redacts)
@@ -741,100 +923,65 @@ class EventsStore(SQLBaseStore):
event_id_list = event_ids
event_ids = set(event_ids)
- event_map = self._get_events_from_cache(
+ event_entry_map = self._get_events_from_cache(
event_ids,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
- missing_events_ids = [e for e in event_ids if e not in event_map]
+ missing_events_ids = [e for e in event_ids if e not in event_entry_map]
if missing_events_ids:
missing_events = yield self._enqueue_events(
missing_events_ids,
check_redacted=check_redacted,
- get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
- event_map.update(missing_events)
-
- defer.returnValue([
- event_map[e_id] for e_id in event_id_list
- if e_id in event_map and event_map[e_id]
- ])
+ event_entry_map.update(missing_events)
- def _get_events_txn(self, txn, event_ids, check_redacted=True,
- get_prev_content=False, allow_rejected=False):
- if not event_ids:
- return []
-
- event_map = self._get_events_from_cache(
- event_ids,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
- )
-
- missing_events_ids = [e for e in event_ids if e not in event_map]
+ events = []
+ for event_id in event_id_list:
+ entry = event_entry_map.get(event_id, None)
+ if not entry:
+ continue
- if not missing_events_ids:
- return [
- event_map[e_id] for e_id in event_ids
- if e_id in event_map and event_map[e_id]
- ]
+ if allow_rejected or not entry.event.rejected_reason:
+ if check_redacted and entry.redacted_event:
+ event = entry.redacted_event
+ else:
+ event = entry.event
- missing_events = self._fetch_events_txn(
- txn,
- missing_events_ids,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
- )
+ events.append(event)
- event_map.update(missing_events)
+ if get_prev_content:
+ if "replaces_state" in event.unsigned:
+ prev = yield self.get_event(
+ event.unsigned["replaces_state"],
+ get_prev_content=False,
+ allow_none=True,
+ )
+ if prev:
+ event.unsigned = dict(event.unsigned)
+ event.unsigned["prev_content"] = prev.content
+ event.unsigned["prev_sender"] = prev.sender
- return [
- event_map[e_id] for e_id in event_ids
- if e_id in event_map and event_map[e_id]
- ]
+ defer.returnValue(events)
def _invalidate_get_event_cache(self, event_id):
- for check_redacted in (False, True):
- for get_prev_content in (False, True):
- self._get_event_cache.invalidate(
- (event_id, check_redacted, get_prev_content)
- )
-
- def _get_event_txn(self, txn, event_id, check_redacted=True,
- get_prev_content=False, allow_rejected=False):
+ self._get_event_cache.invalidate((event_id,))
- events = self._get_events_txn(
- txn, [event_id],
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
- )
-
- return events[0] if events else None
-
- def _get_events_from_cache(self, events, check_redacted, get_prev_content,
- allow_rejected):
+ def _get_events_from_cache(self, events, allow_rejected):
event_map = {}
for event_id in events:
- try:
- ret = self._get_event_cache.get(
- (event_id, check_redacted, get_prev_content,)
- )
+ ret = self._get_event_cache.get((event_id,), None)
+ if not ret:
+ continue
- if allow_rejected or not ret.rejected_reason:
- event_map[event_id] = ret
- else:
- event_map[event_id] = None
- except KeyError:
- pass
+ if allow_rejected or not ret.event.rejected_reason:
+ event_map[event_id] = ret
+ else:
+ event_map[event_id] = None
return event_map
@@ -905,8 +1052,7 @@ class EventsStore(SQLBaseStore):
reactor.callFromThread(fire, event_list)
@defer.inlineCallbacks
- def _enqueue_events(self, events, check_redacted=True,
- get_prev_content=False, allow_rejected=False):
+ def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@@ -944,8 +1090,6 @@ class EventsStore(SQLBaseStore):
[
preserve_fn(self._get_event_from_row)(
row["internal_metadata"], row["json"], row["redacts"],
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
rejected_reason=row["rejects"],
)
for row in rows
@@ -954,7 +1098,7 @@ class EventsStore(SQLBaseStore):
)
defer.returnValue({
- e.event_id: e
+ e.event.event_id: e
for e in res if e
})
@@ -984,37 +1128,8 @@ class EventsStore(SQLBaseStore):
return rows
- def _fetch_events_txn(self, txn, events, check_redacted=True,
- get_prev_content=False, allow_rejected=False):
- if not events:
- return {}
-
- rows = self._fetch_event_rows(
- txn, events,
- )
-
- if not allow_rejected:
- rows[:] = [r for r in rows if not r["rejects"]]
-
- res = [
- self._get_event_from_row_txn(
- txn,
- row["internal_metadata"], row["json"], row["redacts"],
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- rejected_reason=row["rejects"],
- )
- for row in rows
- ]
-
- return {
- r.event_id: r
- for r in res
- }
-
@defer.inlineCallbacks
def _get_event_from_row(self, internal_metadata, js, redacted,
- check_redacted=True, get_prev_content=False,
rejected_reason=None):
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
@@ -1024,26 +1139,27 @@ class EventsStore(SQLBaseStore):
table="rejections",
keyvalues={"event_id": rejected_reason},
retcol="reason",
- desc="_get_event_from_row",
+ desc="_get_event_from_row_rejected_reason",
)
- ev = FrozenEvent(
+ original_ev = FrozenEvent(
d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
- if check_redacted and redacted:
- ev = prune_event(ev)
+ redacted_event = None
+ if redacted:
+ redacted_event = prune_event(original_ev)
redaction_id = yield self._simple_select_one_onecol(
table="redactions",
- keyvalues={"redacts": ev.event_id},
+ keyvalues={"redacts": redacted_event.event_id},
retcol="event_id",
- desc="_get_event_from_row",
+ desc="_get_event_from_row_redactions",
)
- ev.unsigned["redacted_by"] = redaction_id
+ redacted_event.unsigned["redacted_by"] = redaction_id
# Get the redaction event.
because = yield self.get_event(
@@ -1055,86 +1171,16 @@ class EventsStore(SQLBaseStore):
if because:
# It's fine to do add the event directly, since get_pdu_json
# will serialise this field correctly
- ev.unsigned["redacted_because"] = because
-
- if get_prev_content and "replaces_state" in ev.unsigned:
- prev = yield self.get_event(
- ev.unsigned["replaces_state"],
- get_prev_content=False,
- allow_none=True,
- )
- if prev:
- ev.unsigned["prev_content"] = prev.content
- ev.unsigned["prev_sender"] = prev.sender
-
- self._get_event_cache.prefill(
- (ev.event_id, check_redacted, get_prev_content), ev
- )
-
- defer.returnValue(ev)
+ redacted_event.unsigned["redacted_because"] = because
- def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
- check_redacted=True, get_prev_content=False,
- rejected_reason=None):
- d = json.loads(js)
- internal_metadata = json.loads(internal_metadata)
-
- if rejected_reason:
- rejected_reason = self._simple_select_one_onecol_txn(
- txn,
- table="rejections",
- keyvalues={"event_id": rejected_reason},
- retcol="reason",
- )
-
- ev = FrozenEvent(
- d,
- internal_metadata_dict=internal_metadata,
- rejected_reason=rejected_reason,
+ cache_entry = _EventCacheEntry(
+ event=original_ev,
+ redacted_event=redacted_event,
)
- if check_redacted and redacted:
- ev = prune_event(ev)
-
- redaction_id = self._simple_select_one_onecol_txn(
- txn,
- table="redactions",
- keyvalues={"redacts": ev.event_id},
- retcol="event_id",
- )
+ self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
- ev.unsigned["redacted_by"] = redaction_id
- # Get the redaction event.
-
- because = self._get_event_txn(
- txn,
- redaction_id,
- check_redacted=False
- )
-
- if because:
- ev.unsigned["redacted_because"] = because
-
- if get_prev_content and "replaces_state" in ev.unsigned:
- prev = self._get_event_txn(
- txn,
- ev.unsigned["replaces_state"],
- get_prev_content=False,
- )
- if prev:
- ev.unsigned["prev_content"] = prev.content
- ev.unsigned["prev_sender"] = prev.sender
-
- self._get_event_cache.prefill(
- (ev.event_id, check_redacted, get_prev_content), ev
- )
-
- return ev
-
- def _parse_events_txn(self, txn, rows):
- event_ids = [r["event_id"] for r in rows]
-
- return self._get_events_txn(txn, event_ids)
+ defer.returnValue(cache_entry)
@defer.inlineCallbacks
def count_daily_messages(self):
@@ -1208,6 +1254,78 @@ class EventsStore(SQLBaseStore):
defer.returnValue(ret)
@defer.inlineCallbacks
+ def _background_reindex_fields_sender(self, progress, batch_size):
+ target_min_stream_id = progress["target_min_stream_id_inclusive"]
+ max_stream_id = progress["max_stream_id_exclusive"]
+ rows_inserted = progress.get("rows_inserted", 0)
+
+ INSERT_CLUMP_SIZE = 1000
+
+ def reindex_txn(txn):
+ sql = (
+ "SELECT stream_ordering, event_id, json FROM events"
+ " INNER JOIN event_json USING (event_id)"
+ " WHERE ? <= stream_ordering AND stream_ordering < ?"
+ " ORDER BY stream_ordering DESC"
+ " LIMIT ?"
+ )
+
+ txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
+
+ rows = txn.fetchall()
+ if not rows:
+ return 0
+
+ min_stream_id = rows[-1][0]
+
+ update_rows = []
+ for row in rows:
+ try:
+ event_id = row[1]
+ event_json = json.loads(row[2])
+ sender = event_json["sender"]
+ content = event_json["content"]
+
+ contains_url = "url" in content
+ if contains_url:
+ contains_url &= isinstance(content["url"], basestring)
+ except (KeyError, AttributeError):
+ # If the event is missing a necessary field then
+ # skip over it.
+ continue
+
+ update_rows.append((sender, contains_url, event_id))
+
+ sql = (
+ "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
+ )
+
+ for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
+ clump = update_rows[index:index + INSERT_CLUMP_SIZE]
+ txn.executemany(sql, clump)
+
+ progress = {
+ "target_min_stream_id_inclusive": target_min_stream_id,
+ "max_stream_id_exclusive": min_stream_id,
+ "rows_inserted": rows_inserted + len(rows)
+ }
+
+ self._background_update_progress_txn(
+ txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
+ )
+
+ return len(rows)
+
+ result = yield self.runInteraction(
+ self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
+ )
+
+ if not result:
+ yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME)
+
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
def _background_reindex_origin_server_ts(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
@@ -1374,6 +1492,162 @@ class EventsStore(SQLBaseStore):
)
return self.runInteraction("get_all_new_events", get_all_new_events_txn)
+ def delete_old_state(self, room_id, topological_ordering):
+ return self.runInteraction(
+ "delete_old_state",
+ self._delete_old_state_txn, room_id, topological_ordering
+ )
+
+ def _delete_old_state_txn(self, txn, room_id, topological_ordering):
+ """Deletes old room state
+ """
+
+ # Tables that should be pruned:
+ # event_auth
+ # event_backward_extremities
+ # event_content_hashes
+ # event_destinations
+ # event_edge_hashes
+ # event_edges
+ # event_forward_extremities
+ # event_json
+ # event_push_actions
+ # event_reference_hashes
+ # event_search
+ # event_signatures
+ # event_to_state_groups
+ # events
+ # rejections
+ # room_depth
+ # state_groups
+ # state_groups_state
+
+ # First ensure that we're not about to delete all the forward extremeties
+ txn.execute(
+ "SELECT e.event_id, e.depth FROM events as e "
+ "INNER JOIN event_forward_extremities as f "
+ "ON e.event_id = f.event_id "
+ "AND e.room_id = f.room_id "
+ "WHERE f.room_id = ?",
+ (room_id,)
+ )
+ rows = txn.fetchall()
+ max_depth = max(row[0] for row in rows)
+
+ if max_depth <= topological_ordering:
+ # We need to ensure we don't delete all the events from the datanase
+ # otherwise we wouldn't be able to send any events (due to not
+ # having any backwards extremeties)
+ raise SynapseError(
+ 400, "topological_ordering is greater than forward extremeties"
+ )
+
+ txn.execute(
+ "SELECT event_id, state_key FROM events"
+ " LEFT JOIN state_events USING (room_id, event_id)"
+ " WHERE room_id = ? AND topological_ordering < ?",
+ (room_id, topological_ordering,)
+ )
+ event_rows = txn.fetchall()
+
+ # We calculate the new entries for the backward extremeties by finding
+ # all events that point to events that are to be purged
+ txn.execute(
+ "SELECT DISTINCT e.event_id FROM events as e"
+ " INNER JOIN event_edges as ed ON e.event_id = ed.prev_event_id"
+ " INNER JOIN events as e2 ON e2.event_id = ed.event_id"
+ " WHERE e.room_id = ? AND e.topological_ordering < ?"
+ " AND e2.topological_ordering >= ?",
+ (room_id, topological_ordering, topological_ordering)
+ )
+ new_backwards_extrems = txn.fetchall()
+
+ txn.execute(
+ "DELETE FROM event_backward_extremities WHERE room_id = ?",
+ (room_id,)
+ )
+
+ # Update backward extremeties
+ txn.executemany(
+ "INSERT INTO event_backward_extremities (room_id, event_id)"
+ " VALUES (?, ?)",
+ [
+ (room_id, event_id) for event_id, in new_backwards_extrems
+ ]
+ )
+
+ # Get all state groups that are only referenced by events that are
+ # to be deleted.
+ txn.execute(
+ "SELECT state_group FROM event_to_state_groups"
+ " INNER JOIN events USING (event_id)"
+ " WHERE state_group IN ("
+ " SELECT DISTINCT state_group FROM events"
+ " INNER JOIN event_to_state_groups USING (event_id)"
+ " WHERE room_id = ? AND topological_ordering < ?"
+ " )"
+ " GROUP BY state_group HAVING MAX(topological_ordering) < ?",
+ (room_id, topological_ordering, topological_ordering)
+ )
+ state_rows = txn.fetchall()
+ txn.executemany(
+ "DELETE FROM state_groups_state WHERE state_group = ?",
+ state_rows
+ )
+ txn.executemany(
+ "DELETE FROM state_groups WHERE id = ?",
+ state_rows
+ )
+ # Delete all non-state
+ txn.executemany(
+ "DELETE FROM event_to_state_groups WHERE event_id = ?",
+ [(event_id,) for event_id, _ in event_rows]
+ )
+
+ txn.execute(
+ "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
+ (topological_ordering, room_id,)
+ )
+
+ # Delete all remote non-state events
+ to_delete = [
+ (event_id,) for event_id, state_key in event_rows
+ if state_key is None and not self.hs.is_mine_id(event_id)
+ ]
+ for table in (
+ "events",
+ "event_json",
+ "event_auth",
+ "event_content_hashes",
+ "event_destinations",
+ "event_edge_hashes",
+ "event_edges",
+ "event_forward_extremities",
+ "event_push_actions",
+ "event_reference_hashes",
+ "event_search",
+ "event_signatures",
+ "rejections",
+ ):
+ txn.executemany(
+ "DELETE FROM %s WHERE event_id = ?" % (table,),
+ to_delete
+ )
+
+ txn.executemany(
+ "DELETE FROM events WHERE event_id = ?",
+ to_delete
+ )
+ # Mark all state and own events as outliers
+ txn.executemany(
+ "UPDATE events SET outlier = ?"
+ " WHERE event_id = ?",
+ [
+ (True, event_id,) for event_id, state_key in event_rows
+ if state_key is not None or self.hs.is_mine_id(event_id)
+ ]
+ )
+
AllNewEventsResult = namedtuple("AllNewEventsResult", [
"new_forward_events", "new_backfill_events",
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index a495a8a7d9..86b37b9ddd 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -22,6 +22,10 @@ import OpenSSL
from signedjson.key import decode_verify_key_bytes
import hashlib
+import logging
+
+logger = logging.getLogger(__name__)
+
class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys and tls X.509 certificates
@@ -74,22 +78,22 @@ class KeyStore(SQLBaseStore):
)
@cachedInlineCallbacks()
- def get_all_server_verify_keys(self, server_name):
- rows = yield self._simple_select_list(
+ def _get_server_verify_key(self, server_name, key_id):
+ verify_key_bytes = yield self._simple_select_one_onecol(
table="server_signature_keys",
keyvalues={
"server_name": server_name,
+ "key_id": key_id,
},
- retcols=["key_id", "verify_key"],
- desc="get_all_server_verify_keys",
+ retcol="verify_key",
+ desc="_get_server_verify_key",
+ allow_none=True,
)
- defer.returnValue({
- row["key_id"]: decode_verify_key_bytes(
- row["key_id"], str(row["verify_key"])
- )
- for row in rows
- })
+ if verify_key_bytes:
+ defer.returnValue(decode_verify_key_bytes(
+ key_id, str(verify_key_bytes)
+ ))
@defer.inlineCallbacks
def get_server_verify_keys(self, server_name, key_ids):
@@ -101,12 +105,12 @@ class KeyStore(SQLBaseStore):
Returns:
(list of VerifyKey): The verification keys.
"""
- keys = yield self.get_all_server_verify_keys(server_name)
- defer.returnValue({
- k: keys[k]
- for k in key_ids
- if k in keys and keys[k]
- })
+ keys = {}
+ for key_id in key_ids:
+ key = yield self._get_server_verify_key(server_name, key_id)
+ if key:
+ keys[key_id] = key
+ defer.returnValue(keys)
@defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms,
@@ -133,8 +137,6 @@ class KeyStore(SQLBaseStore):
desc="store_server_verify_key",
)
- self.get_all_server_verify_keys.invalidate((server_name,))
-
def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes):
"""Stores the JSON bytes for a set of keys from a server
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index a820fcf07f..4c0f82353d 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -157,10 +157,25 @@ class MediaRepositoryStore(SQLBaseStore):
"created_ts": time_now_ms,
"upload_name": upload_name,
"filesystem_id": filesystem_id,
+ "last_access_ts": time_now_ms,
},
desc="store_cached_remote_media",
)
+ def update_cached_last_access_time(self, origin_id_tuples, time_ts):
+ def update_cache_txn(txn):
+ sql = (
+ "UPDATE remote_media_cache SET last_access_ts = ?"
+ " WHERE media_origin = ? AND media_id = ?"
+ )
+
+ txn.executemany(sql, (
+ (time_ts, media_origin, media_id)
+ for media_origin, media_id in origin_id_tuples
+ ))
+
+ return self.runInteraction("update_cached_last_access_time", update_cache_txn)
+
def get_remote_media_thumbnails(self, origin, media_id):
return self._simple_select_list(
"remote_media_cache_thumbnails",
@@ -190,3 +205,32 @@ class MediaRepositoryStore(SQLBaseStore):
},
desc="store_remote_media_thumbnail",
)
+
+ def get_remote_media_before(self, before_ts):
+ sql = (
+ "SELECT media_origin, media_id, filesystem_id"
+ " FROM remote_media_cache"
+ " WHERE last_access_ts < ?"
+ )
+
+ return self._execute(
+ "get_remote_media_before", self.cursor_to_dict, sql, before_ts
+ )
+
+ def delete_remote_media(self, media_origin, media_id):
+ def delete_remote_media_txn(txn):
+ self._simple_delete_txn(
+ txn,
+ "remote_media_cache",
+ keyvalues={
+ "media_origin": media_origin, "media_id": media_id
+ },
+ )
+ self._simple_delete_txn(
+ txn,
+ "remote_media_cache_thumbnails",
+ keyvalues={
+ "media_origin": media_origin, "media_id": media_id
+ },
+ )
+ return self.runInteraction("delete_remote_media", delete_remote_media_txn)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index c8487c8838..8801669a6b 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 32
+SCHEMA_VERSION = 33
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 3fab57a7e8..d03f7c541e 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -118,6 +118,9 @@ class PresenceStore(SQLBaseStore):
)
def get_all_presence_updates(self, last_id, current_id):
+ if last_id == current_id:
+ return defer.succeed([])
+
def get_all_presence_updates_txn(txn):
sql = (
"SELECT stream_id, user_id, state, last_active_ts,"
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index d2bf7f2aec..8183b7f1b0 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -14,7 +14,8 @@
# limitations under the License.
from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.push.baserules import list_with_base_rules
from twisted.internet import defer
import logging
@@ -23,8 +24,31 @@ import simplejson as json
logger = logging.getLogger(__name__)
+def _load_rules(rawrules, enabled_map):
+ ruleslist = []
+ for rawrule in rawrules:
+ rule = dict(rawrule)
+ rule["conditions"] = json.loads(rawrule["conditions"])
+ rule["actions"] = json.loads(rawrule["actions"])
+ ruleslist.append(rule)
+
+ # We're going to be mutating this a lot, so do a deep copy
+ rules = list(list_with_base_rules(ruleslist))
+
+ for i, rule in enumerate(rules):
+ rule_id = rule['rule_id']
+ if rule_id in enabled_map:
+ if rule.get('enabled', True) != bool(enabled_map[rule_id]):
+ # Rules are cached across users.
+ rule = dict(rule)
+ rule['enabled'] = bool(enabled_map[rule_id])
+ rules[i] = rule
+
+ return rules
+
+
class PushRuleStore(SQLBaseStore):
- @cachedInlineCallbacks()
+ @cachedInlineCallbacks(lru=True)
def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list(
table="push_rules",
@@ -42,9 +66,13 @@ class PushRuleStore(SQLBaseStore):
key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
)
- defer.returnValue(rows)
+ enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
- @cachedInlineCallbacks()
+ rules = _load_rules(rows, enabled_map)
+
+ defer.returnValue(rules)
+
+ @cachedInlineCallbacks(lru=True)
def get_push_rules_enabled_for_user(self, user_id):
results = yield self._simple_select_list(
table="push_rules_enable",
@@ -60,12 +88,16 @@ class PushRuleStore(SQLBaseStore):
r['rule_id']: False if r['enabled'] == 0 else True for r in results
})
- @defer.inlineCallbacks
+ @cachedList(cached_method_name="get_push_rules_for_user",
+ list_name="user_ids", num_args=1, inlineCallbacks=True)
def bulk_get_push_rules(self, user_ids):
if not user_ids:
defer.returnValue({})
- results = {}
+ results = {
+ user_id: []
+ for user_id in user_ids
+ }
rows = yield self._simple_select_many_batch(
table="push_rules",
@@ -75,18 +107,32 @@ class PushRuleStore(SQLBaseStore):
desc="bulk_get_push_rules",
)
- rows.sort(key=lambda e: (-e["priority_class"], -e["priority"]))
+ rows.sort(
+ key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
+ )
for row in rows:
results.setdefault(row['user_name'], []).append(row)
+
+ enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
+
+ for user_id, rules in results.items():
+ results[user_id] = _load_rules(
+ rules, enabled_map_by_user.get(user_id, {})
+ )
+
defer.returnValue(results)
- @defer.inlineCallbacks
+ @cachedList(cached_method_name="get_push_rules_enabled_for_user",
+ list_name="user_ids", num_args=1, inlineCallbacks=True)
def bulk_get_push_rules_enabled(self, user_ids):
if not user_ids:
defer.returnValue({})
- results = {}
+ results = {
+ user_id: {}
+ for user_id in user_ids
+ }
rows = yield self._simple_select_many_batch(
table="push_rules_enable",
@@ -96,7 +142,8 @@ class PushRuleStore(SQLBaseStore):
desc="bulk_get_push_rules_enabled",
)
for row in rows:
- results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
+ enabled = bool(row['enabled'])
+ results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
defer.returnValue(results)
@defer.inlineCallbacks
@@ -374,6 +421,9 @@ class PushRuleStore(SQLBaseStore):
def get_all_push_rule_updates(self, last_id, current_id, limit):
"""Get all the push rules changes that have happend on the server"""
+ if last_id == current_id:
+ return defer.succeed([])
+
def get_all_push_rule_updates_txn(txn):
sql = (
"SELECT stream_id, event_stream_ordering, user_id, rule_id,"
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 9e8e2e2964..a7d7c54d7e 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
from canonicaljson import encode_canonical_json
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
import logging
import simplejson as json
@@ -135,19 +135,35 @@ class PusherStore(SQLBaseStore):
"get_all_updated_pushers", get_all_updated_pushers_txn
)
- @cachedInlineCallbacks(num_args=1)
- def get_users_with_pushers_in_room(self, room_id):
- users = yield self.get_users_in_room(room_id)
-
+ @cachedInlineCallbacks(lru=True, num_args=1, max_entries=15000)
+ def get_if_user_has_pusher(self, user_id):
result = yield self._simple_select_many_batch(
table='pushers',
+ keyvalues={
+ 'user_name': 'user_id',
+ },
+ retcol='user_name',
+ desc='get_if_user_has_pusher',
+ allow_none=True,
+ )
+
+ defer.returnValue(bool(result))
+
+ @cachedList(cached_method_name="get_if_user_has_pusher",
+ list_name="user_ids", num_args=1, inlineCallbacks=True)
+ def get_if_users_have_pushers(self, user_ids):
+ rows = yield self._simple_select_many_batch(
+ table='pushers',
column='user_name',
- iterable=users,
+ iterable=user_ids,
retcols=['user_name'],
- desc='get_users_with_pushers_in_room'
+ desc='get_if_users_have_pushers'
)
- defer.returnValue([r['user_name'] for r in result])
+ result = {user_id: False for user_id in user_ids}
+ result.update({r['user_name']: True for r in rows})
+
+ defer.returnValue(result)
@defer.inlineCallbacks
def add_pusher(self, user_id, access_token, kind, app_id,
@@ -178,16 +194,16 @@ class PusherStore(SQLBaseStore):
},
)
if newly_inserted:
- # get_users_with_pushers_in_room only cares if the user has
+ # get_if_user_has_pusher only cares if the user has
# at least *one* pusher.
- txn.call_after(self.get_users_with_pushers_in_room.invalidate_all)
+ txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
yield self.runInteraction("add_pusher", f)
@defer.inlineCallbacks
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
def delete_pusher_txn(txn, stream_id):
- txn.call_after(self.get_users_with_pushers_in_room.invalidate_all)
+ txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
self._simple_delete_one_txn(
txn,
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index d147a60602..cb4e04a679 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -34,6 +34,26 @@ class ReceiptsStore(SQLBaseStore):
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
)
+ @cachedInlineCallbacks()
+ def get_users_with_read_receipts_in_room(self, room_id):
+ receipts = yield self.get_receipts_for_room(room_id, "m.read")
+ defer.returnValue(set(r['user_id'] for r in receipts))
+
+ def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
+ user_id):
+ if receipt_type != "m.read":
+ return
+
+ # Returns an ObservableDeferred
+ res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None)
+
+ if res and res.called and user_id in res.result:
+ # We'd only be adding to the set, so no point invalidating if the
+ # user is already there
+ return
+
+ self.get_users_with_read_receipts_in_room.invalidate((room_id,))
+
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
return self._simple_select_list(
@@ -254,6 +274,10 @@ class ReceiptsStore(SQLBaseStore):
self.get_receipts_for_room.invalidate, (room_id, receipt_type)
)
txn.call_after(
+ self._invalidate_get_users_with_receipts_in_room,
+ room_id, receipt_type, user_id,
+ )
+ txn.call_after(
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
)
# FIXME: This shouldn't invalidate the whole cache
@@ -399,6 +423,10 @@ class ReceiptsStore(SQLBaseStore):
self.get_receipts_for_room.invalidate, (room_id, receipt_type)
)
txn.call_after(
+ self._invalidate_get_users_with_receipts_in_room,
+ room_id, receipt_type, user_id,
+ )
+ txn.call_after(
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
)
# FIXME: This shouldn't invalidate the whole cache
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index bda84a744a..7e7d32eb66 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -18,25 +18,40 @@ import re
from twisted.internet import defer
from synapse.api.errors import StoreError, Codes
-
-from ._base import SQLBaseStore
+from synapse.storage import background_updates
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-class RegistrationStore(SQLBaseStore):
+class RegistrationStore(background_updates.BackgroundUpdateStore):
def __init__(self, hs):
super(RegistrationStore, self).__init__(hs)
self.clock = hs.get_clock()
+ self.register_background_index_update(
+ "access_tokens_device_index",
+ index_name="access_tokens_device_id",
+ table="access_tokens",
+ columns=["user_id", "device_id"],
+ )
+
+ self.register_background_index_update(
+ "refresh_tokens_device_index",
+ index_name="refresh_tokens_device_id",
+ table="refresh_tokens",
+ columns=["user_id", "device_id"],
+ )
+
@defer.inlineCallbacks
- def add_access_token_to_user(self, user_id, token):
+ def add_access_token_to_user(self, user_id, token, device_id=None):
"""Adds an access token for the given user.
Args:
user_id (str): The user ID.
token (str): The new access token to add.
+ device_id (str): ID of the device to associate with the access
+ token
Raises:
StoreError if there was a problem adding this.
"""
@@ -47,18 +62,21 @@ class RegistrationStore(SQLBaseStore):
{
"id": next_id,
"user_id": user_id,
- "token": token
+ "token": token,
+ "device_id": device_id,
},
desc="add_access_token_to_user",
)
@defer.inlineCallbacks
- def add_refresh_token_to_user(self, user_id, token):
+ def add_refresh_token_to_user(self, user_id, token, device_id=None):
"""Adds a refresh token for the given user.
Args:
user_id (str): The user ID.
token (str): The new refresh token to add.
+ device_id (str): ID of the device to associate with the access
+ token
Raises:
StoreError if there was a problem adding this.
"""
@@ -69,25 +87,31 @@ class RegistrationStore(SQLBaseStore):
{
"id": next_id,
"user_id": user_id,
- "token": token
+ "token": token,
+ "device_id": device_id,
},
desc="add_refresh_token_to_user",
)
@defer.inlineCallbacks
- def register(self, user_id, token, password_hash,
- was_guest=False, make_guest=False, appservice_id=None):
+ def register(self, user_id, token=None, password_hash=None,
+ was_guest=False, make_guest=False, appservice_id=None,
+ create_profile_with_localpart=None, admin=False):
"""Attempts to register an account.
Args:
user_id (str): The desired user ID to register.
- token (str): The desired access token to use for this user.
+ token (str): The desired access token to use for this user. If this
+ is not None, the given access token is associated with the user
+ id.
password_hash (str): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
make_guest (boolean): True if the the new user should be guest,
false to add a regular user account.
appservice_id (str): The ID of the appservice registering the user.
+ create_profile_with_localpart (str): Optionally create a profile for
+ the given localpart.
Raises:
StoreError if the user_id could not be registered.
"""
@@ -99,7 +123,9 @@ class RegistrationStore(SQLBaseStore):
password_hash,
was_guest,
make_guest,
- appservice_id
+ appservice_id,
+ create_profile_with_localpart,
+ admin
)
self.get_user_by_id.invalidate((user_id,))
self.is_guest.invalidate((user_id,))
@@ -112,7 +138,9 @@ class RegistrationStore(SQLBaseStore):
password_hash,
was_guest,
make_guest,
- appservice_id
+ appservice_id,
+ create_profile_with_localpart,
+ admin,
):
now = int(self.clock.time())
@@ -120,29 +148,48 @@ class RegistrationStore(SQLBaseStore):
try:
if was_guest:
- txn.execute("UPDATE users SET"
- " password_hash = ?,"
- " upgrade_ts = ?,"
- " is_guest = ?"
- " WHERE name = ?",
- [password_hash, now, 1 if make_guest else 0, user_id])
+ # Ensure that the guest user actually exists
+ # ``allow_none=False`` makes this raise an exception
+ # if the row isn't in the database.
+ self._simple_select_one_txn(
+ txn,
+ "users",
+ keyvalues={
+ "name": user_id,
+ "is_guest": 1,
+ },
+ retcols=("name",),
+ allow_none=False,
+ )
+
+ self._simple_update_one_txn(
+ txn,
+ "users",
+ keyvalues={
+ "name": user_id,
+ "is_guest": 1,
+ },
+ updatevalues={
+ "password_hash": password_hash,
+ "upgrade_ts": now,
+ "is_guest": 1 if make_guest else 0,
+ "appservice_id": appservice_id,
+ "admin": 1 if admin else 0,
+ }
+ )
else:
- txn.execute("INSERT INTO users "
- "("
- " name,"
- " password_hash,"
- " creation_ts,"
- " is_guest,"
- " appservice_id"
- ") "
- "VALUES (?,?,?,?,?)",
- [
- user_id,
- password_hash,
- now,
- 1 if make_guest else 0,
- appservice_id,
- ])
+ self._simple_insert_txn(
+ txn,
+ "users",
+ values={
+ "name": user_id,
+ "password_hash": password_hash,
+ "creation_ts": now,
+ "is_guest": 1 if make_guest else 0,
+ "appservice_id": appservice_id,
+ "admin": 1 if admin else 0,
+ }
+ )
except self.database_engine.module.IntegrityError:
raise StoreError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
@@ -157,6 +204,12 @@ class RegistrationStore(SQLBaseStore):
(next_id, user_id, token,)
)
+ if create_profile_with_localpart:
+ txn.execute(
+ "INSERT INTO profiles(user_id) VALUES (?)",
+ (create_profile_with_localpart,)
+ )
+
@cached()
def get_user_by_id(self, user_id):
return self._simple_select_one(
@@ -198,16 +251,37 @@ class RegistrationStore(SQLBaseStore):
self.get_user_by_id.invalidate((user_id,))
@defer.inlineCallbacks
- def user_delete_access_tokens(self, user_id, except_token_ids=[]):
- def f(txn):
- sql = "SELECT token FROM access_tokens WHERE user_id = ?"
+ def user_delete_access_tokens(self, user_id, except_token_ids=[],
+ device_id=None,
+ delete_refresh_tokens=False):
+ """
+ Invalidate access/refresh tokens belonging to a user
+
+ Args:
+ user_id (str): ID of user the tokens belong to
+ except_token_ids (list[str]): list of access_tokens which should
+ *not* be deleted
+ device_id (str|None): ID of device the tokens are associated with.
+ If None, tokens associated with any device (or no device) will
+ be deleted
+ delete_refresh_tokens (bool): True to delete refresh tokens as
+ well as access tokens.
+ Returns:
+ defer.Deferred:
+ """
+ def f(txn, table, except_tokens, call_after_delete):
+ sql = "SELECT token FROM %s WHERE user_id = ?" % table
clauses = [user_id]
- if except_token_ids:
+ if device_id is not None:
+ sql += " AND device_id = ?"
+ clauses.append(device_id)
+
+ if except_tokens:
sql += " AND id NOT IN (%s)" % (
- ",".join(["?" for _ in except_token_ids]),
+ ",".join(["?" for _ in except_tokens]),
)
- clauses += except_token_ids
+ clauses += except_tokens
txn.execute(sql, clauses)
@@ -216,16 +290,33 @@ class RegistrationStore(SQLBaseStore):
n = 100
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
for chunk in chunks:
- for row in chunk:
- txn.call_after(self.get_user_by_access_token.invalidate, (row[0],))
+ if call_after_delete:
+ for row in chunk:
+ txn.call_after(call_after_delete, (row[0],))
txn.execute(
- "DELETE FROM access_tokens WHERE token in (%s)" % (
+ "DELETE FROM %s WHERE token in (%s)" % (
+ table,
",".join(["?" for _ in chunk]),
), [r[0] for r in chunk]
)
- yield self.runInteraction("user_delete_access_tokens", f)
+ # delete refresh tokens first, to stop new access tokens being
+ # allocated while our backs are turned
+ if delete_refresh_tokens:
+ yield self.runInteraction(
+ "user_delete_access_tokens", f,
+ table="refresh_tokens",
+ except_tokens=[],
+ call_after_delete=None,
+ )
+
+ yield self.runInteraction(
+ "user_delete_access_tokens", f,
+ table="access_tokens",
+ except_tokens=except_token_ids,
+ call_after_delete=self.get_user_by_access_token.invalidate,
+ )
def delete_access_token(self, access_token):
def f(txn):
@@ -248,9 +339,8 @@ class RegistrationStore(SQLBaseStore):
Args:
token (str): The access token of a user.
Returns:
- dict: Including the name (user_id) and the ID of their access token.
- Raises:
- StoreError if no user was found.
+ defer.Deferred: None, if the token did not match, otherwise dict
+ including the keys `name`, `is_guest`, `device_id`, `token_id`.
"""
return self.runInteraction(
"get_user_by_access_token",
@@ -259,18 +349,18 @@ class RegistrationStore(SQLBaseStore):
)
def exchange_refresh_token(self, refresh_token, token_generator):
- """Exchange a refresh token for a new access token and refresh token.
+ """Exchange a refresh token for a new one.
Doing so invalidates the old refresh token - refresh tokens are single
use.
Args:
- token (str): The refresh token of a user.
+ refresh_token (str): The refresh token of a user.
token_generator (fn: str -> str): Function which, when given a
user ID, returns a unique refresh token for that user. This
function must never return the same value twice.
Returns:
- tuple of (user_id, refresh_token)
+ tuple of (user_id, new_refresh_token, device_id)
Raises:
StoreError if no user was found with that refresh token.
"""
@@ -282,12 +372,13 @@ class RegistrationStore(SQLBaseStore):
)
def _exchange_refresh_token(self, txn, old_token, token_generator):
- sql = "SELECT user_id FROM refresh_tokens WHERE token = ?"
+ sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?"
txn.execute(sql, (old_token,))
rows = self.cursor_to_dict(txn)
if not rows:
raise StoreError(403, "Did not recognize refresh token")
user_id = rows[0]["user_id"]
+ device_id = rows[0]["device_id"]
# TODO(danielwh): Maybe perform a validation on the macaroon that
# macaroon.user_id == user_id.
@@ -296,7 +387,7 @@ class RegistrationStore(SQLBaseStore):
sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
txn.execute(sql, (new_token, old_token,))
- return user_id, new_token
+ return user_id, new_token, device_id
@defer.inlineCallbacks
def is_server_admin(self, user):
@@ -324,7 +415,8 @@ class RegistrationStore(SQLBaseStore):
def _query_for_auth(self, txn, token):
sql = (
- "SELECT users.name, users.is_guest, access_tokens.id as token_id"
+ "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
+ " access_tokens.device_id"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
@@ -373,6 +465,15 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(ret['user_id'])
defer.returnValue(None)
+ def user_delete_threepids(self, user_id):
+ return self._simple_delete(
+ "user_threepids",
+ keyvalues={
+ "user_id": user_id,
+ },
+ desc="user_delete_threepids",
+ )
+
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 26933e593a..8251f58670 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -18,7 +18,6 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks
from .engines import PostgresEngine, Sqlite3Engine
import collections
@@ -192,37 +191,6 @@ class RoomStore(SQLBaseStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
- @cachedInlineCallbacks()
- def get_room_name_and_aliases(self, room_id):
- def f(txn):
- sql = (
- "SELECT event_id FROM current_state_events "
- "WHERE room_id = ? "
- )
-
- sql += " AND ((type = 'm.room.name' AND state_key = '')"
- sql += " OR type = 'm.room.aliases')"
-
- txn.execute(sql, (room_id,))
- results = self.cursor_to_dict(txn)
-
- return self._parse_events_txn(txn, results)
-
- events = yield self.runInteraction("get_room_name_and_aliases", f)
-
- name = None
- aliases = []
-
- for e in events:
- if e.type == 'm.room.name':
- if 'name' in e.content:
- name = e.content['name']
- elif e.type == 'm.room.aliases':
- if 'aliases' in e.content:
- aliases.extend(e.content['aliases'])
-
- defer.returnValue((name, aliases))
-
def add_event_report(self, room_id, event_id, user_id, reason, content,
received_ts):
next_id = self._event_reports_id_gen.get_next()
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index face685ed2..8bd693be72 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -59,9 +59,6 @@ class RoomMemberStore(SQLBaseStore):
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(
- self.get_users_with_pushers_in_room.invalidate, (event.room_id,)
- )
- txn.call_after(
self._membership_stream_cache.entity_has_changed,
event.state_key, event.internal_metadata.stream_ordering
)
@@ -241,30 +238,10 @@ class RoomMemberStore(SQLBaseStore):
return results
- @cached(max_entries=5000)
+ @cachedInlineCallbacks(max_entries=5000)
def get_joined_hosts_for_room(self, room_id):
- return self.runInteraction(
- "get_joined_hosts_for_room",
- self._get_joined_hosts_for_room_txn,
- room_id,
- )
-
- def _get_joined_hosts_for_room_txn(self, txn, room_id):
- rows = self._get_members_rows_txn(
- txn,
- room_id, membership=Membership.JOIN
- )
-
- joined_domains = set(get_domain_from_id(r["user_id"]) for r in rows)
-
- return joined_domains
-
- def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None):
- rows = self._get_members_rows_txn(
- txn,
- room_id, membership, user_id,
- )
- return [r["event_id"] for r in rows]
+ user_ids = yield self.get_users_in_room(room_id)
+ defer.returnValue(set(get_domain_from_id(uid) for uid in user_ids))
def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
where_clause = "c.room_id = ?"
diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py
index b417e3ac08..5b7d8d1ab5 100644
--- a/synapse/storage/schema/delta/30/as_users.py
+++ b/synapse/storage/schema/delta/30/as_users.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from synapse.storage.appservice import ApplicationServiceStore
+from synapse.config.appservice import load_appservices
logger = logging.getLogger(__name__)
@@ -38,7 +38,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
logger.warning("Could not get app_service_config_files from config")
pass
- appservices = ApplicationServiceStore.load_appservices(
+ appservices = load_appservices(
config.server_name, config_files
)
diff --git a/synapse/storage/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/schema/delta/33/access_tokens_device_index.sql
new file mode 100644
index 0000000000..61ad3fe3e8
--- /dev/null
+++ b/synapse/storage/schema/delta/33/access_tokens_device_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('access_tokens_device_index', '{}');
diff --git a/synapse/storage/schema/delta/33/devices.sql b/synapse/storage/schema/delta/33/devices.sql
new file mode 100644
index 0000000000..eca7268d82
--- /dev/null
+++ b/synapse/storage/schema/delta/33/devices.sql
@@ -0,0 +1,21 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE devices (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ display_name TEXT,
+ CONSTRAINT device_uniqueness UNIQUE (user_id, device_id)
+);
diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql b/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql
new file mode 100644
index 0000000000..aa4a3b9f2f
--- /dev/null
+++ b/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql
@@ -0,0 +1,19 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- make sure that we have a device record for each set of E2E keys, so that the
+-- user can delete them if they like.
+INSERT INTO devices
+ SELECT user_id, device_id, NULL FROM e2e_device_keys_json;
diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql b/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
new file mode 100644
index 0000000000..6671573398
--- /dev/null
+++ b/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
@@ -0,0 +1,20 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- a previous version of the "devices_for_e2e_keys" delta set all the device
+-- names to "unknown device". This wasn't terribly helpful
+UPDATE devices
+ SET display_name = NULL
+ WHERE display_name = 'unknown device';
diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/schema/delta/33/event_fields.py
new file mode 100644
index 0000000000..83066cccc9
--- /dev/null
+++ b/synapse/storage/schema/delta/33/event_fields.py
@@ -0,0 +1,60 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.prepare_database import get_statements
+
+import logging
+import ujson
+
+logger = logging.getLogger(__name__)
+
+
+ALTER_TABLE = """
+ALTER TABLE events ADD COLUMN sender TEXT;
+ALTER TABLE events ADD COLUMN contains_url BOOLEAN;
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ for statement in get_statements(ALTER_TABLE.splitlines()):
+ cur.execute(statement)
+
+ cur.execute("SELECT MIN(stream_ordering) FROM events")
+ rows = cur.fetchall()
+ min_stream_id = rows[0][0]
+
+ cur.execute("SELECT MAX(stream_ordering) FROM events")
+ rows = cur.fetchall()
+ max_stream_id = rows[0][0]
+
+ if min_stream_id is not None and max_stream_id is not None:
+ progress = {
+ "target_min_stream_id_inclusive": min_stream_id,
+ "max_stream_id_exclusive": max_stream_id + 1,
+ "rows_inserted": 0,
+ }
+ progress_json = ujson.dumps(progress)
+
+ sql = (
+ "INSERT into background_updates (update_name, progress_json)"
+ " VALUES (?, ?)"
+ )
+
+ sql = database_engine.convert_param_style(sql)
+
+ cur.execute(sql, ("event_fields_sender_url", progress_json))
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+ pass
diff --git a/synapse/storage/schema/delta/33/refreshtoken_device.sql b/synapse/storage/schema/delta/33/refreshtoken_device.sql
new file mode 100644
index 0000000000..290bd6da86
--- /dev/null
+++ b/synapse/storage/schema/delta/33/refreshtoken_device.sql
@@ -0,0 +1,16 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE refresh_tokens ADD COLUMN device_id TEXT;
diff --git a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql b/synapse/storage/schema/delta/33/refreshtoken_device_index.sql
new file mode 100644
index 0000000000..bb225dafbf
--- /dev/null
+++ b/synapse/storage/schema/delta/33/refreshtoken_device_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('refresh_tokens_device_index', '{}');
diff --git a/synapse/storage/schema/delta/33/remote_media_ts.py b/synapse/storage/schema/delta/33/remote_media_ts.py
new file mode 100644
index 0000000000..55ae43f395
--- /dev/null
+++ b/synapse/storage/schema/delta/33/remote_media_ts.py
@@ -0,0 +1,31 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+
+
+ALTER_TABLE = "ALTER TABLE remote_media_cache ADD COLUMN last_access_ts BIGINT"
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ cur.execute(ALTER_TABLE)
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+ cur.execute(
+ database_engine.convert_param_style(
+ "UPDATE remote_media_cache SET last_access_ts = ?"
+ ),
+ (int(time.time() * 1000),)
+ )
diff --git a/synapse/storage/schema/delta/33/user_ips_index.sql b/synapse/storage/schema/delta/33/user_ips_index.sql
new file mode 100644
index 0000000000..473f75a78e
--- /dev/null
+++ b/synapse/storage/schema/delta/33/user_ips_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('user_ips_device_index', '{}');
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index 0224299625..12941d1775 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -21,6 +21,7 @@ from synapse.storage.engines import PostgresEngine, Sqlite3Engine
import logging
import re
+import ujson as json
logger = logging.getLogger(__name__)
@@ -52,7 +53,7 @@ class SearchStore(BackgroundUpdateStore):
def reindex_search_txn(txn):
sql = (
- "SELECT stream_ordering, event_id FROM events"
+ "SELECT stream_ordering, event_id, room_id, type, content FROM events"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
" AND (%s)"
" ORDER BY stream_ordering DESC"
@@ -61,28 +62,30 @@ class SearchStore(BackgroundUpdateStore):
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
- rows = txn.fetchall()
+ rows = self.cursor_to_dict(txn)
if not rows:
return 0
- min_stream_id = rows[-1][0]
- event_ids = [row[1] for row in rows]
-
- events = self._get_events_txn(txn, event_ids)
+ min_stream_id = rows[-1]["stream_ordering"]
event_search_rows = []
- for event in events:
+ for row in rows:
try:
- event_id = event.event_id
- room_id = event.room_id
- content = event.content
- if event.type == "m.room.message":
+ event_id = row["event_id"]
+ room_id = row["room_id"]
+ etype = row["type"]
+ try:
+ content = json.loads(row["content"])
+ except:
+ continue
+
+ if etype == "m.room.message":
key = "content.body"
value = content["body"]
- elif event.type == "m.room.topic":
+ elif etype == "m.room.topic":
key = "content.topic"
value = content["topic"]
- elif event.type == "m.room.name":
+ elif etype == "m.room.name":
key = "content.name"
value = content["name"]
except (KeyError, AttributeError):
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index b10f2a5787..ea6823f18d 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -19,17 +19,24 @@ from ._base import SQLBaseStore
from unpaddedbase64 import encode_base64
from synapse.crypto.event_signing import compute_event_reference_hash
+from synapse.util.caches.descriptors import cached, cachedList
class SignatureStore(SQLBaseStore):
"""Persistence for event signatures and hashes"""
+ @cached(lru=True)
+ def get_event_reference_hash(self, event_id):
+ return self._get_event_reference_hashes_txn(event_id)
+
+ @cachedList(cached_method_name="get_event_reference_hash",
+ list_name="event_ids", num_args=1)
def get_event_reference_hashes(self, event_ids):
def f(txn):
- return [
- self._get_event_reference_hashes_txn(txn, ev)
- for ev in event_ids
- ]
+ return {
+ event_id: self._get_event_reference_hashes_txn(txn, event_id)
+ for event_id in event_ids
+ }
return self.runInteraction(
"get_event_reference_hashes",
@@ -41,15 +48,15 @@ class SignatureStore(SQLBaseStore):
hashes = yield self.get_event_reference_hashes(
event_ids
)
- hashes = [
- {
+ hashes = {
+ e_id: {
k: encode_base64(v) for k, v in h.items()
if k == "sha256"
}
- for h in hashes
- ]
+ for e_id, h in hashes.items()
+ }
- defer.returnValue(zip(event_ids, hashes))
+ defer.returnValue(hashes.items())
def _get_event_reference_hashes_txn(self, txn, event_id):
"""Get all the hashes for a given PDU.
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 95b12559a6..862c5c3ea1 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -40,6 +40,7 @@ from synapse.util.caches.descriptors import cached
from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken
from synapse.util.logcontext import preserve_fn
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
import logging
@@ -54,26 +55,92 @@ _STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological"
-def lower_bound(token):
+def lower_bound(token, engine, inclusive=False):
+ inclusive = "=" if inclusive else ""
if token.topological is None:
- return "(%d < %s)" % (token.stream, "stream_ordering")
+ return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering")
else:
- return "(%d < %s OR (%d = %s AND %d < %s))" % (
+ if isinstance(engine, PostgresEngine):
+ # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
+ # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
+ # use the later form when running against postgres.
+ return "((%d,%d) <%s (%s,%s))" % (
+ token.topological, token.stream, inclusive,
+ "topological_ordering", "stream_ordering",
+ )
+ return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
token.topological, "topological_ordering",
token.topological, "topological_ordering",
- token.stream, "stream_ordering",
+ token.stream, inclusive, "stream_ordering",
)
-def upper_bound(token):
+def upper_bound(token, engine, inclusive=True):
+ inclusive = "=" if inclusive else ""
if token.topological is None:
- return "(%d >= %s)" % (token.stream, "stream_ordering")
+ return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering")
else:
- return "(%d > %s OR (%d = %s AND %d >= %s))" % (
+ if isinstance(engine, PostgresEngine):
+ # Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well
+ # as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
+ # use the later form when running against postgres.
+ return "((%d,%d) >%s (%s,%s))" % (
+ token.topological, token.stream, inclusive,
+ "topological_ordering", "stream_ordering",
+ )
+ return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
token.topological, "topological_ordering",
token.topological, "topological_ordering",
- token.stream, "stream_ordering",
+ token.stream, inclusive, "stream_ordering",
+ )
+
+
+def filter_to_clause(event_filter):
+ # NB: This may create SQL clauses that don't optimise well (and we don't
+ # have indices on all possible clauses). E.g. it may create
+ # "room_id == X AND room_id != X", which postgres doesn't optimise.
+
+ if not event_filter:
+ return "", []
+
+ clauses = []
+ args = []
+
+ if event_filter.types:
+ clauses.append(
+ "(%s)" % " OR ".join("type = ?" for _ in event_filter.types)
)
+ args.extend(event_filter.types)
+
+ for typ in event_filter.not_types:
+ clauses.append("type != ?")
+ args.append(typ)
+
+ if event_filter.senders:
+ clauses.append(
+ "(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders)
+ )
+ args.extend(event_filter.senders)
+
+ for sender in event_filter.not_senders:
+ clauses.append("sender != ?")
+ args.append(sender)
+
+ if event_filter.rooms:
+ clauses.append(
+ "(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms)
+ )
+ args.extend(event_filter.rooms)
+
+ for room_id in event_filter.not_rooms:
+ clauses.append("room_id != ?")
+ args.append(room_id)
+
+ if event_filter.contains_url:
+ clauses.append("contains_url = ?")
+ args.append(event_filter.contains_url)
+
+ return " AND ".join(clauses), args
class StreamStore(SQLBaseStore):
@@ -132,29 +199,25 @@ class StreamStore(SQLBaseStore):
return True
return False
- ret = self._get_events_txn(
- txn,
- # apply the filter on the room id list
- [
- r["event_id"] for r in rows
- if app_service_interested(r)
- ],
- get_prev_content=True
- )
+ return [r for r in rows if app_service_interested(r)]
- self._set_before_and_after(ret, rows)
+ rows = yield self.runInteraction("get_appservice_room_stream", f)
- if rows:
- key = "s%d" % max(r["stream_ordering"] for r in rows)
- else:
- # Assume we didn't get anything because there was nothing to
- # get.
- key = to_key
+ ret = yield self._get_events(
+ [r["event_id"] for r in rows],
+ get_prev_content=True
+ )
- return ret, key
+ self._set_before_and_after(ret, rows, topo_order=from_id is None)
- results = yield self.runInteraction("get_appservice_room_stream", f)
- defer.returnValue(results)
+ if rows:
+ key = "s%d" % max(r["stream_ordering"] for r in rows)
+ else:
+ # Assume we didn't get anything because there was nothing to
+ # get.
+ key = to_key
+
+ defer.returnValue((ret, key))
@defer.inlineCallbacks
def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0,
@@ -305,25 +368,35 @@ class StreamStore(SQLBaseStore):
@defer.inlineCallbacks
def paginate_room_events(self, room_id, from_key, to_key=None,
- direction='b', limit=-1):
+ direction='b', limit=-1, event_filter=None):
# Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities.
args = [False, room_id]
if direction == 'b':
order = "DESC"
- bounds = upper_bound(RoomStreamToken.parse(from_key))
+ bounds = upper_bound(
+ RoomStreamToken.parse(from_key), self.database_engine
+ )
if to_key:
- bounds = "%s AND %s" % (
- bounds, lower_bound(RoomStreamToken.parse(to_key))
- )
+ bounds = "%s AND %s" % (bounds, lower_bound(
+ RoomStreamToken.parse(to_key), self.database_engine
+ ))
else:
order = "ASC"
- bounds = lower_bound(RoomStreamToken.parse(from_key))
+ bounds = lower_bound(
+ RoomStreamToken.parse(from_key), self.database_engine
+ )
if to_key:
- bounds = "%s AND %s" % (
- bounds, upper_bound(RoomStreamToken.parse(to_key))
- )
+ bounds = "%s AND %s" % (bounds, upper_bound(
+ RoomStreamToken.parse(to_key), self.database_engine
+ ))
+
+ filter_clause, filter_args = filter_to_clause(event_filter)
+
+ if filter_clause:
+ bounds += " AND " + filter_clause
+ args.extend(filter_args)
if int(limit) > 0:
args.append(int(limit))
@@ -491,13 +564,13 @@ class StreamStore(SQLBaseStore):
row["topological_ordering"], row["stream_ordering"],)
)
- def get_max_topological_token_for_stream_and_room(self, room_id, stream_key):
+ def get_max_topological_token(self, room_id, stream_key):
sql = (
"SELECT max(topological_ordering) FROM events"
" WHERE room_id = ? AND stream_ordering < ?"
)
return self._execute(
- "get_max_topological_token_for_stream_and_room", None,
+ "get_max_topological_token", None,
sql, room_id, stream_key,
).addCallback(
lambda r: r[0][0] if r else 0
@@ -590,32 +663,60 @@ class StreamStore(SQLBaseStore):
retcols=["stream_ordering", "topological_ordering"],
)
- stream_ordering = results["stream_ordering"]
- topological_ordering = results["topological_ordering"]
-
- query_before = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND (topological_ordering < ?"
- " OR (topological_ordering = ? AND stream_ordering < ?))"
- " ORDER BY topological_ordering DESC, stream_ordering DESC"
- " LIMIT ?"
+ token = RoomStreamToken(
+ results["topological_ordering"],
+ results["stream_ordering"],
)
- query_after = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND (topological_ordering > ?"
- " OR (topological_ordering = ? AND stream_ordering > ?))"
- " ORDER BY topological_ordering ASC, stream_ordering ASC"
- " LIMIT ?"
- )
+ if isinstance(self.database_engine, Sqlite3Engine):
+ # SQLite3 doesn't optimise ``(x < a) OR (x = a AND y < b)``
+ # So we give pass it to SQLite3 as the UNION ALL of the two queries.
+
+ query_before = (
+ "SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND topological_ordering < ?"
+ " UNION ALL"
+ " SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering < ?"
+ " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
+ )
+ before_args = (
+ room_id, token.topological,
+ room_id, token.topological, token.stream,
+ before_limit,
+ )
- txn.execute(
- query_before,
- (
- room_id, topological_ordering, topological_ordering,
- stream_ordering, before_limit,
+ query_after = (
+ "SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND topological_ordering > ?"
+ " UNION ALL"
+ " SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering > ?"
+ " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
)
- )
+ after_args = (
+ room_id, token.topological,
+ room_id, token.topological, token.stream,
+ after_limit,
+ )
+ else:
+ query_before = (
+ "SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND %s"
+ " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
+ ) % (upper_bound(token, self.database_engine, inclusive=False),)
+
+ before_args = (room_id, before_limit)
+
+ query_after = (
+ "SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND %s"
+ " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
+ ) % (lower_bound(token, self.database_engine, inclusive=False),)
+
+ after_args = (room_id, after_limit)
+
+ txn.execute(query_before, before_args)
rows = self.cursor_to_dict(txn)
events_before = [r["event_id"] for r in rows]
@@ -627,17 +728,11 @@ class StreamStore(SQLBaseStore):
))
else:
start_token = str(RoomStreamToken(
- topological_ordering,
- stream_ordering - 1,
+ token.topological,
+ token.stream - 1,
))
- txn.execute(
- query_after,
- (
- room_id, topological_ordering, topological_ordering,
- stream_ordering, after_limit,
- )
- )
+ txn.execute(query_after, after_args)
rows = self.cursor_to_dict(txn)
events_after = [r["event_id"] for r in rows]
@@ -648,10 +743,7 @@ class StreamStore(SQLBaseStore):
rows[-1]["stream_ordering"],
))
else:
- end_token = str(RoomStreamToken(
- topological_ordering,
- stream_ordering,
- ))
+ end_token = str(token)
return {
"before": {
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index 9da23f34cb..5a2c1aa59b 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -68,6 +68,9 @@ class TagsStore(SQLBaseStore):
A deferred list of tuples of stream_id int, user_id string,
room_id string, tag string and content string.
"""
+ if last_id == current_id:
+ defer.returnValue([])
+
def get_all_updated_tags_txn(txn):
sql = (
"SELECT stream_id, user_id, room_id"
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 6c7481a728..6258ff1725 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -24,6 +24,7 @@ from collections import namedtuple
import itertools
import logging
+import ujson as json
logger = logging.getLogger(__name__)
@@ -101,7 +102,7 @@ class TransactionStore(SQLBaseStore):
)
if result and result["response_code"]:
- return result["response_code"], result["response_json"]
+ return result["response_code"], json.loads(str(result["response_json"]))
else:
return None
|