diff options
Diffstat (limited to 'synapse')
39 files changed, 1996 insertions, 1708 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index e1b1823cd7..c684265101 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -21,6 +21,8 @@ from synapse.api.constants import Membership, JoinRules from synapse.api.errors import AuthError, StoreError, Codes, SynapseError from synapse.api.events.room import ( RoomMemberEvent, RoomPowerLevelsEvent, RoomRedactionEvent, + RoomJoinRulesEvent, RoomOpsPowerLevelsEvent, InviteJoinEvent, + RoomCreateEvent, ) from synapse.util.logutils import log_function @@ -47,42 +49,60 @@ class Auth(object): """ try: if hasattr(event, "room_id"): + if event.old_state_events is None: + # Oh, we don't know what the state of the room was, so we + # are trusting that this is allowed (at least for now) + defer.returnValue(True) + + if hasattr(event, "outlier") and event.outlier is True: + # TODO (erikj): Auth for outliers is done differently. + defer.returnValue(True) + is_state = hasattr(event, "state_key") + if event.type == RoomCreateEvent.TYPE: + # FIXME + defer.returnValue(True) + if event.type == RoomMemberEvent.TYPE: - yield self._can_replace_state(event) - allowed = yield self.is_membership_change_allowed(event) + self._can_replace_state(event) + allowed = self.is_membership_change_allowed(event) + if allowed: + logger.debug("Allowing! %s", event) + else: + logger.debug("Denying! %s", event) defer.returnValue(allowed) return - self._check_joined_room( - member=snapshot.membership_state, - user_id=snapshot.user_id, - room_id=snapshot.room_id, - ) + if not event.type == InviteJoinEvent.TYPE: + self.check_event_sender_in_room(event) if is_state: # TODO (erikj): This really only should be called for *new* # state yield self._can_add_state(event) - yield self._can_replace_state(event) + self._can_replace_state(event) else: yield self._can_send_event(event) if event.type == RoomPowerLevelsEvent.TYPE: - yield self._check_power_levels(event) + self._check_power_levels(event) if event.type == RoomRedactionEvent.TYPE: - yield self._check_redaction(event) + self._check_redaction(event) + + logger.debug("Allowing! %s", event) defer.returnValue(True) else: raise AuthError(500, "Unknown event: %s" % event) except AuthError as e: logger.info("Event auth check failed on event %s with msg: %s", event, e.msg) + logger.info("Denying! %s", event) if raises: raise e + defer.returnValue(False) @defer.inlineCallbacks @@ -98,45 +118,72 @@ class Auth(object): pass defer.returnValue(None) + def check_event_sender_in_room(self, event): + key = (RoomMemberEvent.TYPE, event.user_id, ) + member_event = event.state_events.get(key) + + return self._check_joined_room( + member_event, + event.user_id, + event.room_id + ) + def _check_joined_room(self, member, user_id, room_id): if not member or member.membership != Membership.JOIN: raise AuthError(403, "User %s not in room %s (%s)" % ( user_id, room_id, repr(member) )) - @defer.inlineCallbacks + @log_function def is_membership_change_allowed(self, event): target_user_id = event.state_key - # does this room even exist - room = yield self.store.get_room(event.room_id) - if not room: - raise AuthError(403, "Room does not exist") - # get info about the caller - try: - caller = yield self.store.get_room_member( - user_id=event.user_id, - room_id=event.room_id) - except: - caller = None + key = (RoomMemberEvent.TYPE, event.user_id, ) + caller = event.old_state_events.get(key) + caller_in_room = caller and caller.membership == "join" # get info about the target - try: - target = yield self.store.get_room_member( - user_id=target_user_id, - room_id=event.room_id) - except: - target = None + key = (RoomMemberEvent.TYPE, target_user_id, ) + target = event.old_state_events.get(key) + target_in_room = target and target.membership == "join" membership = event.content["membership"] - join_rule = yield self.store.get_room_join_rule(event.room_id) - if not join_rule: + key = (RoomJoinRulesEvent.TYPE, "", ) + join_rule_event = event.old_state_events.get(key) + if join_rule_event: + join_rule = join_rule_event.content.get( + "join_rule", JoinRules.INVITE + ) + else: join_rule = JoinRules.INVITE + user_level = self._get_power_level_from_event_state( + event, + event.user_id, + ) + + ban_level, kick_level, redact_level = ( + self._get_ops_level_from_event_state( + event + ) + ) + + logger.debug( + "is_membership_change_allowed: %s", + { + "caller_in_room": caller_in_room, + "target_in_room": target_in_room, + "membership": membership, + "join_rule": join_rule, + "target_user_id": target_user_id, + "event.user_id": event.user_id, + } + ) + if Membership.INVITE == membership: # TODO (erikj): We should probably handle this more intelligently # PRIVATE join rules. @@ -153,13 +200,10 @@ class Auth(object): # joined: It's a NOOP if event.user_id != target_user_id: raise AuthError(403, "Cannot force another user to join.") - elif join_rule == JoinRules.PUBLIC or room.is_public: + elif join_rule == JoinRules.PUBLIC: pass elif join_rule == JoinRules.INVITE: - if ( - not caller or caller.membership not in - [Membership.INVITE, Membership.JOIN] - ): + if not caller_in_room: raise AuthError(403, "You are not invited to this room.") else: # TODO (erikj): may_join list @@ -171,29 +215,16 @@ class Auth(object): if not caller_in_room: # trying to leave a room you aren't joined raise AuthError(403, "You are not in room %s." % event.room_id) elif target_user_id != event.user_id: - user_level = yield self.store.get_power_level( - event.room_id, - event.user_id, - ) - _, kick_level, _ = yield self.store.get_ops_levels(event.room_id) - if kick_level: kick_level = int(kick_level) else: - kick_level = 50 + kick_level = 50 # FIXME (erikj): What should we do here? if user_level < kick_level: raise AuthError( 403, "You cannot kick user %s." % target_user_id ) elif Membership.BAN == membership: - user_level = yield self.store.get_power_level( - event.room_id, - event.user_id, - ) - - ban_level, _, _ = yield self.store.get_ops_levels(event.room_id) - if ban_level: ban_level = int(ban_level) else: @@ -204,7 +235,30 @@ class Auth(object): else: raise AuthError(500, "Unknown membership %s" % membership) - defer.returnValue(True) + return True + + def _get_power_level_from_event_state(self, event, user_id): + key = (RoomPowerLevelsEvent.TYPE, "", ) + power_level_event = event.old_state_events.get(key) + level = None + if power_level_event: + level = power_level_event.content.get(user_id) + if not level: + level = power_level_event.content.get("default", 0) + + return level + + def _get_ops_level_from_event_state(self, event): + key = (RoomOpsPowerLevelsEvent.TYPE, "", ) + ops_event = event.old_state_events.get(key) + + if ops_event: + return ( + ops_event.content.get("ban_level"), + ops_event.content.get("kick_level"), + ops_event.content.get("redact_level"), + ) + return None, None, None, @defer.inlineCallbacks def get_user_by_req(self, request): @@ -282,8 +336,8 @@ class Auth(object): else: send_level = 0 - user_level = yield self.store.get_power_level( - event.room_id, + user_level = self._get_power_level_from_event_state( + event, event.user_id, ) @@ -308,8 +362,8 @@ class Auth(object): add_level = int(add_level) - user_level = yield self.store.get_power_level( - event.room_id, + user_level = self._get_power_level_from_event_state( + event, event.user_id, ) @@ -322,19 +376,9 @@ class Auth(object): defer.returnValue(True) - @defer.inlineCallbacks def _can_replace_state(self, event): - current_state = yield self.store.get_current_state( - event.room_id, - event.type, - event.state_key, - ) - - if current_state: - current_state = current_state[0] - - user_level = yield self.store.get_power_level( - event.room_id, + user_level = self._get_power_level_from_event_state( + event, event.user_id, ) @@ -346,6 +390,10 @@ class Auth(object): logger.debug( "Checking power level for %s, %s", event.user_id, user_level ) + + key = (event.type, event.state_key, ) + current_state = event.old_state_events.get(key) + if current_state and hasattr(current_state, "required_power_level"): req = current_state.required_power_level @@ -356,10 +404,9 @@ class Auth(object): "You don't have permission to change that state" ) - @defer.inlineCallbacks def _check_redaction(self, event): - user_level = yield self.store.get_power_level( - event.room_id, + user_level = self._get_power_level_from_event_state( + event, event.user_id, ) @@ -368,7 +415,9 @@ class Auth(object): else: user_level = 0 - _, _, redact_level = yield self.store.get_ops_levels(event.room_id) + _, _, redact_level = self._get_ops_level_from_event_state( + event + ) if not redact_level: redact_level = 50 @@ -379,7 +428,6 @@ class Auth(object): "You don't have permission to redact events" ) - @defer.inlineCallbacks def _check_power_levels(self, event): for k, v in event.content.items(): if k == "default": @@ -399,19 +447,16 @@ class Auth(object): except: raise SynapseError(400, "Not a valid power level: %s" % (v,)) - current_state = yield self.store.get_current_state( - event.room_id, - event.type, - event.state_key, - ) + key = (event.type, event.state_key, ) + current_state = event.old_state_events.get(key) if not current_state: return else: current_state = current_state[0] - user_level = yield self.store.get_power_level( - event.room_id, + user_level = self._get_power_level_from_event_state( + event, event.user_id, ) diff --git a/synapse/api/events/__init__.py b/synapse/api/events/__init__.py index f66fea2904..b855811b98 100644 --- a/synapse/api/events/__init__.py +++ b/synapse/api/events/__init__.py @@ -65,13 +65,15 @@ class SynapseEvent(JsonEncodedObject): internal_keys = [ "is_state", - "prev_events", "depth", "destinations", "origin", "outlier", "power_level", "redacted", + "prev_events", + "hashes", + "signatures", ] required_keys = [ diff --git a/synapse/api/events/factory.py b/synapse/api/events/factory.py index 74d0ef77f4..750096c618 100644 --- a/synapse/api/events/factory.py +++ b/synapse/api/events/factory.py @@ -21,6 +21,8 @@ from synapse.api.events.room import ( RoomRedactionEvent, ) +from synapse.types import EventID + from synapse.util.stringutils import random_string @@ -51,12 +53,22 @@ class EventFactory(object): self.clock = hs.get_clock() self.hs = hs + self.event_id_count = 0 + + def create_event_id(self): + i = str(self.event_id_count) + self.event_id_count += 1 + + local_part = str(int(self.clock.time())) + i + random_string(5) + + e_id = EventID.create_local(local_part, self.hs) + + return e_id.to_string() + def create_event(self, etype=None, **kwargs): kwargs["type"] = etype if "event_id" not in kwargs: - kwargs["event_id"] = "%s@%s" % ( - random_string(10), self.hs.hostname - ) + kwargs["event_id"] = self.create_event_id() if "origin_server_ts" not in kwargs: kwargs["origin_server_ts"] = int(self.clock.time_msec()) diff --git a/synapse/api/events/utils.py b/synapse/api/events/utils.py index c3a32be8c1..7fdf45a264 100644 --- a/synapse/api/events/utils.py +++ b/synapse/api/events/utils.py @@ -27,7 +27,14 @@ def prune_event(event): the user has specified, but we do want to keep necessary information like type, state_key etc. """ + return _prune_event_or_pdu(event.type, event) +def prune_pdu(pdu): + """Removes keys that contain unrestricted and non-essential data from a PDU + """ + return _prune_event_or_pdu(pdu.pdu_type, pdu) + +def _prune_event_or_pdu(event_type, event): # Remove all extraneous fields. event.unrecognized_keys = {} @@ -38,25 +45,25 @@ def prune_event(event): if field in event.content: new_content[field] = event.content[field] - if event.type == RoomMemberEvent.TYPE: + if event_type == RoomMemberEvent.TYPE: add_fields("membership") - elif event.type == RoomCreateEvent.TYPE: + elif event_type == RoomCreateEvent.TYPE: add_fields("creator") - elif event.type == RoomJoinRulesEvent.TYPE: + elif event_type == RoomJoinRulesEvent.TYPE: add_fields("join_rule") - elif event.type == RoomPowerLevelsEvent.TYPE: + elif event_type == RoomPowerLevelsEvent.TYPE: # TODO: Actually check these are valid user_ids etc. add_fields("default") for k, v in event.content.items(): if k.startswith("@") and isinstance(v, (int, long)): new_content[k] = v - elif event.type == RoomAddStateLevelEvent.TYPE: + elif event_type == RoomAddStateLevelEvent.TYPE: add_fields("level") - elif event.type == RoomSendEventLevelEvent.TYPE: + elif event_type == RoomSendEventLevelEvent.TYPE: add_fields("level") - elif event.type == RoomOpsPowerLevelsEvent.TYPE: + elif event_type == RoomOpsPowerLevelsEvent.TYPE: add_fields("kick_level", "ban_level", "redact_level") - elif event.type == RoomAliasesEvent.TYPE: + elif event_type == RoomAliasesEvent.TYPE: add_fields("aliases") event.content = new_content diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py new file mode 100644 index 0000000000..cb2db01c04 --- /dev/null +++ b/synapse/crypto/event_signing.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- + +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.federation.units import Pdu +from synapse.api.events.utils import prune_pdu, prune_event +from syutil.jsonutil import encode_canonical_json +from syutil.base64util import encode_base64, decode_base64 +from syutil.crypto.jsonsign import sign_json, verify_signed_json + +import copy +import hashlib +import logging + +logger = logging.getLogger(__name__) + + +def add_event_pdu_content_hash(pdu, hash_algorithm=hashlib.sha256): + hashed = _compute_content_hash(pdu, hash_algorithm) + pdu.hashes[hashed.name] = encode_base64(hashed.digest()) + return pdu + + +def check_event_pdu_content_hash(pdu, hash_algorithm=hashlib.sha256): + """Check whether the hash for this PDU matches the contents""" + computed_hash = _compute_content_hash(pdu, hash_algorithm) + if computed_hash.name not in pdu.hashes: + raise Exception("Algorithm %s not in hashes %s" % ( + computed_hash.name, list(pdu.hashes) + )) + message_hash_base64 = pdu.hashes[computed_hash.name] + try: + message_hash_bytes = decode_base64(message_hash_base64) + except: + raise Exception("Invalid base64: %s" % (message_hash_base64,)) + return message_hash_bytes == computed_hash.digest() + + +def _compute_content_hash(pdu, hash_algorithm): + pdu_json = pdu.get_dict() + #TODO: Make "age_ts" key internal + pdu_json.pop("age_ts", None) + pdu_json.pop("unsigned", None) + pdu_json.pop("signatures", None) + pdu_json.pop("hashes", None) + pdu_json_bytes = encode_canonical_json(pdu_json) + return hash_algorithm(pdu_json_bytes) + + +def compute_pdu_event_reference_hash(pdu, hash_algorithm=hashlib.sha256): + tmp_pdu = Pdu(**pdu.get_dict()) + tmp_pdu = prune_pdu(tmp_pdu) + pdu_json = tmp_pdu.get_dict() + pdu_json.pop("signatures", None) + pdu_json_bytes = encode_canonical_json(pdu_json) + hashed = hash_algorithm(pdu_json_bytes) + return (hashed.name, hashed.digest()) + + +def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256): + tmp_event = copy.deepcopy(event) + tmp_event = prune_event(tmp_event) + event_json = tmp_event.get_dict() + event_json.pop("signatures", None) + event_json_bytes = encode_canonical_json(event_json) + hashed = hash_algorithm(event_json_bytes) + return (hashed.name, hashed.digest()) + + +def sign_event_pdu(pdu, signature_name, signing_key): + tmp_pdu = Pdu(**pdu.get_dict()) + tmp_pdu = prune_pdu(tmp_pdu) + pdu_json = tmp_pdu.get_dict() + pdu_json = sign_json(pdu_json, signature_name, signing_key) + pdu.signatures = pdu_json["signatures"] + return pdu + + +def verify_signed_event_pdu(pdu, signature_name, verify_key): + tmp_pdu = Pdu(**pdu.get_dict()) + tmp_pdu = prune_pdu(tmp_pdu) + pdu_json = tmp_pdu.get_dict() + verify_signed_json(pdu_json, signature_name, verify_key) + + +def add_hashes_and_signatures(event, signature_name, signing_key, + hash_algorithm=hashlib.sha256): + tmp_event = copy.deepcopy(event) + tmp_event = prune_event(tmp_event) + redact_json = tmp_event.get_dict() + redact_json.pop("signatures", None) + redact_json = sign_json(redact_json, signature_name, signing_key) + event.signatures = redact_json["signatures"] + + event_json = event.get_full_dict() + #TODO: We need to sign the JSON that is going out via fedaration. + event_json.pop("age_ts", None) + event_json.pop("unsigned", None) + event_json.pop("signatures", None) + event_json.pop("hashes", None) + event_json_bytes = encode_canonical_json(event_json) + hashed = hash_algorithm(event_json_bytes) + event.hashes[hashed.name] = encode_base64(hashed.digest()) diff --git a/synapse/federation/pdu_codec.py b/synapse/federation/pdu_codec.py index e8180d94fd..d4c896e163 100644 --- a/synapse/federation/pdu_codec.py +++ b/synapse/federation/pdu_codec.py @@ -14,41 +14,43 @@ # limitations under the License. from .units import Pdu +from synapse.crypto.event_signing import ( + add_event_pdu_content_hash, sign_event_pdu +) +from synapse.types import EventID import copy -def decode_event_id(event_id, server_name): - parts = event_id.split("@") - if len(parts) < 2: - return (event_id, server_name) - else: - return (parts[0], "".join(parts[1:])) - - -def encode_event_id(pdu_id, origin): - return "%s@%s" % (pdu_id, origin) - - class PduCodec(object): def __init__(self, hs): + self.signing_key = hs.config.signing_key[0] self.server_name = hs.hostname self.event_factory = hs.get_event_factory() self.clock = hs.get_clock() + self.hs = hs + + def encode_event_id(self, local, domain): + return local + + def decode_event_id(self, event_id): + e_id = self.hs.parse_eventid(event_id) + return event_id, e_id.domain def event_from_pdu(self, pdu): kwargs = {} - kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin) + kwargs["event_id"] = self.encode_event_id(pdu.pdu_id, pdu.origin) kwargs["room_id"] = pdu.context kwargs["etype"] = pdu.pdu_type kwargs["prev_events"] = [ - encode_event_id(p[0], p[1]) for p in pdu.prev_pdus + (self.encode_event_id(i, o), s) + for i, o, s in pdu.prev_pdus ] if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"): - kwargs["prev_state"] = encode_event_id( + kwargs["prev_state"] = self.encode_event_id( pdu.prev_state_id, pdu.prev_state_origin ) @@ -70,21 +72,24 @@ class PduCodec(object): def pdu_from_event(self, event): d = event.get_full_dict() - d["pdu_id"], d["origin"] = decode_event_id( - event.event_id, self.server_name + d["pdu_id"], d["origin"] = self.decode_event_id( + event.event_id ) d["context"] = event.room_id d["pdu_type"] = event.type if hasattr(event, "prev_events"): + def f(e, s): + i, o = self.decode_event_id(e) + return i, o, s d["prev_pdus"] = [ - decode_event_id(e, self.server_name) - for e in event.prev_events + f(e, s) + for e, s in event.prev_events ] if hasattr(event, "prev_state"): d["prev_state_id"], d["prev_state_origin"] = ( - decode_event_id(event.prev_state, self.server_name) + self.decode_event_id(event.prev_state) ) if hasattr(event, "state_key"): @@ -99,4 +104,6 @@ class PduCodec(object): if "origin_server_ts" not in kwargs: kwargs["origin_server_ts"] = int(self.clock.time_msec()) - return Pdu(**kwargs) + pdu = Pdu(**kwargs) + pdu = add_event_pdu_content_hash(pdu) + return sign_event_pdu(pdu, self.server_name, self.signing_key) diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 7043fcc504..a565375e68 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -32,76 +32,6 @@ import logging logger = logging.getLogger(__name__) -class PduActions(object): - """ Defines persistence actions that relate to handling PDUs. - """ - - def __init__(self, datastore): - self.store = datastore - - @log_function - def mark_as_processed(self, pdu): - """ Persist the fact that we have fully processed the given `Pdu` - - Returns: - Deferred - """ - return self.store.mark_pdu_as_processed(pdu.pdu_id, pdu.origin) - - @defer.inlineCallbacks - @log_function - def after_transaction(self, transaction_id, destination, origin): - """ Returns all `Pdu`s that we sent to the given remote home server - after a given transaction id. - - Returns: - Deferred: Results in a list of `Pdu`s - """ - results = yield self.store.get_pdus_after_transaction( - transaction_id, - destination - ) - - defer.returnValue([Pdu.from_pdu_tuple(p) for p in results]) - - @defer.inlineCallbacks - @log_function - def get_all_pdus_from_context(self, context): - results = yield self.store.get_all_pdus_from_context(context) - defer.returnValue([Pdu.from_pdu_tuple(p) for p in results]) - - @defer.inlineCallbacks - @log_function - def backfill(self, context, pdu_list, limit): - """ For a given list of PDU id and origins return the proceeding - `limit` `Pdu`s in the given `context`. - - Returns: - Deferred: Results in a list of `Pdu`s. - """ - results = yield self.store.get_backfill( - context, pdu_list, limit - ) - - defer.returnValue([Pdu.from_pdu_tuple(p) for p in results]) - - @log_function - def is_new(self, pdu): - """ When we receive a `Pdu` from a remote home server, we want to - figure out whether it is `new`, i.e. it is not some historic PDU that - we haven't seen simply because we haven't backfilled back that far. - - Returns: - Deferred: Results in a `bool` - """ - return self.store.is_pdu_new( - pdu_id=pdu.pdu_id, - origin=pdu.origin, - context=pdu.context, - depth=pdu.depth - ) - - class TransactionActions(object): """ Defines persistence actions that relate to handling Transactions. """ diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py index 092411eaf9..159af4eed7 100644 --- a/synapse/federation/replication.py +++ b/synapse/federation/replication.py @@ -21,7 +21,7 @@ from twisted.internet import defer from .units import Transaction, Pdu, Edu -from .persistence import PduActions, TransactionActions +from .persistence import TransactionActions from synapse.util.logutils import log_function @@ -57,7 +57,7 @@ class ReplicationLayer(object): self.transport_layer.register_request_handler(self) self.store = hs.get_datastore() - self.pdu_actions = PduActions(self.store) + # self.pdu_actions = PduActions(self.store) self.transaction_actions = TransactionActions(self.store) self._transaction_queue = _TransactionQueue( @@ -106,7 +106,6 @@ class ReplicationLayer(object): self.query_handlers[query_type] = handler - @defer.inlineCallbacks @log_function def send_pdu(self, pdu): """Informs the replication layer about a new PDU generated within the @@ -135,7 +134,7 @@ class ReplicationLayer(object): logger.debug("[%s] Persisting PDU", pdu.pdu_id) # Save *before* trying to send - yield self.store.persist_event(pdu=pdu) + # yield self.store.persist_event(pdu=pdu) logger.debug("[%s] Persisted PDU", pdu.pdu_id) logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id) @@ -181,7 +180,7 @@ class ReplicationLayer(object): @defer.inlineCallbacks @log_function - def backfill(self, dest, context, limit): + def backfill(self, dest, context, limit, extremities): """Requests some more historic PDUs for the given context from the given destination server. @@ -189,12 +188,12 @@ class ReplicationLayer(object): dest (str): The remote home server to ask. context (str): The context to backfill. limit (int): The maximum number of PDUs to return. + extremities (list): List of PDU id and origins of the first pdus + we have seen from the context Returns: Deferred: Results in the received PDUs. """ - extremities = yield self.store.get_oldest_pdus_in_context(context) - logger.debug("backfill extrem=%s", extremities) # If there are no extremeties then we've (probably) reached the start. @@ -244,13 +243,14 @@ class ReplicationLayer(object): pdu = None if pdu_list: pdu = pdu_list[0] - yield self._handle_new_pdu(pdu) + yield self._handle_new_pdu(destination, pdu) defer.returnValue(pdu) @defer.inlineCallbacks @log_function - def get_state_for_context(self, destination, context): + def get_state_for_context(self, destination, context, pdu_id=None, + pdu_origin=None): """Requests all of the `current` state PDUs for a given context from a remote home server. @@ -263,29 +263,30 @@ class ReplicationLayer(object): """ transaction_data = yield self.transport_layer.get_context_state( - destination, context) + destination, context, pdu_id=pdu_id, pdu_origin=pdu_origin, + ) transaction = Transaction(**transaction_data) pdus = [Pdu(outlier=True, **p) for p in transaction.pdus] for pdu in pdus: - yield self._handle_new_pdu(pdu) + yield self._handle_new_pdu(destination, pdu) defer.returnValue(pdus) @defer.inlineCallbacks @log_function def on_context_pdus_request(self, context): - pdus = yield self.pdu_actions.get_all_pdus_from_context( - context + raise NotImplementedError( + "on_context_pdus_request is a security violation" ) - defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) @defer.inlineCallbacks @log_function def on_backfill_request(self, context, versions, limit): - - pdus = yield self.pdu_actions.backfill(context, versions, limit) + pdus = yield self.handler.on_backfill_request( + context, versions, limit + ) defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) @@ -295,6 +296,10 @@ class ReplicationLayer(object): transaction = Transaction(**transaction_data) for p in transaction.pdus: + if "unsigned" in p: + unsigned = p["unsigned"] + if "age" in unsigned: + p["age"] = unsigned["age"] if "age" in p: p["age_ts"] = int(self._clock.time_msec()) - int(p["age"]) del p["age"] @@ -315,7 +320,7 @@ class ReplicationLayer(object): dl = [] for pdu in pdu_list: - dl.append(self._handle_new_pdu(pdu)) + dl.append(self._handle_new_pdu(transaction.origin, pdu)) if hasattr(transaction, "edus"): for edu in [Edu(**x) for x in transaction.edus]: @@ -347,14 +352,20 @@ class ReplicationLayer(object): @defer.inlineCallbacks @log_function - def on_context_state_request(self, context): - results = yield self.store.get_current_state_for_context( - context - ) - - logger.debug("Context returning %d results", len(results)) + def on_context_state_request(self, context, pdu_id, pdu_origin): + if pdu_id and pdu_origin: + pdus = yield self.handler.get_state_for_pdu( + pdu_id, pdu_origin + ) + else: + raise NotImplementedError("Specify an event") + # results = yield self.store.get_current_state_for_context( + # context + # ) + # pdus = [Pdu.from_pdu_tuple(p) for p in results] + # + # logger.debug("Context returning %d results", len(pdus)) - pdus = [Pdu.from_pdu_tuple(p) for p in results] defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) @defer.inlineCallbacks @@ -372,20 +383,22 @@ class ReplicationLayer(object): @defer.inlineCallbacks @log_function def on_pull_request(self, origin, versions): - transaction_id = max([int(v) for v in versions]) - - response = yield self.pdu_actions.after_transaction( - transaction_id, - origin, - self.server_name - ) - - if not response: - response = [] - - defer.returnValue( - (200, self._transaction_from_pdus(response).get_dict()) - ) + raise NotImplementedError("Pull transacions not implemented") + + # transaction_id = max([int(v) for v in versions]) + # + # response = yield self.pdu_actions.after_transaction( + # transaction_id, + # origin, + # self.server_name + # ) + # + # if not response: + # response = [] + # + # defer.returnValue( + # (200, self._transaction_from_pdus(response).get_dict()) + # ) @defer.inlineCallbacks def on_query_request(self, query_type, args): @@ -393,11 +406,56 @@ class ReplicationLayer(object): response = yield self.query_handlers[query_type](args) defer.returnValue((200, response)) else: - defer.returnValue((404, "No handler for Query type '%s'" - % (query_type) - )) + defer.returnValue( + (404, "No handler for Query type '%s'" % (query_type, )) + ) + + @defer.inlineCallbacks + def on_make_join_request(self, context, user_id): + pdu = yield self.handler.on_make_join_request(context, user_id) + defer.returnValue(pdu.get_dict()) + + @defer.inlineCallbacks + def on_invite_request(self, origin, content): + pdu = Pdu(**content) + ret_pdu = yield self.handler.on_send_join_request(origin, pdu) + defer.returnValue((200, ret_pdu.get_dict())) @defer.inlineCallbacks + def on_send_join_request(self, origin, content): + pdu = Pdu(**content) + state = yield self.handler.on_send_join_request(origin, pdu) + defer.returnValue((200, self._transaction_from_pdus(state).get_dict())) + + @defer.inlineCallbacks + def make_join(self, destination, context, user_id): + pdu_dict = yield self.transport_layer.make_join( + destination=destination, + context=context, + user_id=user_id, + ) + + logger.debug("Got response to make_join: %s", pdu_dict) + + defer.returnValue(Pdu(**pdu_dict)) + + @defer.inlineCallbacks + def send_join(self, destination, pdu): + _, content = yield self.transport_layer.send_join( + destination, + pdu.context, + pdu.pdu_id, + pdu.origin, + pdu.get_dict(), + ) + + logger.debug("Got content: %s", content) + pdus = [Pdu(outlier=True, **p) for p in content.get("pdus", [])] + for pdu in pdus: + yield self._handle_new_pdu(destination, pdu) + + defer.returnValue(pdus) + @log_function def _get_persisted_pdu(self, pdu_id, pdu_origin): """ Get a PDU from the database with given origin and id. @@ -405,29 +463,29 @@ class ReplicationLayer(object): Returns: Deferred: Results in a `Pdu`. """ - pdu_tuple = yield self.store.get_pdu(pdu_id, pdu_origin) - - defer.returnValue(Pdu.from_pdu_tuple(pdu_tuple)) + return self.handler.get_persisted_pdu(pdu_id, pdu_origin) def _transaction_from_pdus(self, pdu_list): """Returns a new Transaction containing the given PDUs suitable for transmission. """ pdus = [p.get_dict() for p in pdu_list] + time_now = self._clock.time_msec() for p in pdus: - if "age_ts" in pdus: - p["age"] = int(self.clock.time_msec()) - p["age_ts"] - + if "age_ts" in p: + age = time_now - p["age_ts"] + p.setdefault("unsigned", {})["age"] = int(age) + del p["age_ts"] return Transaction( origin=self.server_name, pdus=pdus, - origin_server_ts=int(self._clock.time_msec()), + origin_server_ts=int(time_now), destination=None, ) @defer.inlineCallbacks @log_function - def _handle_new_pdu(self, pdu, backfilled=False): + def _handle_new_pdu(self, origin, pdu, backfilled=False): # We reprocess pdus when we have seen them only as outliers existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin) @@ -436,14 +494,17 @@ class ReplicationLayer(object): defer.returnValue({}) return + state = None + # Get missing pdus if necessary. - is_new = yield self.pdu_actions.is_new(pdu) - if is_new and not pdu.outlier: + if not pdu.outlier: # We only backfill backwards to the min depth. - min_depth = yield self.store.get_min_depth_for_context(pdu.context) + min_depth = yield self.handler.get_min_depth_for_context( + pdu.context + ) if min_depth and pdu.depth > min_depth: - for pdu_id, origin in pdu.prev_pdus: + for pdu_id, origin, hashes in pdu.prev_pdus: exists = yield self._get_persisted_pdu(pdu_id, origin) if not exists: @@ -459,16 +520,26 @@ class ReplicationLayer(object): except: # TODO(erikj): Do some more intelligent retries. logger.exception("Failed to get PDU") + else: + # We need to get the state at this event, since we have reached + # a backward extremity edge. + state = yield self.get_state_for_context( + origin, pdu.context, pdu.pdu_id, pdu.origin, + ) # Persist the Pdu, but don't mark it as processed yet. - yield self.store.persist_event(pdu=pdu) + # yield self.store.persist_event(pdu=pdu) if not backfilled: - ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled) + ret = yield self.handler.on_receive_pdu( + pdu, + backfilled=backfilled, + state=state, + ) else: ret = None - yield self.pdu_actions.mark_as_processed(pdu) + # yield self.pdu_actions.mark_as_processed(pdu) defer.returnValue(ret) @@ -589,7 +660,7 @@ class _TransactionQueue(object): logger.debug("TX [%s] Persisting transaction...", destination) transaction = Transaction.create_new( - origin_server_ts=self._clock.time_msec(), + origin_server_ts=int(self._clock.time_msec()), transaction_id=str(self._next_txn_id), origin=self.server_name, destination=destination, @@ -614,7 +685,9 @@ class _TransactionQueue(object): if "pdus" in data: for p in data["pdus"]: if "age_ts" in p: - p["age"] = now - int(p["age_ts"]) + unsigned = p.setdefault("unsigned", {}) + unsigned["age"] = now - int(p["age_ts"]) + del p["age_ts"] return data code, response = yield self.transport_layer.send_transaction( diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py index e7517cac4d..7f01b4faaf 100644 --- a/synapse/federation/transport.py +++ b/synapse/federation/transport.py @@ -72,7 +72,8 @@ class TransportLayer(object): self.received_handler = None @log_function - def get_context_state(self, destination, context): + def get_context_state(self, destination, context, pdu_id=None, + pdu_origin=None): """ Requests all state for a given context (i.e. room) from the given server. @@ -89,7 +90,14 @@ class TransportLayer(object): subpath = "/state/%s/" % context - return self._do_request_for_transaction(destination, subpath) + args = {} + if pdu_id and pdu_origin: + args["pdu_id"] = pdu_id + args["pdu_origin"] = pdu_origin + + return self._do_request_for_transaction( + destination, subpath, args=args + ) @log_function def get_pdu(self, destination, pdu_origin, pdu_id): @@ -135,8 +143,10 @@ class TransportLayer(object): subpath = "/backfill/%s/" % context - args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]} - args["limit"] = limit + args = { + "v": ["%s,%s" % (i, o) for i, o in pdu_tuples], + "limit": limit, + } return self._do_request_for_transaction( dest, @@ -198,6 +208,59 @@ class TransportLayer(object): defer.returnValue(response) @defer.inlineCallbacks + @log_function + def make_join(self, destination, context, user_id, retry_on_dns_fail=True): + path = PREFIX + "/make_join/%s/%s" % (context, user_id,) + + response = yield self.client.get_json( + destination=destination, + path=path, + retry_on_dns_fail=retry_on_dns_fail, + ) + + defer.returnValue(response) + + @defer.inlineCallbacks + @log_function + def send_join(self, destination, context, pdu_id, origin, content): + path = PREFIX + "/send_join/%s/%s/%s" % ( + context, + origin, + pdu_id, + ) + + code, content = yield self.client.put_json( + destination=destination, + path=path, + data=content, + ) + + if not 200 <= code < 300: + raise RuntimeError("Got %d from send_join", code) + + defer.returnValue(json.loads(content)) + + @defer.inlineCallbacks + @log_function + def send_invite(self, destination, context, pdu_id, origin, content): + path = PREFIX + "/invite/%s/%s/%s" % ( + context, + origin, + pdu_id, + ) + + code, content = yield self.client.put_json( + destination=destination, + path=path, + data=content, + ) + + if not 200 <= code < 300: + raise RuntimeError("Got %d from send_invite", code) + + defer.returnValue(json.loads(content)) + + @defer.inlineCallbacks def _authenticate_request(self, request): json_request = { "method": request.method, @@ -326,7 +389,11 @@ class TransportLayer(object): re.compile("^" + PREFIX + "/state/([^/]*)/$"), self._with_authentication( lambda origin, content, query, context: - handler.on_context_state_request(context) + handler.on_context_state_request( + context, + query.get("pdu_id", [None])[0], + query.get("pdu_origin", [None])[0] + ) ) ) @@ -362,6 +429,39 @@ class TransportLayer(object): ) ) + self.server.register_path( + "GET", + re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"), + self._with_authentication( + lambda origin, content, query, context, user_id: + self._on_make_join_request( + origin, content, query, context, user_id + ) + ) + ) + + self.server.register_path( + "PUT", + re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)/([^/]*)$"), + self._with_authentication( + lambda origin, content, query, context, pdu_origin, pdu_id: + self._on_send_join_request( + origin, content, query, + ) + ) + ) + + self.server.register_path( + "PUT", + re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)/([^/]*)$"), + self._with_authentication( + lambda origin, content, query, context, pdu_origin, pdu_id: + self._on_invite_request( + origin, content, query, + ) + ) + ) + @defer.inlineCallbacks @log_function def _on_send_request(self, origin, content, query, transaction_id): @@ -451,7 +551,34 @@ class TransportLayer(object): versions = [v.split(",", 1) for v in v_list] return self.request_handler.on_backfill_request( - context, versions, limit) + context, versions, limit + ) + + @defer.inlineCallbacks + @log_function + def _on_make_join_request(self, origin, content, query, context, user_id): + content = yield self.request_handler.on_make_join_request( + context, user_id, + ) + defer.returnValue((200, content)) + + @defer.inlineCallbacks + @log_function + def _on_send_join_request(self, origin, content, query): + content = yield self.request_handler.on_send_join_request( + origin, content, + ) + + defer.returnValue((200, content)) + + @defer.inlineCallbacks + @log_function + def _on_invite_request(self, origin, content, query): + content = yield self.request_handler.on_invite_request( + origin, content, + ) + + defer.returnValue((200, content)) class TransportReceivedHandler(object): diff --git a/synapse/federation/units.py b/synapse/federation/units.py index b2fb964180..adc3385644 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -18,6 +18,7 @@ server protocol. """ from synapse.util.jsonobject import JsonEncodedObject +from syutil.base64util import encode_base64 import logging import json @@ -63,9 +64,10 @@ class Pdu(JsonEncodedObject): "depth", "content", "outlier", + "hashes", + "signatures", "is_state", # Below this are keys valid only for State Pdus. "state_key", - "power_level", "prev_state_id", "prev_state_origin", "required_power_level", @@ -91,7 +93,7 @@ class Pdu(JsonEncodedObject): # just leaving it as a dict. (OR DO WE?!) def __init__(self, destinations=[], is_state=False, prev_pdus=[], - outlier=False, **kwargs): + outlier=False, hashes={}, signatures={}, **kwargs): if is_state: for required_key in ["state_key"]: if required_key not in kwargs: @@ -99,9 +101,11 @@ class Pdu(JsonEncodedObject): super(Pdu, self).__init__( destinations=destinations, - is_state=is_state, + is_state=bool(is_state), prev_pdus=prev_pdus, outlier=outlier, + hashes=hashes, + signatures=signatures, **kwargs ) @@ -120,6 +124,10 @@ class Pdu(JsonEncodedObject): d = copy.copy(pdu_tuple.pdu_entry._asdict()) d["origin_server_ts"] = d.pop("ts") + for k in d.keys(): + if d[k] is None: + del d[k] + d["content"] = json.loads(d["content_json"]) del d["content_json"] @@ -127,8 +135,28 @@ class Pdu(JsonEncodedObject): if "unrecognized_keys" in d and d["unrecognized_keys"]: args.update(json.loads(d["unrecognized_keys"])) + hashes = { + alg: encode_base64(hsh) + for alg, hsh in pdu_tuple.hashes.items() + } + + signatures = { + kid: encode_base64(sig) + for kid, sig in pdu_tuple.signatures.items() + } + + prev_pdus = [] + for prev_pdu in pdu_tuple.prev_pdu_list: + prev_hashes = pdu_tuple.edge_hashes.get(prev_pdu, {}) + prev_hashes = { + alg: encode_base64(hsh) for alg, hsh in prev_hashes.items() + } + prev_pdus.append((prev_pdu[0], prev_pdu[1], prev_hashes)) + return Pdu( - prev_pdus=pdu_tuple.prev_pdu_list, + prev_pdus=prev_pdus, + hashes=hashes, + signatures=signatures, **args ) else: diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index de4d23bbb3..787a01efc5 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -16,6 +16,8 @@ from twisted.internet import defer from synapse.api.errors import LimitExceededError +from synapse.util.async import run_on_reactor + class BaseHandler(object): def __init__(self, hs): @@ -44,9 +46,19 @@ class BaseHandler(object): @defer.inlineCallbacks def _on_new_room_event(self, event, snapshot, extra_destinations=[], - extra_users=[]): + extra_users=[], suppress_auth=False): + yield run_on_reactor() + snapshot.fill_out_prev_events(event) + yield self.state_handler.annotate_state_groups(event) + + if not suppress_auth: + yield self.auth.check(event, snapshot, raises=True) + + if hasattr(event, "state_key"): + yield self.state_handler.handle_new_event(event, snapshot) + yield self.store.persist_event(event) destinations = set(extra_destinations) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index a56830d520..6e897e915d 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -152,5 +152,6 @@ class DirectoryHandler(BaseHandler): user_id=user_id, ) - yield self.state_handler.handle_new_event(event, snapshot) - yield self._on_new_room_event(event, snapshot, extra_users=[user_id]) + yield self._on_new_room_event( + event, snapshot, extra_users=[user_id], suppress_auth=True + ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index f52591d2a3..18cb1d4e97 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -22,6 +22,8 @@ from synapse.api.constants import Membership from synapse.util.logutils import log_function from synapse.federation.pdu_codec import PduCodec from synapse.api.errors import SynapseError +from synapse.util.async import run_on_reactor +from synapse.types import EventID from twisted.internet import defer, reactor @@ -62,6 +64,9 @@ class FederationHandler(BaseHandler): self.pdu_codec = PduCodec(hs) + # When joining a room we need to queue any events for that room up + self.room_queues = {} + @log_function @defer.inlineCallbacks def handle_new_event(self, event, snapshot): @@ -78,6 +83,8 @@ class FederationHandler(BaseHandler): processing. """ + yield run_on_reactor() + pdu = self.pdu_codec.pdu_from_event(event) if not hasattr(pdu, "destinations") or not pdu.destinations: @@ -87,98 +94,83 @@ class FederationHandler(BaseHandler): @log_function @defer.inlineCallbacks - def on_receive_pdu(self, pdu, backfilled): + def on_receive_pdu(self, pdu, backfilled, state=None): """ Called by the ReplicationLayer when we have a new pdu. We need to - do auth checks and put it throught the StateHandler. + do auth checks and put it through the StateHandler. """ event = self.pdu_codec.event_from_pdu(pdu) logger.debug("Got event: %s", event.event_id) - with (yield self.lock_manager.lock(pdu.context)): - if event.is_state and not backfilled: - is_new_state = yield self.state_handler.handle_new_state( - pdu - ) - else: - is_new_state = False + if event.room_id in self.room_queues: + self.room_queues[event.room_id].append(pdu) + return + + logger.debug("Processing event: %s", event.event_id) + + if state: + state = [self.pdu_codec.event_from_pdu(p) for p in state] + + is_new_state = yield self.state_handler.annotate_state_groups( + event, + old_state=state + ) + + logger.debug("Event: %s", event) + + if not backfilled: + yield self.auth.check(event, None, raises=True) + + is_new_state = is_new_state and not backfilled + # TODO: Implement something in federation that allows us to # respond to PDU. - target_is_mine = False - if hasattr(event, "target_host"): - target_is_mine = event.target_host == self.hs.hostname - - if event.type == InviteJoinEvent.TYPE: - if not target_is_mine: - logger.debug("Ignoring invite/join event %s", event) - return - - # If we receive an invite/join event then we need to join the - # sender to the given room. - # TODO: We should probably auth this or some such - content = event.content - content.update({"membership": Membership.JOIN}) - new_event = self.event_factory.create_event( - etype=RoomMemberEvent.TYPE, - state_key=event.user_id, - room_id=event.room_id, - user_id=event.user_id, - membership=Membership.JOIN, - content=content + with (yield self.room_lock.lock(event.room_id)): + yield self.store.persist_event( + event, + backfilled, + is_new_state=is_new_state ) - yield self.hs.get_handlers().room_member_handler.change_membership( - new_event, - do_auth=False, - ) + room = yield self.store.get_room(event.room_id) - else: - with (yield self.room_lock.lock(event.room_id)): - yield self.store.persist_event( - event, - backfilled, - is_new_state=is_new_state + if not room: + # Huh, let's try and get the current state + try: + yield self.replication_layer.get_state_for_context( + event.origin, event.room_id, pdu.pdu_id, pdu.origin, ) - room = yield self.store.get_room(event.room_id) - - if not room: - # Huh, let's try and get the current state - try: - yield self.replication_layer.get_state_for_context( - event.origin, event.room_id - ) - - hosts = yield self.store.get_joined_hosts_for_room( - event.room_id - ) - if self.hs.hostname in hosts: - try: - yield self.store.store_room( - room_id=event.room_id, - room_creator_user_id="", - is_public=False, - ) - except: - pass - except: - logger.exception( - "Failed to get current state for room %s", - event.room_id - ) - - if not backfilled: - extra_users = [] - if event.type == RoomMemberEvent.TYPE: - target_user_id = event.state_key - target_user = self.hs.parse_userid(target_user_id) - extra_users.append(target_user) - - yield self.notifier.on_new_room_event( - event, extra_users=extra_users + hosts = yield self.store.get_joined_hosts_for_room( + event.room_id + ) + if self.hs.hostname in hosts: + try: + yield self.store.store_room( + room_id=event.room_id, + room_creator_user_id="", + is_public=False, + ) + except: + pass + except: + logger.exception( + "Failed to get current state for room %s", + event.room_id ) + if not backfilled: + extra_users = [] + if event.type == RoomMemberEvent.TYPE: + target_user_id = event.state_key + target_user = self.hs.parse_userid(target_user_id) + extra_users.append(target_user) + + yield self.notifier.on_new_room_event( + event, extra_users=extra_users + ) + if event.type == RoomMemberEvent.TYPE: if event.membership == Membership.JOIN: user = self.hs.parse_userid(event.state_key) @@ -189,13 +181,28 @@ class FederationHandler(BaseHandler): @log_function @defer.inlineCallbacks def backfill(self, dest, room_id, limit): - pdus = yield self.replication_layer.backfill(dest, room_id, limit) + extremities = yield self.store.get_oldest_events_in_room(room_id) + + pdus = yield self.replication_layer.backfill( + dest, + room_id, + limit, + extremities=[ + self.pdu_codec.decode_event_id(e) + for e in extremities + ] + ) events = [] for pdu in pdus: event = self.pdu_codec.event_from_pdu(pdu) + + # FIXME (erikj): Not sure this actually works :/ + yield self.state_handler.annotate_state_groups(event) + events.append(event) + yield self.store.persist_event(event, backfilled=True) defer.returnValue(events) @@ -203,62 +210,232 @@ class FederationHandler(BaseHandler): @log_function @defer.inlineCallbacks def do_invite_join(self, target_host, room_id, joinee, content, snapshot): - hosts = yield self.store.get_joined_hosts_for_room(room_id) if self.hs.hostname in hosts: # We are already in the room. logger.debug("We're already in the room apparently") defer.returnValue(False) - # First get current state to see if we are already joined. + pdu = yield self.replication_layer.make_join( + target_host, + room_id, + joinee + ) + + logger.debug("Got response to make_join: %s", pdu) + + event = self.pdu_codec.event_from_pdu(pdu) + + # We should assert some things. + assert(event.type == RoomMemberEvent.TYPE) + assert(event.user_id == joinee) + assert(event.state_key == joinee) + assert(event.room_id == room_id) + + event.outlier = False + + self.room_queues[room_id] = [] + try: - yield self.replication_layer.get_state_for_context( - target_host, room_id + event.event_id = self.event_factory.create_event_id() + event.content = content + + state = yield self.replication_layer.send_join( + target_host, + self.pdu_codec.pdu_from_event(event) ) - hosts = yield self.store.get_joined_hosts_for_room(room_id) - if self.hs.hostname in hosts: - # Oh, we were actually in the room already. - logger.debug("We're already in the room apparently") - defer.returnValue(False) - except Exception: - logger.exception("Failed to get current state") - - new_event = self.event_factory.create_event( - etype=InviteJoinEvent.TYPE, - target_host=target_host, - room_id=room_id, - user_id=joinee, - content=content - ) + state = [self.pdu_codec.event_from_pdu(p) for p in state] - new_event.destinations = [target_host] + logger.debug("do_invite_join state: %s", state) - snapshot.fill_out_prev_events(new_event) - yield self.handle_new_event(new_event, snapshot) + is_new_state = yield self.state_handler.annotate_state_groups( + event, + old_state=state + ) - # TODO (erikj): Time out here. - d = defer.Deferred() - self.waiting_for_join_list.setdefault((joinee, room_id), []).append(d) - reactor.callLater(10, d.cancel) + logger.debug("do_invite_join event: %s", event) - try: - yield d - except defer.CancelledError: - raise SynapseError(500, "Unable to join remote room") + try: + yield self.store.store_room( + room_id=room_id, + room_creator_user_id="", + is_public=False + ) + except: + # FIXME + pass - try: - yield self.store.store_room( - room_id=room_id, - room_creator_user_id="", - is_public=False + for e in state: + # FIXME: Auth these. + e.outlier = True + + yield self.state_handler.annotate_state_groups( + e, + ) + + yield self.store.persist_event( + e, + backfilled=False, + is_new_state=False + ) + + yield self.store.persist_event( + event, + backfilled=False, + is_new_state=is_new_state ) - except: - pass + finally: + room_queue = self.room_queues[room_id] + del self.room_queues[room_id] + for p in room_queue: + try: + yield self.on_receive_pdu(p, backfilled=False) + except: + pass defer.returnValue(True) + @defer.inlineCallbacks + @log_function + def on_make_join_request(self, context, user_id): + event = self.event_factory.create_event( + etype=RoomMemberEvent.TYPE, + content={"membership": Membership.JOIN}, + room_id=context, + user_id=user_id, + state_key=user_id, + ) + + snapshot = yield self.store.snapshot_room( + event.room_id, event.user_id, + ) + snapshot.fill_out_prev_events(event) + + yield self.state_handler.annotate_state_groups(event) + yield self.auth.check(event, None, raises=True) + + pdu = self.pdu_codec.pdu_from_event(event) + + defer.returnValue(pdu) + + @defer.inlineCallbacks + @log_function + def on_send_join_request(self, origin, pdu): + event = self.pdu_codec.event_from_pdu(pdu) + + event.outlier = False + + is_new_state = yield self.state_handler.annotate_state_groups(event) + yield self.auth.check(event, None, raises=True) + + # FIXME (erikj): All this is duplicated above :( + + yield self.store.persist_event( + event, + backfilled=False, + is_new_state=is_new_state + ) + + extra_users = [] + if event.type == RoomMemberEvent.TYPE: + target_user_id = event.state_key + target_user = self.hs.parse_userid(target_user_id) + extra_users.append(target_user) + + yield self.notifier.on_new_room_event( + event, extra_users=extra_users + ) + + if event.type == RoomMemberEvent.TYPE: + if event.membership == Membership.JOIN: + user = self.hs.parse_userid(event.state_key) + self.distributor.fire( + "user_joined_room", user=user, room_id=event.room_id + ) + + new_pdu = self.pdu_codec.pdu_from_event(event); + new_pdu.destinations = yield self.store.get_joined_hosts_for_room( + event.room_id + ) + + yield self.replication_layer.send_pdu(new_pdu) + + defer.returnValue([ + self.pdu_codec.pdu_from_event(e) + for e in event.state_events.values() + ]) + + @defer.inlineCallbacks + def get_state_for_pdu(self, pdu_id, pdu_origin): + yield run_on_reactor() + + event_id = EventID.create(pdu_id, pdu_origin, self.hs).to_string() + + state_groups = yield self.store.get_state_groups( + [event_id] + ) + + if state_groups: + results = { + (e.type, e.state_key): e for e in state_groups[0].state + } + + event = yield self.store.get_event(event_id) + if hasattr(event, "state_key"): + # Get previous state + if hasattr(event, "prev_state") and event.prev_state: + prev_event = yield self.store.get_event(event.prev_state) + results[(event.type, event.state_key)] = prev_event + else: + del results[(event.type, event.state_key)] + + defer.returnValue( + [ + self.pdu_codec.pdu_from_event(s) + for s in results.values() + ] + ) + else: + defer.returnValue([]) + + @defer.inlineCallbacks + @log_function + def on_backfill_request(self, context, pdu_list, limit): + + events = yield self.store.get_backfill_events( + context, + [self.pdu_codec.encode_event_id(i, o) for i, o in pdu_list], + limit + ) + + defer.returnValue([ + self.pdu_codec.pdu_from_event(e) + for e in events + ]) + + @defer.inlineCallbacks + @log_function + def get_persisted_pdu(self, pdu_id, origin): + """ Get a PDU from the database with given origin and id. + + Returns: + Deferred: Results in a `Pdu`. + """ + event = yield self.store.get_event( + self.pdu_codec.encode_event_id(pdu_id, origin), + allow_none=True, + ) + + if event: + defer.returnValue(self.pdu_codec.pdu_from_event(event)) + else: + defer.returnValue(None) + + @log_function + def get_min_depth_for_context(self, context): + return self.store.get_min_depth(context) @log_function def _on_user_joined(self, user, room_id): diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 72894869ea..c6f6ab14d1 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -83,10 +83,9 @@ class MessageHandler(BaseHandler): snapshot = yield self.store.snapshot_room(event.room_id, event.user_id) - if not suppress_auth: - yield self.auth.check(event, snapshot, raises=True) - - yield self._on_new_room_event(event, snapshot) + yield self._on_new_room_event( + event, snapshot, suppress_auth=suppress_auth + ) self.hs.get_handlers().presence_handler.bump_presence_active_time( user @@ -149,10 +148,6 @@ class MessageHandler(BaseHandler): state_key=event.state_key, ) - yield self.auth.check(event, snapshot, raises=True) - - yield self.state_handler.handle_new_event(event, snapshot) - yield self._on_new_room_event(event, snapshot) @defer.inlineCallbacks @@ -201,7 +196,7 @@ class MessageHandler(BaseHandler): raise RoomError( 403, "Member does not meet private room rules.") - data = yield self.store.get_current_state( + data = yield self.state_handler.get_current_state( room_id, event_type, state_key ) defer.returnValue(data) @@ -221,8 +216,6 @@ class MessageHandler(BaseHandler): def send_feedback(self, event): snapshot = yield self.store.snapshot_room(event.room_id, event.user_id) - yield self.auth.check(event, snapshot, raises=True) - # store message in db yield self._on_new_room_event(event, snapshot) @@ -239,7 +232,7 @@ class MessageHandler(BaseHandler): yield self.auth.check_joined_room(room_id, user_id) # TODO: This is duplicating logic from snapshot_all_rooms - current_state = yield self.store.get_current_state(room_id) + current_state = yield self.state_handler.get_current_state(room_id) defer.returnValue([self.hs.serialize_event(c) for c in current_state]) @defer.inlineCallbacks @@ -316,7 +309,7 @@ class MessageHandler(BaseHandler): "end": end_token.to_string(), } - current_state = yield self.store.get_current_state( + current_state = yield self.state_handler.get_current_state( event.room_id ) d["state"] = [self.hs.serialize_event(c) for c in current_state] diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index dab9b03f04..4cd0a06093 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -218,5 +218,6 @@ class ProfileHandler(BaseHandler): user_id=j.state_key, ) - yield self.state_handler.handle_new_event(new_event, snapshot) - yield self._on_new_room_event(new_event, snapshot) + yield self._on_new_room_event( + new_event, snapshot, suppress_auth=True + ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 81ce1a5907..ffc0892f1a 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -129,8 +129,9 @@ class RoomCreationHandler(BaseHandler): logger.debug("Event: %s", event) - yield self.state_handler.handle_new_event(event, snapshot) - yield self._on_new_room_event(event, snapshot, extra_users=[user]) + yield self._on_new_room_event( + event, snapshot, extra_users=[user], suppress_auth=True + ) for event in creation_events: yield handle_event(event) @@ -391,8 +392,6 @@ class RoomMemberHandler(BaseHandler): yield self._do_join(event, snapshot, do_auth=do_auth) else: # This is not a JOIN, so we can handle it normally. - if do_auth: - yield self.auth.check(event, snapshot, raises=True) # If we're banning someone, set a req power level if event.membership == Membership.BAN: @@ -414,6 +413,7 @@ class RoomMemberHandler(BaseHandler): event, membership=event.content["membership"], snapshot=snapshot, + do_auth=do_auth, ) defer.returnValue({"room_id": room_id}) @@ -502,14 +502,11 @@ class RoomMemberHandler(BaseHandler): if not have_joined: logger.debug("Doing normal join") - if do_auth: - yield self.auth.check(event, snapshot, raises=True) - - yield self.state_handler.handle_new_event(event, snapshot) yield self._do_local_membership_update( event, membership=event.content["membership"], snapshot=snapshot, + do_auth=do_auth, ) user = self.hs.parse_userid(event.user_id) @@ -553,7 +550,8 @@ class RoomMemberHandler(BaseHandler): defer.returnValue([r.room_id for r in rooms]) - def _do_local_membership_update(self, event, membership, snapshot): + def _do_local_membership_update(self, event, membership, snapshot, + do_auth): destinations = [] # If we're inviting someone, then we should also send it to that @@ -570,9 +568,10 @@ class RoomMemberHandler(BaseHandler): return self._on_new_room_event( event, snapshot, extra_destinations=destinations, - extra_users=[target_user] + extra_users=[target_user], suppress_auth=(not do_auth), ) + class RoomListHandler(BaseHandler): @defer.inlineCallbacks diff --git a/synapse/rest/base.py b/synapse/rest/base.py index 2e8e3fa7d4..dc784c1527 100644 --- a/synapse/rest/base.py +++ b/synapse/rest/base.py @@ -18,6 +18,11 @@ from synapse.api.urls import CLIENT_PREFIX from synapse.rest.transactions import HttpTransactionStore import re +import logging + + +logger = logging.getLogger(__name__) + def client_path_pattern(path_regex): """Creates a regex compiled client path with the correct client path diff --git a/synapse/rest/events.py b/synapse/rest/events.py index 097195d7cc..92ff5e5ca7 100644 --- a/synapse/rest/events.py +++ b/synapse/rest/events.py @@ -20,6 +20,12 @@ from synapse.api.errors import SynapseError from synapse.streams.config import PaginationConfig from synapse.rest.base import RestServlet, client_path_pattern +import logging + + +logger = logging.getLogger(__name__) + + class EventStreamRestServlet(RestServlet): PATTERN = client_path_pattern("/events$") @@ -29,18 +35,22 @@ class EventStreamRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): auth_user = yield self.auth.get_user_by_req(request) - - handler = self.handlers.event_stream_handler - pagin_config = PaginationConfig.from_request(request) - timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS - if "timeout" in request.args: - try: - timeout = int(request.args["timeout"][0]) - except ValueError: - raise SynapseError(400, "timeout must be in milliseconds.") - - chunk = yield handler.get_stream(auth_user.to_string(), pagin_config, - timeout=timeout) + try: + handler = self.handlers.event_stream_handler + pagin_config = PaginationConfig.from_request(request) + timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS + if "timeout" in request.args: + try: + timeout = int(request.args["timeout"][0]) + except ValueError: + raise SynapseError(400, "timeout must be in milliseconds.") + + chunk = yield handler.get_stream( + auth_user.to_string(), pagin_config, timeout=timeout + ) + except: + logger.exception("Event stream failed") + raise defer.returnValue((200, chunk)) diff --git a/synapse/server.py b/synapse/server.py index a4d2d4aba5..d770b20b19 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -28,7 +28,7 @@ from synapse.handlers import Handlers from synapse.rest import RestServletFactory from synapse.state import StateHandler from synapse.storage import DataStore -from synapse.types import UserID, RoomAlias, RoomID +from synapse.types import UserID, RoomAlias, RoomID, EventID from synapse.util import Clock from synapse.util.distributor import Distributor from synapse.util.lockutils import LockManager @@ -143,6 +143,11 @@ class BaseHomeServer(object): object.""" return RoomID.from_string(s, hs=self) + def parse_eventid(self, s): + """Parse the string given by 's' as a Event ID and return a EventID + object.""" + return EventID.from_string(s, hs=self) + def serialize_event(self, e): return serialize_event(self, e) diff --git a/synapse/state.py b/synapse/state.py index 9db84c9b5c..2548deed28 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -16,11 +16,14 @@ from twisted.internet import defer -from synapse.federation.pdu_codec import encode_event_id, decode_event_id from synapse.util.logutils import log_function +from synapse.util.async import run_on_reactor + +from synapse.types import EventID from collections import namedtuple +import copy import logging import hashlib @@ -35,13 +38,14 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) class StateHandler(object): - """ Repsonsible for doing state conflict resolution. + """ Responsible for doing state conflict resolution. """ def __init__(self, hs): self.store = hs.get_datastore() self._replication = hs.get_replication_layer() self.server_name = hs.hostname + self.hs = hs @defer.inlineCallbacks @log_function @@ -50,7 +54,7 @@ class StateHandler(object): to update the state and b) works out what the prev_state should be. Returns: - Deferred: Resolved with a boolean indicating if we succesfully + Deferred: Resolved with a boolean indicating if we successfully updated the state. Raised: @@ -71,128 +75,157 @@ class StateHandler(object): # (w.r.t. to power levels) snapshot.fill_out_prev_events(event) + yield self.annotate_state_groups(event) - event.prev_events = [ - e for e in event.prev_events if e != event.event_id - ] - - current_state = snapshot.prev_state_pdu - - if current_state: - event.prev_state = encode_event_id( - current_state.pdu_id, current_state.origin + if event.old_state_events: + current_state = event.old_state_events.get( + (event.type, event.state_key) ) + if current_state: + event.prev_state = current_state.event_id + # TODO check current_state to see if the min power level is less # than the power level of the user # power_level = self._get_power_level_for_event(event) - pdu_id, origin = decode_event_id(event.event_id, self.server_name) - - yield self.store.update_current_state( - pdu_id=pdu_id, - origin=origin, - context=key.context, - pdu_type=key.type, - state_key=key.state_key - ) - defer.returnValue(True) @defer.inlineCallbacks @log_function - def handle_new_state(self, new_pdu): - """ Apply conflict resolution to `new_pdu`. + def annotate_state_groups(self, event, old_state=None): + yield run_on_reactor() - This should be called on every new state pdu, regardless of whether or - not there is a conflict. + if old_state: + event.state_group = None + event.old_state_events = { + (s.type, s.state_key): s for s in old_state + } + event.state_events = event.old_state_events - This function is safe against the race of it getting called with two - `PDU`s trying to update the same state. - """ + if hasattr(event, "state_key"): + event.state_events[(event.type, event.state_key)] = event - # This needs to be done in a transaction. + defer.returnValue(False) + return - is_new = yield self._handle_new_state(new_pdu) + if hasattr(event, "outlier") and event.outlier: + event.state_group = None + event.old_state_events = None + event.state_events = {} + defer.returnValue(False) + return - logger.debug("is_new: %s %s %s", is_new, new_pdu.pdu_id, new_pdu.origin) + new_state = yield self.resolve_state_groups( + [e for e, _ in event.prev_events] + ) - if is_new: - yield self.store.update_current_state( - pdu_id=new_pdu.pdu_id, - origin=new_pdu.origin, - context=new_pdu.context, - pdu_type=new_pdu.pdu_type, - state_key=new_pdu.state_key - ) + event.old_state_events = copy.deepcopy(new_state) - defer.returnValue(is_new) + if hasattr(event, "state_key"): + new_state[(event.type, event.state_key)] = event - def _get_power_level_for_event(self, event): - # return self._persistence.get_power_level_for_user(event.room_id, - # event.sender) - return event.power_level + event.state_group = None + event.state_events = new_state + + defer.returnValue(hasattr(event, "state_key")) @defer.inlineCallbacks - @log_function - def _handle_new_state(self, new_pdu): - tree, missing_branch = yield self.store.get_unresolved_state_tree( - new_pdu - ) - new_branch, current_branch = tree + def get_current_state(self, room_id, event_type=None, state_key=""): + events = yield self.store.get_latest_events_in_room(room_id) - logger.debug( - "_handle_new_state new=%s, current=%s", - new_branch, current_branch - ) + event_ids = [ + e_id + for e_id, _, _ in events + ] - if missing_branch is not None: - # We're missing some PDUs. Fetch them. - # TODO (erikj): Limit this. - missing_prev = tree[missing_branch][-1] + res = yield self.resolve_state_groups(event_ids) - pdu_id = missing_prev.prev_state_id - origin = missing_prev.prev_state_origin + if event_type: + defer.returnValue(res.get((event_type, state_key))) + return - is_missing = yield self.store.get_pdu(pdu_id, origin) is None - if not is_missing: - raise Exception("Conflict resolution failed") + defer.returnValue(res.values()) - yield self._replication.get_pdu( - destination=missing_prev.origin, - pdu_origin=origin, - pdu_id=pdu_id, - outlier=True - ) + @defer.inlineCallbacks + @log_function + def resolve_state_groups(self, event_ids): + state_groups = yield self.store.get_state_groups( + event_ids + ) - updated_current = yield self._handle_new_state(new_pdu) - defer.returnValue(updated_current) + state = {} + for group in state_groups: + for s in group.state: + state.setdefault( + (s.type, s.state_key), + {} + )[s.event_id] = s + + unconflicted_state = { + k: v.values()[0] for k, v in state.items() + if len(v.values()) == 1 + } + + conflicted_state = { + k: v.values() + for k, v in state.items() + if len(v.values()) > 1 + } + + try: + new_state = {} + new_state.update(unconflicted_state) + for key, events in conflicted_state.items(): + new_state[key] = yield self._resolve_state_events(events) + except: + logger.exception("Failed to resolve state") + raise + + defer.returnValue(new_state) - if not current_branch: - # There is no current state - defer.returnValue(True) - return + @defer.inlineCallbacks + @log_function + def _resolve_state_events(self, events): + curr_events = events - n = new_branch[-1] - c = current_branch[-1] + new_powers_deferreds = [] + for e in curr_events: + new_powers_deferreds.append( + self.store.get_power_level(e.room_id, e.user_id) + ) - common_ancestor = n.pdu_id == c.pdu_id and n.origin == c.origin + new_powers = yield defer.gatherResults( + new_powers_deferreds, + consumeErrors=True + ) - if common_ancestor: - # We found a common ancestor! + max_power = max([int(p) for p in new_powers]) - if len(current_branch) == 1: - # This is a direct clobber so we can just... - defer.returnValue(True) + curr_events = [ + z[0] for z in zip(curr_events, new_powers) + if int(z[1]) == max_power + ] - else: - # We didn't find a common ancestor. This is probably fine. - pass + if not curr_events: + raise RuntimeError("Max didn't get a max?") + elif len(curr_events) == 1: + defer.returnValue(curr_events[0]) - result = yield self._do_conflict_res( - new_branch, current_branch, common_ancestor + # TODO: For now, just choose the one with the largest event_id. + defer.returnValue( + sorted( + curr_events, + key=lambda e: hashlib.sha1( + e.event_id + e.user_id + e.room_id + e.type + ).hexdigest() + )[0] ) - defer.returnValue(result) + + def _get_power_level_for_event(self, event): + # return self._persistence.get_power_level_for_user(event.room_id, + # event.sender) + return event.power_level @defer.inlineCallbacks def _do_conflict_res(self, new_branch, current_branch, common_ancestor): diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 4e9291fdff..1f39a4094e 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -37,9 +37,17 @@ from .registration import RegistrationStore from .room import RoomStore from .roommember import RoomMemberStore from .stream import StreamStore -from .pdu import StatePduStore, PduStore, PdusTable from .transactions import TransactionStore from .keys import KeyStore +from .event_federation import EventFederationStore + +from .state import StateStore +from .signatures import SignatureStore + +from syutil.base64util import decode_base64 + +from synapse.crypto.event_signing import compute_event_reference_hash + import json import logging @@ -51,7 +59,6 @@ logger = logging.getLogger(__name__) SCHEMAS = [ "transactions", - "pdu", "users", "profiles", "presence", @@ -59,6 +66,9 @@ SCHEMAS = [ "room_aliases", "keys", "redactions", + "state", + "event_edges", + "event_signatures", ] @@ -73,10 +83,12 @@ class _RollbackButIsFineException(Exception): """ pass + class DataStore(RoomMemberStore, RoomStore, RegistrationStore, StreamStore, ProfileStore, FeedbackStore, - PresenceStore, PduStore, StatePduStore, TransactionStore, - DirectoryStore, KeyStore): + PresenceStore, TransactionStore, + DirectoryStore, KeyStore, StateStore, SignatureStore, + EventFederationStore, ): def __init__(self, hs): super(DataStore, self).__init__(hs) @@ -99,6 +111,7 @@ class DataStore(RoomMemberStore, RoomStore, try: yield self.runInteraction( + "persist_event", self._persist_pdu_event_txn, pdu=pdu, event=event, @@ -119,7 +132,8 @@ class DataStore(RoomMemberStore, RoomStore, "type", "room_id", "content", - "unrecognized_keys" + "unrecognized_keys", + "depth", ], allow_none=allow_none, ) @@ -133,39 +147,12 @@ class DataStore(RoomMemberStore, RoomStore, def _persist_pdu_event_txn(self, txn, pdu=None, event=None, backfilled=False, stream_ordering=None, is_new_state=True): - if pdu is not None: - self._persist_event_pdu_txn(txn, pdu) if event is not None: return self._persist_event_txn( txn, event, backfilled, stream_ordering, is_new_state=is_new_state, ) - def _persist_event_pdu_txn(self, txn, pdu): - cols = dict(pdu.__dict__) - unrec_keys = dict(pdu.unrecognized_keys) - del cols["content"] - del cols["prev_pdus"] - cols["content_json"] = json.dumps(pdu.content) - - unrec_keys.update({ - k: v for k, v in cols.items() - if k not in PdusTable.fields - }) - - cols["unrecognized_keys"] = json.dumps(unrec_keys) - - cols["ts"] = cols.pop("origin_server_ts") - - logger.debug("Persisting: %s", repr(cols)) - - if pdu.is_state: - self._persist_state_txn(txn, pdu.prev_pdus, cols) - else: - self._persist_pdu_txn(txn, pdu.prev_pdus, cols) - - self._update_min_depth_for_context_txn(txn, pdu.context, pdu.depth) - @log_function def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None, is_new_state=True): @@ -190,6 +177,10 @@ class DataStore(RoomMemberStore, RoomStore, elif event.type == RoomRedactionEvent.TYPE: self._store_redaction(txn, event) + outlier = False + if hasattr(event, "outlier"): + outlier = event.outlier + vals = { "topological_ordering": event.depth, "event_id": event.event_id, @@ -197,25 +188,30 @@ class DataStore(RoomMemberStore, RoomStore, "room_id": event.room_id, "content": json.dumps(event.content), "processed": True, + "outlier": outlier, + "depth": event.depth, } if stream_ordering is not None: vals["stream_ordering"] = stream_ordering - if hasattr(event, "outlier"): - vals["outlier"] = event.outlier - else: - vals["outlier"] = False - unrec = { k: v for k, v in event.get_full_dict().items() - if k not in vals.keys() and k not in ["redacted", "redacted_because"] + if k not in vals.keys() and k not in [ + "redacted", "redacted_because", "signatures", "hashes", + "prev_events", + ] } vals["unrecognized_keys"] = json.dumps(unrec) try: - self._simple_insert_txn(txn, "events", vals) + self._simple_insert_txn( + txn, + "events", + vals, + or_replace=(not outlier), + ) except: logger.warn( "Failed to persist, probably duplicate: %s", @@ -224,6 +220,16 @@ class DataStore(RoomMemberStore, RoomStore, ) raise _RollbackButIsFineException("_persist_event") + self._handle_prev_events( + txn, + outlier=outlier, + event_id=event.event_id, + prev_events=event.prev_events, + room_id=event.room_id, + ) + + self._store_state_groups_txn(txn, event) + is_state = hasattr(event, "state_key") and event.state_key is not None if is_new_state and is_state: vals = { @@ -249,6 +255,35 @@ class DataStore(RoomMemberStore, RoomStore, } ) + for hash_alg, hash_base64 in event.hashes.items(): + hash_bytes = decode_base64(hash_base64) + self._store_event_content_hash_txn( + txn, event.event_id, hash_alg, hash_bytes, + ) + + if hasattr(event, "signatures"): + signatures = event.signatures.get(event.origin, {}) + + for key_id, signature_base64 in signatures.items(): + signature_bytes = decode_base64(signature_base64) + self._store_event_origin_signature_txn( + txn, event.event_id, event.origin, key_id, signature_bytes, + ) + + for prev_event_id, prev_hashes in event.prev_events: + for alg, hash_base64 in prev_hashes.items(): + hash_bytes = decode_base64(hash_base64) + self._store_prev_event_hash_txn( + txn, event.event_id, prev_event_id, alg, hash_bytes + ) + + (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event) + self._store_event_reference_hash_txn( + txn, event.event_id, ref_alg, ref_hash_bytes + ) + + self._update_min_depth_for_room_txn(txn, event.room_id, event.depth) + def _store_redaction(self, txn, event): txn.execute( "INSERT OR IGNORE INTO redactions " @@ -331,28 +366,19 @@ class DataStore(RoomMemberStore, RoomStore, """ def _snapshot(txn): membership_state = self._get_room_member(txn, user_id, room_id) - prev_pdus = self._get_latest_pdus_in_context( - txn, room_id - ) - if state_type is not None and state_key is not None: - prev_state_pdu = self._get_current_state_pdu( - txn, room_id, state_type, state_key - ) - else: - prev_state_pdu = None + prev_events = self._get_latest_events_in_room(txn, room_id) return Snapshot( store=self, room_id=room_id, user_id=user_id, - prev_pdus=prev_pdus, + prev_events=prev_events, membership_state=membership_state, state_type=state_type, state_key=state_key, - prev_state_pdu=prev_state_pdu, ) - return self.runInteraction(_snapshot) + return self.runInteraction("snapshot_room", _snapshot) class Snapshot(object): @@ -361,7 +387,7 @@ class Snapshot(object): store (DataStore): The datastore. room_id (RoomId): The room of the snapshot. user_id (UserId): The user this snapshot is for. - prev_pdus (list): The list of PDU ids this snapshot is after. + prev_events (list): The list of event ids this snapshot is after. membership_state (RoomMemberEvent): The current state of the user in the room. state_type (str, optional): State type captured by the snapshot @@ -370,13 +396,13 @@ class Snapshot(object): the previous value of the state type and key in the room. """ - def __init__(self, store, room_id, user_id, prev_pdus, + def __init__(self, store, room_id, user_id, prev_events, membership_state, state_type=None, state_key=None, prev_state_pdu=None): self.store = store self.room_id = room_id self.user_id = user_id - self.prev_pdus = prev_pdus + self.prev_events = prev_events self.membership_state = membership_state self.state_type = state_type self.state_key = state_key @@ -386,14 +412,13 @@ class Snapshot(object): if hasattr(event, "prev_events"): return - es = [ - "%s@%s" % (p_id, origin) for p_id, origin, _ in self.prev_pdus + event.prev_events = [ + (event_id, hashes) + for event_id, hashes, _ in self.prev_events ] - event.prev_events = [e for e in es if e != event.event_id] - - if self.prev_pdus: - event.depth = max([int(v) for _, _, v in self.prev_pdus]) + 1 + if self.prev_events: + event.depth = max([int(v) for _, _, v in self.prev_events]) + 1 else: event.depth = 0 @@ -452,9 +477,10 @@ def prepare_database(db_conn): db_conn.commit() else: - sql_script = "BEGIN TRANSACTION;" + sql_script = "BEGIN TRANSACTION;\n" for sql_loc in SCHEMAS: sql_script += read_schema(sql_loc) + sql_script += "\n" sql_script += "COMMIT TRANSACTION;" c.executescript(sql_script) db_conn.commit() diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 65a86e9056..464b12f032 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -19,54 +19,66 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.api.events.utils import prune_event from synapse.util.logutils import log_function +from syutil.base64util import encode_base64 import collections import copy import json +import sys +import time logger = logging.getLogger(__name__) sql_logger = logging.getLogger("synapse.storage.SQL") +transaction_logger = logging.getLogger("synapse.storage.txn") class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging to the .execute() method.""" - __slots__ = ["txn"] + __slots__ = ["txn", "name"] - def __init__(self, txn): + def __init__(self, txn, name): object.__setattr__(self, "txn", txn) + object.__setattr__(self, "name", name) - def __getattribute__(self, name): - if name == "execute": - return object.__getattribute__(self, "execute") - - return getattr(object.__getattribute__(self, "txn"), name) + def __getattr__(self, name): + return getattr(self.txn, name) def __setattr__(self, name, value): - setattr(object.__getattribute__(self, "txn"), name, value) + setattr(self.txn, name, value) def execute(self, sql, *args, **kwargs): # TODO(paul): Maybe use 'info' and 'debug' for values? - sql_logger.debug("[SQL] %s", sql) + sql_logger.debug("[SQL] {%s} %s", self.name, sql) try: if args and args[0]: values = args[0] - sql_logger.debug("[SQL values] " + - ", ".join(("<%s>",) * len(values)), *values) + sql_logger.debug( + "[SQL values] {%s} " + ", ".join(("<%s>",) * len(values)), + self.name, + *values + ) except: # Don't let logging failures stop SQL from working pass - # TODO(paul): Here would be an excellent place to put some timing - # measurements, and log (warning?) slow queries. - return object.__getattribute__(self, "txn").execute( - sql, *args, **kwargs - ) + start = time.clock() * 1000 + try: + return self.txn.execute( + sql, *args, **kwargs + ) + except: + logger.exception("[SQL FAIL] {%s}", self.name) + raise + finally: + end = time.clock() * 1000 + sql_logger.debug("[SQL time] {%s} %f", self.name, end - start) class SQLBaseStore(object): + _TXN_ID = 0 def __init__(self, hs): self.hs = hs @@ -74,10 +86,30 @@ class SQLBaseStore(object): self.event_factory = hs.get_event_factory() self._clock = hs.get_clock() - def runInteraction(self, func, *args, **kwargs): + def runInteraction(self, desc, func, *args, **kwargs): """Wraps the .runInteraction() method on the underlying db_pool.""" def inner_func(txn, *args, **kwargs): - return func(LoggingTransaction(txn), *args, **kwargs) + start = time.clock() * 1000 + txn_id = SQLBaseStore._TXN_ID + + # We don't really need these to be unique, so lets stop it from + # growing really large. + self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1) + + name = "%s-%x" % (desc, txn_id, ) + + transaction_logger.debug("[TXN START] {%s}", name) + try: + return func(LoggingTransaction(txn, name), *args, **kwargs) + except: + logger.exception("[TXN FAIL] {%s}", name) + raise + finally: + end = time.clock() * 1000 + transaction_logger.debug( + "[TXN END] {%s} %f", + name, end - start + ) return self._db_pool.runInteraction(inner_func, *args, **kwargs) @@ -113,7 +145,7 @@ class SQLBaseStore(object): else: return cursor.fetchall() - return self.runInteraction(interaction) + return self.runInteraction("_execute", interaction) def _execute_and_decode(self, query, *args): return self._execute(self.cursor_to_dict, query, *args) @@ -130,6 +162,7 @@ class SQLBaseStore(object): or_replace : bool; if True performs an INSERT OR REPLACE """ return self.runInteraction( + "_simple_insert", self._simple_insert_txn, table, values, or_replace=or_replace, or_ignore=or_ignore, ) @@ -170,7 +203,6 @@ class SQLBaseStore(object): table, keyvalues, retcols=retcols, allow_none=allow_none ) - @defer.inlineCallbacks def _simple_select_one_onecol(self, table, keyvalues, retcol, allow_none=False): """Executes a SELECT query on the named table, which is expected to @@ -181,19 +213,41 @@ class SQLBaseStore(object): keyvalues : dict of column names and values to select the row with retcol : string giving the name of the column to return """ - ret = yield self._simple_select_one( + return self.runInteraction( + "_simple_select_one_onecol_txn", + self._simple_select_one_onecol_txn, + table, keyvalues, retcol, allow_none=allow_none, + ) + + def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol, + allow_none=False): + ret = self._simple_select_onecol_txn( + txn, table=table, keyvalues=keyvalues, - retcols=[retcol], - allow_none=allow_none + retcol=retcol, ) if ret: - defer.returnValue(ret[retcol]) + return ret[0] else: - defer.returnValue(None) + if allow_none: + return None + else: + raise StoreError(404, "No row found") + + def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol): + sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % { + "retcol": retcol, + "table": table, + "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()), + } + + txn.execute(sql, keyvalues.values()) + + return [r[0] for r in txn.fetchall()] + - @defer.inlineCallbacks def _simple_select_onecol(self, table, keyvalues, retcol): """Executes a SELECT query on the named table, which returns a list comprising of the values of the named column from the selected rows. @@ -206,19 +260,11 @@ class SQLBaseStore(object): Returns: Deferred: Results in a list """ - sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % { - "retcol": retcol, - "table": table, - "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()), - } - - def func(txn): - txn.execute(sql, keyvalues.values()) - return txn.fetchall() - - res = yield self.runInteraction(func) - - defer.returnValue([r[0] for r in res]) + return self.runInteraction( + "_simple_select_onecol", + self._simple_select_onecol_txn, + table, keyvalues, retcol + ) def _simple_select_list(self, table, keyvalues, retcols): """Executes a SELECT query on the named table, which may return zero or @@ -239,7 +285,7 @@ class SQLBaseStore(object): txn.execute(sql, keyvalues.values()) return self.cursor_to_dict(txn) - return self.runInteraction(func) + return self.runInteraction("_simple_select_list", func) def _simple_update_one(self, table, keyvalues, updatevalues, retcols=None): @@ -307,7 +353,7 @@ class SQLBaseStore(object): raise StoreError(500, "More than one row matched") return ret - return self.runInteraction(func) + return self.runInteraction("_simple_selectupdate_one", func) def _simple_delete_one(self, table, keyvalues): """Executes a DELETE query on the named table, expecting to delete a @@ -319,7 +365,7 @@ class SQLBaseStore(object): """ sql = "DELETE FROM %s WHERE %s" % ( table, - " AND ".join("%s = ?" % (k) for k in keyvalues) + " AND ".join("%s = ?" % (k, ) for k in keyvalues) ) def func(txn): @@ -328,7 +374,25 @@ class SQLBaseStore(object): raise StoreError(404, "No row found") if txn.rowcount > 1: raise StoreError(500, "more than one row matched") - return self.runInteraction(func) + return self.runInteraction("_simple_delete_one", func) + + def _simple_delete(self, table, keyvalues): + """Executes a DELETE query on the named table. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + """ + + return self.runInteraction("_simple_delete", self._simple_delete_txn) + + def _simple_delete_txn(self, txn, table, keyvalues): + sql = "DELETE FROM %s WHERE %s" % ( + table, + " AND ".join("%s = ?" % (k, ) for k in keyvalues) + ) + + return txn.execute(sql, keyvalues.values()) def _simple_max_id(self, table): """Executes a SELECT query on the named table, expecting to return the @@ -346,7 +410,7 @@ class SQLBaseStore(object): return 0 return max_id - return self.runInteraction(func) + return self.runInteraction("_simple_max_id", func) def _parse_event_from_row(self, row_dict): d = copy.deepcopy({k: v for k, v in row_dict.items()}) @@ -370,7 +434,9 @@ class SQLBaseStore(object): ) def _parse_events(self, rows): - return self.runInteraction(self._parse_events_txn, rows) + return self.runInteraction( + "_parse_events", self._parse_events_txn, rows + ) def _parse_events_txn(self, txn, rows): events = [self._parse_event_from_row(r) for r in rows] @@ -378,6 +444,17 @@ class SQLBaseStore(object): sql = "SELECT * FROM events WHERE event_id = ?" for ev in events: + signatures = self._get_event_origin_signatures_txn( + txn, ev.event_id, + ) + + ev.signatures = { + k: encode_base64(v) for k, v in signatures.items() + } + + prev_events = self._get_latest_events_in_room(txn, ev.room_id) + ev.prev_events = [(e_id, s,) for e_id, s, _ in prev_events] + if hasattr(ev, "prev_state"): # Load previous state_content. # TODO: Should we be pulling this out above? diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 52373a28a6..d6a7113b9c 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -95,6 +95,7 @@ class DirectoryStore(SQLBaseStore): def delete_room_alias(self, room_alias): return self.runInteraction( + "delete_room_alias", self._delete_room_alias_txn, room_alias, ) diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py new file mode 100644 index 0000000000..dcc116bad2 --- /dev/null +++ b/synapse/storage/event_federation.py @@ -0,0 +1,253 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore +from syutil.base64util import encode_base64 + +import logging + + +logger = logging.getLogger(__name__) + + +class EventFederationStore(SQLBaseStore): + + def get_oldest_events_in_room(self, room_id): + return self.runInteraction( + "get_oldest_events_in_room", + self._get_oldest_events_in_room_txn, + room_id, + ) + + def _get_oldest_events_in_room_txn(self, txn, room_id): + return self._simple_select_onecol_txn( + txn, + table="event_backward_extremities", + keyvalues={ + "room_id": room_id, + }, + retcol="event_id", + ) + + def get_latest_events_in_room(self, room_id): + return self.runInteraction( + "get_latest_events_in_room", + self._get_latest_events_in_room, + room_id, + ) + + def _get_latest_events_in_room(self, txn, room_id): + self._simple_select_onecol_txn( + txn, + table="event_forward_extremities", + keyvalues={ + "room_id": room_id, + }, + retcol="event_id", + ) + + sql = ( + "SELECT e.event_id, e.depth FROM events as e " + "INNER JOIN event_forward_extremities as f " + "ON e.event_id = f.event_id " + "WHERE f.room_id = ?" + ) + + txn.execute(sql, (room_id, )) + + results = [] + for event_id, depth in txn.fetchall(): + hashes = self._get_event_reference_hashes_txn(txn, event_id) + prev_hashes = { + k: encode_base64(v) for k, v in hashes.items() + if k == "sha256" + } + results.append((event_id, prev_hashes, depth)) + + return results + + def get_min_depth(self, room_id): + return self.runInteraction( + "get_min_depth", + self._get_min_depth_interaction, + room_id, + ) + + def _get_min_depth_interaction(self, txn, room_id): + min_depth = self._simple_select_one_onecol_txn( + txn, + table="room_depth", + keyvalues={"room_id": room_id,}, + retcol="min_depth", + allow_none=True, + ) + + return int(min_depth) if min_depth is not None else None + + def _update_min_depth_for_room_txn(self, txn, room_id, depth): + min_depth = self._get_min_depth_interaction(txn, room_id) + + do_insert = depth < min_depth if min_depth else True + + if do_insert: + self._simple_insert_txn( + txn, + table="room_depth", + values={ + "room_id": room_id, + "min_depth": depth, + }, + or_replace=True, + ) + + def _handle_prev_events(self, txn, outlier, event_id, prev_events, + room_id): + for e_id, _ in prev_events: + # TODO (erikj): This could be done as a bulk insert + self._simple_insert_txn( + txn, + table="event_edges", + values={ + "event_id": event_id, + "prev_event_id": e_id, + "room_id": room_id, + }, + or_ignore=True, + ) + + # Update the extremities table if this is not an outlier. + if not outlier: + for e_id, _ in prev_events: + # TODO (erikj): This could be done as a bulk insert + self._simple_delete_txn( + txn, + table="event_forward_extremities", + keyvalues={ + "event_id": e_id, + "room_id": room_id, + } + ) + + + + # We only insert as a forward extremity the new pdu if there are no + # other pdus that reference it as a prev pdu + query = ( + "INSERT OR IGNORE INTO %(table)s (event_id, room_id) " + "SELECT ?, ? WHERE NOT EXISTS (" + "SELECT 1 FROM %(event_edges)s WHERE " + "prev_event_id = ? " + ")" + ) % { + "table": "event_forward_extremities", + "event_edges": "event_edges", + } + + logger.debug("query: %s", query) + + txn.execute(query, (event_id, room_id, event_id)) + + # Insert all the prev_pdus as a backwards thing, they'll get + # deleted in a second if they're incorrect anyway. + for e_id, _ in prev_events: + # TODO (erikj): This could be done as a bulk insert + self._simple_insert_txn( + txn, + table="event_backward_extremities", + values={ + "event_id": e_id, + "room_id": room_id, + }, + or_ignore=True, + ) + + # Also delete from the backwards extremities table all ones that + # reference pdus that we have already seen + query = ( + "DELETE FROM event_backward_extremities WHERE EXISTS (" + "SELECT 1 FROM events " + "WHERE " + "event_backward_extremities.event_id = events.event_id " + "AND not events.outlier " + ")" + ) + txn.execute(query) + + + def get_backfill_events(self, room_id, event_list, limit): + """Get a list of Events for a given topic that occured before (and + including) the pdus in pdu_list. Return a list of max size `limit`. + + Args: + txn + room_id (str) + event_list (list) + limit (int) + + Return: + list: A list of PduTuples + """ + return self.runInteraction( + "get_backfill_events", + self._get_backfill_events, room_id, event_list, limit + ) + + def _get_backfill_events(self, txn, room_id, event_list, limit): + logger.debug( + "_get_backfill_events: %s, %s, %s", + room_id, repr(event_list), limit + ) + + # We seed the pdu_results with the things from the pdu_list. + event_results = event_list + + front = event_list + + query = ( + "SELECT prev_event_id FROM event_edges " + "WHERE room_id = ? AND event_id = ? " + "LIMIT ?" + ) + + # We iterate through all event_ids in `front` to select their previous + # events. These are dumped in `new_front`. + # We continue until we reach the limit *or* new_front is empty (i.e., + # we've run out of things to select + while front and len(event_results) < limit: + + new_front = [] + for event_id in front: + logger.debug( + "_backfill_interaction: id=%s", + event_id + ) + + txn.execute( + query, + (room_id, event_id, limit - len(event_results)) + ) + + for row in txn.fetchall(): + logger.debug( + "_backfill_interaction: got id=%s", + *row + ) + new_front.append(row) + + front = new_front + event_results += new_front + + # We also want to update the `prev_pdus` attributes before returning. + return self._get_pdu_tuples(txn, event_results) diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py deleted file mode 100644 index d70467dcd6..0000000000 --- a/synapse/storage/pdu.py +++ /dev/null @@ -1,915 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from twisted.internet import defer - -from ._base import SQLBaseStore, Table, JoinHelper - -from synapse.federation.units import Pdu -from synapse.util.logutils import log_function - -from collections import namedtuple - -import logging - -logger = logging.getLogger(__name__) - - -class PduStore(SQLBaseStore): - """A collection of queries for handling PDUs. - """ - - def get_pdu(self, pdu_id, origin): - """Given a pdu_id and origin, get a PDU. - - Args: - txn - pdu_id (str) - origin (str) - - Returns: - PduTuple: If the pdu does not exist in the database, returns None - """ - - return self.runInteraction( - self._get_pdu_tuple, pdu_id, origin - ) - - def _get_pdu_tuple(self, txn, pdu_id, origin): - res = self._get_pdu_tuples(txn, [(pdu_id, origin)]) - return res[0] if res else None - - def _get_pdu_tuples(self, txn, pdu_id_tuples): - results = [] - for pdu_id, origin in pdu_id_tuples: - txn.execute( - PduEdgesTable.select_statement("pdu_id = ? AND origin = ?"), - (pdu_id, origin) - ) - - edges = [ - (r.prev_pdu_id, r.prev_origin) - for r in PduEdgesTable.decode_results(txn.fetchall()) - ] - - query = ( - "SELECT %(fields)s FROM %(pdus)s as p " - "LEFT JOIN %(state)s as s " - "ON p.pdu_id = s.pdu_id AND p.origin = s.origin " - "WHERE p.pdu_id = ? AND p.origin = ? " - ) % { - "fields": _pdu_state_joiner.get_fields( - PdusTable="p", StatePdusTable="s"), - "pdus": PdusTable.table_name, - "state": StatePdusTable.table_name, - } - - txn.execute(query, (pdu_id, origin)) - - row = txn.fetchone() - if row: - results.append(PduTuple(PduEntry(*row), edges)) - - return results - - def get_current_state_for_context(self, context): - """Get a list of PDUs that represent the current state for a given - context - - Args: - context (str) - - Returns: - list: A list of PduTuples - """ - - return self.runInteraction( - self._get_current_state_for_context, - context - ) - - def _get_current_state_for_context(self, txn, context): - query = ( - "SELECT pdu_id, origin FROM %s WHERE context = ?" - % CurrentStateTable.table_name - ) - - logger.debug("get_current_state %s, Args=%s", query, context) - txn.execute(query, (context,)) - - res = txn.fetchall() - - logger.debug("get_current_state %d results", len(res)) - - return self._get_pdu_tuples(txn, res) - - def _persist_pdu_txn(self, txn, prev_pdus, cols): - """Inserts a (non-state) PDU into the database. - - Args: - txn, - prev_pdus (list) - **cols: The columns to insert into the PdusTable. - """ - entry = PdusTable.EntryType( - **{k: cols.get(k, None) for k in PdusTable.fields} - ) - - txn.execute(PdusTable.insert_statement(), entry) - - self._handle_prev_pdus( - txn, entry.outlier, entry.pdu_id, entry.origin, - prev_pdus, entry.context - ) - - def mark_pdu_as_processed(self, pdu_id, pdu_origin): - """Mark a received PDU as processed. - - Args: - txn - pdu_id (str) - pdu_origin (str) - """ - - return self.runInteraction( - self._mark_as_processed, pdu_id, pdu_origin - ) - - def _mark_as_processed(self, txn, pdu_id, pdu_origin): - txn.execute("UPDATE %s SET have_processed = 1" % PdusTable.table_name) - - def get_all_pdus_from_context(self, context): - """Get a list of all PDUs for a given context.""" - return self.runInteraction( - self._get_all_pdus_from_context, context, - ) - - def _get_all_pdus_from_context(self, txn, context): - query = ( - "SELECT pdu_id, origin FROM %s " - "WHERE context = ?" - ) % PdusTable.table_name - - txn.execute(query, (context,)) - - return self._get_pdu_tuples(txn, txn.fetchall()) - - def get_backfill(self, context, pdu_list, limit): - """Get a list of Pdus for a given topic that occured before (and - including) the pdus in pdu_list. Return a list of max size `limit`. - - Args: - txn - context (str) - pdu_list (list) - limit (int) - - Return: - list: A list of PduTuples - """ - return self.runInteraction( - self._get_backfill, context, pdu_list, limit - ) - - def _get_backfill(self, txn, context, pdu_list, limit): - logger.debug( - "backfill: %s, %s, %s", - context, repr(pdu_list), limit - ) - - # We seed the pdu_results with the things from the pdu_list. - pdu_results = pdu_list - - front = pdu_list - - query = ( - "SELECT prev_pdu_id, prev_origin FROM %(edges_table)s " - "WHERE context = ? AND pdu_id = ? AND origin = ? " - "LIMIT ?" - ) % { - "edges_table": PduEdgesTable.table_name, - } - - # We iterate through all pdu_ids in `front` to select their previous - # pdus. These are dumped in `new_front`. We continue until we reach the - # limit *or* new_front is empty (i.e., we've run out of things to - # select - while front and len(pdu_results) < limit: - - new_front = [] - for pdu_id, origin in front: - logger.debug( - "_backfill_interaction: i=%s, o=%s", - pdu_id, origin - ) - - txn.execute( - query, - (context, pdu_id, origin, limit - len(pdu_results)) - ) - - for row in txn.fetchall(): - logger.debug( - "_backfill_interaction: got i=%s, o=%s", - *row - ) - new_front.append(row) - - front = new_front - pdu_results += new_front - - # We also want to update the `prev_pdus` attributes before returning. - return self._get_pdu_tuples(txn, pdu_results) - - def get_min_depth_for_context(self, context): - """Get the current minimum depth for a context - - Args: - txn - context (str) - """ - return self.runInteraction( - self._get_min_depth_for_context, context - ) - - def _get_min_depth_for_context(self, txn, context): - return self._get_min_depth_interaction(txn, context) - - def _get_min_depth_interaction(self, txn, context): - txn.execute( - "SELECT min_depth FROM %s WHERE context = ?" - % ContextDepthTable.table_name, - (context,) - ) - - row = txn.fetchone() - - return row[0] if row else None - - def _update_min_depth_for_context_txn(self, txn, context, depth): - """Update the minimum `depth` of the given context, which is the line - on which we stop backfilling backwards. - - Args: - context (str) - depth (int) - """ - min_depth = self._get_min_depth_interaction(txn, context) - - do_insert = depth < min_depth if min_depth else True - - if do_insert: - txn.execute( - "INSERT OR REPLACE INTO %s (context, min_depth) " - "VALUES (?,?)" % ContextDepthTable.table_name, - (context, depth) - ) - - def _get_latest_pdus_in_context(self, txn, context): - """Get's a list of the most current pdus for a given context. This is - used when we are sending a Pdu and need to fill out the `prev_pdus` - key - - Args: - txn - context - """ - query = ( - "SELECT p.pdu_id, p.origin, p.depth FROM %(pdus)s as p " - "INNER JOIN %(forward)s as f ON p.pdu_id = f.pdu_id " - "AND f.origin = p.origin " - "WHERE f.context = ?" - ) % { - "pdus": PdusTable.table_name, - "forward": PduForwardExtremitiesTable.table_name, - } - - logger.debug("get_prev query: %s", query) - - txn.execute( - query, - (context, ) - ) - - results = txn.fetchall() - - return [(row[0], row[1], row[2]) for row in results] - - @defer.inlineCallbacks - def get_oldest_pdus_in_context(self, context): - """Get a list of Pdus that we haven't backfilled beyond yet (and havent - seen). This list is used when we want to backfill backwards and is the - list we send to the remote server. - - Args: - txn - context (str) - - Returns: - list: A list of PduIdTuple. - """ - results = yield self._execute( - None, - "SELECT pdu_id, origin FROM %(back)s WHERE context = ?" - % {"back": PduBackwardExtremitiesTable.table_name, }, - context - ) - - defer.returnValue([PduIdTuple(i, o) for i, o in results]) - - def is_pdu_new(self, pdu_id, origin, context, depth): - """For a given Pdu, try and figure out if it's 'new', i.e., if it's - not something we got randomly from the past, for example when we - request the current state of the room that will probably return a bunch - of pdus from before we joined. - - Args: - txn - pdu_id (str) - origin (str) - context (str) - depth (int) - - Returns: - bool - """ - - return self.runInteraction( - self._is_pdu_new, - pdu_id=pdu_id, - origin=origin, - context=context, - depth=depth - ) - - def _is_pdu_new(self, txn, pdu_id, origin, context, depth): - # If depth > min depth in back table, then we classify it as new. - # OR if there is nothing in the back table, then it kinda needs to - # be a new thing. - query = ( - "SELECT min(p.depth) FROM %(edges)s as e " - "INNER JOIN %(back)s as b " - "ON e.prev_pdu_id = b.pdu_id AND e.prev_origin = b.origin " - "INNER JOIN %(pdus)s as p " - "ON e.pdu_id = p.pdu_id AND p.origin = e.origin " - "WHERE p.context = ?" - ) % { - "pdus": PdusTable.table_name, - "edges": PduEdgesTable.table_name, - "back": PduBackwardExtremitiesTable.table_name, - } - - txn.execute(query, (context,)) - - min_depth, = txn.fetchone() - - if not min_depth or depth > int(min_depth): - logger.debug( - "is_new true: id=%s, o=%s, d=%s min_depth=%s", - pdu_id, origin, depth, min_depth - ) - return True - - # If this pdu is in the forwards table, then it also is a new one - query = ( - "SELECT * FROM %(forward)s WHERE pdu_id = ? AND origin = ?" - ) % { - "forward": PduForwardExtremitiesTable.table_name, - } - - txn.execute(query, (pdu_id, origin)) - - # Did we get anything? - if txn.fetchall(): - logger.debug( - "is_new true: id=%s, o=%s, d=%s was forward", - pdu_id, origin, depth - ) - return True - - logger.debug( - "is_new false: id=%s, o=%s, d=%s", - pdu_id, origin, depth - ) - - # FINE THEN. It's probably old. - return False - - @staticmethod - @log_function - def _handle_prev_pdus(txn, outlier, pdu_id, origin, prev_pdus, - context): - txn.executemany( - PduEdgesTable.insert_statement(), - [(pdu_id, origin, p[0], p[1], context) for p in prev_pdus] - ) - - # Update the extremities table if this is not an outlier. - if not outlier: - - # First, we delete the new one from the forwards extremities table. - query = ( - "DELETE FROM %s WHERE pdu_id = ? AND origin = ?" - % PduForwardExtremitiesTable.table_name - ) - txn.executemany(query, prev_pdus) - - # We only insert as a forward extremety the new pdu if there are no - # other pdus that reference it as a prev pdu - query = ( - "INSERT INTO %(table)s (pdu_id, origin, context) " - "SELECT ?, ?, ? WHERE NOT EXISTS (" - "SELECT 1 FROM %(pdu_edges)s WHERE " - "prev_pdu_id = ? AND prev_origin = ?" - ")" - ) % { - "table": PduForwardExtremitiesTable.table_name, - "pdu_edges": PduEdgesTable.table_name - } - - logger.debug("query: %s", query) - - txn.execute(query, (pdu_id, origin, context, pdu_id, origin)) - - # Insert all the prev_pdus as a backwards thing, they'll get - # deleted in a second if they're incorrect anyway. - txn.executemany( - PduBackwardExtremitiesTable.insert_statement(), - [(i, o, context) for i, o in prev_pdus] - ) - - # Also delete from the backwards extremities table all ones that - # reference pdus that we have already seen - query = ( - "DELETE FROM %(pdu_back)s WHERE EXISTS (" - "SELECT 1 FROM %(pdus)s AS pdus " - "WHERE " - "%(pdu_back)s.pdu_id = pdus.pdu_id " - "AND %(pdu_back)s.origin = pdus.origin " - "AND not pdus.outlier " - ")" - ) % { - "pdu_back": PduBackwardExtremitiesTable.table_name, - "pdus": PdusTable.table_name, - } - txn.execute(query) - - -class StatePduStore(SQLBaseStore): - """A collection of queries for handling state PDUs. - """ - - def _persist_state_txn(self, txn, prev_pdus, cols): - """Inserts a state PDU into the database - - Args: - txn, - prev_pdus (list) - **cols: The columns to insert into the PdusTable and StatePdusTable - """ - pdu_entry = PdusTable.EntryType( - **{k: cols.get(k, None) for k in PdusTable.fields} - ) - state_entry = StatePdusTable.EntryType( - **{k: cols.get(k, None) for k in StatePdusTable.fields} - ) - - logger.debug("Inserting pdu: %s", repr(pdu_entry)) - logger.debug("Inserting state: %s", repr(state_entry)) - - txn.execute(PdusTable.insert_statement(), pdu_entry) - txn.execute(StatePdusTable.insert_statement(), state_entry) - - self._handle_prev_pdus( - txn, - pdu_entry.outlier, pdu_entry.pdu_id, pdu_entry.origin, prev_pdus, - pdu_entry.context - ) - - def get_unresolved_state_tree(self, new_state_pdu): - return self.runInteraction( - self._get_unresolved_state_tree, new_state_pdu - ) - - @log_function - def _get_unresolved_state_tree(self, txn, new_pdu): - current = self._get_current_interaction( - txn, - new_pdu.context, new_pdu.pdu_type, new_pdu.state_key - ) - - ReturnType = namedtuple( - "StateReturnType", ["new_branch", "current_branch"] - ) - return_value = ReturnType([new_pdu], []) - - if not current: - logger.debug("get_unresolved_state_tree No current state.") - return (return_value, None) - - return_value.current_branch.append(current) - - enum_branches = self._enumerate_state_branches( - txn, new_pdu, current - ) - - missing_branch = None - for branch, prev_state, state in enum_branches: - if state: - return_value[branch].append(state) - else: - # We don't have prev_state :( - missing_branch = branch - break - - return (return_value, missing_branch) - - def update_current_state(self, pdu_id, origin, context, pdu_type, - state_key): - return self.runInteraction( - self._update_current_state, - pdu_id, origin, context, pdu_type, state_key - ) - - def _update_current_state(self, txn, pdu_id, origin, context, pdu_type, - state_key): - query = ( - "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)" - ) % { - "curr": CurrentStateTable.table_name, - "fields": CurrentStateTable.get_fields_string(), - "qs": ", ".join(["?"] * len(CurrentStateTable.fields)) - } - - query_args = CurrentStateTable.EntryType( - pdu_id=pdu_id, - origin=origin, - context=context, - pdu_type=pdu_type, - state_key=state_key - ) - - txn.execute(query, query_args) - - def get_current_state_pdu(self, context, pdu_type, state_key): - """For a given context, pdu_type, state_key 3-tuple, return what is - currently considered the current state. - - Args: - txn - context (str) - pdu_type (str) - state_key (str) - - Returns: - PduEntry - """ - - return self.runInteraction( - self._get_current_state_pdu, context, pdu_type, state_key - ) - - def _get_current_state_pdu(self, txn, context, pdu_type, state_key): - return self._get_current_interaction(txn, context, pdu_type, state_key) - - def _get_current_interaction(self, txn, context, pdu_type, state_key): - logger.debug( - "_get_current_interaction %s %s %s", - context, pdu_type, state_key - ) - - fields = _pdu_state_joiner.get_fields( - PdusTable="p", StatePdusTable="s") - - current_query = ( - "SELECT %(fields)s FROM %(state)s as s " - "INNER JOIN %(pdus)s as p " - "ON s.pdu_id = p.pdu_id AND s.origin = p.origin " - "INNER JOIN %(curr)s as c " - "ON s.pdu_id = c.pdu_id AND s.origin = c.origin " - "WHERE s.context = ? AND s.pdu_type = ? AND s.state_key = ? " - ) % { - "fields": fields, - "curr": CurrentStateTable.table_name, - "state": StatePdusTable.table_name, - "pdus": PdusTable.table_name, - } - - txn.execute( - current_query, - (context, pdu_type, state_key) - ) - - row = txn.fetchone() - - result = PduEntry(*row) if row else None - - if not result: - logger.debug("_get_current_interaction not found") - else: - logger.debug( - "_get_current_interaction found %s %s", - result.pdu_id, result.origin - ) - - return result - - def handle_new_state(self, new_pdu): - """Actually perform conflict resolution on the new_pdu on the - assumption we have all the pdus required to perform it. - - Args: - new_pdu - - Returns: - bool: True if the new_pdu clobbered the current state, False if not - """ - return self.runInteraction( - self._handle_new_state, new_pdu - ) - - def _handle_new_state(self, txn, new_pdu): - logger.debug( - "handle_new_state %s %s", - new_pdu.pdu_id, new_pdu.origin - ) - - current = self._get_current_interaction( - txn, - new_pdu.context, new_pdu.pdu_type, new_pdu.state_key - ) - - is_current = False - - if (not current or not current.prev_state_id - or not current.prev_state_origin): - # Oh, we don't have any state for this yet. - is_current = True - elif (current.pdu_id == new_pdu.prev_state_id - and current.origin == new_pdu.prev_state_origin): - # Oh! A direct clobber. Just do it. - is_current = True - else: - ## - # Ok, now loop through until we get to a common ancestor. - max_new = int(new_pdu.power_level) - max_current = int(current.power_level) - - enum_branches = self._enumerate_state_branches( - txn, new_pdu, current - ) - for branch, prev_state, state in enum_branches: - if not state: - raise RuntimeError( - "Could not find state_pdu %s %s" % - ( - prev_state.prev_state_id, - prev_state.prev_state_origin - ) - ) - - if branch == 0: - max_new = max(int(state.depth), max_new) - else: - max_current = max(int(state.depth), max_current) - - is_current = max_new > max_current - - if is_current: - logger.debug("handle_new_state make current") - - # Right, this is a new thing, so woo, just insert it. - txn.execute( - "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)" - % { - "curr": CurrentStateTable.table_name, - "fields": CurrentStateTable.get_fields_string(), - "qs": ", ".join(["?"] * len(CurrentStateTable.fields)) - }, - CurrentStateTable.EntryType( - *(new_pdu.__dict__[k] for k in CurrentStateTable.fields) - ) - ) - else: - logger.debug("handle_new_state not current") - - logger.debug("handle_new_state done") - - return is_current - - @log_function - def _enumerate_state_branches(self, txn, pdu_a, pdu_b): - branch_a = pdu_a - branch_b = pdu_b - - while True: - if (branch_a.pdu_id == branch_b.pdu_id - and branch_a.origin == branch_b.origin): - # Woo! We found a common ancestor - logger.debug("_enumerate_state_branches Found common ancestor") - break - - do_branch_a = ( - hasattr(branch_a, "prev_state_id") and - branch_a.prev_state_id - ) - - do_branch_b = ( - hasattr(branch_b, "prev_state_id") and - branch_b.prev_state_id - ) - - logger.debug( - "do_branch_a=%s, do_branch_b=%s", - do_branch_a, do_branch_b - ) - - if do_branch_a and do_branch_b: - do_branch_a = int(branch_a.depth) > int(branch_b.depth) - - if do_branch_a: - pdu_tuple = PduIdTuple( - branch_a.prev_state_id, - branch_a.prev_state_origin - ) - - prev_branch = branch_a - - logger.debug("getting branch_a prev %s", pdu_tuple) - branch_a = self._get_pdu_tuple(txn, *pdu_tuple) - if branch_a: - branch_a = Pdu.from_pdu_tuple(branch_a) - - logger.debug("branch_a=%s", branch_a) - - yield (0, prev_branch, branch_a) - - if not branch_a: - break - elif do_branch_b: - pdu_tuple = PduIdTuple( - branch_b.prev_state_id, - branch_b.prev_state_origin - ) - - prev_branch = branch_b - - logger.debug("getting branch_b prev %s", pdu_tuple) - branch_b = self._get_pdu_tuple(txn, *pdu_tuple) - if branch_b: - branch_b = Pdu.from_pdu_tuple(branch_b) - - logger.debug("branch_b=%s", branch_b) - - yield (1, prev_branch, branch_b) - - if not branch_b: - break - else: - break - - -class PdusTable(Table): - table_name = "pdus" - - fields = [ - "pdu_id", - "origin", - "context", - "pdu_type", - "ts", - "depth", - "is_state", - "content_json", - "unrecognized_keys", - "outlier", - "have_processed", - ] - - EntryType = namedtuple("PdusEntry", fields) - - -class PduDestinationsTable(Table): - table_name = "pdu_destinations" - - fields = [ - "pdu_id", - "origin", - "destination", - "delivered_ts", - ] - - EntryType = namedtuple("PduDestinationsEntry", fields) - - -class PduEdgesTable(Table): - table_name = "pdu_edges" - - fields = [ - "pdu_id", - "origin", - "prev_pdu_id", - "prev_origin", - "context" - ] - - EntryType = namedtuple("PduEdgesEntry", fields) - - -class PduForwardExtremitiesTable(Table): - table_name = "pdu_forward_extremities" - - fields = [ - "pdu_id", - "origin", - "context", - ] - - EntryType = namedtuple("PduForwardExtremitiesEntry", fields) - - -class PduBackwardExtremitiesTable(Table): - table_name = "pdu_backward_extremities" - - fields = [ - "pdu_id", - "origin", - "context", - ] - - EntryType = namedtuple("PduBackwardExtremitiesEntry", fields) - - -class ContextDepthTable(Table): - table_name = "context_depth" - - fields = [ - "context", - "min_depth", - ] - - EntryType = namedtuple("ContextDepthEntry", fields) - - -class StatePdusTable(Table): - table_name = "state_pdus" - - fields = [ - "pdu_id", - "origin", - "context", - "pdu_type", - "state_key", - "power_level", - "prev_state_id", - "prev_state_origin", - ] - - EntryType = namedtuple("StatePdusEntry", fields) - - -class CurrentStateTable(Table): - table_name = "current_state" - - fields = [ - "pdu_id", - "origin", - "context", - "pdu_type", - "state_key", - ] - - EntryType = namedtuple("CurrentStateEntry", fields) - -_pdu_state_joiner = JoinHelper(PdusTable, StatePdusTable) - - -# TODO: These should probably be put somewhere more sensible -PduIdTuple = namedtuple("PduIdTuple", ("pdu_id", "origin")) - -PduEntry = _pdu_state_joiner.EntryType -""" We are always interested in the join of the PdusTable and StatePdusTable, -rather than just the PdusTable. - -This does not include a prev_pdus key. -""" - -PduTuple = namedtuple( - "PduTuple", - ("pdu_entry", "prev_pdu_list") -) -""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent -the `prev_pdus` key of a PDU. -""" diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 719806f82b..a2ca6f9a69 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -62,8 +62,10 @@ class RegistrationStore(SQLBaseStore): Raises: StoreError if the user_id could not be registered. """ - yield self.runInteraction(self._register, user_id, token, - password_hash) + yield self.runInteraction( + "register", + self._register, user_id, token, password_hash + ) def _register(self, txn, user_id, token, password_hash): now = int(self.clock.time()) @@ -100,6 +102,7 @@ class RegistrationStore(SQLBaseStore): StoreError if no user was found. """ return self.runInteraction( + "get_user_by_token", self._query_for_auth, token ) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 8cd46334cf..7e48ce9cc3 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -150,6 +150,7 @@ class RoomStore(SQLBaseStore): def get_power_level(self, room_id, user_id): return self.runInteraction( + "get_power_level", self._get_power_level, room_id, user_id, ) @@ -183,6 +184,7 @@ class RoomStore(SQLBaseStore): def get_ops_levels(self, room_id): return self.runInteraction( + "get_ops_levels", self._get_ops_levels, room_id, ) diff --git a/synapse/storage/schema/edge_pdus.sql b/synapse/storage/schema/edge_pdus.sql deleted file mode 100644 index 8a00868065..0000000000 --- a/synapse/storage/schema/edge_pdus.sql +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2014 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -CREATE TABLE IF NOT EXISTS context_edge_pdus( - id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this - pdu_id TEXT, - origin TEXT, - context TEXT, - CONSTRAINT context_edge_pdu_id_origin UNIQUE (pdu_id, origin) -); - -CREATE TABLE IF NOT EXISTS origin_edge_pdus( - id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this - pdu_id TEXT, - origin TEXT, - CONSTRAINT origin_edge_pdu_id_origin UNIQUE (pdu_id, origin) -); - -CREATE INDEX IF NOT EXISTS context_edge_pdu_id ON context_edge_pdus(pdu_id, origin); -CREATE INDEX IF NOT EXISTS origin_edge_pdu_id ON origin_edge_pdus(pdu_id, origin); diff --git a/synapse/storage/schema/event_edges.sql b/synapse/storage/schema/event_edges.sql new file mode 100644 index 0000000000..e5f768c705 --- /dev/null +++ b/synapse/storage/schema/event_edges.sql @@ -0,0 +1,49 @@ + +CREATE TABLE IF NOT EXISTS event_forward_extremities( + event_id TEXT, + room_id TEXT, + CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE +); + +CREATE INDEX IF NOT EXISTS ev_extrem_room ON event_forward_extremities(room_id); +CREATE INDEX IF NOT EXISTS ev_extrem_id ON event_forward_extremities(event_id); + + +CREATE TABLE IF NOT EXISTS event_backward_extremities( + event_id TEXT, + room_id TEXT, + CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE +); + +CREATE INDEX IF NOT EXISTS ev_b_extrem_room ON event_backward_extremities(room_id); +CREATE INDEX IF NOT EXISTS ev_b_extrem_id ON event_backward_extremities(event_id); + + +CREATE TABLE IF NOT EXISTS event_edges( + event_id TEXT, + prev_event_id TEXT, + room_id TEXT, + CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id) +); + +CREATE INDEX IF NOT EXISTS ev_edges_id ON event_edges(event_id); +CREATE INDEX IF NOT EXISTS ev_edges_prev_id ON event_edges(prev_event_id); + + +CREATE TABLE IF NOT EXISTS room_depth( + room_id TEXT, + min_depth INTEGER, + CONSTRAINT uniqueness UNIQUE (room_id) +); + +CREATE INDEX IF NOT EXISTS room_depth_room ON room_depth(room_id); + + +create TABLE IF NOT EXISTS event_destinations( + event_id TEXT, + destination TEXT, + delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered + CONSTRAINT uniqueness UNIQUE (event_id, destination) ON CONFLICT REPLACE +); + +CREATE INDEX IF NOT EXISTS event_destinations_id ON event_destinations(event_id); diff --git a/synapse/storage/schema/event_signatures.sql b/synapse/storage/schema/event_signatures.sql new file mode 100644 index 0000000000..5491c7ecec --- /dev/null +++ b/synapse/storage/schema/event_signatures.sql @@ -0,0 +1,65 @@ +/* Copyright 2014 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS event_content_hashes ( + event_id TEXT, + algorithm TEXT, + hash BLOB, + CONSTRAINT uniqueness UNIQUE (event_id, algorithm) +); + +CREATE INDEX IF NOT EXISTS event_content_hashes_id ON event_content_hashes( + event_id +); + + +CREATE TABLE IF NOT EXISTS event_reference_hashes ( + event_id TEXT, + algorithm TEXT, + hash BLOB, + CONSTRAINT uniqueness UNIQUE (event_id, algorithm) +); + +CREATE INDEX IF NOT EXISTS event_reference_hashes_id ON event_reference_hashes ( + event_id +); + + +CREATE TABLE IF NOT EXISTS event_origin_signatures ( + event_id TEXT, + origin TEXT, + key_id TEXT, + signature BLOB, + CONSTRAINT uniqueness UNIQUE (event_id, key_id) +); + +CREATE INDEX IF NOT EXISTS event_origin_signatures_id ON event_origin_signatures ( + event_id +); + + +CREATE TABLE IF NOT EXISTS event_edge_hashes( + event_id TEXT, + prev_event_id TEXT, + algorithm TEXT, + hash BLOB, + CONSTRAINT uniqueness UNIQUE ( + event_id, prev_event_id, algorithm + ) +); + +CREATE INDEX IF NOT EXISTS event_edge_hashes_id ON event_edge_hashes( + event_id +); diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/im.sql index 3aa83f5c8c..8d6f655993 100644 --- a/synapse/storage/schema/im.sql +++ b/synapse/storage/schema/im.sql @@ -23,6 +23,7 @@ CREATE TABLE IF NOT EXISTS events( unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, + depth INTEGER DEFAULT 0 NOT NULL, CONSTRAINT ev_uniq UNIQUE (event_id) ); diff --git a/synapse/storage/schema/pdu.sql b/synapse/storage/schema/pdu.sql deleted file mode 100644 index 16e111a56c..0000000000 --- a/synapse/storage/schema/pdu.sql +++ /dev/null @@ -1,106 +0,0 @@ -/* Copyright 2014 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ --- Stores pdus and their content -CREATE TABLE IF NOT EXISTS pdus( - pdu_id TEXT, - origin TEXT, - context TEXT, - pdu_type TEXT, - ts INTEGER, - depth INTEGER DEFAULT 0 NOT NULL, - is_state BOOL, - content_json TEXT, - unrecognized_keys TEXT, - outlier BOOL NOT NULL, - have_processed BOOL, - CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin) -); - --- Stores what the current state pdu is for a given (context, pdu_type, key) tuple -CREATE TABLE IF NOT EXISTS state_pdus( - pdu_id TEXT, - origin TEXT, - context TEXT, - pdu_type TEXT, - state_key TEXT, - power_level TEXT, - prev_state_id TEXT, - prev_state_origin TEXT, - CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin) - CONSTRAINT prev_pdu_id_origin UNIQUE (prev_state_id, prev_state_origin) -); - -CREATE TABLE IF NOT EXISTS current_state( - pdu_id TEXT, - origin TEXT, - context TEXT, - pdu_type TEXT, - state_key TEXT, - CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin) - CONSTRAINT uniqueness UNIQUE (context, pdu_type, state_key) ON CONFLICT REPLACE -); - --- Stores where each pdu we want to send should be sent and the delivery status. -create TABLE IF NOT EXISTS pdu_destinations( - pdu_id TEXT, - origin TEXT, - destination TEXT, - delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered - CONSTRAINT uniqueness UNIQUE (pdu_id, origin, destination) ON CONFLICT REPLACE -); - -CREATE TABLE IF NOT EXISTS pdu_forward_extremities( - pdu_id TEXT, - origin TEXT, - context TEXT, - CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE -); - -CREATE TABLE IF NOT EXISTS pdu_backward_extremities( - pdu_id TEXT, - origin TEXT, - context TEXT, - CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE -); - -CREATE TABLE IF NOT EXISTS pdu_edges( - pdu_id TEXT, - origin TEXT, - prev_pdu_id TEXT, - prev_origin TEXT, - context TEXT, - CONSTRAINT uniqueness UNIQUE (pdu_id, origin, prev_pdu_id, prev_origin, context) -); - -CREATE TABLE IF NOT EXISTS context_depth( - context TEXT, - min_depth INTEGER, - CONSTRAINT uniqueness UNIQUE (context) -); - -CREATE INDEX IF NOT EXISTS context_depth_context ON context_depth(context); - - -CREATE INDEX IF NOT EXISTS pdu_id ON pdus(pdu_id, origin); - -CREATE INDEX IF NOT EXISTS dests_id ON pdu_destinations (pdu_id, origin); --- CREATE INDEX IF NOT EXISTS dests ON pdu_destinations (destination); - -CREATE INDEX IF NOT EXISTS pdu_extrem_context ON pdu_forward_extremities(context); -CREATE INDEX IF NOT EXISTS pdu_extrem_id ON pdu_forward_extremities(pdu_id, origin); - -CREATE INDEX IF NOT EXISTS pdu_edges_id ON pdu_edges(pdu_id, origin); - -CREATE INDEX IF NOT EXISTS pdu_b_extrem_context ON pdu_backward_extremities(context); diff --git a/synapse/storage/schema/state.sql b/synapse/storage/schema/state.sql new file mode 100644 index 0000000000..b44c56b519 --- /dev/null +++ b/synapse/storage/schema/state.sql @@ -0,0 +1,33 @@ +/* Copyright 2014 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS state_groups( + id INTEGER PRIMARY KEY, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS state_groups_state( + state_group INTEGER NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS event_to_state_groups( + event_id TEXT NOT NULL, + state_group INTEGER NOT NULL +); \ No newline at end of file diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py new file mode 100644 index 0000000000..b4b3d5d7ea --- /dev/null +++ b/synapse/storage/signatures.py @@ -0,0 +1,165 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from _base import SQLBaseStore + + +class SignatureStore(SQLBaseStore): + """Persistence for event signatures and hashes""" + + def _get_event_content_hashes_txn(self, txn, event_id): + """Get all the hashes for a given Event. + Args: + txn (cursor): + event_id (str): Id for the Event. + Returns: + A dict of algorithm -> hash. + """ + query = ( + "SELECT algorithm, hash" + " FROM event_content_hashes" + " WHERE event_id = ?" + ) + txn.execute(query, (event_id, )) + return dict(txn.fetchall()) + + def _store_event_content_hash_txn(self, txn, event_id, algorithm, + hash_bytes): + """Store a hash for a Event + Args: + txn (cursor): + event_id (str): Id for the Event. + algorithm (str): Hashing algorithm. + hash_bytes (bytes): Hash function output bytes. + """ + self._simple_insert_txn( + txn, + "event_content_hashes", + { + "event_id": event_id, + "algorithm": algorithm, + "hash": buffer(hash_bytes), + }, + or_ignore=True, + ) + + def _get_event_reference_hashes_txn(self, txn, event_id): + """Get all the hashes for a given PDU. + Args: + txn (cursor): + event_id (str): Id for the Event. + Returns: + A dict of algorithm -> hash. + """ + query = ( + "SELECT algorithm, hash" + " FROM event_reference_hashes" + " WHERE event_id = ?" + ) + txn.execute(query, (event_id, )) + return dict(txn.fetchall()) + + def _store_event_reference_hash_txn(self, txn, event_id, algorithm, + hash_bytes): + """Store a hash for a PDU + Args: + txn (cursor): + event_id (str): Id for the Event. + algorithm (str): Hashing algorithm. + hash_bytes (bytes): Hash function output bytes. + """ + self._simple_insert_txn( + txn, + "event_reference_hashes", + { + "event_id": event_id, + "algorithm": algorithm, + "hash": buffer(hash_bytes), + }, + or_ignore=True, + ) + + + def _get_event_origin_signatures_txn(self, txn, event_id): + """Get all the signatures for a given PDU. + Args: + txn (cursor): + event_id (str): Id for the Event. + Returns: + A dict of key_id -> signature_bytes. + """ + query = ( + "SELECT key_id, signature" + " FROM event_origin_signatures" + " WHERE event_id = ? " + ) + txn.execute(query, (event_id, )) + return dict(txn.fetchall()) + + def _store_event_origin_signature_txn(self, txn, event_id, origin, key_id, + signature_bytes): + """Store a signature from the origin server for a PDU. + Args: + txn (cursor): + event_id (str): Id for the Event. + origin (str): origin of the Event. + key_id (str): Id for the signing key. + signature (bytes): The signature. + """ + self._simple_insert_txn( + txn, + "event_origin_signatures", + { + "event_id": event_id, + "origin": origin, + "key_id": key_id, + "signature": buffer(signature_bytes), + }, + or_ignore=True, + ) + + def _get_prev_event_hashes_txn(self, txn, event_id): + """Get all the hashes for previous PDUs of a PDU + Args: + txn (cursor): + event_id (str): Id for the Event. + Returns: + dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes. + """ + query = ( + "SELECT prev_event_id, algorithm, hash" + " FROM event_edge_hashes" + " WHERE event_id = ?" + ) + txn.execute(query, (event_id, )) + results = {} + for prev_event_id, algorithm, hash_bytes in txn.fetchall(): + hashes = results.setdefault(prev_event_id, {}) + hashes[algorithm] = hash_bytes + return results + + def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id, + algorithm, hash_bytes): + self._simple_insert_txn( + txn, + "event_edge_hashes", + { + "event_id": event_id, + "prev_event_id": prev_event_id, + "algorithm": algorithm, + "hash": buffer(hash_bytes), + }, + or_ignore=True, + ) \ No newline at end of file diff --git a/synapse/storage/state.py b/synapse/storage/state.py new file mode 100644 index 0000000000..e08acd6404 --- /dev/null +++ b/synapse/storage/state.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore +from twisted.internet import defer + +from collections import namedtuple + + +StateGroup = namedtuple("StateGroup", ("group", "state")) + + +class StateStore(SQLBaseStore): + + @defer.inlineCallbacks + def get_state_groups(self, event_ids): + groups = set() + for event_id in event_ids: + group = yield self._simple_select_one_onecol( + table="event_to_state_groups", + keyvalues={"event_id": event_id}, + retcol="state_group", + allow_none=True, + ) + if group: + groups.add(group) + + res = [] + for group in groups: + state_ids = yield self._simple_select_onecol( + table="state_groups_state", + keyvalues={"state_group": group}, + retcol="event_id", + ) + state = [] + for state_id in state_ids: + s = yield self.get_event( + state_id, + allow_none=True, + ) + if s: + state.append(s) + + res.append(StateGroup(group, state)) + + defer.returnValue(res) + + def store_state_groups(self, event): + return self.runInteraction( + "store_state_groups", + self._store_state_groups_txn, event + ) + + def _store_state_groups_txn(self, txn, event): + if not event.state_events: + return + + state_group = event.state_group + if not state_group: + state_group = self._simple_insert_txn( + txn, + table="state_groups", + values={ + "room_id": event.room_id, + "event_id": event.event_id, + } + ) + + for state in event.state_events.values(): + self._simple_insert_txn( + txn, + table="state_groups_state", + values={ + "state_group": state_group, + "room_id": state.room_id, + "type": state.type, + "state_key": state.state_key, + "event_id": state.event_id, + } + ) + + self._simple_insert_txn( + txn, + table="event_to_state_groups", + values={ + "state_group": state_group, + "event_id": event.event_id, + } + ) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index d61f909939..8f7f61d29d 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -309,7 +309,10 @@ class StreamStore(SQLBaseStore): defer.returnValue(ret) def get_room_events_max_id(self): - return self.runInteraction(self._get_room_events_max_id_txn) + return self.runInteraction( + "get_room_events_max_id", + self._get_room_events_max_id_txn + ) def _get_room_events_max_id_txn(self, txn): txn.execute( diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 2ba8e30efe..ea67900788 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -14,7 +14,6 @@ # limitations under the License. from ._base import SQLBaseStore, Table -from .pdu import PdusTable from collections import namedtuple @@ -42,6 +41,7 @@ class TransactionStore(SQLBaseStore): """ return self.runInteraction( + "get_received_txn_response", self._get_received_txn_response, transaction_id, origin ) @@ -73,6 +73,7 @@ class TransactionStore(SQLBaseStore): """ return self.runInteraction( + "set_received_txn_response", self._set_received_txn_response, transaction_id, origin, code, response_dict ) @@ -106,6 +107,7 @@ class TransactionStore(SQLBaseStore): """ return self.runInteraction( + "prep_send_transaction", self._prep_send_transaction, transaction_id, destination, origin_server_ts, pdu_list ) @@ -139,15 +141,15 @@ class TransactionStore(SQLBaseStore): # Update the tx id -> pdu id mapping - values = [ - (transaction_id, destination, pdu[0], pdu[1]) - for pdu in pdu_list - ] - - logger.debug("Inserting: %s", repr(values)) - - query = TransactionsToPduTable.insert_statement() - txn.executemany(query, values) + # values = [ + # (transaction_id, destination, pdu[0], pdu[1]) + # for pdu in pdu_list + # ] + # + # logger.debug("Inserting: %s", repr(values)) + # + # query = TransactionsToPduTable.insert_statement() + # txn.executemany(query, values) return prev_txns @@ -161,6 +163,7 @@ class TransactionStore(SQLBaseStore): response_json (str) """ return self.runInteraction( + "delivered_txn", self._delivered_txn, transaction_id, destination, code, response_dict ) @@ -186,6 +189,7 @@ class TransactionStore(SQLBaseStore): list: A list of `ReceivedTransactionsTable.EntryType` """ return self.runInteraction( + "get_transactions_after", self._get_transactions_after, transaction_id, destination ) @@ -202,49 +206,6 @@ class TransactionStore(SQLBaseStore): return ReceivedTransactionsTable.decode_results(txn.fetchall()) - def get_pdus_after_transaction(self, transaction_id, destination): - """For a given local transaction_id that we sent to a given destination - home server, return a list of PDUs that were sent to that destination - after it. - - Args: - txn - transaction_id (str) - destination (str) - - Returns - list: A list of PduTuple - """ - return self.runInteraction( - self._get_pdus_after_transaction, - transaction_id, destination - ) - - def _get_pdus_after_transaction(self, txn, transaction_id, destination): - - # Query that first get's all transaction_ids with an id greater than - # the one given from the `sent_transactions` table. Then JOIN on this - # from the `tx->pdu` table to get a list of (pdu_id, origin) that - # specify the pdus that were sent in those transactions. - query = ( - "SELECT pdu_id, pdu_origin FROM %(tx_pdu)s as tp " - "INNER JOIN %(sent_tx)s as st " - "ON tp.transaction_id = st.transaction_id " - "AND tp.destination = st.destination " - "WHERE st.id > (" - "SELECT id FROM %(sent_tx)s " - "WHERE transaction_id = ? AND destination = ?" - ) % { - "tx_pdu": TransactionsToPduTable.table_name, - "sent_tx": SentTransactions.table_name, - } - - txn.execute(query, (transaction_id, destination)) - - pdus = PdusTable.decode_results(txn.fetchall()) - - return self._get_pdu_tuples(txn, pdus) - class ReceivedTransactionsTable(Table): table_name = "received_transactions" diff --git a/synapse/types.py b/synapse/types.py index c51bc8e4f2..649ff2f7d7 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -78,6 +78,11 @@ class DomainSpecificString( """Create a structure on the local domain""" return cls(localpart=localpart, domain=hs.hostname, is_mine=True) + @classmethod + def create(cls, localpart, domain, hs): + is_mine = domain == hs.hostname + return cls(localpart=localpart, domain=domain, is_mine=is_mine) + class UserID(DomainSpecificString): """Structure representing a user ID.""" @@ -94,6 +99,11 @@ class RoomID(DomainSpecificString): SIGIL = "!" +class EventID(DomainSpecificString): + """Structure representing an event id. """ + SIGIL = "$" + + class StreamToken( namedtuple( "Token", diff --git a/synapse/util/async.py b/synapse/util/async.py index 647ea6142c..bf578f8bfb 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -21,3 +21,10 @@ def sleep(seconds): d = defer.Deferred() reactor.callLater(seconds, d.callback, seconds) return d + + +def run_on_reactor(): + """ This will cause the rest of the function to be invoked upon the next + iteration of the main loop + """ + return sleep(0) \ No newline at end of file |