diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/__init__.py | 136 | ||||
-rw-r--r-- | synapse/storage/_base.py | 167 | ||||
-rw-r--r-- | synapse/storage/directory.py | 1 | ||||
-rw-r--r-- | synapse/storage/event_federation.py | 162 | ||||
-rw-r--r-- | synapse/storage/pdu.py | 48 | ||||
-rw-r--r-- | synapse/storage/registration.py | 7 | ||||
-rw-r--r-- | synapse/storage/room.py | 2 | ||||
-rw-r--r-- | synapse/storage/schema/event_edges.sql | 49 | ||||
-rw-r--r-- | synapse/storage/schema/event_signatures.sql | 65 | ||||
-rw-r--r-- | synapse/storage/schema/im.sql | 1 | ||||
-rw-r--r-- | synapse/storage/schema/signatures.sql | 66 | ||||
-rw-r--r-- | synapse/storage/schema/state.sql | 33 | ||||
-rw-r--r-- | synapse/storage/signatures.py | 302 | ||||
-rw-r--r-- | synapse/storage/state.py | 101 | ||||
-rw-r--r-- | synapse/storage/stream.py | 5 | ||||
-rw-r--r-- | synapse/storage/transactions.py | 6 |
16 files changed, 1072 insertions, 79 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 4e9291fdff..d75c366834 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -40,6 +40,15 @@ from .stream import StreamStore from .pdu import StatePduStore, PduStore, PdusTable from .transactions import TransactionStore from .keys import KeyStore +from .event_federation import EventFederationStore + +from .state import StateStore +from .signatures import SignatureStore + +from syutil.base64util import decode_base64 + +from synapse.crypto.event_signing import compute_pdu_event_reference_hash + import json import logging @@ -59,6 +68,10 @@ SCHEMAS = [ "room_aliases", "keys", "redactions", + "state", + "signatures", + "event_edges", + "event_signatures", ] @@ -73,10 +86,12 @@ class _RollbackButIsFineException(Exception): """ pass + class DataStore(RoomMemberStore, RoomStore, RegistrationStore, StreamStore, ProfileStore, FeedbackStore, PresenceStore, PduStore, StatePduStore, TransactionStore, - DirectoryStore, KeyStore): + DirectoryStore, KeyStore, StateStore, SignatureStore, + EventFederationStore, ): def __init__(self, hs): super(DataStore, self).__init__(hs) @@ -99,6 +114,7 @@ class DataStore(RoomMemberStore, RoomStore, try: yield self.runInteraction( + "persist_event", self._persist_pdu_event_txn, pdu=pdu, event=event, @@ -119,7 +135,8 @@ class DataStore(RoomMemberStore, RoomStore, "type", "room_id", "content", - "unrecognized_keys" + "unrecognized_keys", + "depth", ], allow_none=allow_none, ) @@ -144,6 +161,8 @@ class DataStore(RoomMemberStore, RoomStore, def _persist_event_pdu_txn(self, txn, pdu): cols = dict(pdu.__dict__) unrec_keys = dict(pdu.unrecognized_keys) + del cols["hashes"] + del cols["signatures"] del cols["content"] del cols["prev_pdus"] cols["content_json"] = json.dumps(pdu.content) @@ -159,6 +178,33 @@ class DataStore(RoomMemberStore, RoomStore, logger.debug("Persisting: %s", repr(cols)) + for hash_alg, hash_base64 in pdu.hashes.items(): + hash_bytes = decode_base64(hash_base64) + self._store_pdu_content_hash_txn( + txn, pdu.pdu_id, pdu.origin, hash_alg, hash_bytes, + ) + + signatures = pdu.signatures.get(pdu.origin, {}) + + for key_id, signature_base64 in signatures.items(): + signature_bytes = decode_base64(signature_base64) + self._store_pdu_origin_signature_txn( + txn, pdu.pdu_id, pdu.origin, key_id, signature_bytes, + ) + + for prev_pdu_id, prev_origin, prev_hashes in pdu.prev_pdus: + for alg, hash_base64 in prev_hashes.items(): + hash_bytes = decode_base64(hash_base64) + self._store_prev_pdu_hash_txn( + txn, pdu.pdu_id, pdu.origin, prev_pdu_id, prev_origin, alg, + hash_bytes + ) + + (ref_alg, ref_hash_bytes) = compute_pdu_event_reference_hash(pdu) + self._store_pdu_reference_hash_txn( + txn, pdu.pdu_id, pdu.origin, ref_alg, ref_hash_bytes + ) + if pdu.is_state: self._persist_state_txn(txn, pdu.prev_pdus, cols) else: @@ -190,6 +236,10 @@ class DataStore(RoomMemberStore, RoomStore, elif event.type == RoomRedactionEvent.TYPE: self._store_redaction(txn, event) + outlier = False + if hasattr(event, "outlier"): + outlier = event.outlier + vals = { "topological_ordering": event.depth, "event_id": event.event_id, @@ -197,25 +247,30 @@ class DataStore(RoomMemberStore, RoomStore, "room_id": event.room_id, "content": json.dumps(event.content), "processed": True, + "outlier": outlier, + "depth": event.depth, } if stream_ordering is not None: vals["stream_ordering"] = stream_ordering - if hasattr(event, "outlier"): - vals["outlier"] = event.outlier - else: - vals["outlier"] = False - unrec = { k: v for k, v in event.get_full_dict().items() - if k not in vals.keys() and k not in ["redacted", "redacted_because"] + if k not in vals.keys() and k not in [ + "redacted", "redacted_because", "signatures", "hashes", + "prev_events", + ] } vals["unrecognized_keys"] = json.dumps(unrec) try: - self._simple_insert_txn(txn, "events", vals) + self._simple_insert_txn( + txn, + "events", + vals, + or_replace=(not outlier), + ) except: logger.warn( "Failed to persist, probably duplicate: %s", @@ -224,6 +279,16 @@ class DataStore(RoomMemberStore, RoomStore, ) raise _RollbackButIsFineException("_persist_event") + self._handle_prev_events( + txn, + outlier=outlier, + event_id=event.event_id, + prev_events=event.prev_events, + room_id=event.room_id, + ) + + self._store_state_groups_txn(txn, event) + is_state = hasattr(event, "state_key") and event.state_key is not None if is_new_state and is_state: vals = { @@ -249,6 +314,30 @@ class DataStore(RoomMemberStore, RoomStore, } ) + if hasattr(event, "signatures"): + signatures = event.signatures.get(event.origin, {}) + + for key_id, signature_base64 in signatures.items(): + signature_bytes = decode_base64(signature_base64) + self._store_event_origin_signature_txn( + txn, event.event_id, event.origin, key_id, signature_bytes, + ) + + for prev_event_id, prev_hashes in event.prev_events: + for alg, hash_base64 in prev_hashes.items(): + hash_bytes = decode_base64(hash_base64) + self._store_prev_event_hash_txn( + txn, event.event_id, prev_event_id, alg, hash_bytes + ) + + # TODO + # (ref_alg, ref_hash_bytes) = compute_pdu_event_reference_hash(pdu) + # self._store_event_reference_hash_txn( + # txn, event.event_id, ref_alg, ref_hash_bytes + # ) + + self._update_min_depth_for_room_txn(txn, event.room_id, event.depth) + def _store_redaction(self, txn, event): txn.execute( "INSERT OR IGNORE INTO redactions " @@ -331,9 +420,8 @@ class DataStore(RoomMemberStore, RoomStore, """ def _snapshot(txn): membership_state = self._get_room_member(txn, user_id, room_id) - prev_pdus = self._get_latest_pdus_in_context( - txn, room_id - ) + prev_events = self._get_latest_events_in_room(txn, room_id) + if state_type is not None and state_key is not None: prev_state_pdu = self._get_current_state_pdu( txn, room_id, state_type, state_key @@ -345,14 +433,14 @@ class DataStore(RoomMemberStore, RoomStore, store=self, room_id=room_id, user_id=user_id, - prev_pdus=prev_pdus, + prev_events=prev_events, membership_state=membership_state, state_type=state_type, state_key=state_key, prev_state_pdu=prev_state_pdu, ) - return self.runInteraction(_snapshot) + return self.runInteraction("snapshot_room", _snapshot) class Snapshot(object): @@ -361,7 +449,7 @@ class Snapshot(object): store (DataStore): The datastore. room_id (RoomId): The room of the snapshot. user_id (UserId): The user this snapshot is for. - prev_pdus (list): The list of PDU ids this snapshot is after. + prev_events (list): The list of event ids this snapshot is after. membership_state (RoomMemberEvent): The current state of the user in the room. state_type (str, optional): State type captured by the snapshot @@ -370,13 +458,13 @@ class Snapshot(object): the previous value of the state type and key in the room. """ - def __init__(self, store, room_id, user_id, prev_pdus, + def __init__(self, store, room_id, user_id, prev_events, membership_state, state_type=None, state_key=None, prev_state_pdu=None): self.store = store self.room_id = room_id self.user_id = user_id - self.prev_pdus = prev_pdus + self.prev_events = prev_events self.membership_state = membership_state self.state_type = state_type self.state_key = state_key @@ -386,14 +474,13 @@ class Snapshot(object): if hasattr(event, "prev_events"): return - es = [ - "%s@%s" % (p_id, origin) for p_id, origin, _ in self.prev_pdus + event.prev_events = [ + (event_id, hashes) + for event_id, hashes, _ in self.prev_events ] - event.prev_events = [e for e in es if e != event.event_id] - - if self.prev_pdus: - event.depth = max([int(v) for _, _, v in self.prev_pdus]) + 1 + if self.prev_events: + event.depth = max([int(v) for _, _, v in self.prev_events]) + 1 else: event.depth = 0 @@ -452,9 +539,10 @@ def prepare_database(db_conn): db_conn.commit() else: - sql_script = "BEGIN TRANSACTION;" + sql_script = "BEGIN TRANSACTION;\n" for sql_loc in SCHEMAS: sql_script += read_schema(sql_loc) + sql_script += "\n" sql_script += "COMMIT TRANSACTION;" c.executescript(sql_script) db_conn.commit() diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 65a86e9056..464b12f032 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -19,54 +19,66 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.api.events.utils import prune_event from synapse.util.logutils import log_function +from syutil.base64util import encode_base64 import collections import copy import json +import sys +import time logger = logging.getLogger(__name__) sql_logger = logging.getLogger("synapse.storage.SQL") +transaction_logger = logging.getLogger("synapse.storage.txn") class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging to the .execute() method.""" - __slots__ = ["txn"] + __slots__ = ["txn", "name"] - def __init__(self, txn): + def __init__(self, txn, name): object.__setattr__(self, "txn", txn) + object.__setattr__(self, "name", name) - def __getattribute__(self, name): - if name == "execute": - return object.__getattribute__(self, "execute") - - return getattr(object.__getattribute__(self, "txn"), name) + def __getattr__(self, name): + return getattr(self.txn, name) def __setattr__(self, name, value): - setattr(object.__getattribute__(self, "txn"), name, value) + setattr(self.txn, name, value) def execute(self, sql, *args, **kwargs): # TODO(paul): Maybe use 'info' and 'debug' for values? - sql_logger.debug("[SQL] %s", sql) + sql_logger.debug("[SQL] {%s} %s", self.name, sql) try: if args and args[0]: values = args[0] - sql_logger.debug("[SQL values] " + - ", ".join(("<%s>",) * len(values)), *values) + sql_logger.debug( + "[SQL values] {%s} " + ", ".join(("<%s>",) * len(values)), + self.name, + *values + ) except: # Don't let logging failures stop SQL from working pass - # TODO(paul): Here would be an excellent place to put some timing - # measurements, and log (warning?) slow queries. - return object.__getattribute__(self, "txn").execute( - sql, *args, **kwargs - ) + start = time.clock() * 1000 + try: + return self.txn.execute( + sql, *args, **kwargs + ) + except: + logger.exception("[SQL FAIL] {%s}", self.name) + raise + finally: + end = time.clock() * 1000 + sql_logger.debug("[SQL time] {%s} %f", self.name, end - start) class SQLBaseStore(object): + _TXN_ID = 0 def __init__(self, hs): self.hs = hs @@ -74,10 +86,30 @@ class SQLBaseStore(object): self.event_factory = hs.get_event_factory() self._clock = hs.get_clock() - def runInteraction(self, func, *args, **kwargs): + def runInteraction(self, desc, func, *args, **kwargs): """Wraps the .runInteraction() method on the underlying db_pool.""" def inner_func(txn, *args, **kwargs): - return func(LoggingTransaction(txn), *args, **kwargs) + start = time.clock() * 1000 + txn_id = SQLBaseStore._TXN_ID + + # We don't really need these to be unique, so lets stop it from + # growing really large. + self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1) + + name = "%s-%x" % (desc, txn_id, ) + + transaction_logger.debug("[TXN START] {%s}", name) + try: + return func(LoggingTransaction(txn, name), *args, **kwargs) + except: + logger.exception("[TXN FAIL] {%s}", name) + raise + finally: + end = time.clock() * 1000 + transaction_logger.debug( + "[TXN END] {%s} %f", + name, end - start + ) return self._db_pool.runInteraction(inner_func, *args, **kwargs) @@ -113,7 +145,7 @@ class SQLBaseStore(object): else: return cursor.fetchall() - return self.runInteraction(interaction) + return self.runInteraction("_execute", interaction) def _execute_and_decode(self, query, *args): return self._execute(self.cursor_to_dict, query, *args) @@ -130,6 +162,7 @@ class SQLBaseStore(object): or_replace : bool; if True performs an INSERT OR REPLACE """ return self.runInteraction( + "_simple_insert", self._simple_insert_txn, table, values, or_replace=or_replace, or_ignore=or_ignore, ) @@ -170,7 +203,6 @@ class SQLBaseStore(object): table, keyvalues, retcols=retcols, allow_none=allow_none ) - @defer.inlineCallbacks def _simple_select_one_onecol(self, table, keyvalues, retcol, allow_none=False): """Executes a SELECT query on the named table, which is expected to @@ -181,19 +213,41 @@ class SQLBaseStore(object): keyvalues : dict of column names and values to select the row with retcol : string giving the name of the column to return """ - ret = yield self._simple_select_one( + return self.runInteraction( + "_simple_select_one_onecol_txn", + self._simple_select_one_onecol_txn, + table, keyvalues, retcol, allow_none=allow_none, + ) + + def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol, + allow_none=False): + ret = self._simple_select_onecol_txn( + txn, table=table, keyvalues=keyvalues, - retcols=[retcol], - allow_none=allow_none + retcol=retcol, ) if ret: - defer.returnValue(ret[retcol]) + return ret[0] else: - defer.returnValue(None) + if allow_none: + return None + else: + raise StoreError(404, "No row found") + + def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol): + sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % { + "retcol": retcol, + "table": table, + "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()), + } + + txn.execute(sql, keyvalues.values()) + + return [r[0] for r in txn.fetchall()] + - @defer.inlineCallbacks def _simple_select_onecol(self, table, keyvalues, retcol): """Executes a SELECT query on the named table, which returns a list comprising of the values of the named column from the selected rows. @@ -206,19 +260,11 @@ class SQLBaseStore(object): Returns: Deferred: Results in a list """ - sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % { - "retcol": retcol, - "table": table, - "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()), - } - - def func(txn): - txn.execute(sql, keyvalues.values()) - return txn.fetchall() - - res = yield self.runInteraction(func) - - defer.returnValue([r[0] for r in res]) + return self.runInteraction( + "_simple_select_onecol", + self._simple_select_onecol_txn, + table, keyvalues, retcol + ) def _simple_select_list(self, table, keyvalues, retcols): """Executes a SELECT query on the named table, which may return zero or @@ -239,7 +285,7 @@ class SQLBaseStore(object): txn.execute(sql, keyvalues.values()) return self.cursor_to_dict(txn) - return self.runInteraction(func) + return self.runInteraction("_simple_select_list", func) def _simple_update_one(self, table, keyvalues, updatevalues, retcols=None): @@ -307,7 +353,7 @@ class SQLBaseStore(object): raise StoreError(500, "More than one row matched") return ret - return self.runInteraction(func) + return self.runInteraction("_simple_selectupdate_one", func) def _simple_delete_one(self, table, keyvalues): """Executes a DELETE query on the named table, expecting to delete a @@ -319,7 +365,7 @@ class SQLBaseStore(object): """ sql = "DELETE FROM %s WHERE %s" % ( table, - " AND ".join("%s = ?" % (k) for k in keyvalues) + " AND ".join("%s = ?" % (k, ) for k in keyvalues) ) def func(txn): @@ -328,7 +374,25 @@ class SQLBaseStore(object): raise StoreError(404, "No row found") if txn.rowcount > 1: raise StoreError(500, "more than one row matched") - return self.runInteraction(func) + return self.runInteraction("_simple_delete_one", func) + + def _simple_delete(self, table, keyvalues): + """Executes a DELETE query on the named table. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + """ + + return self.runInteraction("_simple_delete", self._simple_delete_txn) + + def _simple_delete_txn(self, txn, table, keyvalues): + sql = "DELETE FROM %s WHERE %s" % ( + table, + " AND ".join("%s = ?" % (k, ) for k in keyvalues) + ) + + return txn.execute(sql, keyvalues.values()) def _simple_max_id(self, table): """Executes a SELECT query on the named table, expecting to return the @@ -346,7 +410,7 @@ class SQLBaseStore(object): return 0 return max_id - return self.runInteraction(func) + return self.runInteraction("_simple_max_id", func) def _parse_event_from_row(self, row_dict): d = copy.deepcopy({k: v for k, v in row_dict.items()}) @@ -370,7 +434,9 @@ class SQLBaseStore(object): ) def _parse_events(self, rows): - return self.runInteraction(self._parse_events_txn, rows) + return self.runInteraction( + "_parse_events", self._parse_events_txn, rows + ) def _parse_events_txn(self, txn, rows): events = [self._parse_event_from_row(r) for r in rows] @@ -378,6 +444,17 @@ class SQLBaseStore(object): sql = "SELECT * FROM events WHERE event_id = ?" for ev in events: + signatures = self._get_event_origin_signatures_txn( + txn, ev.event_id, + ) + + ev.signatures = { + k: encode_base64(v) for k, v in signatures.items() + } + + prev_events = self._get_latest_events_in_room(txn, ev.room_id) + ev.prev_events = [(e_id, s,) for e_id, s, _ in prev_events] + if hasattr(ev, "prev_state"): # Load previous state_content. # TODO: Should we be pulling this out above? diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 52373a28a6..d6a7113b9c 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -95,6 +95,7 @@ class DirectoryStore(SQLBaseStore): def delete_room_alias(self, room_alias): return self.runInteraction( + "delete_room_alias", self._delete_room_alias_txn, room_alias, ) diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py new file mode 100644 index 0000000000..88d09d9ba8 --- /dev/null +++ b/synapse/storage/event_federation.py @@ -0,0 +1,162 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore +from syutil.base64util import encode_base64 + +import logging + + +logger = logging.getLogger(__name__) + + +class EventFederationStore(SQLBaseStore): + + def get_latest_events_in_room(self, room_id): + return self.runInteraction( + "get_latest_events_in_room", + self._get_latest_events_in_room, + room_id, + ) + + def _get_latest_events_in_room(self, txn, room_id): + self._simple_select_onecol_txn( + txn, + table="event_forward_extremities", + keyvalues={ + "room_id": room_id, + }, + retcol="event_id", + ) + + sql = ( + "SELECT e.event_id, e.depth FROM events as e " + "INNER JOIN event_forward_extremities as f " + "ON e.event_id = f.event_id " + "WHERE f.room_id = ?" + ) + + txn.execute(sql, (room_id, )) + + results = [] + for event_id, depth in txn.fetchall(): + hashes = self._get_prev_event_hashes_txn(txn, event_id) + prev_hashes = { + k: encode_base64(v) for k, v in hashes.items() + if k == "sha256" + } + results.append((event_id, prev_hashes, depth)) + + return results + + def _get_min_depth_interaction(self, txn, room_id): + min_depth = self._simple_select_one_onecol_txn( + txn, + table="room_depth", + keyvalues={"room_id": room_id,}, + retcol="min_depth", + allow_none=True, + ) + + return int(min_depth) if min_depth is not None else None + + def _update_min_depth_for_room_txn(self, txn, room_id, depth): + min_depth = self._get_min_depth_interaction(txn, room_id) + + do_insert = depth < min_depth if min_depth else True + + if do_insert: + self._simple_insert_txn( + txn, + table="room_depth", + values={ + "room_id": room_id, + "min_depth": depth, + }, + or_replace=True, + ) + + def _handle_prev_events(self, txn, outlier, event_id, prev_events, + room_id): + for e_id, _ in prev_events: + # TODO (erikj): This could be done as a bulk insert + self._simple_insert_txn( + txn, + table="event_edges", + values={ + "event_id": event_id, + "prev_event_id": e_id, + "room_id": room_id, + }, + or_ignore=True, + ) + + # Update the extremities table if this is not an outlier. + if not outlier: + for e_id, _ in prev_events: + # TODO (erikj): This could be done as a bulk insert + self._simple_delete_txn( + txn, + table="event_forward_extremities", + keyvalues={ + "event_id": e_id, + "room_id": room_id, + } + ) + + + + # We only insert as a forward extremity the new pdu if there are no + # other pdus that reference it as a prev pdu + query = ( + "INSERT OR IGNORE INTO %(table)s (event_id, room_id) " + "SELECT ?, ? WHERE NOT EXISTS (" + "SELECT 1 FROM %(event_edges)s WHERE " + "prev_event_id = ? " + ")" + ) % { + "table": "event_forward_extremities", + "event_edges": "event_edges", + } + + logger.debug("query: %s", query) + + txn.execute(query, (event_id, room_id, event_id)) + + # Insert all the prev_pdus as a backwards thing, they'll get + # deleted in a second if they're incorrect anyway. + for e_id, _ in prev_events: + # TODO (erikj): This could be done as a bulk insert + self._simple_insert_txn( + txn, + table="event_backward_extremities", + values={ + "event_id": e_id, + "room_id": room_id, + }, + or_ignore=True, + ) + + # Also delete from the backwards extremities table all ones that + # reference pdus that we have already seen + query = ( + "DELETE FROM event_backward_extremities WHERE EXISTS (" + "SELECT 1 FROM events " + "WHERE " + "event_backward_extremities.event_id = events.event_id " + "AND not events.outlier " + ")" + ) + txn.execute(query) \ No newline at end of file diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py index d70467dcd6..4a4341907b 100644 --- a/synapse/storage/pdu.py +++ b/synapse/storage/pdu.py @@ -20,10 +20,13 @@ from ._base import SQLBaseStore, Table, JoinHelper from synapse.federation.units import Pdu from synapse.util.logutils import log_function +from syutil.base64util import encode_base64 + from collections import namedtuple import logging + logger = logging.getLogger(__name__) @@ -44,7 +47,7 @@ class PduStore(SQLBaseStore): """ return self.runInteraction( - self._get_pdu_tuple, pdu_id, origin + "get_pdu", self._get_pdu_tuple, pdu_id, origin ) def _get_pdu_tuple(self, txn, pdu_id, origin): @@ -64,6 +67,13 @@ class PduStore(SQLBaseStore): for r in PduEdgesTable.decode_results(txn.fetchall()) ] + edge_hashes = self._get_prev_pdu_hashes_txn(txn, pdu_id, origin) + + hashes = self._get_pdu_content_hashes_txn(txn, pdu_id, origin) + signatures = self._get_pdu_origin_signatures_txn( + txn, pdu_id, origin + ) + query = ( "SELECT %(fields)s FROM %(pdus)s as p " "LEFT JOIN %(state)s as s " @@ -80,7 +90,9 @@ class PduStore(SQLBaseStore): row = txn.fetchone() if row: - results.append(PduTuple(PduEntry(*row), edges)) + results.append(PduTuple( + PduEntry(*row), edges, hashes, signatures, edge_hashes + )) return results @@ -96,6 +108,7 @@ class PduStore(SQLBaseStore): """ return self.runInteraction( + "get_current_state_for_context", self._get_current_state_for_context, context ) @@ -144,6 +157,7 @@ class PduStore(SQLBaseStore): """ return self.runInteraction( + "mark_pdu_as_processed", self._mark_as_processed, pdu_id, pdu_origin ) @@ -153,6 +167,7 @@ class PduStore(SQLBaseStore): def get_all_pdus_from_context(self, context): """Get a list of all PDUs for a given context.""" return self.runInteraction( + "get_all_pdus_from_context", self._get_all_pdus_from_context, context, ) @@ -180,6 +195,7 @@ class PduStore(SQLBaseStore): list: A list of PduTuples """ return self.runInteraction( + "get_backfill", self._get_backfill, context, pdu_list, limit ) @@ -241,6 +257,7 @@ class PduStore(SQLBaseStore): context (str) """ return self.runInteraction( + "get_min_depth_for_context", self._get_min_depth_for_context, context ) @@ -277,6 +294,13 @@ class PduStore(SQLBaseStore): (context, depth) ) + def get_latest_pdus_in_context(self, context): + return self.runInteraction( + "get_latest_pdus_in_context", + self._get_latest_pdus_in_context, + context + ) + def _get_latest_pdus_in_context(self, txn, context): """Get's a list of the most current pdus for a given context. This is used when we are sending a Pdu and need to fill out the `prev_pdus` @@ -303,9 +327,14 @@ class PduStore(SQLBaseStore): (context, ) ) - results = txn.fetchall() + results = [] + for pdu_id, origin, depth in txn.fetchall(): + hashes = self._get_pdu_reference_hashes_txn(txn, pdu_id, origin) + sha256_bytes = hashes["sha256"] + prev_hashes = {"sha256": encode_base64(sha256_bytes)} + results.append((pdu_id, origin, prev_hashes, depth)) - return [(row[0], row[1], row[2]) for row in results] + return results @defer.inlineCallbacks def get_oldest_pdus_in_context(self, context): @@ -347,6 +376,7 @@ class PduStore(SQLBaseStore): """ return self.runInteraction( + "is_pdu_new", self._is_pdu_new, pdu_id=pdu_id, origin=origin, @@ -424,7 +454,7 @@ class PduStore(SQLBaseStore): "DELETE FROM %s WHERE pdu_id = ? AND origin = ?" % PduForwardExtremitiesTable.table_name ) - txn.executemany(query, prev_pdus) + txn.executemany(query, list(p[:2] for p in prev_pdus)) # We only insert as a forward extremety the new pdu if there are no # other pdus that reference it as a prev pdu @@ -447,7 +477,7 @@ class PduStore(SQLBaseStore): # deleted in a second if they're incorrect anyway. txn.executemany( PduBackwardExtremitiesTable.insert_statement(), - [(i, o, context) for i, o in prev_pdus] + [(i, o, context) for i, o, _ in prev_pdus] ) # Also delete from the backwards extremities table all ones that @@ -500,6 +530,7 @@ class StatePduStore(SQLBaseStore): def get_unresolved_state_tree(self, new_state_pdu): return self.runInteraction( + "get_unresolved_state_tree", self._get_unresolved_state_tree, new_state_pdu ) @@ -539,6 +570,7 @@ class StatePduStore(SQLBaseStore): def update_current_state(self, pdu_id, origin, context, pdu_type, state_key): return self.runInteraction( + "update_current_state", self._update_current_state, pdu_id, origin, context, pdu_type, state_key ) @@ -578,6 +610,7 @@ class StatePduStore(SQLBaseStore): """ return self.runInteraction( + "get_current_state_pdu", self._get_current_state_pdu, context, pdu_type, state_key ) @@ -637,6 +670,7 @@ class StatePduStore(SQLBaseStore): bool: True if the new_pdu clobbered the current state, False if not """ return self.runInteraction( + "handle_new_state", self._handle_new_state, new_pdu ) @@ -908,7 +942,7 @@ This does not include a prev_pdus key. PduTuple = namedtuple( "PduTuple", - ("pdu_entry", "prev_pdu_list") + ("pdu_entry", "prev_pdu_list", "hashes", "signatures", "edge_hashes") ) """ This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent the `prev_pdus` key of a PDU. diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 719806f82b..a2ca6f9a69 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -62,8 +62,10 @@ class RegistrationStore(SQLBaseStore): Raises: StoreError if the user_id could not be registered. """ - yield self.runInteraction(self._register, user_id, token, - password_hash) + yield self.runInteraction( + "register", + self._register, user_id, token, password_hash + ) def _register(self, txn, user_id, token, password_hash): now = int(self.clock.time()) @@ -100,6 +102,7 @@ class RegistrationStore(SQLBaseStore): StoreError if no user was found. """ return self.runInteraction( + "get_user_by_token", self._query_for_auth, token ) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 8cd46334cf..7e48ce9cc3 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -150,6 +150,7 @@ class RoomStore(SQLBaseStore): def get_power_level(self, room_id, user_id): return self.runInteraction( + "get_power_level", self._get_power_level, room_id, user_id, ) @@ -183,6 +184,7 @@ class RoomStore(SQLBaseStore): def get_ops_levels(self, room_id): return self.runInteraction( + "get_ops_levels", self._get_ops_levels, room_id, ) diff --git a/synapse/storage/schema/event_edges.sql b/synapse/storage/schema/event_edges.sql new file mode 100644 index 0000000000..e5f768c705 --- /dev/null +++ b/synapse/storage/schema/event_edges.sql @@ -0,0 +1,49 @@ + +CREATE TABLE IF NOT EXISTS event_forward_extremities( + event_id TEXT, + room_id TEXT, + CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE +); + +CREATE INDEX IF NOT EXISTS ev_extrem_room ON event_forward_extremities(room_id); +CREATE INDEX IF NOT EXISTS ev_extrem_id ON event_forward_extremities(event_id); + + +CREATE TABLE IF NOT EXISTS event_backward_extremities( + event_id TEXT, + room_id TEXT, + CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE +); + +CREATE INDEX IF NOT EXISTS ev_b_extrem_room ON event_backward_extremities(room_id); +CREATE INDEX IF NOT EXISTS ev_b_extrem_id ON event_backward_extremities(event_id); + + +CREATE TABLE IF NOT EXISTS event_edges( + event_id TEXT, + prev_event_id TEXT, + room_id TEXT, + CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id) +); + +CREATE INDEX IF NOT EXISTS ev_edges_id ON event_edges(event_id); +CREATE INDEX IF NOT EXISTS ev_edges_prev_id ON event_edges(prev_event_id); + + +CREATE TABLE IF NOT EXISTS room_depth( + room_id TEXT, + min_depth INTEGER, + CONSTRAINT uniqueness UNIQUE (room_id) +); + +CREATE INDEX IF NOT EXISTS room_depth_room ON room_depth(room_id); + + +create TABLE IF NOT EXISTS event_destinations( + event_id TEXT, + destination TEXT, + delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered + CONSTRAINT uniqueness UNIQUE (event_id, destination) ON CONFLICT REPLACE +); + +CREATE INDEX IF NOT EXISTS event_destinations_id ON event_destinations(event_id); diff --git a/synapse/storage/schema/event_signatures.sql b/synapse/storage/schema/event_signatures.sql new file mode 100644 index 0000000000..5491c7ecec --- /dev/null +++ b/synapse/storage/schema/event_signatures.sql @@ -0,0 +1,65 @@ +/* Copyright 2014 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS event_content_hashes ( + event_id TEXT, + algorithm TEXT, + hash BLOB, + CONSTRAINT uniqueness UNIQUE (event_id, algorithm) +); + +CREATE INDEX IF NOT EXISTS event_content_hashes_id ON event_content_hashes( + event_id +); + + +CREATE TABLE IF NOT EXISTS event_reference_hashes ( + event_id TEXT, + algorithm TEXT, + hash BLOB, + CONSTRAINT uniqueness UNIQUE (event_id, algorithm) +); + +CREATE INDEX IF NOT EXISTS event_reference_hashes_id ON event_reference_hashes ( + event_id +); + + +CREATE TABLE IF NOT EXISTS event_origin_signatures ( + event_id TEXT, + origin TEXT, + key_id TEXT, + signature BLOB, + CONSTRAINT uniqueness UNIQUE (event_id, key_id) +); + +CREATE INDEX IF NOT EXISTS event_origin_signatures_id ON event_origin_signatures ( + event_id +); + + +CREATE TABLE IF NOT EXISTS event_edge_hashes( + event_id TEXT, + prev_event_id TEXT, + algorithm TEXT, + hash BLOB, + CONSTRAINT uniqueness UNIQUE ( + event_id, prev_event_id, algorithm + ) +); + +CREATE INDEX IF NOT EXISTS event_edge_hashes_id ON event_edge_hashes( + event_id +); diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/im.sql index 3aa83f5c8c..8d6f655993 100644 --- a/synapse/storage/schema/im.sql +++ b/synapse/storage/schema/im.sql @@ -23,6 +23,7 @@ CREATE TABLE IF NOT EXISTS events( unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, + depth INTEGER DEFAULT 0 NOT NULL, CONSTRAINT ev_uniq UNIQUE (event_id) ); diff --git a/synapse/storage/schema/signatures.sql b/synapse/storage/schema/signatures.sql new file mode 100644 index 0000000000..1c45a51bec --- /dev/null +++ b/synapse/storage/schema/signatures.sql @@ -0,0 +1,66 @@ +/* Copyright 2014 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS pdu_content_hashes ( + pdu_id TEXT, + origin TEXT, + algorithm TEXT, + hash BLOB, + CONSTRAINT uniqueness UNIQUE (pdu_id, origin, algorithm) +); + +CREATE INDEX IF NOT EXISTS pdu_content_hashes_id ON pdu_content_hashes ( + pdu_id, origin +); + +CREATE TABLE IF NOT EXISTS pdu_reference_hashes ( + pdu_id TEXT, + origin TEXT, + algorithm TEXT, + hash BLOB, + CONSTRAINT uniqueness UNIQUE (pdu_id, origin, algorithm) +); + +CREATE INDEX IF NOT EXISTS pdu_reference_hashes_id ON pdu_reference_hashes ( + pdu_id, origin +); + +CREATE TABLE IF NOT EXISTS pdu_origin_signatures ( + pdu_id TEXT, + origin TEXT, + key_id TEXT, + signature BLOB, + CONSTRAINT uniqueness UNIQUE (pdu_id, origin, key_id) +); + +CREATE INDEX IF NOT EXISTS pdu_origin_signatures_id ON pdu_origin_signatures ( + pdu_id, origin +); + +CREATE TABLE IF NOT EXISTS pdu_edge_hashes( + pdu_id TEXT, + origin TEXT, + prev_pdu_id TEXT, + prev_origin TEXT, + algorithm TEXT, + hash BLOB, + CONSTRAINT uniqueness UNIQUE ( + pdu_id, origin, prev_pdu_id, prev_origin, algorithm + ) +); + +CREATE INDEX IF NOT EXISTS pdu_edge_hashes_id ON pdu_edge_hashes( + pdu_id, origin +); diff --git a/synapse/storage/schema/state.sql b/synapse/storage/schema/state.sql new file mode 100644 index 0000000000..b44c56b519 --- /dev/null +++ b/synapse/storage/schema/state.sql @@ -0,0 +1,33 @@ +/* Copyright 2014 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS state_groups( + id INTEGER PRIMARY KEY, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS state_groups_state( + state_group INTEGER NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS event_to_state_groups( + event_id TEXT NOT NULL, + state_group INTEGER NOT NULL +); \ No newline at end of file diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py new file mode 100644 index 0000000000..5e99174fcd --- /dev/null +++ b/synapse/storage/signatures.py @@ -0,0 +1,302 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from _base import SQLBaseStore + + +class SignatureStore(SQLBaseStore): + """Persistence for PDU signatures and hashes""" + + def _get_pdu_content_hashes_txn(self, txn, pdu_id, origin): + """Get all the hashes for a given PDU. + Args: + txn (cursor): + pdu_id (str): Id for the PDU. + origin (str): origin of the PDU. + Returns: + A dict of algorithm -> hash. + """ + query = ( + "SELECT algorithm, hash" + " FROM pdu_content_hashes" + " WHERE pdu_id = ? and origin = ?" + ) + txn.execute(query, (pdu_id, origin)) + return dict(txn.fetchall()) + + def _store_pdu_content_hash_txn(self, txn, pdu_id, origin, algorithm, + hash_bytes): + """Store a hash for a PDU + Args: + txn (cursor): + pdu_id (str): Id for the PDU. + origin (str): origin of the PDU. + algorithm (str): Hashing algorithm. + hash_bytes (bytes): Hash function output bytes. + """ + self._simple_insert_txn(txn, "pdu_content_hashes", { + "pdu_id": pdu_id, + "origin": origin, + "algorithm": algorithm, + "hash": buffer(hash_bytes), + }) + + def _get_pdu_reference_hashes_txn(self, txn, pdu_id, origin): + """Get all the hashes for a given PDU. + Args: + txn (cursor): + pdu_id (str): Id for the PDU. + origin (str): origin of the PDU. + Returns: + A dict of algorithm -> hash. + """ + query = ( + "SELECT algorithm, hash" + " FROM pdu_reference_hashes" + " WHERE pdu_id = ? and origin = ?" + ) + txn.execute(query, (pdu_id, origin)) + return dict(txn.fetchall()) + + def _store_pdu_reference_hash_txn(self, txn, pdu_id, origin, algorithm, + hash_bytes): + """Store a hash for a PDU + Args: + txn (cursor): + pdu_id (str): Id for the PDU. + origin (str): origin of the PDU. + algorithm (str): Hashing algorithm. + hash_bytes (bytes): Hash function output bytes. + """ + self._simple_insert_txn(txn, "pdu_reference_hashes", { + "pdu_id": pdu_id, + "origin": origin, + "algorithm": algorithm, + "hash": buffer(hash_bytes), + }) + + + def _get_pdu_origin_signatures_txn(self, txn, pdu_id, origin): + """Get all the signatures for a given PDU. + Args: + txn (cursor): + pdu_id (str): Id for the PDU. + origin (str): origin of the PDU. + Returns: + A dict of key_id -> signature_bytes. + """ + query = ( + "SELECT key_id, signature" + " FROM pdu_origin_signatures" + " WHERE pdu_id = ? and origin = ?" + ) + txn.execute(query, (pdu_id, origin)) + return dict(txn.fetchall()) + + def _store_pdu_origin_signature_txn(self, txn, pdu_id, origin, key_id, + signature_bytes): + """Store a signature from the origin server for a PDU. + Args: + txn (cursor): + pdu_id (str): Id for the PDU. + origin (str): origin of the PDU. + key_id (str): Id for the signing key. + signature (bytes): The signature. + """ + self._simple_insert_txn(txn, "pdu_origin_signatures", { + "pdu_id": pdu_id, + "origin": origin, + "key_id": key_id, + "signature": buffer(signature_bytes), + }) + + def _get_prev_pdu_hashes_txn(self, txn, pdu_id, origin): + """Get all the hashes for previous PDUs of a PDU + Args: + txn (cursor): + pdu_id (str): Id of the PDU. + origin (str): Origin of the PDU. + Returns: + dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes. + """ + query = ( + "SELECT prev_pdu_id, prev_origin, algorithm, hash" + " FROM pdu_edge_hashes" + " WHERE pdu_id = ? and origin = ?" + ) + txn.execute(query, (pdu_id, origin)) + results = {} + for prev_pdu_id, prev_origin, algorithm, hash_bytes in txn.fetchall(): + hashes = results.setdefault((prev_pdu_id, prev_origin), {}) + hashes[algorithm] = hash_bytes + return results + + def _store_prev_pdu_hash_txn(self, txn, pdu_id, origin, prev_pdu_id, + prev_origin, algorithm, hash_bytes): + self._simple_insert_txn(txn, "pdu_edge_hashes", { + "pdu_id": pdu_id, + "origin": origin, + "prev_pdu_id": prev_pdu_id, + "prev_origin": prev_origin, + "algorithm": algorithm, + "hash": buffer(hash_bytes), + }) + + ## Events ## + + def _get_event_content_hashes_txn(self, txn, event_id): + """Get all the hashes for a given Event. + Args: + txn (cursor): + event_id (str): Id for the Event. + Returns: + A dict of algorithm -> hash. + """ + query = ( + "SELECT algorithm, hash" + " FROM event_content_hashes" + " WHERE event_id = ?" + ) + txn.execute(query, (event_id, )) + return dict(txn.fetchall()) + + def _store_event_content_hash_txn(self, txn, event_id, algorithm, + hash_bytes): + """Store a hash for a Event + Args: + txn (cursor): + event_id (str): Id for the Event. + algorithm (str): Hashing algorithm. + hash_bytes (bytes): Hash function output bytes. + """ + self._simple_insert_txn( + txn, + "event_content_hashes", + { + "event_id": event_id, + "algorithm": algorithm, + "hash": buffer(hash_bytes), + }, + or_ignore=True, + ) + + def _get_event_reference_hashes_txn(self, txn, event_id): + """Get all the hashes for a given PDU. + Args: + txn (cursor): + event_id (str): Id for the Event. + Returns: + A dict of algorithm -> hash. + """ + query = ( + "SELECT algorithm, hash" + " FROM event_reference_hashes" + " WHERE event_id = ?" + ) + txn.execute(query, (event_id, )) + return dict(txn.fetchall()) + + def _store_event_reference_hash_txn(self, txn, event_id, algorithm, + hash_bytes): + """Store a hash for a PDU + Args: + txn (cursor): + event_id (str): Id for the Event. + algorithm (str): Hashing algorithm. + hash_bytes (bytes): Hash function output bytes. + """ + self._simple_insert_txn( + txn, + "event_reference_hashes", + { + "event_id": event_id, + "algorithm": algorithm, + "hash": buffer(hash_bytes), + }, + or_ignore=True, + ) + + + def _get_event_origin_signatures_txn(self, txn, event_id): + """Get all the signatures for a given PDU. + Args: + txn (cursor): + event_id (str): Id for the Event. + Returns: + A dict of key_id -> signature_bytes. + """ + query = ( + "SELECT key_id, signature" + " FROM event_origin_signatures" + " WHERE event_id = ? " + ) + txn.execute(query, (event_id, )) + return dict(txn.fetchall()) + + def _store_event_origin_signature_txn(self, txn, event_id, origin, key_id, + signature_bytes): + """Store a signature from the origin server for a PDU. + Args: + txn (cursor): + event_id (str): Id for the Event. + origin (str): origin of the Event. + key_id (str): Id for the signing key. + signature (bytes): The signature. + """ + self._simple_insert_txn( + txn, + "event_origin_signatures", + { + "event_id": event_id, + "origin": origin, + "key_id": key_id, + "signature": buffer(signature_bytes), + }, + or_ignore=True, + ) + + def _get_prev_event_hashes_txn(self, txn, event_id): + """Get all the hashes for previous PDUs of a PDU + Args: + txn (cursor): + event_id (str): Id for the Event. + Returns: + dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes. + """ + query = ( + "SELECT prev_event_id, algorithm, hash" + " FROM event_edge_hashes" + " WHERE event_id = ?" + ) + txn.execute(query, (event_id, )) + results = {} + for prev_event_id, algorithm, hash_bytes in txn.fetchall(): + hashes = results.setdefault(prev_event_id, {}) + hashes[algorithm] = hash_bytes + return results + + def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id, + algorithm, hash_bytes): + self._simple_insert_txn( + txn, + "event_edge_hashes", + { + "event_id": event_id, + "prev_event_id": prev_event_id, + "algorithm": algorithm, + "hash": buffer(hash_bytes), + }, + or_ignore=True, + ) \ No newline at end of file diff --git a/synapse/storage/state.py b/synapse/storage/state.py new file mode 100644 index 0000000000..e08acd6404 --- /dev/null +++ b/synapse/storage/state.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore +from twisted.internet import defer + +from collections import namedtuple + + +StateGroup = namedtuple("StateGroup", ("group", "state")) + + +class StateStore(SQLBaseStore): + + @defer.inlineCallbacks + def get_state_groups(self, event_ids): + groups = set() + for event_id in event_ids: + group = yield self._simple_select_one_onecol( + table="event_to_state_groups", + keyvalues={"event_id": event_id}, + retcol="state_group", + allow_none=True, + ) + if group: + groups.add(group) + + res = [] + for group in groups: + state_ids = yield self._simple_select_onecol( + table="state_groups_state", + keyvalues={"state_group": group}, + retcol="event_id", + ) + state = [] + for state_id in state_ids: + s = yield self.get_event( + state_id, + allow_none=True, + ) + if s: + state.append(s) + + res.append(StateGroup(group, state)) + + defer.returnValue(res) + + def store_state_groups(self, event): + return self.runInteraction( + "store_state_groups", + self._store_state_groups_txn, event + ) + + def _store_state_groups_txn(self, txn, event): + if not event.state_events: + return + + state_group = event.state_group + if not state_group: + state_group = self._simple_insert_txn( + txn, + table="state_groups", + values={ + "room_id": event.room_id, + "event_id": event.event_id, + } + ) + + for state in event.state_events.values(): + self._simple_insert_txn( + txn, + table="state_groups_state", + values={ + "state_group": state_group, + "room_id": state.room_id, + "type": state.type, + "state_key": state.state_key, + "event_id": state.event_id, + } + ) + + self._simple_insert_txn( + txn, + table="event_to_state_groups", + values={ + "state_group": state_group, + "event_id": event.event_id, + } + ) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index d61f909939..8f7f61d29d 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -309,7 +309,10 @@ class StreamStore(SQLBaseStore): defer.returnValue(ret) def get_room_events_max_id(self): - return self.runInteraction(self._get_room_events_max_id_txn) + return self.runInteraction( + "get_room_events_max_id", + self._get_room_events_max_id_txn + ) def _get_room_events_max_id_txn(self, txn): txn.execute( diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 2ba8e30efe..908014d38b 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -42,6 +42,7 @@ class TransactionStore(SQLBaseStore): """ return self.runInteraction( + "get_received_txn_response", self._get_received_txn_response, transaction_id, origin ) @@ -73,6 +74,7 @@ class TransactionStore(SQLBaseStore): """ return self.runInteraction( + "set_received_txn_response", self._set_received_txn_response, transaction_id, origin, code, response_dict ) @@ -106,6 +108,7 @@ class TransactionStore(SQLBaseStore): """ return self.runInteraction( + "prep_send_transaction", self._prep_send_transaction, transaction_id, destination, origin_server_ts, pdu_list ) @@ -161,6 +164,7 @@ class TransactionStore(SQLBaseStore): response_json (str) """ return self.runInteraction( + "delivered_txn", self._delivered_txn, transaction_id, destination, code, response_dict ) @@ -186,6 +190,7 @@ class TransactionStore(SQLBaseStore): list: A list of `ReceivedTransactionsTable.EntryType` """ return self.runInteraction( + "get_transactions_after", self._get_transactions_after, transaction_id, destination ) @@ -216,6 +221,7 @@ class TransactionStore(SQLBaseStore): list: A list of PduTuple """ return self.runInteraction( + "get_pdus_after_transaction", self._get_pdus_after_transaction, transaction_id, destination ) |