diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/__init__.py | 98 | ||||
-rw-r--r-- | synapse/storage/_base.py | 156 | ||||
-rw-r--r-- | synapse/storage/event_federation.py | 11 | ||||
-rw-r--r-- | synapse/storage/schema/im.sql | 13 | ||||
-rw-r--r-- | synapse/storage/schema/state.sql | 3 | ||||
-rw-r--r-- | synapse/storage/signatures.py | 19 | ||||
-rw-r--r-- | synapse/storage/state.py | 13 |
7 files changed, 154 insertions, 159 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index c9ab434b4e..cc1dcc2e74 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -15,12 +15,8 @@ from twisted.internet import defer -from synapse.api.events.room import ( - RoomMemberEvent, RoomTopicEvent, FeedbackEvent, RoomNameEvent, - RoomRedactionEvent, -) - from synapse.util.logutils import log_function +from synapse.api.constants import EventTypes from .directory import DirectoryStore from .feedback import FeedbackStore @@ -39,6 +35,7 @@ from .state import StateStore from .signatures import SignatureStore from syutil.base64util import decode_base64 +from syutil.jsonutil import encode_canonical_json from synapse.crypto.event_signing import compute_event_reference_hash @@ -89,7 +86,6 @@ class DataStore(RoomMemberStore, RoomStore, def __init__(self, hs): super(DataStore, self).__init__(hs) - self.event_factory = hs.get_event_factory() self.hs = hs self.min_token_deferred = self._get_min_token() @@ -97,8 +93,8 @@ class DataStore(RoomMemberStore, RoomStore, @defer.inlineCallbacks @log_function - def persist_event(self, event, backfilled=False, is_new_state=True, - current_state=None): + def persist_event(self, event, context, backfilled=False, + is_new_state=True, current_state=None): stream_ordering = None if backfilled: if not self.min_token_deferred.called: @@ -111,6 +107,7 @@ class DataStore(RoomMemberStore, RoomStore, "persist_event", self._persist_event_txn, event=event, + context=context, backfilled=backfilled, stream_ordering=stream_ordering, is_new_state=is_new_state, @@ -121,50 +118,66 @@ class DataStore(RoomMemberStore, RoomStore, @defer.inlineCallbacks def get_event(self, event_id, allow_none=False): - events_dict = yield self._simple_select_one( - "events", - {"event_id": event_id}, - [ - "event_id", - "type", - "room_id", - "content", - "unrecognized_keys", - "depth", - ], - allow_none=allow_none, - ) + events = yield self._get_events([event_id]) - if not events_dict: - defer.returnValue(None) + if not events: + if allow_none: + defer.returnValue(None) + else: + raise RuntimeError("Could not find event %s" % (event_id,)) - event = yield self._parse_events([events_dict]) - defer.returnValue(event[0]) + defer.returnValue(events[0]) @log_function - def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None, - is_new_state=True, current_state=None): - if event.type == RoomMemberEvent.TYPE: + def _persist_event_txn(self, txn, event, context, backfilled, + stream_ordering=None, is_new_state=True, + current_state=None): + if event.type == EventTypes.Member: self._store_room_member_txn(txn, event) - elif event.type == FeedbackEvent.TYPE: + elif event.type == EventTypes.Feedback: self._store_feedback_txn(txn, event) - elif event.type == RoomNameEvent.TYPE: + elif event.type == EventTypes.Name: self._store_room_name_txn(txn, event) - elif event.type == RoomTopicEvent.TYPE: + elif event.type == EventTypes.Topic: self._store_room_topic_txn(txn, event) - elif event.type == RoomRedactionEvent.TYPE: + elif event.type == EventTypes.Redaction: self._store_redaction(txn, event) outlier = False - if hasattr(event, "outlier"): - outlier = event.outlier + if hasattr(event.internal_metadata, "outlier"): + outlier = event.internal_metadata.outlier + + event_dict = { + k: v + for k, v in event.get_dict().items() + if k not in [ + "redacted", + "redacted_because", + ] + } + + metadata_json = encode_canonical_json( + event.internal_metadata.get_dict() + ) + + self._simple_insert_txn( + txn, + table="event_json", + values={ + "event_id": event.event_id, + "room_id": event.room_id, + "internal_metadata": metadata_json.decode("UTF-8"), + "json": encode_canonical_json(event_dict).decode("UTF-8"), + }, + or_replace=True, + ) vals = { "topological_ordering": event.depth, "event_id": event.event_id, "type": event.type, "room_id": event.room_id, - "content": json.dumps(event.content), + "content": json.dumps(event.get_dict()["content"]), "processed": True, "outlier": outlier, "depth": event.depth, @@ -175,7 +188,7 @@ class DataStore(RoomMemberStore, RoomStore, unrec = { k: v - for k, v in event.get_full_dict().items() + for k, v in event.get_dict().items() if k not in vals.keys() and k not in [ "redacted", "redacted_because", @@ -210,7 +223,8 @@ class DataStore(RoomMemberStore, RoomStore, room_id=event.room_id, ) - self._store_state_groups_txn(txn, event) + if not outlier: + self._store_state_groups_txn(txn, event, context) if current_state: txn.execute( @@ -304,16 +318,6 @@ class DataStore(RoomMemberStore, RoomStore, txn, event.event_id, hash_alg, hash_bytes, ) - if hasattr(event, "signatures"): - logger.debug("sigs: %s", event.signatures) - for name, sigs in event.signatures.items(): - for key_id, signature_base64 in sigs.items(): - signature_bytes = decode_base64(signature_base64) - self._store_event_signature_txn( - txn, event.event_id, name, key_id, - signature_bytes, - ) - for prev_event_id, prev_hashes in event.prev_events: for alg, hash_base64 in prev_hashes.items(): hash_bytes = decode_base64(hash_base64) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index e72200e2f7..6dc857c4aa 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -15,15 +15,14 @@ import logging from synapse.api.errors import StoreError -from synapse.api.events.utils import prune_event +from synapse.events import FrozenEvent +from synapse.events.utils import prune_event from synapse.util.logutils import log_function from synapse.util.logcontext import PreserveLoggingContext, LoggingContext -from syutil.base64util import encode_base64 from twisted.internet import defer import collections -import copy import json import sys import time @@ -84,7 +83,6 @@ class SQLBaseStore(object): def __init__(self, hs): self.hs = hs self._db_pool = hs.get_db_pool() - self.event_factory = hs.get_event_factory() self._clock = hs.get_clock() @defer.inlineCallbacks @@ -436,123 +434,77 @@ class SQLBaseStore(object): return self.runInteraction("_simple_max_id", func) - def _parse_event_from_row(self, row_dict): - d = copy.deepcopy({k: v for k, v in row_dict.items()}) - - d.pop("stream_ordering", None) - d.pop("topological_ordering", None) - d.pop("processed", None) - d["origin_server_ts"] = d.pop("ts", 0) - replaces_state = d.pop("prev_state", None) + def _get_events(self, event_ids): + return self.runInteraction( + "_get_events", self._get_events_txn, event_ids + ) - if replaces_state: - d["replaces_state"] = replaces_state + def _get_events_txn(self, txn, event_ids): + events = [] + for e_id in event_ids: + ev = self._get_event_txn(txn, e_id) - d.update(json.loads(row_dict["unrecognized_keys"])) - d["content"] = json.loads(d["content"]) - del d["unrecognized_keys"] + if ev: + events.append(ev) - if "age_ts" not in d: - # For compatibility - d["age_ts"] = d.get("origin_server_ts", 0) + return events - return self.event_factory.create_event( - etype=d["type"], - **d + def _get_event_txn(self, txn, event_id, check_redacted=True, + get_prev_content=True): + sql = ( + "SELECT internal_metadata, json, r.event_id FROM event_json as e " + "LEFT JOIN redactions as r ON e.event_id = r.redacts " + "WHERE e.event_id = ? " + "LIMIT 1 " ) - def _get_events_txn(self, txn, event_ids): - # FIXME (erikj): This should be batched? + txn.execute(sql, (event_id,)) - sql = "SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc" + res = txn.fetchone() - event_rows = [] - for e_id in event_ids: - c = txn.execute(sql, (e_id,)) - event_rows.extend(self.cursor_to_dict(c)) + if not res: + return None - return self._parse_events_txn(txn, event_rows) + internal_metadata, js, redacted = res - def _parse_events(self, rows): - return self.runInteraction( - "_parse_events", self._parse_events_txn, rows - ) + d = json.loads(js) + internal_metadata = json.loads(internal_metadata) - def _parse_events_txn(self, txn, rows): - events = [self._parse_event_from_row(r) for r in rows] + ev = FrozenEvent(d, internal_metadata_dict=internal_metadata) - select_event_sql = ( - "SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc" - ) + if check_redacted and redacted: + ev = prune_event(ev) - for i, ev in enumerate(events): - signatures = self._get_event_signatures_txn( - txn, ev.event_id, + ev.unsigned["redacted_by"] = redacted + # Get the redaction event. + + because = self._get_event_txn( + txn, + redacted, + check_redacted=False ) - ev.signatures = { - n: { - k: encode_base64(v) for k, v in s.items() - } - for n, s in signatures.items() - } + if because: + ev.unsigned["redacted_because"] = because - hashes = self._get_event_content_hashes_txn( - txn, ev.event_id, - ) + if get_prev_content and "replaces_state" in ev.unsigned: + ev.unsigned["prev_content"] = self._get_event_txn( + txn, + ev.unsigned["replaces_state"], + get_prev_content=False, + ).get_dict()["content"] - ev.hashes = { - k: encode_base64(v) for k, v in hashes.items() - } - - prevs = self._get_prev_events_and_state(txn, ev.event_id) - - ev.prev_events = [ - (e_id, h) - for e_id, h, is_state in prevs - if is_state == 0 - ] - - ev.auth_events = self._get_auth_events(txn, ev.event_id) - - if hasattr(ev, "state_key"): - ev.prev_state = [ - (e_id, h) - for e_id, h, is_state in prevs - if is_state == 1 - ] - - if hasattr(ev, "replaces_state"): - # Load previous state_content. - # FIXME (erikj): Handle multiple prev_states. - cursor = txn.execute( - select_event_sql, - (ev.replaces_state,) - ) - prevs = self.cursor_to_dict(cursor) - if prevs: - prev = self._parse_event_from_row(prevs[0]) - ev.prev_content = prev.content - - if not hasattr(ev, "redacted"): - logger.debug("Doesn't have redacted key: %s", ev) - ev.redacted = self._has_been_redacted_txn(txn, ev) - - if ev.redacted: - # Get the redaction event. - select_event_sql = "SELECT * FROM events WHERE event_id = ?" - txn.execute(select_event_sql, (ev.redacted,)) - - del_evs = self._parse_events_txn( - txn, self.cursor_to_dict(txn) - ) + return ev - if del_evs: - ev = prune_event(ev) - events[i] = ev - ev.redacted_because = del_evs[0] + def _parse_events(self, rows): + return self.runInteraction( + "_parse_events", self._parse_events_txn, rows + ) - return events + def _parse_events_txn(self, txn, rows): + event_ids = [r["event_id"] for r in rows] + + return self._get_events_txn(txn, event_ids) def _has_been_redacted_txn(self, txn, event): sql = "SELECT event_id FROM redactions WHERE redacts = ?" diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 6c559f8f63..ced066f407 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -177,14 +177,15 @@ class EventFederationStore(SQLBaseStore): retcols=["prev_event_id", "is_state"], ) + hashes = self._get_prev_event_hashes_txn(txn, event_id) + results = [] for d in res: - hashes = self._get_event_reference_hashes_txn( - txn, - d["prev_event_id"] - ) + edge_hash = self._get_event_reference_hashes_txn(txn, d["prev_event_id"]) + edge_hash.update(hashes.get(d["prev_event_id"], {})) prev_hashes = { - k: encode_base64(v) for k, v in hashes.items() + k: encode_base64(v) + for k, v in edge_hash.items() if k == "sha256" } results.append((d["prev_event_id"], prev_hashes, d["is_state"])) diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/im.sql index 8ba732a23b..253f9f779b 100644 --- a/synapse/storage/schema/im.sql +++ b/synapse/storage/schema/im.sql @@ -32,6 +32,19 @@ CREATE INDEX IF NOT EXISTS events_stream_ordering ON events (stream_ordering); CREATE INDEX IF NOT EXISTS events_topological_ordering ON events (topological_ordering); CREATE INDEX IF NOT EXISTS events_room_id ON events (room_id); + +CREATE TABLE IF NOT EXISTS event_json( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + internal_metadata NOT NULL, + json BLOB NOT NULL, + CONSTRAINT ev_j_uniq UNIQUE (event_id) +); + +CREATE INDEX IF NOT EXISTS event_json_id ON event_json(event_id); +CREATE INDEX IF NOT EXISTS event_json_room_id ON event_json(room_id); + + CREATE TABLE IF NOT EXISTS state_events( event_id TEXT NOT NULL, room_id TEXT NOT NULL, diff --git a/synapse/storage/schema/state.sql b/synapse/storage/schema/state.sql index 44f7aafb27..2c48d6daca 100644 --- a/synapse/storage/schema/state.sql +++ b/synapse/storage/schema/state.sql @@ -29,7 +29,8 @@ CREATE TABLE IF NOT EXISTS state_groups_state( CREATE TABLE IF NOT EXISTS event_to_state_groups( event_id TEXT NOT NULL, - state_group INTEGER NOT NULL + state_group INTEGER NOT NULL, + CONSTRAINT event_to_state_groups_uniq UNIQUE (event_id) ); CREATE INDEX IF NOT EXISTS state_groups_id ON state_groups(id); diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index eea4f21065..3a705119fd 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -13,8 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + from _base import SQLBaseStore +from syutil.base64util import encode_base64 + class SignatureStore(SQLBaseStore): """Persistence for event signatures and hashes""" @@ -67,6 +71,21 @@ class SignatureStore(SQLBaseStore): f ) + @defer.inlineCallbacks + def add_event_hashes(self, event_ids): + hashes = yield self.get_event_reference_hashes( + event_ids + ) + hashes = [ + { + k: encode_base64(v) for k, v in h.items() + if k == "sha256" + } + for h in hashes + ] + + defer.returnValue(zip(event_ids, hashes)) + def _get_event_reference_hashes_txn(self, txn, event_id): """Get all the hashes for a given PDU. Args: diff --git a/synapse/storage/state.py b/synapse/storage/state.py index e0f44b3e59..afe3e5edea 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -86,11 +86,16 @@ class StateStore(SQLBaseStore): self._store_state_groups_txn, event ) - def _store_state_groups_txn(self, txn, event): - if event.state_events is None: + def _store_state_groups_txn(self, txn, event, context): + if context.current_state is None: return - state_group = event.state_group + state_events = context.current_state + + if event.is_state(): + state_events[(event.type, event.state_key)] = event + + state_group = context.state_group if not state_group: state_group = self._simple_insert_txn( txn, @@ -102,7 +107,7 @@ class StateStore(SQLBaseStore): or_ignore=True, ) - for state in event.state_events.values(): + for state in state_events.values(): self._simple_insert_txn( txn, table="state_groups_state", |