diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b7637b5dc0..88546ad614 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -40,7 +40,7 @@ class SQLBaseStore(object):
def __init__(self, database: Database, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
- self.database_engine = hs.database_engine
+ self.database_engine = database.engine
self.db = database
self.rand = random.SystemRandom()
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index cafedd5c0d..d20df5f076 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/data_stores/__init__.py
@@ -13,24 +13,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.database import Database
+import logging
+
+from synapse.storage.data_stores.state import StateGroupDataStore
+from synapse.storage.database import Database, make_conn
+from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
+logger = logging.getLogger(__name__)
+
class DataStores(object):
"""The various data stores.
These are low level interfaces to physical databases.
+
+ Attributes:
+ main (DataStore)
"""
- def __init__(self, main_store_class, db_conn, hs):
+ def __init__(self, main_store_class, hs):
# Note we pass in the main store class here as workers use a different main
# store.
- database = Database(hs)
- # Check that db is correctly configured.
- database.engine.check_database(db_conn.cursor())
+ self.databases = []
+
+ for database_config in hs.config.database.databases:
+ db_name = database_config.name
+ engine = create_engine(database_config.config)
+
+ with make_conn(database_config, engine) as db_conn:
+ logger.info("Preparing database %r...", db_name)
+
+ engine.check_database(db_conn.cursor())
+ prepare_database(
+ db_conn, engine, hs.config, data_stores=database_config.data_stores,
+ )
+
+ database = Database(hs, database_config, engine)
+
+ if "main" in database_config.data_stores:
+ logger.info("Starting 'main' data store")
+ self.main = main_store_class(database, db_conn, hs)
+
+ if "state" in database_config.data_stores:
+ logger.info("Starting 'state' data store")
+ self.state = StateGroupDataStore(database, db_conn, hs)
+
+ db_conn.commit()
- prepare_database(db_conn, database.engine, config=hs.config)
+ self.databases.append(database)
- self.main = main_store_class(database, db_conn, hs)
+ logger.info("Database %r prepared", db_name)
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 320c5b0f07..13f4c9c72e 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -412,7 +412,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
def _update_client_ips_batch(self):
# If the DB pool has already terminated, don't try updating
- if not self.hs.get_db_pool().running:
+ if not self.db.is_running():
return
to_update = self._batch_row_update
@@ -451,16 +451,18 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
# Technically an access token might not be associated with
# a device so we need to check.
if device_id:
- self.db.simple_upsert_txn(
+ # this is always an update rather than an upsert: the row should
+ # already exist, and if it doesn't, that may be because it has been
+ # deleted, and we don't want to re-create it.
+ self.db.simple_update_txn(
txn,
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
- values={
+ updatevalues={
"user_agent": user_agent,
"last_seen": last_seen,
"ip": ip,
},
- lock=False,
)
except Exception as e:
# Failed to upsert, log and continue
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 85cfa16850..0613b49f4a 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -358,21 +358,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
def _add_messages_to_local_device_inbox_txn(
self, txn, stream_id, messages_by_user_then_device
):
- # Compatible method of performing an upsert
- sql = "SELECT stream_id FROM device_max_stream_id"
-
- txn.execute(sql)
- rows = txn.fetchone()
- if rows:
- db_stream_id = rows[0]
- if db_stream_id < stream_id:
- # Insert the new stream_id
- sql = "UPDATE device_max_stream_id SET stream_id = ?"
- else:
- # No rows, perform an insert
- sql = "INSERT INTO device_max_stream_id (stream_id) VALUES (?)"
-
- txn.execute(sql, (stream_id,))
+ sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?"
+ txn.execute(sql, (stream_id, stream_id))
local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items():
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 38cd0ca9b8..e551606f9d 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -14,15 +14,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Dict, List
+
from six import iteritems
from canonicaljson import encode_canonical_json, json
+from twisted.enterprise.adbapi import Connection
from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedList
class EndToEndKeyWorkerStore(SQLBaseStore):
@@ -271,7 +274,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
user_id (str): the user whose key is being requested
- key_type (str): the type of key that is being set: either 'master'
+ key_type (str): the type of key that is being requested: either 'master'
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
from_user_id (str): if specified, signatures made by this user on
@@ -316,8 +319,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"""Returns a user's cross-signing key.
Args:
- user_id (str): the user whose self-signing key is being requested
- key_type (str): the type of cross-signing key to get
+ user_id (str): the user whose key is being requested
+ key_type (str): the type of key that is being requested: either 'master'
+ for a master key, 'self_signing' for a self-signing key, or
+ 'user_signing' for a user-signing key
from_user_id (str): if specified, signatures made by this user on
the self-signing key will be included in the result
@@ -332,6 +337,206 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
from_user_id,
)
+ @cached(num_args=1)
+ def _get_bare_e2e_cross_signing_keys(self, user_id):
+ """Dummy function. Only used to make a cache for
+ _get_bare_e2e_cross_signing_keys_bulk.
+ """
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="_get_bare_e2e_cross_signing_keys",
+ list_name="user_ids",
+ num_args=1,
+ )
+ def _get_bare_e2e_cross_signing_keys_bulk(
+ self, user_ids: List[str]
+ ) -> Dict[str, Dict[str, dict]]:
+ """Returns the cross-signing keys for a set of users. The output of this
+ function should be passed to _get_e2e_cross_signing_signatures_txn if
+ the signatures for the calling user need to be fetched.
+
+ Args:
+ user_ids (list[str]): the users whose keys are being requested
+
+ Returns:
+ dict[str, dict[str, dict]]: mapping from user ID to key type to key
+ data. If a user's cross-signing keys were not found, either
+ their user ID will not be in the dict, or their user ID will map
+ to None.
+
+ """
+ return self.db.runInteraction(
+ "get_bare_e2e_cross_signing_keys_bulk",
+ self._get_bare_e2e_cross_signing_keys_bulk_txn,
+ user_ids,
+ )
+
+ def _get_bare_e2e_cross_signing_keys_bulk_txn(
+ self, txn: Connection, user_ids: List[str],
+ ) -> Dict[str, Dict[str, dict]]:
+ """Returns the cross-signing keys for a set of users. The output of this
+ function should be passed to _get_e2e_cross_signing_signatures_txn if
+ the signatures for the calling user need to be fetched.
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ user_ids (list[str]): the users whose keys are being requested
+
+ Returns:
+ dict[str, dict[str, dict]]: mapping from user ID to key type to key
+ data. If a user's cross-signing keys were not found, their user
+ ID will not be in the dict.
+
+ """
+ result = {}
+
+ batch_size = 100
+ chunks = [
+ user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)
+ ]
+ for user_chunk in chunks:
+ sql = """
+ SELECT k.user_id, k.keytype, k.keydata, k.stream_id
+ FROM e2e_cross_signing_keys k
+ INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
+ FROM e2e_cross_signing_keys
+ GROUP BY user_id, keytype) s
+ USING (user_id, stream_id, keytype)
+ WHERE k.user_id IN (%s)
+ """ % (
+ ",".join("?" for u in user_chunk),
+ )
+ query_params = []
+ query_params.extend(user_chunk)
+
+ txn.execute(sql, query_params)
+ rows = self.db.cursor_to_dict(txn)
+
+ for row in rows:
+ user_id = row["user_id"]
+ key_type = row["keytype"]
+ key = json.loads(row["keydata"])
+ user_info = result.setdefault(user_id, {})
+ user_info[key_type] = key
+
+ return result
+
+ def _get_e2e_cross_signing_signatures_txn(
+ self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str,
+ ) -> Dict[str, Dict[str, dict]]:
+ """Returns the cross-signing signatures made by a user on a set of keys.
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ keys (dict[str, dict[str, dict]]): a map of user ID to key type to
+ key data. This dict will be modified to add signatures.
+ from_user_id (str): fetch the signatures made by this user
+
+ Returns:
+ dict[str, dict[str, dict]]: mapping from user ID to key type to key
+ data. The return value will be the same as the keys argument,
+ with the modifications included.
+ """
+
+ # find out what cross-signing keys (a.k.a. devices) we need to get
+ # signatures for. This is a map of (user_id, device_id) to key type
+ # (device_id is the key's public part).
+ devices = {}
+
+ for user_id, user_info in keys.items():
+ if user_info is None:
+ continue
+ for key_type, key in user_info.items():
+ device_id = None
+ for k in key["keys"].values():
+ device_id = k
+ devices[(user_id, device_id)] = key_type
+
+ device_list = list(devices)
+
+ # split into batches
+ batch_size = 100
+ chunks = [
+ device_list[i : i + batch_size]
+ for i in range(0, len(device_list), batch_size)
+ ]
+ for user_chunk in chunks:
+ sql = """
+ SELECT target_user_id, target_device_id, key_id, signature
+ FROM e2e_cross_signing_signatures
+ WHERE user_id = ?
+ AND (%s)
+ """ % (
+ " OR ".join(
+ "(target_user_id = ? AND target_device_id = ?)" for d in devices
+ )
+ )
+ query_params = [from_user_id]
+ for item in devices:
+ # item is a (user_id, device_id) tuple
+ query_params.extend(item)
+
+ txn.execute(sql, query_params)
+ rows = self.db.cursor_to_dict(txn)
+
+ # and add the signatures to the appropriate keys
+ for row in rows:
+ key_id = row["key_id"]
+ target_user_id = row["target_user_id"]
+ target_device_id = row["target_device_id"]
+ key_type = devices[(target_user_id, target_device_id)]
+ # We need to copy everything, because the result may have come
+ # from the cache. dict.copy only does a shallow copy, so we
+ # need to recursively copy the dicts that will be modified.
+ user_info = keys[target_user_id] = keys[target_user_id].copy()
+ target_user_key = user_info[key_type] = user_info[key_type].copy()
+ if "signatures" in target_user_key:
+ signatures = target_user_key["signatures"] = target_user_key[
+ "signatures"
+ ].copy()
+ if from_user_id in signatures:
+ user_sigs = signatures[from_user_id] = signatures[from_user_id]
+ user_sigs[key_id] = row["signature"]
+ else:
+ signatures[from_user_id] = {key_id: row["signature"]}
+ else:
+ target_user_key["signatures"] = {
+ from_user_id: {key_id: row["signature"]}
+ }
+
+ return keys
+
+ @defer.inlineCallbacks
+ def get_e2e_cross_signing_keys_bulk(
+ self, user_ids: List[str], from_user_id: str = None
+ ) -> defer.Deferred:
+ """Returns the cross-signing keys for a set of users.
+
+ Args:
+ user_ids (list[str]): the users whose keys are being requested
+ from_user_id (str): if specified, signatures made by this user on
+ the self-signing keys will be included in the result
+
+ Returns:
+ Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to
+ key data. If a user's cross-signing keys were not found, either
+ their user ID will not be in the dict, or their user ID will map
+ to None.
+ """
+
+ result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
+
+ if from_user_id:
+ result = yield self.db.runInteraction(
+ "get_e2e_cross_signing_signatures",
+ self._get_e2e_cross_signing_signatures_txn,
+ result,
+ from_user_id,
+ )
+
+ return result
+
def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
"""Return a list of changes from the user signature stream to notify remotes.
Note that the user signature stream represents when a user signs their
@@ -520,6 +725,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
},
)
+ self._invalidate_cache_and_stream(
+ txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
+ )
+
def set_e2e_cross_signing_key(self, user_id, key_type, key):
"""Set a user's cross-signing key.
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 998bba1aad..58f35d7f56 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -1757,163 +1757,6 @@ class EventsStore(
return state_groups
- def purge_unreferenced_state_groups(
- self, room_id: str, state_groups_to_delete
- ) -> defer.Deferred:
- """Deletes no longer referenced state groups and de-deltas any state
- groups that reference them.
-
- Args:
- room_id: The room the state groups belong to (must all be in the
- same room).
- state_groups_to_delete (Collection[int]): Set of all state groups
- to delete.
- """
-
- return self.db.runInteraction(
- "purge_unreferenced_state_groups",
- self._purge_unreferenced_state_groups,
- room_id,
- state_groups_to_delete,
- )
-
- def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete):
- logger.info(
- "[purge] found %i state groups to delete", len(state_groups_to_delete)
- )
-
- rows = self.db.simple_select_many_txn(
- txn,
- table="state_group_edges",
- column="prev_state_group",
- iterable=state_groups_to_delete,
- keyvalues={},
- retcols=("state_group",),
- )
-
- remaining_state_groups = set(
- row["state_group"]
- for row in rows
- if row["state_group"] not in state_groups_to_delete
- )
-
- logger.info(
- "[purge] de-delta-ing %i remaining state groups",
- len(remaining_state_groups),
- )
-
- # Now we turn the state groups that reference to-be-deleted state
- # groups to non delta versions.
- for sg in remaining_state_groups:
- logger.info("[purge] de-delta-ing remaining state group %s", sg)
- curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
- curr_state = curr_state[sg]
-
- self.db.simple_delete_txn(
- txn, table="state_groups_state", keyvalues={"state_group": sg}
- )
-
- self.db.simple_delete_txn(
- txn, table="state_group_edges", keyvalues={"state_group": sg}
- )
-
- self.db.simple_insert_many_txn(
- txn,
- table="state_groups_state",
- values=[
- {
- "state_group": sg,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
- for key, state_id in iteritems(curr_state)
- ],
- )
-
- logger.info("[purge] removing redundant state groups")
- txn.executemany(
- "DELETE FROM state_groups_state WHERE state_group = ?",
- ((sg,) for sg in state_groups_to_delete),
- )
- txn.executemany(
- "DELETE FROM state_groups WHERE id = ?",
- ((sg,) for sg in state_groups_to_delete),
- )
-
- @defer.inlineCallbacks
- def get_previous_state_groups(self, state_groups):
- """Fetch the previous groups of the given state groups.
-
- Args:
- state_groups (Iterable[int])
-
- Returns:
- Deferred[dict[int, int]]: mapping from state group to previous
- state group.
- """
-
- rows = yield self.db.simple_select_many_batch(
- table="state_group_edges",
- column="prev_state_group",
- iterable=state_groups,
- keyvalues={},
- retcols=("prev_state_group", "state_group"),
- desc="get_previous_state_groups",
- )
-
- return {row["state_group"]: row["prev_state_group"] for row in rows}
-
- def purge_room_state(self, room_id, state_groups_to_delete):
- """Deletes all record of a room from state tables
-
- Args:
- room_id (str):
- state_groups_to_delete (list[int]): State groups to delete
- """
-
- return self.db.runInteraction(
- "purge_room_state",
- self._purge_room_state_txn,
- room_id,
- state_groups_to_delete,
- )
-
- def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete):
- # first we have to delete the state groups states
- logger.info("[purge] removing %s from state_groups_state", room_id)
-
- self.db.simple_delete_many_txn(
- txn,
- table="state_groups_state",
- column="state_group",
- iterable=state_groups_to_delete,
- keyvalues={},
- )
-
- # ... and the state group edges
- logger.info("[purge] removing %s from state_group_edges", room_id)
-
- self.db.simple_delete_many_txn(
- txn,
- table="state_group_edges",
- column="state_group",
- iterable=state_groups_to_delete,
- keyvalues={},
- )
-
- # ... and the state groups
- logger.info("[purge] removing %s from state_groups", room_id)
-
- self.db.simple_delete_many_txn(
- txn,
- table="state_groups",
- column="id",
- iterable=state_groups_to_delete,
- keyvalues={},
- )
-
async def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream
"""
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 9ee117ce0f..2c9142814c 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -19,8 +19,10 @@ import itertools
import logging
import threading
from collections import namedtuple
+from typing import List, Optional
from canonicaljson import json
+from constantly import NamedConstant, Names
from twisted.internet import defer
@@ -55,6 +57,16 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+class EventRedactBehaviour(Names):
+ """
+ What to do when retrieving a redacted event from the database.
+ """
+
+ AS_IS = NamedConstant()
+ REDACT = NamedConstant()
+ BLOCK = NamedConstant()
+
+
class EventsWorkerStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs):
super(EventsWorkerStore, self).__init__(database, db_conn, hs)
@@ -125,25 +137,27 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_event(
self,
- event_id,
- check_redacted=True,
- get_prev_content=False,
- allow_rejected=False,
- allow_none=False,
- check_room_id=None,
+ event_id: List[str],
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ allow_none: bool = False,
+ check_room_id: Optional[str] = None,
):
"""Get an event from the database by event_id.
Args:
- event_id (str): The event_id of the event to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
+ event_id: The event_id of the event to fetch
+ redact_behaviour: Determine what to do with a redacted event. Possible values:
+ * AS_IS - Return the full event body with no redacted content
+ * REDACT - Return the event but with a redacted body
+ * DISALLOW - Do not return redacted events
+ get_prev_content: If True and event is a state event,
include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
- allow_none (bool): If True, return None if no event found, if
+ allow_rejected: If True return rejected events.
+ allow_none: If True, return None if no event found, if
False throw a NotFoundError
- check_room_id (str|None): if not None, check the room of the found event.
+ check_room_id: if not None, check the room of the found event.
If there is a mismatch, behave as per allow_none.
Returns:
@@ -154,7 +168,7 @@ class EventsWorkerStore(SQLBaseStore):
events = yield self.get_events_as_list(
[event_id],
- check_redacted=check_redacted,
+ redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
@@ -173,27 +187,30 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_events(
self,
- event_ids,
- check_redacted=True,
- get_prev_content=False,
- allow_rejected=False,
+ event_ids: List[str],
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
):
"""Get events from the database
Args:
- event_ids (list): The event_ids of the events to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
+ event_ids: The event_ids of the events to fetch
+ redact_behaviour: Determine what to do with a redacted event. Possible
+ values:
+ * AS_IS - Return the full event body with no redacted content
+ * REDACT - Return the event but with a redacted body
+ * DISALLOW - Do not return redacted events
+ get_prev_content: If True and event is a state event,
include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
+ allow_rejected: If True return rejected events.
Returns:
Deferred : Dict from event_id to event.
"""
events = yield self.get_events_as_list(
event_ids,
- check_redacted=check_redacted,
+ redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
@@ -203,21 +220,23 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_events_as_list(
self,
- event_ids,
- check_redacted=True,
- get_prev_content=False,
- allow_rejected=False,
+ event_ids: List[str],
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
):
"""Get events from the database and return in a list in the same order
as given by `event_ids` arg.
Args:
- event_ids (list): The event_ids of the events to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
+ event_ids: The event_ids of the events to fetch
+ redact_behaviour: Determine what to do with a redacted event. Possible values:
+ * AS_IS - Return the full event body with no redacted content
+ * REDACT - Return the event but with a redacted body
+ * DISALLOW - Do not return redacted events
+ get_prev_content: If True and event is a state event,
include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
+ allow_rejected: If True, return rejected events.
Returns:
Deferred[list[EventBase]]: List of events fetched from the database. The
@@ -319,10 +338,14 @@ class EventsWorkerStore(SQLBaseStore):
# Update the cache to save doing the checks again.
entry.event.internal_metadata.recheck_redaction = False
- if check_redacted and entry.redacted_event:
- event = entry.redacted_event
- else:
- event = entry.event
+ event = entry.event
+
+ if entry.redacted_event:
+ if redact_behaviour == EventRedactBehaviour.BLOCK:
+ # Skip this event
+ continue
+ elif redact_behaviour == EventRedactBehaviour.REDACT:
+ event = entry.redacted_event
events.append(event)
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index 5ba13aa973..e2673ae073 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -244,7 +244,7 @@ class PushRulesWorkerStore(
# To do this we set the state_group to a new object as object() != object()
state_group = object()
- current_state_ids = yield context.get_current_state_ids(self)
+ current_state_ids = yield context.get_current_state_ids()
result = yield self._bulk_get_push_rules_for_room(
event.room_id, state_group, current_state_ids, event=event
)
diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py
index f07309ef09..6b03233262 100644
--- a/synapse/storage/data_stores/main/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -15,8 +15,7 @@
# limitations under the License.
import logging
-
-import six
+from typing import Iterable, Iterator
from canonicaljson import encode_canonical_json, json
@@ -27,21 +26,16 @@ from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
logger = logging.getLogger(__name__)
-if six.PY2:
- db_binary_type = six.moves.builtins.buffer
-else:
- db_binary_type = memoryview
-
class PusherWorkerStore(SQLBaseStore):
- def _decode_pushers_rows(self, rows):
+ def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]:
+ """JSON-decode the data in the rows returned from the `pushers` table
+
+ Drops any rows whose data cannot be decoded
+ """
for r in rows:
dataJson = r["data"]
- r["data"] = None
try:
- if isinstance(dataJson, db_binary_type):
- dataJson = str(dataJson).decode("UTF8")
-
r["data"] = json.loads(dataJson)
except Exception as e:
logger.warning(
@@ -50,12 +44,9 @@ class PusherWorkerStore(SQLBaseStore):
dataJson,
e.args[0],
)
- pass
-
- if isinstance(r["pushkey"], db_binary_type):
- r["pushkey"] = str(r["pushkey"]).decode("UTF8")
+ continue
- return rows
+ yield r
@defer.inlineCallbacks
def user_has_pusher(self, user_id):
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 92e3b9c512..70ff5751b6 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -477,7 +477,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
- current_state_ids = yield context.get_current_state_ids(self)
+ current_state_ids = yield context.get_current_state_ids()
result = yield self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context
)
diff --git a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql
index 4219cdd06a..2de50d408c 100644
--- a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql
+++ b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql
@@ -20,7 +20,6 @@ DROP INDEX IF EXISTS events_room_id; -- Prefix of events_room_stream
DROP INDEX IF EXISTS events_order; -- Prefix of events_order_topo_stream_room
DROP INDEX IF EXISTS events_topological_ordering; -- Prefix of events_order_topo_stream_room
DROP INDEX IF EXISTS events_stream_ordering; -- Duplicate of PRIMARY KEY
-DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY
DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY
DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT
diff --git a/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql
new file mode 100644
index 0000000000..c2f557fde9
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql
@@ -0,0 +1,20 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This line already existed in deltas/35/device_stream_id but was not included in the
+-- 54 full schema SQL. Add some SQL here to insert the missing row if it does not exist
+INSERT INTO device_max_stream_id (stream_id) SELECT 0 WHERE NOT EXISTS (
+ SELECT * from device_max_stream_id
+);
\ No newline at end of file
diff --git a/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql
new file mode 100644
index 0000000000..4f24c1405d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql
@@ -0,0 +1,29 @@
+/* Copyright 2019 Werner Sembach
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Groups/communities now get deleted when the last member leaves. This is a one time cleanup to remove old groups/communities that were already empty before that change was made.
+DELETE FROM group_attestations_remote WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_attestations_renewals WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_invites WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_users WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM local_group_membership WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM local_group_updates WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM groups WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
index 4ad2929f32..889a9a0ce4 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
@@ -975,40 +975,6 @@ CREATE TABLE state_events (
-CREATE TABLE state_group_edges (
- state_group bigint NOT NULL,
- prev_state_group bigint NOT NULL
-);
-
-
-
-CREATE SEQUENCE state_group_id_seq
- START WITH 1
- INCREMENT BY 1
- NO MINVALUE
- NO MAXVALUE
- CACHE 1;
-
-
-
-CREATE TABLE state_groups (
- id bigint NOT NULL,
- room_id text NOT NULL,
- event_id text NOT NULL
-);
-
-
-
-CREATE TABLE state_groups_state (
- state_group bigint NOT NULL,
- room_id text NOT NULL,
- type text NOT NULL,
- state_key text NOT NULL,
- event_id text NOT NULL
-);
-
-
-
CREATE TABLE stats_stream_pos (
lock character(1) DEFAULT 'X'::bpchar NOT NULL,
stream_id bigint,
@@ -1482,12 +1448,6 @@ ALTER TABLE ONLY state_events
ADD CONSTRAINT state_events_event_id_key UNIQUE (event_id);
-
-ALTER TABLE ONLY state_groups
- ADD CONSTRAINT state_groups_pkey PRIMARY KEY (id);
-
-
-
ALTER TABLE ONLY stats_stream_pos
ADD CONSTRAINT stats_stream_pos_lock_key UNIQUE (lock);
@@ -1928,18 +1888,6 @@ CREATE UNIQUE INDEX room_stats_room_ts ON room_stats USING btree (room_id, ts);
-CREATE INDEX state_group_edges_idx ON state_group_edges USING btree (state_group);
-
-
-
-CREATE INDEX state_group_edges_prev_idx ON state_group_edges USING btree (prev_state_group);
-
-
-
-CREATE INDEX state_groups_state_type_idx ON state_groups_state USING btree (state_group, type, state_key);
-
-
-
CREATE INDEX stream_ordering_to_exterm_idx ON stream_ordering_to_exterm USING btree (stream_ordering);
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
index bad33291e7..a0411ede7e 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
@@ -42,8 +42,6 @@ CREATE INDEX ev_edges_id ON event_edges(event_id);
CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id);
CREATE TABLE room_depth( room_id TEXT NOT NULL, min_depth INTEGER NOT NULL, UNIQUE (room_id) );
CREATE INDEX room_depth_room ON room_depth(room_id);
-CREATE TABLE state_groups( id BIGINT PRIMARY KEY, room_id TEXT NOT NULL, event_id TEXT NOT NULL );
-CREATE TABLE state_groups_state( state_group BIGINT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, event_id TEXT NOT NULL );
CREATE TABLE event_to_state_groups( event_id TEXT NOT NULL, state_group BIGINT NOT NULL, UNIQUE (event_id) );
CREATE TABLE local_media_repository ( media_id TEXT, media_type TEXT, media_length INTEGER, created_ts BIGINT, upload_name TEXT, user_id TEXT, quarantined_by TEXT, url_cache TEXT, last_access_ts BIGINT, UNIQUE (media_id) );
CREATE TABLE local_media_repository_thumbnails ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type ) );
@@ -120,9 +118,6 @@ CREATE TABLE device_max_stream_id ( stream_id BIGINT NOT NULL );
CREATE TABLE public_room_list_stream ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, visibility BOOLEAN NOT NULL , appservice_id TEXT, network_id TEXT);
CREATE INDEX public_room_list_stream_idx on public_room_list_stream( stream_id );
CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream( room_id, stream_id );
-CREATE TABLE state_group_edges( state_group BIGINT NOT NULL, prev_state_group BIGINT NOT NULL );
-CREATE INDEX state_group_edges_idx ON state_group_edges(state_group);
-CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group);
CREATE TABLE stream_ordering_to_exterm ( stream_ordering BIGINT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL );
CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( stream_ordering );
CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( room_id, stream_ordering );
@@ -254,6 +249,5 @@ CREATE INDEX user_ips_last_seen_only ON user_ips (last_seen);
CREATE INDEX users_creation_ts ON users (creation_ts);
CREATE INDEX event_to_state_groups_sg_index ON event_to_state_groups (state_group);
CREATE UNIQUE INDEX device_lists_remote_cache_unique_id ON device_lists_remote_cache (user_id, device_id);
-CREATE INDEX state_groups_state_type_idx ON state_groups_state(state_group, type, state_key);
CREATE UNIQUE INDEX device_lists_remote_extremeties_unique_idx ON device_lists_remote_extremeties (user_id);
CREATE UNIQUE INDEX user_ips_user_token_ip_unique_index ON user_ips (user_id, access_token, ip);
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql
index c265fd20e2..91d21b2921 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql
@@ -5,3 +5,4 @@ INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coales
INSERT INTO user_directory_stream_pos (stream_id) VALUES (0);
INSERT INTO stats_stream_pos (stream_id) VALUES (0);
INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0);
+-- device_max_stream_id is handled separately in 56/device_stream_id_insert.sql
\ No newline at end of file
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/data_stores/main/schema/full_schemas/README.md
new file mode 100644
index 0000000000..bbd3f18604
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/full_schemas/README.md
@@ -0,0 +1,13 @@
+# Building full schema dumps
+
+These schemas need to be made from a database that has had all background updates run.
+
+To do so, use `scripts-dev/make_full_schema.sh`. This will produce
+`full.sql.postgres ` and `full.sql.sqlite` files.
+
+Ensure postgres is installed and your user has the ability to run bash commands
+such as `createdb`.
+
+```
+./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/
+```
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.txt b/synapse/storage/data_stores/main/schema/full_schemas/README.txt
deleted file mode 100644
index d3f6401344..0000000000
--- a/synapse/storage/data_stores/main/schema/full_schemas/README.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-Building full schema dumps
-==========================
-
-These schemas need to be made from a database that has had all background updates run.
-
-Postgres
---------
-
-$ pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner $DATABASE_NAME| sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > full.sql.postgres
-
-SQLite
-------
-
-$ sqlite3 $DATABASE_FILE ".schema" > full.sql.sqlite
-
-After
------
-
-Delete the CREATE statements for "sqlite_stat1", "schema_version", "applied_schema_deltas", and "applied_module_schemas".
\ No newline at end of file
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index 4eec2fae5e..47ebb8a214 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -25,6 +25,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@@ -384,7 +385,7 @@ class SearchStore(SearchBackgroundUpdateStore):
"""
clauses = []
- search_query = search_query = _parse_query(self.database_engine, search_term)
+ search_query = _parse_query(self.database_engine, search_term)
args = []
@@ -453,7 +454,12 @@ class SearchStore(SearchBackgroundUpdateStore):
results = list(filter(lambda row: row["room_id"] in room_ids, results))
- events = yield self.get_events_as_list([r["event_id"] for r in results])
+ # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
+ # search results (which is a data leak)
+ events = yield self.get_events_as_list(
+ [r["event_id"] for r in results],
+ redact_behaviour=EventRedactBehaviour.BLOCK,
+ )
event_map = {ev.event_id: ev for ev in events}
@@ -495,7 +501,7 @@ class SearchStore(SearchBackgroundUpdateStore):
"""
clauses = []
- search_query = search_query = _parse_query(self.database_engine, search_term)
+ search_query = _parse_query(self.database_engine, search_term)
args = []
@@ -600,7 +606,12 @@ class SearchStore(SearchBackgroundUpdateStore):
results = list(filter(lambda row: row["room_id"] in room_ids, results))
- events = yield self.get_events_as_list([r["event_id"] for r in results])
+ # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
+ # search results (which is a data leak)
+ events = yield self.get_events_as_list(
+ [r["event_id"] for r in results],
+ redact_behaviour=EventRedactBehaviour.BLOCK,
+ )
event_map = {ev.event_id: ev for ev in events}
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 9ef7b48c74..0dc39f139c 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -17,8 +17,7 @@ import logging
from collections import namedtuple
from typing import Iterable, Tuple
-from six import iteritems, itervalues
-from six.moves import range
+from six import iteritems
from twisted.internet import defer
@@ -29,11 +28,9 @@ from synapse.events.snapshot import EventContext
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.database import Database
-from synapse.storage.engines import PostgresEngine
from synapse.storage.state import StateFilter
-from synapse.util.caches import get_cache_factor_for, intern_string
+from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
-from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.stringutils import to_ascii
logger = logging.getLogger(__name__)
@@ -55,207 +52,14 @@ class _GetStateGroupDelta(
return len(self.delta_ids) if self.delta_ids else 0
-class StateGroupBackgroundUpdateStore(SQLBaseStore):
- """Defines functions related to state groups needed to run the state backgroud
- updates.
- """
-
- def _count_state_group_hops_txn(self, txn, state_group):
- """Given a state group, count how many hops there are in the tree.
-
- This is used to ensure the delta chains don't get too long.
- """
- if isinstance(self.database_engine, PostgresEngine):
- sql = """
- WITH RECURSIVE state(state_group) AS (
- VALUES(?::bigint)
- UNION ALL
- SELECT prev_state_group FROM state_group_edges e, state s
- WHERE s.state_group = e.state_group
- )
- SELECT count(*) FROM state;
- """
-
- txn.execute(sql, (state_group,))
- row = txn.fetchone()
- if row and row[0]:
- return row[0]
- else:
- return 0
- else:
- # We don't use WITH RECURSIVE on sqlite3 as there are distributions
- # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
- next_group = state_group
- count = 0
-
- while next_group:
- next_group = self.db.simple_select_one_onecol_txn(
- txn,
- table="state_group_edges",
- keyvalues={"state_group": next_group},
- retcol="prev_state_group",
- allow_none=True,
- )
- if next_group:
- count += 1
-
- return count
-
- def _get_state_groups_from_groups_txn(
- self, txn, groups, state_filter=StateFilter.all()
- ):
- results = {group: {} for group in groups}
-
- where_clause, where_args = state_filter.make_sql_filter_clause()
-
- # Unless the filter clause is empty, we're going to append it after an
- # existing where clause
- if where_clause:
- where_clause = " AND (%s)" % (where_clause,)
-
- if isinstance(self.database_engine, PostgresEngine):
- # Temporarily disable sequential scans in this transaction. This is
- # a temporary hack until we can add the right indices in
- txn.execute("SET LOCAL enable_seqscan=off")
-
- # The below query walks the state_group tree so that the "state"
- # table includes all state_groups in the tree. It then joins
- # against `state_groups_state` to fetch the latest state.
- # It assumes that previous state groups are always numerically
- # lesser.
- # The PARTITION is used to get the event_id in the greatest state
- # group for the given type, state_key.
- # This may return multiple rows per (type, state_key), but last_value
- # should be the same.
- sql = """
- WITH RECURSIVE state(state_group) AS (
- VALUES(?::bigint)
- UNION ALL
- SELECT prev_state_group FROM state_group_edges e, state s
- WHERE s.state_group = e.state_group
- )
- SELECT DISTINCT type, state_key, last_value(event_id) OVER (
- PARTITION BY type, state_key ORDER BY state_group ASC
- ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
- ) AS event_id FROM state_groups_state
- WHERE state_group IN (
- SELECT state_group FROM state
- )
- """
-
- for group in groups:
- args = [group]
- args.extend(where_args)
-
- txn.execute(sql + where_clause, args)
- for row in txn:
- typ, state_key, event_id = row
- key = (typ, state_key)
- results[group][key] = event_id
- else:
- max_entries_returned = state_filter.max_entries_returned()
-
- # We don't use WITH RECURSIVE on sqlite3 as there are distributions
- # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
- for group in groups:
- next_group = group
-
- while next_group:
- # We did this before by getting the list of group ids, and
- # then passing that list to sqlite to get latest event for
- # each (type, state_key). However, that was terribly slow
- # without the right indices (which we can't add until
- # after we finish deduping state, which requires this func)
- args = [next_group]
- args.extend(where_args)
-
- txn.execute(
- "SELECT type, state_key, event_id FROM state_groups_state"
- " WHERE state_group = ? " + where_clause,
- args,
- )
- results[group].update(
- ((typ, state_key), event_id)
- for typ, state_key, event_id in txn
- if (typ, state_key) not in results[group]
- )
-
- # If the number of entries in the (type,state_key)->event_id dict
- # matches the number of (type,state_keys) types we were searching
- # for, then we must have found them all, so no need to go walk
- # further down the tree... UNLESS our types filter contained
- # wildcards (i.e. Nones) in which case we have to do an exhaustive
- # search
- if (
- max_entries_returned is not None
- and len(results[group]) == max_entries_returned
- ):
- break
-
- next_group = self.db.simple_select_one_onecol_txn(
- txn,
- table="state_group_edges",
- keyvalues={"state_group": next_group},
- retcol="prev_state_group",
- allow_none=True,
- )
-
- return results
-
-
# this inherits from EventsWorkerStore because it calls self.get_events
-class StateGroupWorkerStore(
- EventsWorkerStore, StateGroupBackgroundUpdateStore, SQLBaseStore
-):
+class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers.
"""
- STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
- STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
- CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
-
def __init__(self, database: Database, db_conn, hs):
super(StateGroupWorkerStore, self).__init__(database, db_conn, hs)
- # Originally the state store used a single DictionaryCache to cache the
- # event IDs for the state types in a given state group to avoid hammering
- # on the state_group* tables.
- #
- # The point of using a DictionaryCache is that it can cache a subset
- # of the state events for a given state group (i.e. a subset of the keys for a
- # given dict which is an entry in the cache for a given state group ID).
- #
- # However, this poses problems when performing complicated queries
- # on the store - for instance: "give me all the state for this group, but
- # limit members to this subset of users", as DictionaryCache's API isn't
- # rich enough to say "please cache any of these fields, apart from this subset".
- # This is problematic when lazy loading members, which requires this behaviour,
- # as without it the cache has no choice but to speculatively load all
- # state events for the group, which negates the efficiency being sought.
- #
- # Rather than overcomplicating DictionaryCache's API, we instead split the
- # state_group_cache into two halves - one for tracking non-member events,
- # and the other for tracking member_events. This means that lazy loading
- # queries can be made in a cache-friendly manner by querying both caches
- # separately and then merging the result. So for the example above, you
- # would query the members cache for a specific subset of state keys
- # (which DictionaryCache will handle efficiently and fine) and the non-members
- # cache for all state (which DictionaryCache will similarly handle fine)
- # and then just merge the results together.
- #
- # We size the non-members cache to be smaller than the members cache as the
- # vast majority of state in Matrix (today) is member events.
-
- self._state_group_cache = DictionaryCache(
- "*stateGroupCache*",
- # TODO: this hasn't been tuned yet
- 50000 * get_cache_factor_for("stateGroupCache"),
- )
- self._state_group_members_cache = DictionaryCache(
- "*stateGroupMembersCache*",
- 500000 * get_cache_factor_for("stateGroupMembersCache"),
- )
-
@defer.inlineCallbacks
def get_room_version(self, room_id):
"""Get the room_version of a given room
@@ -278,7 +82,7 @@ class StateGroupWorkerStore(
@defer.inlineCallbacks
def get_room_predecessor(self, room_id):
- """Get the predecessor room of an upgraded room if one exists.
+ """Get the predecessor of an upgraded room if it exists.
Otherwise return None.
Args:
@@ -291,14 +95,22 @@ class StateGroupWorkerStore(
* room_id (str): The room ID of the predecessor room
* event_id (str): The ID of the tombstone event in the predecessor room
+ None if a predecessor key is not found, or is not a dictionary.
+
Raises:
- NotFoundError if the room is unknown
+ NotFoundError if the given room is unknown
"""
# Retrieve the room's create event
create_event = yield self.get_create_event_for_room(room_id)
- # Return predecessor if present
- return create_event.content.get("predecessor", None)
+ # Retrieve the predecessor key of the create event
+ predecessor = create_event.content.get("predecessor", None)
+
+ # Ensure the key is a dictionary
+ if not isinstance(predecessor, dict):
+ return None
+
+ return predecessor
@defer.inlineCallbacks
def get_create_event_for_room(self, room_id):
@@ -318,7 +130,7 @@ class StateGroupWorkerStore(
# If we can't find the create event, assume we've hit a dead end
if not create_id:
- raise NotFoundError("Unknown room %s" % (room_id))
+ raise NotFoundError("Unknown room %s" % (room_id,))
# Retrieve the room's create event and return
create_event = yield self.get_event(create_id)
@@ -423,229 +235,6 @@ class StateGroupWorkerStore(
return event.content.get("canonical_alias")
- @cached(max_entries=10000, iterable=True)
- def get_state_group_delta(self, state_group):
- """Given a state group try to return a previous group and a delta between
- the old and the new.
-
- Returns:
- (prev_group, delta_ids), where both may be None.
- """
-
- def _get_state_group_delta_txn(txn):
- prev_group = self.db.simple_select_one_onecol_txn(
- txn,
- table="state_group_edges",
- keyvalues={"state_group": state_group},
- retcol="prev_state_group",
- allow_none=True,
- )
-
- if not prev_group:
- return _GetStateGroupDelta(None, None)
-
- delta_ids = self.db.simple_select_list_txn(
- txn,
- table="state_groups_state",
- keyvalues={"state_group": state_group},
- retcols=("type", "state_key", "event_id"),
- )
-
- return _GetStateGroupDelta(
- prev_group,
- {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
- )
-
- return self.db.runInteraction(
- "get_state_group_delta", _get_state_group_delta_txn
- )
-
- @defer.inlineCallbacks
- def get_state_groups_ids(self, _room_id, event_ids):
- """Get the event IDs of all the state for the state groups for the given events
-
- Args:
- _room_id (str): id of the room for these events
- event_ids (iterable[str]): ids of the events
-
- Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
- """
- if not event_ids:
- return {}
-
- event_to_groups = yield self._get_state_group_for_events(event_ids)
-
- groups = set(itervalues(event_to_groups))
- group_to_state = yield self._get_state_for_groups(groups)
-
- return group_to_state
-
- @defer.inlineCallbacks
- def get_state_ids_for_group(self, state_group):
- """Get the event IDs of all the state in the given state group
-
- Args:
- state_group (int)
-
- Returns:
- Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
- """
- group_to_state = yield self._get_state_for_groups((state_group,))
-
- return group_to_state[state_group]
-
- @defer.inlineCallbacks
- def get_state_groups(self, room_id, event_ids):
- """ Get the state groups for the given list of event_ids
-
- Returns:
- Deferred[dict[int, list[EventBase]]]:
- dict of state_group_id -> list of state events.
- """
- if not event_ids:
- return {}
-
- group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
-
- state_event_map = yield self.get_events(
- [
- ev_id
- for group_ids in itervalues(group_to_ids)
- for ev_id in itervalues(group_ids)
- ],
- get_prev_content=False,
- )
-
- return {
- group: [
- state_event_map[v]
- for v in itervalues(event_id_map)
- if v in state_event_map
- ]
- for group, event_id_map in iteritems(group_to_ids)
- }
-
- @defer.inlineCallbacks
- def _get_state_groups_from_groups(self, groups, state_filter):
- """Returns the state groups for a given set of groups, filtering on
- types of state events.
-
- Args:
- groups(list[int]): list of state group IDs to query
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
- Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
- """
- results = {}
-
- chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
- for chunk in chunks:
- res = yield self.db.runInteraction(
- "_get_state_groups_from_groups",
- self._get_state_groups_from_groups_txn,
- chunk,
- state_filter,
- )
- results.update(res)
-
- return results
-
- @defer.inlineCallbacks
- def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
- """Given a list of event_ids and type tuples, return a list of state
- dicts for each event.
-
- Args:
- event_ids (list[string])
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
-
- Returns:
- deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
- """
- event_to_groups = yield self._get_state_group_for_events(event_ids)
-
- groups = set(itervalues(event_to_groups))
- group_to_state = yield self._get_state_for_groups(groups, state_filter)
-
- state_event_map = yield self.get_events(
- [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
- get_prev_content=False,
- )
-
- event_to_state = {
- event_id: {
- k: state_event_map[v]
- for k, v in iteritems(group_to_state[group])
- if v in state_event_map
- }
- for event_id, group in iteritems(event_to_groups)
- }
-
- return {event: event_to_state[event] for event in event_ids}
-
- @defer.inlineCallbacks
- def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
- """
- Get the state dicts corresponding to a list of events, containing the event_ids
- of the state events (as opposed to the events themselves)
-
- Args:
- event_ids(list(str)): events whose state should be returned
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
-
- Returns:
- A deferred dict from event_id -> (type, state_key) -> event_id
- """
- event_to_groups = yield self._get_state_group_for_events(event_ids)
-
- groups = set(itervalues(event_to_groups))
- group_to_state = yield self._get_state_for_groups(groups, state_filter)
-
- event_to_state = {
- event_id: group_to_state[group]
- for event_id, group in iteritems(event_to_groups)
- }
-
- return {event: event_to_state[event] for event in event_ids}
-
- @defer.inlineCallbacks
- def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
- """
- Get the state dict corresponding to a particular event
-
- Args:
- event_id(str): event whose state should be returned
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
-
- Returns:
- A deferred dict from (type, state_key) -> state_event
- """
- state_map = yield self.get_state_for_events([event_id], state_filter)
- return state_map[event_id]
-
- @defer.inlineCallbacks
- def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
- """
- Get the state dict corresponding to a particular event
-
- Args:
- event_id(str): event whose state should be returned
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
-
- Returns:
- A deferred dict from (type, state_key) -> state_event
- """
- state_map = yield self.get_state_ids_for_events([event_id], state_filter)
- return state_map[event_id]
-
@cached(max_entries=50000)
def _get_state_group_for_event(self, event_id):
return self.db.simple_select_one_onecol(
@@ -676,329 +265,6 @@ class StateGroupWorkerStore(
return {row["event_id"]: row["state_group"] for row in rows}
- def _get_state_for_group_using_cache(self, cache, group, state_filter):
- """Checks if group is in cache. See `_get_state_for_groups`
-
- Args:
- cache(DictionaryCache): the state group cache to use
- group(int): The state group to lookup
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
-
- Returns 2-tuple (`state_dict`, `got_all`).
- `got_all` is a bool indicating if we successfully retrieved all
- requests state from the cache, if False we need to query the DB for the
- missing state.
- """
- is_all, known_absent, state_dict_ids = cache.get(group)
-
- if is_all or state_filter.is_full():
- # Either we have everything or want everything, either way
- # `is_all` tells us whether we've gotten everything.
- return state_filter.filter_state(state_dict_ids), is_all
-
- # tracks whether any of our requested types are missing from the cache
- missing_types = False
-
- if state_filter.has_wildcards():
- # We don't know if we fetched all the state keys for the types in
- # the filter that are wildcards, so we have to assume that we may
- # have missed some.
- missing_types = True
- else:
- # There aren't any wild cards, so `concrete_types()` returns the
- # complete list of event types we're wanting.
- for key in state_filter.concrete_types():
- if key not in state_dict_ids and key not in known_absent:
- missing_types = True
- break
-
- return state_filter.filter_state(state_dict_ids), not missing_types
-
- @defer.inlineCallbacks
- def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
- """Gets the state at each of a list of state groups, optionally
- filtering by type/state_key
-
- Args:
- groups (iterable[int]): list of state groups for which we want
- to get the state.
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
- Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
- """
-
- member_filter, non_member_filter = state_filter.get_member_split()
-
- # Now we look them up in the member and non-member caches
- (
- non_member_state,
- incomplete_groups_nm,
- ) = yield self._get_state_for_groups_using_cache(
- groups, self._state_group_cache, state_filter=non_member_filter
- )
-
- (
- member_state,
- incomplete_groups_m,
- ) = yield self._get_state_for_groups_using_cache(
- groups, self._state_group_members_cache, state_filter=member_filter
- )
-
- state = dict(non_member_state)
- for group in groups:
- state[group].update(member_state[group])
-
- # Now fetch any missing groups from the database
-
- incomplete_groups = incomplete_groups_m | incomplete_groups_nm
-
- if not incomplete_groups:
- return state
-
- cache_sequence_nm = self._state_group_cache.sequence
- cache_sequence_m = self._state_group_members_cache.sequence
-
- # Help the cache hit ratio by expanding the filter a bit
- db_state_filter = state_filter.return_expanded()
-
- group_to_state_dict = yield self._get_state_groups_from_groups(
- list(incomplete_groups), state_filter=db_state_filter
- )
-
- # Now lets update the caches
- self._insert_into_cache(
- group_to_state_dict,
- db_state_filter,
- cache_seq_num_members=cache_sequence_m,
- cache_seq_num_non_members=cache_sequence_nm,
- )
-
- # And finally update the result dict, by filtering out any extra
- # stuff we pulled out of the database.
- for group, group_state_dict in iteritems(group_to_state_dict):
- # We just replace any existing entries, as we will have loaded
- # everything we need from the database anyway.
- state[group] = state_filter.filter_state(group_state_dict)
-
- return state
-
- def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
- """Gets the state at each of a list of state groups, optionally
- filtering by type/state_key, querying from a specific cache.
-
- Args:
- groups (iterable[int]): list of state groups for which we want
- to get the state.
- cache (DictionaryCache): the cache of group ids to state dicts which
- we will pass through - either the normal state cache or the specific
- members state cache.
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
-
- Returns:
- tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
- dict of state_group_id -> (dict of (type, state_key) -> event id)
- of entries in the cache, and the state group ids either missing
- from the cache or incomplete.
- """
- results = {}
- incomplete_groups = set()
- for group in set(groups):
- state_dict_ids, got_all = self._get_state_for_group_using_cache(
- cache, group, state_filter
- )
- results[group] = state_dict_ids
-
- if not got_all:
- incomplete_groups.add(group)
-
- return results, incomplete_groups
-
- def _insert_into_cache(
- self,
- group_to_state_dict,
- state_filter,
- cache_seq_num_members,
- cache_seq_num_non_members,
- ):
- """Inserts results from querying the database into the relevant cache.
-
- Args:
- group_to_state_dict (dict): The new entries pulled from database.
- Map from state group to state dict
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
- cache_seq_num_members (int): Sequence number of member cache since
- last lookup in cache
- cache_seq_num_non_members (int): Sequence number of member cache since
- last lookup in cache
- """
-
- # We need to work out which types we've fetched from the DB for the
- # member vs non-member caches. This should be as accurate as possible,
- # but can be an underestimate (e.g. when we have wild cards)
-
- member_filter, non_member_filter = state_filter.get_member_split()
- if member_filter.is_full():
- # We fetched all member events
- member_types = None
- else:
- # `concrete_types()` will only return a subset when there are wild
- # cards in the filter, but that's fine.
- member_types = member_filter.concrete_types()
-
- if non_member_filter.is_full():
- # We fetched all non member events
- non_member_types = None
- else:
- non_member_types = non_member_filter.concrete_types()
-
- for group, group_state_dict in iteritems(group_to_state_dict):
- state_dict_members = {}
- state_dict_non_members = {}
-
- for k, v in iteritems(group_state_dict):
- if k[0] == EventTypes.Member:
- state_dict_members[k] = v
- else:
- state_dict_non_members[k] = v
-
- self._state_group_members_cache.update(
- cache_seq_num_members,
- key=group,
- value=state_dict_members,
- fetched_keys=member_types,
- )
-
- self._state_group_cache.update(
- cache_seq_num_non_members,
- key=group,
- value=state_dict_non_members,
- fetched_keys=non_member_types,
- )
-
- def store_state_group(
- self, event_id, room_id, prev_group, delta_ids, current_state_ids
- ):
- """Store a new set of state, returning a newly assigned state group.
-
- Args:
- event_id (str): The event ID for which the state was calculated
- room_id (str)
- prev_group (int|None): A previous state group for the room, optional.
- delta_ids (dict|None): The delta between state at `prev_group` and
- `current_state_ids`, if `prev_group` was given. Same format as
- `current_state_ids`.
- current_state_ids (dict): The state to store. Map of (type, state_key)
- to event_id.
-
- Returns:
- Deferred[int]: The state group ID
- """
-
- def _store_state_group_txn(txn):
- if current_state_ids is None:
- # AFAIK, this can never happen
- raise Exception("current_state_ids cannot be None")
-
- state_group = self.database_engine.get_next_state_group_id(txn)
-
- self.db.simple_insert_txn(
- txn,
- table="state_groups",
- values={"id": state_group, "room_id": room_id, "event_id": event_id},
- )
-
- # We persist as a delta if we can, while also ensuring the chain
- # of deltas isn't tooo long, as otherwise read performance degrades.
- if prev_group:
- is_in_db = self.db.simple_select_one_onecol_txn(
- txn,
- table="state_groups",
- keyvalues={"id": prev_group},
- retcol="id",
- allow_none=True,
- )
- if not is_in_db:
- raise Exception(
- "Trying to persist state with unpersisted prev_group: %r"
- % (prev_group,)
- )
-
- potential_hops = self._count_state_group_hops_txn(txn, prev_group)
- if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
- self.db.simple_insert_txn(
- txn,
- table="state_group_edges",
- values={"state_group": state_group, "prev_state_group": prev_group},
- )
-
- self.db.simple_insert_many_txn(
- txn,
- table="state_groups_state",
- values=[
- {
- "state_group": state_group,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
- for key, state_id in iteritems(delta_ids)
- ],
- )
- else:
- self.db.simple_insert_many_txn(
- txn,
- table="state_groups_state",
- values=[
- {
- "state_group": state_group,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
- for key, state_id in iteritems(current_state_ids)
- ],
- )
-
- # Prefill the state group caches with this group.
- # It's fine to use the sequence like this as the state group map
- # is immutable. (If the map wasn't immutable then this prefill could
- # race with another update)
-
- current_member_state_ids = {
- s: ev
- for (s, ev) in iteritems(current_state_ids)
- if s[0] == EventTypes.Member
- }
- txn.call_after(
- self._state_group_members_cache.update,
- self._state_group_members_cache.sequence,
- key=state_group,
- value=dict(current_member_state_ids),
- )
-
- current_non_member_state_ids = {
- s: ev
- for (s, ev) in iteritems(current_state_ids)
- if s[0] != EventTypes.Member
- }
- txn.call_after(
- self._state_group_cache.update,
- self._state_group_cache.sequence,
- key=state_group,
- value=dict(current_non_member_state_ids),
- )
-
- return state_group
-
- return self.db.runInteraction("store_state_group", _store_state_group_txn)
-
@defer.inlineCallbacks
def get_referenced_state_groups(self, state_groups):
"""Check if the state groups are referenced by events.
@@ -1023,22 +289,14 @@ class StateGroupWorkerStore(
return set(row["state_group"] for row in rows)
-class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
+class MainStateBackgroundUpdateStore(SQLBaseStore):
- STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
- STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
def __init__(self, database: Database, db_conn, hs):
- super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
- self.db.updates.register_background_update_handler(
- self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
- self._background_deduplicate_state,
- )
- self.db.updates.register_background_update_handler(
- self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
- )
+ super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+
self.db.updates.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
index_name="current_state_events_member_index",
@@ -1053,181 +311,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
columns=["state_group"],
)
- @defer.inlineCallbacks
- def _background_deduplicate_state(self, progress, batch_size):
- """This background update will slowly deduplicate state by reencoding
- them as deltas.
- """
- last_state_group = progress.get("last_state_group", 0)
- rows_inserted = progress.get("rows_inserted", 0)
- max_group = progress.get("max_group", None)
-
- BATCH_SIZE_SCALE_FACTOR = 100
- batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
-
- if max_group is None:
- rows = yield self.db.execute(
- "_background_deduplicate_state",
- None,
- "SELECT coalesce(max(id), 0) FROM state_groups",
- )
- max_group = rows[0][0]
-
- def reindex_txn(txn):
- new_last_state_group = last_state_group
- for count in range(batch_size):
- txn.execute(
- "SELECT id, room_id FROM state_groups"
- " WHERE ? < id AND id <= ?"
- " ORDER BY id ASC"
- " LIMIT 1",
- (new_last_state_group, max_group),
- )
- row = txn.fetchone()
- if row:
- state_group, room_id = row
-
- if not row or not state_group:
- return True, count
-
- txn.execute(
- "SELECT state_group FROM state_group_edges"
- " WHERE state_group = ?",
- (state_group,),
- )
-
- # If we reach a point where we've already started inserting
- # edges we should stop.
- if txn.fetchall():
- return True, count
-
- txn.execute(
- "SELECT coalesce(max(id), 0) FROM state_groups"
- " WHERE id < ? AND room_id = ?",
- (state_group, room_id),
- )
- (prev_group,) = txn.fetchone()
- new_last_state_group = state_group
-
- if prev_group:
- potential_hops = self._count_state_group_hops_txn(txn, prev_group)
- if potential_hops >= MAX_STATE_DELTA_HOPS:
- # We want to ensure chains are at most this long,#
- # otherwise read performance degrades.
- continue
-
- prev_state = self._get_state_groups_from_groups_txn(
- txn, [prev_group]
- )
- prev_state = prev_state[prev_group]
-
- curr_state = self._get_state_groups_from_groups_txn(
- txn, [state_group]
- )
- curr_state = curr_state[state_group]
-
- if not set(prev_state.keys()) - set(curr_state.keys()):
- # We can only do a delta if the current has a strict super set
- # of keys
-
- delta_state = {
- key: value
- for key, value in iteritems(curr_state)
- if prev_state.get(key, None) != value
- }
-
- self.db.simple_delete_txn(
- txn,
- table="state_group_edges",
- keyvalues={"state_group": state_group},
- )
-
- self.db.simple_insert_txn(
- txn,
- table="state_group_edges",
- values={
- "state_group": state_group,
- "prev_state_group": prev_group,
- },
- )
-
- self.db.simple_delete_txn(
- txn,
- table="state_groups_state",
- keyvalues={"state_group": state_group},
- )
-
- self.db.simple_insert_many_txn(
- txn,
- table="state_groups_state",
- values=[
- {
- "state_group": state_group,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
- for key, state_id in iteritems(delta_state)
- ],
- )
-
- progress = {
- "last_state_group": state_group,
- "rows_inserted": rows_inserted + batch_size,
- "max_group": max_group,
- }
-
- self.db.updates._background_update_progress_txn(
- txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
- )
-
- return False, batch_size
-
- finished, result = yield self.db.runInteraction(
- self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
- )
-
- if finished:
- yield self.db.updates._end_background_update(
- self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
- )
-
- return result * BATCH_SIZE_SCALE_FACTOR
-
- @defer.inlineCallbacks
- def _background_index_state(self, progress, batch_size):
- def reindex_txn(conn):
- conn.rollback()
- if isinstance(self.database_engine, PostgresEngine):
- # postgres insists on autocommit for the index
- conn.set_session(autocommit=True)
- try:
- txn = conn.cursor()
- txn.execute(
- "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
- " ON state_groups_state(state_group, type, state_key)"
- )
- txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
- finally:
- conn.set_session(autocommit=False)
- else:
- txn = conn.cursor()
- txn.execute(
- "CREATE INDEX state_groups_state_type_idx"
- " ON state_groups_state(state_group, type, state_key)"
- )
- txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
-
- yield self.db.runWithConnection(reindex_txn)
-
- yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
-
- return 1
-
-
-class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
+class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
""" Keeps track of the state at a given event.
This is done by the concept of `state groups`. Every event is a assigned
diff --git a/synapse/storage/data_stores/state/__init__.py b/synapse/storage/data_stores/state/__init__.py
new file mode 100644
index 0000000000..86e09f6229
--- /dev/null
+++ b/synapse/storage/data_stores/state/__init__.py
@@ -0,0 +1,16 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.data_stores.state.store import StateGroupDataStore # noqa: F401
diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/data_stores/state/bg_updates.py
new file mode 100644
index 0000000000..e8edaf9f7b
--- /dev/null
+++ b/synapse/storage/data_stores/state/bg_updates.py
@@ -0,0 +1,374 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from six import iteritems
+
+from twisted.internet import defer
+
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.state import StateFilter
+
+logger = logging.getLogger(__name__)
+
+
+MAX_STATE_DELTA_HOPS = 100
+
+
+class StateGroupBackgroundUpdateStore(SQLBaseStore):
+ """Defines functions related to state groups needed to run the state backgroud
+ updates.
+ """
+
+ def _count_state_group_hops_txn(self, txn, state_group):
+ """Given a state group, count how many hops there are in the tree.
+
+ This is used to ensure the delta chains don't get too long.
+ """
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = """
+ WITH RECURSIVE state(state_group) AS (
+ VALUES(?::bigint)
+ UNION ALL
+ SELECT prev_state_group FROM state_group_edges e, state s
+ WHERE s.state_group = e.state_group
+ )
+ SELECT count(*) FROM state;
+ """
+
+ txn.execute(sql, (state_group,))
+ row = txn.fetchone()
+ if row and row[0]:
+ return row[0]
+ else:
+ return 0
+ else:
+ # We don't use WITH RECURSIVE on sqlite3 as there are distributions
+ # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
+ next_group = state_group
+ count = 0
+
+ while next_group:
+ next_group = self.db.simple_select_one_onecol_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={"state_group": next_group},
+ retcol="prev_state_group",
+ allow_none=True,
+ )
+ if next_group:
+ count += 1
+
+ return count
+
+ def _get_state_groups_from_groups_txn(
+ self, txn, groups, state_filter=StateFilter.all()
+ ):
+ results = {group: {} for group in groups}
+
+ where_clause, where_args = state_filter.make_sql_filter_clause()
+
+ # Unless the filter clause is empty, we're going to append it after an
+ # existing where clause
+ if where_clause:
+ where_clause = " AND (%s)" % (where_clause,)
+
+ if isinstance(self.database_engine, PostgresEngine):
+ # Temporarily disable sequential scans in this transaction. This is
+ # a temporary hack until we can add the right indices in
+ txn.execute("SET LOCAL enable_seqscan=off")
+
+ # The below query walks the state_group tree so that the "state"
+ # table includes all state_groups in the tree. It then joins
+ # against `state_groups_state` to fetch the latest state.
+ # It assumes that previous state groups are always numerically
+ # lesser.
+ # The PARTITION is used to get the event_id in the greatest state
+ # group for the given type, state_key.
+ # This may return multiple rows per (type, state_key), but last_value
+ # should be the same.
+ sql = """
+ WITH RECURSIVE state(state_group) AS (
+ VALUES(?::bigint)
+ UNION ALL
+ SELECT prev_state_group FROM state_group_edges e, state s
+ WHERE s.state_group = e.state_group
+ )
+ SELECT DISTINCT type, state_key, last_value(event_id) OVER (
+ PARTITION BY type, state_key ORDER BY state_group ASC
+ ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+ ) AS event_id FROM state_groups_state
+ WHERE state_group IN (
+ SELECT state_group FROM state
+ )
+ """
+
+ for group in groups:
+ args = [group]
+ args.extend(where_args)
+
+ txn.execute(sql + where_clause, args)
+ for row in txn:
+ typ, state_key, event_id = row
+ key = (typ, state_key)
+ results[group][key] = event_id
+ else:
+ max_entries_returned = state_filter.max_entries_returned()
+
+ # We don't use WITH RECURSIVE on sqlite3 as there are distributions
+ # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
+ for group in groups:
+ next_group = group
+
+ while next_group:
+ # We did this before by getting the list of group ids, and
+ # then passing that list to sqlite to get latest event for
+ # each (type, state_key). However, that was terribly slow
+ # without the right indices (which we can't add until
+ # after we finish deduping state, which requires this func)
+ args = [next_group]
+ args.extend(where_args)
+
+ txn.execute(
+ "SELECT type, state_key, event_id FROM state_groups_state"
+ " WHERE state_group = ? " + where_clause,
+ args,
+ )
+ results[group].update(
+ ((typ, state_key), event_id)
+ for typ, state_key, event_id in txn
+ if (typ, state_key) not in results[group]
+ )
+
+ # If the number of entries in the (type,state_key)->event_id dict
+ # matches the number of (type,state_keys) types we were searching
+ # for, then we must have found them all, so no need to go walk
+ # further down the tree... UNLESS our types filter contained
+ # wildcards (i.e. Nones) in which case we have to do an exhaustive
+ # search
+ if (
+ max_entries_returned is not None
+ and len(results[group]) == max_entries_returned
+ ):
+ break
+
+ next_group = self.db.simple_select_one_onecol_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={"state_group": next_group},
+ retcol="prev_state_group",
+ allow_none=True,
+ )
+
+ return results
+
+
+class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
+
+ STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
+ STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
+ STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ self.db.updates.register_background_update_handler(
+ self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
+ self._background_deduplicate_state,
+ )
+ self.db.updates.register_background_update_handler(
+ self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
+ )
+ self.db.updates.register_background_index_update(
+ self.STATE_GROUPS_ROOM_INDEX_UPDATE_NAME,
+ index_name="state_groups_room_id_idx",
+ table="state_groups",
+ columns=["room_id"],
+ )
+
+ @defer.inlineCallbacks
+ def _background_deduplicate_state(self, progress, batch_size):
+ """This background update will slowly deduplicate state by reencoding
+ them as deltas.
+ """
+ last_state_group = progress.get("last_state_group", 0)
+ rows_inserted = progress.get("rows_inserted", 0)
+ max_group = progress.get("max_group", None)
+
+ BATCH_SIZE_SCALE_FACTOR = 100
+
+ batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
+
+ if max_group is None:
+ rows = yield self.db.execute(
+ "_background_deduplicate_state",
+ None,
+ "SELECT coalesce(max(id), 0) FROM state_groups",
+ )
+ max_group = rows[0][0]
+
+ def reindex_txn(txn):
+ new_last_state_group = last_state_group
+ for count in range(batch_size):
+ txn.execute(
+ "SELECT id, room_id FROM state_groups"
+ " WHERE ? < id AND id <= ?"
+ " ORDER BY id ASC"
+ " LIMIT 1",
+ (new_last_state_group, max_group),
+ )
+ row = txn.fetchone()
+ if row:
+ state_group, room_id = row
+
+ if not row or not state_group:
+ return True, count
+
+ txn.execute(
+ "SELECT state_group FROM state_group_edges"
+ " WHERE state_group = ?",
+ (state_group,),
+ )
+
+ # If we reach a point where we've already started inserting
+ # edges we should stop.
+ if txn.fetchall():
+ return True, count
+
+ txn.execute(
+ "SELECT coalesce(max(id), 0) FROM state_groups"
+ " WHERE id < ? AND room_id = ?",
+ (state_group, room_id),
+ )
+ (prev_group,) = txn.fetchone()
+ new_last_state_group = state_group
+
+ if prev_group:
+ potential_hops = self._count_state_group_hops_txn(txn, prev_group)
+ if potential_hops >= MAX_STATE_DELTA_HOPS:
+ # We want to ensure chains are at most this long,#
+ # otherwise read performance degrades.
+ continue
+
+ prev_state = self._get_state_groups_from_groups_txn(
+ txn, [prev_group]
+ )
+ prev_state = prev_state[prev_group]
+
+ curr_state = self._get_state_groups_from_groups_txn(
+ txn, [state_group]
+ )
+ curr_state = curr_state[state_group]
+
+ if not set(prev_state.keys()) - set(curr_state.keys()):
+ # We can only do a delta if the current has a strict super set
+ # of keys
+
+ delta_state = {
+ key: value
+ for key, value in iteritems(curr_state)
+ if prev_state.get(key, None) != value
+ }
+
+ self.db.simple_delete_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={"state_group": state_group},
+ )
+
+ self.db.simple_insert_txn(
+ txn,
+ table="state_group_edges",
+ values={
+ "state_group": state_group,
+ "prev_state_group": prev_group,
+ },
+ )
+
+ self.db.simple_delete_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={"state_group": state_group},
+ )
+
+ self.db.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": state_group,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in iteritems(delta_state)
+ ],
+ )
+
+ progress = {
+ "last_state_group": state_group,
+ "rows_inserted": rows_inserted + batch_size,
+ "max_group": max_group,
+ }
+
+ self.db.updates._background_update_progress_txn(
+ txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
+ )
+
+ return False, batch_size
+
+ finished, result = yield self.db.runInteraction(
+ self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
+ )
+
+ if finished:
+ yield self.db.updates._end_background_update(
+ self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
+ )
+
+ return result * BATCH_SIZE_SCALE_FACTOR
+
+ @defer.inlineCallbacks
+ def _background_index_state(self, progress, batch_size):
+ def reindex_txn(conn):
+ conn.rollback()
+ if isinstance(self.database_engine, PostgresEngine):
+ # postgres insists on autocommit for the index
+ conn.set_session(autocommit=True)
+ try:
+ txn = conn.cursor()
+ txn.execute(
+ "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
+ " ON state_groups_state(state_group, type, state_key)"
+ )
+ txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
+ finally:
+ conn.set_session(autocommit=False)
+ else:
+ txn = conn.cursor()
+ txn.execute(
+ "CREATE INDEX state_groups_state_type_idx"
+ " ON state_groups_state(state_group, type, state_key)"
+ )
+ txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
+
+ yield self.db.runWithConnection(reindex_txn)
+
+ yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
+
+ return 1
diff --git a/synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql
index ae09fa0065..ae09fa0065 100644
--- a/synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql
+++ b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/30/state_stream.sql b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql
index e85699e82e..e85699e82e 100644
--- a/synapse/storage/data_stores/main/schema/delta/30/state_stream.sql
+++ b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql
diff --git a/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql
new file mode 100644
index 0000000000..1450313bfa
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql
@@ -0,0 +1,19 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- The following indices are redundant, other indices are equivalent or
+-- supersets
+DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY
diff --git a/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql
index 33980d02f0..33980d02f0 100644
--- a/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql
+++ b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/35/state.sql b/synapse/storage/data_stores/state/schema/delta/35/state.sql
index 0f1fa68a89..0f1fa68a89 100644
--- a/synapse/storage/data_stores/main/schema/delta/35/state.sql
+++ b/synapse/storage/data_stores/state/schema/delta/35/state.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql
index 97e5067ef4..97e5067ef4 100644
--- a/synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql
+++ b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py
index 9fd1ccf6f7..9fd1ccf6f7 100644
--- a/synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py
+++ b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py
diff --git a/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql
new file mode 100644
index 0000000000..7916ef18b2
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('state_groups_room_id_idx', '{}');
diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql
new file mode 100644
index 0000000000..35f97d6b3d
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql
@@ -0,0 +1,37 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE state_groups (
+ id BIGINT PRIMARY KEY,
+ room_id TEXT NOT NULL,
+ event_id TEXT NOT NULL
+);
+
+CREATE TABLE state_groups_state (
+ state_group BIGINT NOT NULL,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ state_key TEXT NOT NULL,
+ event_id TEXT NOT NULL
+);
+
+CREATE TABLE state_group_edges (
+ state_group BIGINT NOT NULL,
+ prev_state_group BIGINT NOT NULL
+);
+
+CREATE INDEX state_group_edges_idx ON state_group_edges (state_group);
+CREATE INDEX state_group_edges_prev_idx ON state_group_edges (prev_state_group);
+CREATE INDEX state_groups_state_type_idx ON state_groups_state (state_group, type, state_key);
diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres
new file mode 100644
index 0000000000..fcd926c9fb
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres
@@ -0,0 +1,21 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE state_group_id_seq
+ START WITH 1
+ INCREMENT BY 1
+ NO MINVALUE
+ NO MAXVALUE
+ CACHE 1;
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
new file mode 100644
index 0000000000..d53695f238
--- /dev/null
+++ b/synapse/storage/data_stores/state/store.py
@@ -0,0 +1,640 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from collections import namedtuple
+
+from six import iteritems
+from six.moves import range
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
+from synapse.storage.database import Database
+from synapse.storage.state import StateFilter
+from synapse.util.caches import get_cache_factor_for
+from synapse.util.caches.descriptors import cached
+from synapse.util.caches.dictionary_cache import DictionaryCache
+
+logger = logging.getLogger(__name__)
+
+
+MAX_STATE_DELTA_HOPS = 100
+
+
+class _GetStateGroupDelta(
+ namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
+):
+ """Return type of get_state_group_delta that implements __len__, which lets
+ us use the itrable flag when caching
+ """
+
+ __slots__ = []
+
+ def __len__(self):
+ return len(self.delta_ids) if self.delta_ids else 0
+
+
+class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
+ """A data store for fetching/storing state groups.
+ """
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(StateGroupDataStore, self).__init__(database, db_conn, hs)
+
+ # Originally the state store used a single DictionaryCache to cache the
+ # event IDs for the state types in a given state group to avoid hammering
+ # on the state_group* tables.
+ #
+ # The point of using a DictionaryCache is that it can cache a subset
+ # of the state events for a given state group (i.e. a subset of the keys for a
+ # given dict which is an entry in the cache for a given state group ID).
+ #
+ # However, this poses problems when performing complicated queries
+ # on the store - for instance: "give me all the state for this group, but
+ # limit members to this subset of users", as DictionaryCache's API isn't
+ # rich enough to say "please cache any of these fields, apart from this subset".
+ # This is problematic when lazy loading members, which requires this behaviour,
+ # as without it the cache has no choice but to speculatively load all
+ # state events for the group, which negates the efficiency being sought.
+ #
+ # Rather than overcomplicating DictionaryCache's API, we instead split the
+ # state_group_cache into two halves - one for tracking non-member events,
+ # and the other for tracking member_events. This means that lazy loading
+ # queries can be made in a cache-friendly manner by querying both caches
+ # separately and then merging the result. So for the example above, you
+ # would query the members cache for a specific subset of state keys
+ # (which DictionaryCache will handle efficiently and fine) and the non-members
+ # cache for all state (which DictionaryCache will similarly handle fine)
+ # and then just merge the results together.
+ #
+ # We size the non-members cache to be smaller than the members cache as the
+ # vast majority of state in Matrix (today) is member events.
+
+ self._state_group_cache = DictionaryCache(
+ "*stateGroupCache*",
+ # TODO: this hasn't been tuned yet
+ 50000 * get_cache_factor_for("stateGroupCache"),
+ )
+ self._state_group_members_cache = DictionaryCache(
+ "*stateGroupMembersCache*",
+ 500000 * get_cache_factor_for("stateGroupMembersCache"),
+ )
+
+ @cached(max_entries=10000, iterable=True)
+ def get_state_group_delta(self, state_group):
+ """Given a state group try to return a previous group and a delta between
+ the old and the new.
+
+ Returns:
+ (prev_group, delta_ids), where both may be None.
+ """
+
+ def _get_state_group_delta_txn(txn):
+ prev_group = self.db.simple_select_one_onecol_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={"state_group": state_group},
+ retcol="prev_state_group",
+ allow_none=True,
+ )
+
+ if not prev_group:
+ return _GetStateGroupDelta(None, None)
+
+ delta_ids = self.db.simple_select_list_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={"state_group": state_group},
+ retcols=("type", "state_key", "event_id"),
+ )
+
+ return _GetStateGroupDelta(
+ prev_group,
+ {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
+ )
+
+ return self.db.runInteraction(
+ "get_state_group_delta", _get_state_group_delta_txn
+ )
+
+ @defer.inlineCallbacks
+ def _get_state_groups_from_groups(self, groups, state_filter):
+ """Returns the state groups for a given set of groups, filtering on
+ types of state events.
+
+ Args:
+ groups(list[int]): list of state group IDs to query
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+ Returns:
+ Deferred[dict[int, dict[tuple[str, str], str]]]:
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
+ """
+ results = {}
+
+ chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
+ for chunk in chunks:
+ res = yield self.db.runInteraction(
+ "_get_state_groups_from_groups",
+ self._get_state_groups_from_groups_txn,
+ chunk,
+ state_filter,
+ )
+ results.update(res)
+
+ return results
+
+ def _get_state_for_group_using_cache(self, cache, group, state_filter):
+ """Checks if group is in cache. See `_get_state_for_groups`
+
+ Args:
+ cache(DictionaryCache): the state group cache to use
+ group(int): The state group to lookup
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+
+ Returns 2-tuple (`state_dict`, `got_all`).
+ `got_all` is a bool indicating if we successfully retrieved all
+ requests state from the cache, if False we need to query the DB for the
+ missing state.
+ """
+ is_all, known_absent, state_dict_ids = cache.get(group)
+
+ if is_all or state_filter.is_full():
+ # Either we have everything or want everything, either way
+ # `is_all` tells us whether we've gotten everything.
+ return state_filter.filter_state(state_dict_ids), is_all
+
+ # tracks whether any of our requested types are missing from the cache
+ missing_types = False
+
+ if state_filter.has_wildcards():
+ # We don't know if we fetched all the state keys for the types in
+ # the filter that are wildcards, so we have to assume that we may
+ # have missed some.
+ missing_types = True
+ else:
+ # There aren't any wild cards, so `concrete_types()` returns the
+ # complete list of event types we're wanting.
+ for key in state_filter.concrete_types():
+ if key not in state_dict_ids and key not in known_absent:
+ missing_types = True
+ break
+
+ return state_filter.filter_state(state_dict_ids), not missing_types
+
+ @defer.inlineCallbacks
+ def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+ """Gets the state at each of a list of state groups, optionally
+ filtering by type/state_key
+
+ Args:
+ groups (iterable[int]): list of state groups for which we want
+ to get the state.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+ Returns:
+ Deferred[dict[int, dict[tuple[str, str], str]]]:
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
+ """
+
+ member_filter, non_member_filter = state_filter.get_member_split()
+
+ # Now we look them up in the member and non-member caches
+ (
+ non_member_state,
+ incomplete_groups_nm,
+ ) = yield self._get_state_for_groups_using_cache(
+ groups, self._state_group_cache, state_filter=non_member_filter
+ )
+
+ (
+ member_state,
+ incomplete_groups_m,
+ ) = yield self._get_state_for_groups_using_cache(
+ groups, self._state_group_members_cache, state_filter=member_filter
+ )
+
+ state = dict(non_member_state)
+ for group in groups:
+ state[group].update(member_state[group])
+
+ # Now fetch any missing groups from the database
+
+ incomplete_groups = incomplete_groups_m | incomplete_groups_nm
+
+ if not incomplete_groups:
+ return state
+
+ cache_sequence_nm = self._state_group_cache.sequence
+ cache_sequence_m = self._state_group_members_cache.sequence
+
+ # Help the cache hit ratio by expanding the filter a bit
+ db_state_filter = state_filter.return_expanded()
+
+ group_to_state_dict = yield self._get_state_groups_from_groups(
+ list(incomplete_groups), state_filter=db_state_filter
+ )
+
+ # Now lets update the caches
+ self._insert_into_cache(
+ group_to_state_dict,
+ db_state_filter,
+ cache_seq_num_members=cache_sequence_m,
+ cache_seq_num_non_members=cache_sequence_nm,
+ )
+
+ # And finally update the result dict, by filtering out any extra
+ # stuff we pulled out of the database.
+ for group, group_state_dict in iteritems(group_to_state_dict):
+ # We just replace any existing entries, as we will have loaded
+ # everything we need from the database anyway.
+ state[group] = state_filter.filter_state(group_state_dict)
+
+ return state
+
+ def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
+ """Gets the state at each of a list of state groups, optionally
+ filtering by type/state_key, querying from a specific cache.
+
+ Args:
+ groups (iterable[int]): list of state groups for which we want
+ to get the state.
+ cache (DictionaryCache): the cache of group ids to state dicts which
+ we will pass through - either the normal state cache or the specific
+ members state cache.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+
+ Returns:
+ tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
+ of entries in the cache, and the state group ids either missing
+ from the cache or incomplete.
+ """
+ results = {}
+ incomplete_groups = set()
+ for group in set(groups):
+ state_dict_ids, got_all = self._get_state_for_group_using_cache(
+ cache, group, state_filter
+ )
+ results[group] = state_dict_ids
+
+ if not got_all:
+ incomplete_groups.add(group)
+
+ return results, incomplete_groups
+
+ def _insert_into_cache(
+ self,
+ group_to_state_dict,
+ state_filter,
+ cache_seq_num_members,
+ cache_seq_num_non_members,
+ ):
+ """Inserts results from querying the database into the relevant cache.
+
+ Args:
+ group_to_state_dict (dict): The new entries pulled from database.
+ Map from state group to state dict
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+ cache_seq_num_members (int): Sequence number of member cache since
+ last lookup in cache
+ cache_seq_num_non_members (int): Sequence number of member cache since
+ last lookup in cache
+ """
+
+ # We need to work out which types we've fetched from the DB for the
+ # member vs non-member caches. This should be as accurate as possible,
+ # but can be an underestimate (e.g. when we have wild cards)
+
+ member_filter, non_member_filter = state_filter.get_member_split()
+ if member_filter.is_full():
+ # We fetched all member events
+ member_types = None
+ else:
+ # `concrete_types()` will only return a subset when there are wild
+ # cards in the filter, but that's fine.
+ member_types = member_filter.concrete_types()
+
+ if non_member_filter.is_full():
+ # We fetched all non member events
+ non_member_types = None
+ else:
+ non_member_types = non_member_filter.concrete_types()
+
+ for group, group_state_dict in iteritems(group_to_state_dict):
+ state_dict_members = {}
+ state_dict_non_members = {}
+
+ for k, v in iteritems(group_state_dict):
+ if k[0] == EventTypes.Member:
+ state_dict_members[k] = v
+ else:
+ state_dict_non_members[k] = v
+
+ self._state_group_members_cache.update(
+ cache_seq_num_members,
+ key=group,
+ value=state_dict_members,
+ fetched_keys=member_types,
+ )
+
+ self._state_group_cache.update(
+ cache_seq_num_non_members,
+ key=group,
+ value=state_dict_non_members,
+ fetched_keys=non_member_types,
+ )
+
+ def store_state_group(
+ self, event_id, room_id, prev_group, delta_ids, current_state_ids
+ ):
+ """Store a new set of state, returning a newly assigned state group.
+
+ Args:
+ event_id (str): The event ID for which the state was calculated
+ room_id (str)
+ prev_group (int|None): A previous state group for the room, optional.
+ delta_ids (dict|None): The delta between state at `prev_group` and
+ `current_state_ids`, if `prev_group` was given. Same format as
+ `current_state_ids`.
+ current_state_ids (dict): The state to store. Map of (type, state_key)
+ to event_id.
+
+ Returns:
+ Deferred[int]: The state group ID
+ """
+
+ def _store_state_group_txn(txn):
+ if current_state_ids is None:
+ # AFAIK, this can never happen
+ raise Exception("current_state_ids cannot be None")
+
+ state_group = self.database_engine.get_next_state_group_id(txn)
+
+ self.db.simple_insert_txn(
+ txn,
+ table="state_groups",
+ values={"id": state_group, "room_id": room_id, "event_id": event_id},
+ )
+
+ # We persist as a delta if we can, while also ensuring the chain
+ # of deltas isn't tooo long, as otherwise read performance degrades.
+ if prev_group:
+ is_in_db = self.db.simple_select_one_onecol_txn(
+ txn,
+ table="state_groups",
+ keyvalues={"id": prev_group},
+ retcol="id",
+ allow_none=True,
+ )
+ if not is_in_db:
+ raise Exception(
+ "Trying to persist state with unpersisted prev_group: %r"
+ % (prev_group,)
+ )
+
+ potential_hops = self._count_state_group_hops_txn(txn, prev_group)
+ if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
+ self.db.simple_insert_txn(
+ txn,
+ table="state_group_edges",
+ values={"state_group": state_group, "prev_state_group": prev_group},
+ )
+
+ self.db.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": state_group,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in iteritems(delta_ids)
+ ],
+ )
+ else:
+ self.db.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": state_group,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in iteritems(current_state_ids)
+ ],
+ )
+
+ # Prefill the state group caches with this group.
+ # It's fine to use the sequence like this as the state group map
+ # is immutable. (If the map wasn't immutable then this prefill could
+ # race with another update)
+
+ current_member_state_ids = {
+ s: ev
+ for (s, ev) in iteritems(current_state_ids)
+ if s[0] == EventTypes.Member
+ }
+ txn.call_after(
+ self._state_group_members_cache.update,
+ self._state_group_members_cache.sequence,
+ key=state_group,
+ value=dict(current_member_state_ids),
+ )
+
+ current_non_member_state_ids = {
+ s: ev
+ for (s, ev) in iteritems(current_state_ids)
+ if s[0] != EventTypes.Member
+ }
+ txn.call_after(
+ self._state_group_cache.update,
+ self._state_group_cache.sequence,
+ key=state_group,
+ value=dict(current_non_member_state_ids),
+ )
+
+ return state_group
+
+ return self.db.runInteraction("store_state_group", _store_state_group_txn)
+
+ def purge_unreferenced_state_groups(
+ self, room_id: str, state_groups_to_delete
+ ) -> defer.Deferred:
+ """Deletes no longer referenced state groups and de-deltas any state
+ groups that reference them.
+
+ Args:
+ room_id: The room the state groups belong to (must all be in the
+ same room).
+ state_groups_to_delete (Collection[int]): Set of all state groups
+ to delete.
+ """
+
+ return self.db.runInteraction(
+ "purge_unreferenced_state_groups",
+ self._purge_unreferenced_state_groups,
+ room_id,
+ state_groups_to_delete,
+ )
+
+ def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete):
+ logger.info(
+ "[purge] found %i state groups to delete", len(state_groups_to_delete)
+ )
+
+ rows = self.db.simple_select_many_txn(
+ txn,
+ table="state_group_edges",
+ column="prev_state_group",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ retcols=("state_group",),
+ )
+
+ remaining_state_groups = set(
+ row["state_group"]
+ for row in rows
+ if row["state_group"] not in state_groups_to_delete
+ )
+
+ logger.info(
+ "[purge] de-delta-ing %i remaining state groups",
+ len(remaining_state_groups),
+ )
+
+ # Now we turn the state groups that reference to-be-deleted state
+ # groups to non delta versions.
+ for sg in remaining_state_groups:
+ logger.info("[purge] de-delta-ing remaining state group %s", sg)
+ curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
+ curr_state = curr_state[sg]
+
+ self.db.simple_delete_txn(
+ txn, table="state_groups_state", keyvalues={"state_group": sg}
+ )
+
+ self.db.simple_delete_txn(
+ txn, table="state_group_edges", keyvalues={"state_group": sg}
+ )
+
+ self.db.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": sg,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in iteritems(curr_state)
+ ],
+ )
+
+ logger.info("[purge] removing redundant state groups")
+ txn.executemany(
+ "DELETE FROM state_groups_state WHERE state_group = ?",
+ ((sg,) for sg in state_groups_to_delete),
+ )
+ txn.executemany(
+ "DELETE FROM state_groups WHERE id = ?",
+ ((sg,) for sg in state_groups_to_delete),
+ )
+
+ @defer.inlineCallbacks
+ def get_previous_state_groups(self, state_groups):
+ """Fetch the previous groups of the given state groups.
+
+ Args:
+ state_groups (Iterable[int])
+
+ Returns:
+ Deferred[dict[int, int]]: mapping from state group to previous
+ state group.
+ """
+
+ rows = yield self.db.simple_select_many_batch(
+ table="state_group_edges",
+ column="prev_state_group",
+ iterable=state_groups,
+ keyvalues={},
+ retcols=("prev_state_group", "state_group"),
+ desc="get_previous_state_groups",
+ )
+
+ return {row["state_group"]: row["prev_state_group"] for row in rows}
+
+ def purge_room_state(self, room_id, state_groups_to_delete):
+ """Deletes all record of a room from state tables
+
+ Args:
+ room_id (str):
+ state_groups_to_delete (list[int]): State groups to delete
+ """
+
+ return self.db.runInteraction(
+ "purge_room_state",
+ self._purge_room_state_txn,
+ room_id,
+ state_groups_to_delete,
+ )
+
+ def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete):
+ # first we have to delete the state groups states
+ logger.info("[purge] removing %s from state_groups_state", room_id)
+
+ self.db.simple_delete_many_txn(
+ txn,
+ table="state_groups_state",
+ column="state_group",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ )
+
+ # ... and the state group edges
+ logger.info("[purge] removing %s from state_group_edges", room_id)
+
+ self.db.simple_delete_many_txn(
+ txn,
+ table="state_group_edges",
+ column="state_group",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ )
+
+ # ... and the state groups
+ logger.info("[purge] removing %s from state_groups", room_id)
+
+ self.db.simple_delete_many_txn(
+ txn,
+ table="state_groups",
+ column="id",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ )
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index ec19ae1d9d..1003dd84a5 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -24,9 +24,11 @@ from six.moves import intern, range
from prometheus_client import Histogram
+from twisted.enterprise import adbapi
from twisted.internet import defer
from synapse.api.errors import StoreError
+from synapse.config.database import DatabaseConnectionConfig
from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
@@ -74,6 +76,37 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
}
+def make_pool(
+ reactor, db_config: DatabaseConnectionConfig, engine
+) -> adbapi.ConnectionPool:
+ """Get the connection pool for the database.
+ """
+
+ return adbapi.ConnectionPool(
+ db_config.config["name"],
+ cp_reactor=reactor,
+ cp_openfun=engine.on_new_connection,
+ **db_config.config.get("args", {})
+ )
+
+
+def make_conn(db_config: DatabaseConnectionConfig, engine):
+ """Make a new connection to the database and return it.
+
+ Returns:
+ Connection
+ """
+
+ db_params = {
+ k: v
+ for k, v in db_config.config.get("args", {}).items()
+ if not k.startswith("cp_")
+ }
+ db_conn = engine.module.connect(**db_params)
+ engine.on_new_connection(db_conn)
+ return db_conn
+
+
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
@@ -218,10 +251,11 @@ class Database(object):
_TXN_ID = 0
- def __init__(self, hs):
+ def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
self.hs = hs
self._clock = hs.get_clock()
- self._db_pool = hs.get_db_pool()
+ self._database_config = database_config
+ self._db_pool = make_pool(hs.get_reactor(), database_config, engine)
self.updates = BackgroundUpdater(hs, self)
@@ -234,7 +268,7 @@ class Database(object):
# to watch it
self._txn_perf_counters = PerformanceCounters()
- self.engine = hs.database_engine
+ self.engine = engine
# A set of tables that are not safe to use native upserts in.
self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
@@ -255,6 +289,11 @@ class Database(object):
self._check_safe_to_upsert,
)
+ def is_running(self):
+ """Is the database pool currently running
+ """
+ return self._db_pool.running
+
@defer.inlineCallbacks
def _check_safe_to_upsert(self):
"""
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index cbc74cd302..df039a072d 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -16,8 +16,6 @@
import struct
import threading
-from synapse.storage.prepare_database import prepare_database
-
class Sqlite3Engine(object):
single_threaded = True
@@ -62,6 +60,10 @@ class Sqlite3Engine(object):
return sql
def on_new_connection(self, db_conn):
+
+ # We need to import here to avoid an import loop.
+ from synapse.storage.prepare_database import prepare_database
+
if self._is_in_memory:
# In memory databases need to be rebuilt each time. Ideally we'd
# reuse the same connection as we do when starting up, but that
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index fa03ca9ff7..1ed44925fc 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -183,7 +183,7 @@ class EventsPersistenceStorage(object):
# so we use separate variables here even though they point to the same
# store for now.
self.main_store = stores.main
- self.state_store = stores.main
+ self.state_store = stores.state
self._clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 731e1c9d9c..e70026b80a 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -18,6 +18,7 @@ import imp
import logging
import os
import re
+from collections import Counter
import attr
@@ -41,7 +42,7 @@ class UpgradeDatabaseException(PrepareDatabaseException):
pass
-def prepare_database(db_conn, database_engine, config):
+def prepare_database(db_conn, database_engine, config, data_stores=["main", "state"]):
"""Prepares a database for usage. Will either create all necessary tables
or upgrade from an older schema version.
@@ -54,11 +55,10 @@ def prepare_database(db_conn, database_engine, config):
config (synapse.config.homeserver.HomeServerConfig|None):
application config, or None if we are connecting to an existing
database which we expect to be configured already
+ data_stores (list[str]): The name of the data stores that will be used
+ with this database. Defaults to all data stores.
"""
- # For now we only have the one datastore.
- data_stores = ["main"]
-
try:
cur = db_conn.cursor()
version_info = _get_or_create_schema_state(cur, database_engine)
@@ -70,7 +70,10 @@ def prepare_database(db_conn, database_engine, config):
if user_version != SCHEMA_VERSION:
# If we don't pass in a config file then we are expecting to
# have already upgraded the DB.
- raise UpgradeDatabaseException("Database needs to be upgraded")
+ raise UpgradeDatabaseException(
+ "Expected database schema version %i but got %i"
+ % (SCHEMA_VERSION, user_version)
+ )
else:
_upgrade_existing_database(
cur,
@@ -313,6 +316,9 @@ def _upgrade_existing_database(
)
)
+ # Used to check if we have any duplicate file names
+ file_name_counter = Counter()
+
# Now find which directories have anything of interest.
directory_entries = []
for directory in directories:
@@ -323,6 +329,9 @@ def _upgrade_existing_database(
_DirectoryListing(file_name, os.path.join(directory, file_name))
for file_name in file_names
)
+
+ for file_name in file_names:
+ file_name_counter[file_name] += 1
except FileNotFoundError:
# Data stores can have empty entries for a given version delta.
pass
@@ -331,6 +340,17 @@ def _upgrade_existing_database(
"Could not open delta dir for version %d: %s" % (v, directory)
)
+ duplicates = set(
+ file_name for file_name, count in file_name_counter.items() if count > 1
+ )
+ if duplicates:
+ # We don't support using the same file name in the same delta version.
+ raise PrepareDatabaseException(
+ "Found multiple delta files with the same name in v%d: %s",
+ v,
+ duplicates,
+ )
+
# We sort to ensure that we apply the delta files in a consistent
# order (to avoid bugs caused by inconsistent directory listing order)
directory_entries.sort()
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index a368182034..d6a7bd7834 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -58,7 +58,7 @@ class PurgeEventsStorage(object):
sg_to_delete = yield self._find_unreferenced_groups(state_groups)
- yield self.stores.main.purge_unreferenced_state_groups(room_id, sg_to_delete)
+ yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
@defer.inlineCallbacks
def _find_unreferenced_groups(self, state_groups):
@@ -102,7 +102,7 @@ class PurgeEventsStorage(object):
# groups that are referenced.
current_search -= referenced
- edges = yield self.stores.main.get_previous_state_groups(current_search)
+ edges = yield self.stores.state.get_previous_state_groups(current_search)
prevs = set(edges.values())
# We don't bother re-handling groups we've already seen
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 3735846899..cbeb586014 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -342,7 +342,7 @@ class StateGroupStorage(object):
(prev_group, delta_ids)
"""
- return self.stores.main.get_state_group_delta(state_group)
+ return self.stores.state.get_state_group_delta(state_group)
@defer.inlineCallbacks
def get_state_groups_ids(self, _room_id, event_ids):
@@ -362,7 +362,7 @@ class StateGroupStorage(object):
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
- group_to_state = yield self.stores.main._get_state_for_groups(groups)
+ group_to_state = yield self.stores.state._get_state_for_groups(groups)
return group_to_state
@@ -423,7 +423,7 @@ class StateGroupStorage(object):
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
- return self.stores.main._get_state_groups_from_groups(groups, state_filter)
+ return self.stores.state._get_state_groups_from_groups(groups, state_filter)
@defer.inlineCallbacks
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
@@ -439,7 +439,7 @@ class StateGroupStorage(object):
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
- group_to_state = yield self.stores.main._get_state_for_groups(
+ group_to_state = yield self.stores.state._get_state_for_groups(
groups, state_filter
)
@@ -476,7 +476,7 @@ class StateGroupStorage(object):
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
- group_to_state = yield self.stores.main._get_state_for_groups(
+ group_to_state = yield self.stores.state._get_state_for_groups(
groups, state_filter
)
@@ -532,7 +532,7 @@ class StateGroupStorage(object):
Deferred[dict[int, dict[tuple[str, str], str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
- return self.stores.main._get_state_for_groups(groups, state_filter)
+ return self.stores.state._get_state_for_groups(groups, state_filter)
def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
@@ -552,6 +552,6 @@ class StateGroupStorage(object):
Returns:
Deferred[int]: The state group ID
"""
- return self.stores.main.store_state_group(
+ return self.stores.state.store_state_group(
event_id, room_id, prev_group, delta_ids, current_state_ids
)
|