diff options
97 files changed, 2733 insertions, 2466 deletions
diff --git a/CHANGES.rst b/CHANGES.rst index 08efbbf244..78c178bafd 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,19 @@ +Changes in synapse 0.4.2 (2014-10-31) +===================================== + +Homeserver: + * Fix bugs where we did not notify users of correct presence updates. + * Fix bug where we did not handle sub second event stream timeouts. + +Webclient: + * Add ability to click on messages to see JSON. + * Add ability to redact messages. + * Add ability to view and edit all room state JSON. + * Handle incoming redactions. + * Improve feedback on errors. + * Fix bugs in mobile CSS. + * Fix bugs with desktop notifications. + Changes in synapse 0.4.1 (2014-10-17) ===================================== Webclient: diff --git a/VERSION b/VERSION index 267577d47e..2b7c5ae018 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.4.1 +0.4.2 diff --git a/demo/start.sh b/demo/start.sh index fc6cd6303f..0530f0a26e 100755 --- a/demo/start.sh +++ b/demo/start.sh @@ -32,7 +32,7 @@ for port in 8080 8081 8082; do -D --pid-file "$DIR/$port.pid" \ --manhole $((port + 1000)) \ --tls-dh-params-path "demo/demo.tls.dh" \ - $PARAMS + $PARAMS $SYNAPSE_PARAMS python -m synapse.app.homeserver \ --config-path "demo/etc/$port.config" \ diff --git a/synapse/__init__.py b/synapse/__init__.py index 7067188c5b..23ae5f003f 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a synapse home server. """ -__version__ = "0.4.1" +__version__ = "0.4.2" diff --git a/synapse/api/__init__.py b/synapse/api/__init__.py index 9bff9ec169..f9811bfa04 100644 --- a/synapse/api/__init__.py +++ b/synapse/api/__init__.py @@ -12,4 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/synapse/api/auth.py b/synapse/api/auth.py index e1b1823cd7..c684265101 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -21,6 +21,8 @@ from synapse.api.constants import Membership, JoinRules from synapse.api.errors import AuthError, StoreError, Codes, SynapseError from synapse.api.events.room import ( RoomMemberEvent, RoomPowerLevelsEvent, RoomRedactionEvent, + RoomJoinRulesEvent, RoomOpsPowerLevelsEvent, InviteJoinEvent, + RoomCreateEvent, ) from synapse.util.logutils import log_function @@ -47,42 +49,60 @@ class Auth(object): """ try: if hasattr(event, "room_id"): + if event.old_state_events is None: + # Oh, we don't know what the state of the room was, so we + # are trusting that this is allowed (at least for now) + defer.returnValue(True) + + if hasattr(event, "outlier") and event.outlier is True: + # TODO (erikj): Auth for outliers is done differently. + defer.returnValue(True) + is_state = hasattr(event, "state_key") + if event.type == RoomCreateEvent.TYPE: + # FIXME + defer.returnValue(True) + if event.type == RoomMemberEvent.TYPE: - yield self._can_replace_state(event) - allowed = yield self.is_membership_change_allowed(event) + self._can_replace_state(event) + allowed = self.is_membership_change_allowed(event) + if allowed: + logger.debug("Allowing! %s", event) + else: + logger.debug("Denying! %s", event) defer.returnValue(allowed) return - self._check_joined_room( - member=snapshot.membership_state, - user_id=snapshot.user_id, - room_id=snapshot.room_id, - ) + if not event.type == InviteJoinEvent.TYPE: + self.check_event_sender_in_room(event) if is_state: # TODO (erikj): This really only should be called for *new* # state yield self._can_add_state(event) - yield self._can_replace_state(event) + self._can_replace_state(event) else: yield self._can_send_event(event) if event.type == RoomPowerLevelsEvent.TYPE: - yield self._check_power_levels(event) + self._check_power_levels(event) if event.type == RoomRedactionEvent.TYPE: - yield self._check_redaction(event) + self._check_redaction(event) + + logger.debug("Allowing! %s", event) defer.returnValue(True) else: raise AuthError(500, "Unknown event: %s" % event) except AuthError as e: logger.info("Event auth check failed on event %s with msg: %s", event, e.msg) + logger.info("Denying! %s", event) if raises: raise e + defer.returnValue(False) @defer.inlineCallbacks @@ -98,45 +118,72 @@ class Auth(object): pass defer.returnValue(None) + def check_event_sender_in_room(self, event): + key = (RoomMemberEvent.TYPE, event.user_id, ) + member_event = event.state_events.get(key) + + return self._check_joined_room( + member_event, + event.user_id, + event.room_id + ) + def _check_joined_room(self, member, user_id, room_id): if not member or member.membership != Membership.JOIN: raise AuthError(403, "User %s not in room %s (%s)" % ( user_id, room_id, repr(member) )) - @defer.inlineCallbacks + @log_function def is_membership_change_allowed(self, event): target_user_id = event.state_key - # does this room even exist - room = yield self.store.get_room(event.room_id) - if not room: - raise AuthError(403, "Room does not exist") - # get info about the caller - try: - caller = yield self.store.get_room_member( - user_id=event.user_id, - room_id=event.room_id) - except: - caller = None + key = (RoomMemberEvent.TYPE, event.user_id, ) + caller = event.old_state_events.get(key) + caller_in_room = caller and caller.membership == "join" # get info about the target - try: - target = yield self.store.get_room_member( - user_id=target_user_id, - room_id=event.room_id) - except: - target = None + key = (RoomMemberEvent.TYPE, target_user_id, ) + target = event.old_state_events.get(key) + target_in_room = target and target.membership == "join" membership = event.content["membership"] - join_rule = yield self.store.get_room_join_rule(event.room_id) - if not join_rule: + key = (RoomJoinRulesEvent.TYPE, "", ) + join_rule_event = event.old_state_events.get(key) + if join_rule_event: + join_rule = join_rule_event.content.get( + "join_rule", JoinRules.INVITE + ) + else: join_rule = JoinRules.INVITE + user_level = self._get_power_level_from_event_state( + event, + event.user_id, + ) + + ban_level, kick_level, redact_level = ( + self._get_ops_level_from_event_state( + event + ) + ) + + logger.debug( + "is_membership_change_allowed: %s", + { + "caller_in_room": caller_in_room, + "target_in_room": target_in_room, + "membership": membership, + "join_rule": join_rule, + "target_user_id": target_user_id, + "event.user_id": event.user_id, + } + ) + if Membership.INVITE == membership: # TODO (erikj): We should probably handle this more intelligently # PRIVATE join rules. @@ -153,13 +200,10 @@ class Auth(object): # joined: It's a NOOP if event.user_id != target_user_id: raise AuthError(403, "Cannot force another user to join.") - elif join_rule == JoinRules.PUBLIC or room.is_public: + elif join_rule == JoinRules.PUBLIC: pass elif join_rule == JoinRules.INVITE: - if ( - not caller or caller.membership not in - [Membership.INVITE, Membership.JOIN] - ): + if not caller_in_room: raise AuthError(403, "You are not invited to this room.") else: # TODO (erikj): may_join list @@ -171,29 +215,16 @@ class Auth(object): if not caller_in_room: # trying to leave a room you aren't joined raise AuthError(403, "You are not in room %s." % event.room_id) elif target_user_id != event.user_id: - user_level = yield self.store.get_power_level( - event.room_id, - event.user_id, - ) - _, kick_level, _ = yield self.store.get_ops_levels(event.room_id) - if kick_level: kick_level = int(kick_level) else: - kick_level = 50 + kick_level = 50 # FIXME (erikj): What should we do here? if user_level < kick_level: raise AuthError( 403, "You cannot kick user %s." % target_user_id ) elif Membership.BAN == membership: - user_level = yield self.store.get_power_level( - event.room_id, - event.user_id, - ) - - ban_level, _, _ = yield self.store.get_ops_levels(event.room_id) - if ban_level: ban_level = int(ban_level) else: @@ -204,7 +235,30 @@ class Auth(object): else: raise AuthError(500, "Unknown membership %s" % membership) - defer.returnValue(True) + return True + + def _get_power_level_from_event_state(self, event, user_id): + key = (RoomPowerLevelsEvent.TYPE, "", ) + power_level_event = event.old_state_events.get(key) + level = None + if power_level_event: + level = power_level_event.content.get(user_id) + if not level: + level = power_level_event.content.get("default", 0) + + return level + + def _get_ops_level_from_event_state(self, event): + key = (RoomOpsPowerLevelsEvent.TYPE, "", ) + ops_event = event.old_state_events.get(key) + + if ops_event: + return ( + ops_event.content.get("ban_level"), + ops_event.content.get("kick_level"), + ops_event.content.get("redact_level"), + ) + return None, None, None, @defer.inlineCallbacks def get_user_by_req(self, request): @@ -282,8 +336,8 @@ class Auth(object): else: send_level = 0 - user_level = yield self.store.get_power_level( - event.room_id, + user_level = self._get_power_level_from_event_state( + event, event.user_id, ) @@ -308,8 +362,8 @@ class Auth(object): add_level = int(add_level) - user_level = yield self.store.get_power_level( - event.room_id, + user_level = self._get_power_level_from_event_state( + event, event.user_id, ) @@ -322,19 +376,9 @@ class Auth(object): defer.returnValue(True) - @defer.inlineCallbacks def _can_replace_state(self, event): - current_state = yield self.store.get_current_state( - event.room_id, - event.type, - event.state_key, - ) - - if current_state: - current_state = current_state[0] - - user_level = yield self.store.get_power_level( - event.room_id, + user_level = self._get_power_level_from_event_state( + event, event.user_id, ) @@ -346,6 +390,10 @@ class Auth(object): logger.debug( "Checking power level for %s, %s", event.user_id, user_level ) + + key = (event.type, event.state_key, ) + current_state = event.old_state_events.get(key) + if current_state and hasattr(current_state, "required_power_level"): req = current_state.required_power_level @@ -356,10 +404,9 @@ class Auth(object): "You don't have permission to change that state" ) - @defer.inlineCallbacks def _check_redaction(self, event): - user_level = yield self.store.get_power_level( - event.room_id, + user_level = self._get_power_level_from_event_state( + event, event.user_id, ) @@ -368,7 +415,9 @@ class Auth(object): else: user_level = 0 - _, _, redact_level = yield self.store.get_ops_levels(event.room_id) + _, _, redact_level = self._get_ops_level_from_event_state( + event + ) if not redact_level: redact_level = 50 @@ -379,7 +428,6 @@ class Auth(object): "You don't have permission to redact events" ) - @defer.inlineCallbacks def _check_power_levels(self, event): for k, v in event.content.items(): if k == "default": @@ -399,19 +447,16 @@ class Auth(object): except: raise SynapseError(400, "Not a valid power level: %s" % (v,)) - current_state = yield self.store.get_current_state( - event.room_id, - event.type, - event.state_key, - ) + key = (event.type, event.state_key, ) + current_state = event.old_state_events.get(key) if not current_state: return else: current_state = current_state[0] - user_level = yield self.store.get_power_level( - event.room_id, + user_level = self._get_power_level_from_event_state( + event, event.user_id, ) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 618d3d7577..3cafff0e32 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -58,4 +58,4 @@ class LoginType(object): EMAIL_CODE = u"m.login.email.code" EMAIL_URL = u"m.login.email.url" EMAIL_IDENTITY = u"m.login.email.identity" - RECAPTCHA = u"m.login.recaptcha" \ No newline at end of file + RECAPTCHA = u"m.login.recaptcha" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 6d7d499fea..38ccb4f9d1 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -54,7 +54,7 @@ class SynapseError(CodeMessageException): """Constructs a synapse error. Args: - code (int): The integer error code (typically an HTTP response code) + code (int): The integer error code (an HTTP response code) msg (str): The human-readable error message. err (str): The error code e.g 'M_FORBIDDEN' """ @@ -67,6 +67,7 @@ class SynapseError(CodeMessageException): self.errcode, ) + class RoomError(SynapseError): """An error raised when a room event fails.""" pass @@ -117,6 +118,7 @@ class InvalidCaptchaError(SynapseError): error_url=self.error_url, ) + class LimitExceededError(SynapseError): """A client has sent too many requests and is being throttled. """ diff --git a/synapse/api/events/__init__.py b/synapse/api/events/__init__.py index a5a55742e0..b855811b98 100644 --- a/synapse/api/events/__init__.py +++ b/synapse/api/events/__init__.py @@ -71,7 +71,9 @@ class SynapseEvent(JsonEncodedObject): "outlier", "power_level", "redacted", - "prev_pdus", + "prev_events", + "hashes", + "signatures", ] required_keys = [ diff --git a/synapse/api/events/factory.py b/synapse/api/events/factory.py index 74d0ef77f4..9134c82eff 100644 --- a/synapse/api/events/factory.py +++ b/synapse/api/events/factory.py @@ -21,6 +21,8 @@ from synapse.api.events.room import ( RoomRedactionEvent, ) +from synapse.types import EventID + from synapse.util.stringutils import random_string @@ -51,12 +53,26 @@ class EventFactory(object): self.clock = hs.get_clock() self.hs = hs + self.event_id_count = 0 + + def create_event_id(self): + i = str(self.event_id_count) + self.event_id_count += 1 + + local_part = str(int(self.clock.time())) + i + random_string(5) + + e_id = EventID.create_local(local_part, self.hs) + + return e_id.to_string() + def create_event(self, etype=None, **kwargs): kwargs["type"] = etype if "event_id" not in kwargs: - kwargs["event_id"] = "%s@%s" % ( - random_string(10), self.hs.hostname - ) + kwargs["event_id"] = self.create_event_id() + kwargs["origin"] = self.hs.hostname + else: + ev_id = self.hs.parse_eventid(kwargs["event_id"]) + kwargs["origin"] = ev_id.domain if "origin_server_ts" not in kwargs: kwargs["origin_server_ts"] = int(self.clock.time_msec()) diff --git a/synapse/api/events/utils.py b/synapse/api/events/utils.py index 7fdf45a264..31601fd3a9 100644 --- a/synapse/api/events/utils.py +++ b/synapse/api/events/utils.py @@ -32,7 +32,7 @@ def prune_event(event): def prune_pdu(pdu): """Removes keys that contain unrestricted and non-essential data from a PDU """ - return _prune_event_or_pdu(pdu.pdu_type, pdu) + return _prune_event_or_pdu(pdu.type, pdu) def _prune_event_or_pdu(event_type, event): # Remove all extraneous fields. diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index 9bff9ec169..f9811bfa04 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -12,4 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 6394bc27d1..a20376b9d6 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -233,7 +233,10 @@ def setup(): f.namespace['hs'] = hs reactor.listenTCP(config.manhole, f, interface='127.0.0.1') - hs.start_listening(config.bind_port, config.unsecure_port) + bind_port = config.bind_port + if config.no_tls: + bind_port = None + hs.start_listening(bind_port, config.unsecure_port) if config.daemonize: print config.pid_file diff --git a/synapse/config/_base.py b/synapse/config/_base.py index b3aeff327c..8ebd2eba4a 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -116,18 +116,25 @@ class Config(object): config = {} for key, value in vars(args).items(): if (key not in set(["config_path", "generate_config"]) - and value is not None): + and value is not None): config[key] = value with open(config_args.config_path, "w") as config_file: # TODO(paul) it would be lovely if we wrote out vim- and emacs- # style mode markers into the file, to hint to people that # this is a YAML file. yaml.dump(config, config_file, default_flow_style=False) - print "A config file has been generated in %s for server name '%s') with corresponding SSL keys and self-signed certificates. Please review this file and customise it to your needs." % (config_args.config_path, config['server_name']) - print "If this server name is incorrect, you will need to regenerate the SSL certificates" + print ( + "A config file has been generated in %s for server name" + " '%s' with corresponding SSL keys and self-signed" + " certificates. Please review this file and customise it to" + " your needs." + ) % ( + config_args.config_path, config['server_name'] + ) + print ( + "If this server name is incorrect, you will need to regenerate" + " the SSL certificates" + ) sys.exit(0) return cls(args) - - - diff --git a/synapse/config/database.py b/synapse/config/database.py index 460445f15d..0aac8c8382 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -16,6 +16,7 @@ from ._base import Config import os + class DatabaseConfig(Config): def __init__(self, args): super(DatabaseConfig, self).__init__(args) @@ -34,4 +35,3 @@ class DatabaseConfig(Config): def generate_config(cls, args, config_dir_path): super(DatabaseConfig, cls).generate_config(args, config_dir_path) args.database_path = os.path.abspath(args.database_path) - diff --git a/synapse/config/email.py b/synapse/config/email.py index 9bcc5a8fea..6bab133224 100644 --- a/synapse/config/email.py +++ b/synapse/config/email.py @@ -35,5 +35,8 @@ class EmailConfig(Config): email_group.add_argument( "--email-smtp-server", default="", - help="The SMTP server to send emails from (e.g. for password resets)." - ) \ No newline at end of file + help=( + "The SMTP server to send emails from (e.g. for password" + " resets)." + ) + ) diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 56cd095433..05611d02f7 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -19,6 +19,7 @@ from twisted.python.log import PythonLoggingObserver import logging import logging.config + class LoggingConfig(Config): def __init__(self, args): super(LoggingConfig, self).__init__(args) @@ -51,7 +52,7 @@ class LoggingConfig(Config): level = logging.INFO if self.verbosity: - level = logging.DEBUG + level = logging.DEBUG # FIXME: we need a logging.WARN for a -q quiet option diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index f126782b8d..fb63ed7d9b 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -14,6 +14,7 @@ from ._base import Config + class RatelimitConfig(Config): def __init__(self, args): diff --git a/synapse/config/repository.py b/synapse/config/repository.py index b71d30227c..743bc26474 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -15,6 +15,7 @@ from ._base import Config + class ContentRepositoryConfig(Config): def __init__(self, args): super(ContentRepositoryConfig, self).__init__(args) diff --git a/synapse/config/server.py b/synapse/config/server.py index 086937044f..814a4c349b 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -30,11 +30,12 @@ class ServerConfig(Config): self.pid_file = self.abspath(args.pid_file) self.webclient = True self.manhole = args.manhole + self.no_tls = args.no_tls if not args.content_addr: host = args.server_name if ':' not in host: - host = "%s:%d" % (host, args.bind_port) + host = "%s:%d" % (host, args.bind_port) args.content_addr = "https://%s" % (host,) self.content_addr = args.content_addr @@ -67,6 +68,8 @@ class ServerConfig(Config): server_group.add_argument("--content-addr", default=None, help="The host and scheme to use for the " "content repository") + server_group.add_argument("--no-tls", action='store_true', + help="Don't bind to the https port.") def read_signing_key(self, signing_key_path): signing_keys = self.read_file(signing_key_path, "signing_key") diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 72d5518a89..3600c3ea9e 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -19,7 +19,7 @@ from OpenSSL import crypto import subprocess import os -GENERATE_DH_PARAMS=False +GENERATE_DH_PARAMS = False class TlsConfig(Config): diff --git a/synapse/config/voip.py b/synapse/config/voip.py index 3a51664f46..06675966ce 100644 --- a/synapse/config/voip.py +++ b/synapse/config/voip.py @@ -33,7 +33,10 @@ class VoipConfig(Config): ) group.add_argument( "--turn-shared-secret", type=str, default=None, - help="The shared secret used to compute passwords for the TURN server" + help=( + "The shared secret used to compute passwords for the TURN" + " server" + ) ) group.add_argument( "--turn-user-lifetime", type=int, default=(1000 * 60 * 60), diff --git a/synapse/crypto/__init__.py b/synapse/crypto/__init__.py index 9bff9ec169..f9811bfa04 100644 --- a/synapse/crypto/__init__.py +++ b/synapse/crypto/__init__.py @@ -12,4 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index f402c795bb..3143322d9c 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -20,6 +20,7 @@ import logging logger = logging.getLogger(__name__) + class ServerContextFactory(ssl.ContextFactory): """Factory for PyOpenSSL SSL contexts that are used to handle incoming connections and to make connections to remote servers.""" @@ -43,4 +44,3 @@ class ServerContextFactory(ssl.ContextFactory): def getContext(self): return self._context - diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index 61edd2c6f9..0e8bc7eb6c 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -16,11 +16,12 @@ from synapse.federation.units import Pdu -from synapse.api.events.utils import prune_pdu +from synapse.api.events.utils import prune_pdu, prune_event from syutil.jsonutil import encode_canonical_json from syutil.base64util import encode_base64, decode_base64 from syutil.crypto.jsonsign import sign_json, verify_signed_json +import copy import hashlib import logging @@ -69,6 +70,16 @@ def compute_pdu_event_reference_hash(pdu, hash_algorithm=hashlib.sha256): return (hashed.name, hashed.digest()) +def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256): + tmp_event = copy.deepcopy(event) + tmp_event = prune_event(tmp_event) + event_json = tmp_event.get_dict() + event_json.pop("signatures", None) + event_json_bytes = encode_canonical_json(event_json) + hashed = hash_algorithm(event_json_bytes) + return (hashed.name, hashed.digest()) + + def sign_event_pdu(pdu, signature_name, signing_key): tmp_pdu = Pdu(**pdu.get_dict()) tmp_pdu = prune_pdu(tmp_pdu) @@ -83,3 +94,25 @@ def verify_signed_event_pdu(pdu, signature_name, verify_key): tmp_pdu = prune_pdu(tmp_pdu) pdu_json = tmp_pdu.get_dict() verify_signed_json(pdu_json, signature_name, verify_key) + + +def add_hashes_and_signatures(event, signature_name, signing_key, + hash_algorithm=hashlib.sha256): + tmp_event = copy.deepcopy(event) + tmp_event = prune_event(tmp_event) + redact_json = tmp_event.get_dict() + redact_json.pop("signatures", None) + redact_json = sign_json(redact_json, signature_name, signing_key) + event.signatures = redact_json["signatures"] + + event_json = event.get_full_dict() + #TODO: We need to sign the JSON that is going out via fedaration. + event_json.pop("age_ts", None) + event_json.pop("unsigned", None) + event_json.pop("signatures", None) + event_json.pop("hashes", None) + event_json_bytes = encode_canonical_json(event_json) + hashed = hash_algorithm(event_json_bytes) + if not hasattr(event, "hashes"): + event.hashes = {} + event.hashes[hashed.name] = encode_base64(hashed.digest()) diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index 7cfec5148e..5191be4570 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -98,4 +98,3 @@ class SynapseKeyClientProtocol(HTTPClient): class SynapseKeyClientFactory(Factory): protocol = SynapseKeyClientProtocol - diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 2440d604c3..694aed3a7d 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -44,7 +44,7 @@ class Keyring(object): raise SynapseError( 400, "Not signed with a supported algorithm", - Codes.UNAUTHORIZED, + Codes.UNAUTHORIZED, ) try: verify_key = yield self.get_server_verify_key(server_name, key_ids) @@ -100,7 +100,7 @@ class Keyring(object): ) if ("signatures" not in response - or server_name not in response["signatures"]): + or server_name not in response["signatures"]): raise ValueError("Key response not signed by remote server") if "tls_certificate" not in response: diff --git a/synapse/federation/pdu_codec.py b/synapse/federation/pdu_codec.py index 991aae2a56..5ec97a698e 100644 --- a/synapse/federation/pdu_codec.py +++ b/synapse/federation/pdu_codec.py @@ -17,22 +17,11 @@ from .units import Pdu from synapse.crypto.event_signing import ( add_event_pdu_content_hash, sign_event_pdu ) +from synapse.types import EventID import copy -def decode_event_id(event_id, server_name): - parts = event_id.split("@") - if len(parts) < 2: - return (event_id, server_name) - else: - return (parts[0], "".join(parts[1:])) - - -def encode_event_id(pdu_id, origin): - return "%s@%s" % (pdu_id, origin) - - class PduCodec(object): def __init__(self, hs): @@ -40,30 +29,18 @@ class PduCodec(object): self.server_name = hs.hostname self.event_factory = hs.get_event_factory() self.clock = hs.get_clock() + self.hs = hs def event_from_pdu(self, pdu): kwargs = {} - kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin) - kwargs["room_id"] = pdu.context - kwargs["etype"] = pdu.pdu_type - kwargs["prev_pdus"] = pdu.prev_pdus - - if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"): - kwargs["prev_state"] = encode_event_id( - pdu.prev_state_id, pdu.prev_state_origin - ) + kwargs["etype"] = pdu.type kwargs.update({ k: v for k, v in pdu.get_full_dict().items() if k not in [ - "pdu_id", - "context", - "pdu_type", - "prev_pdus", - "prev_state_id", - "prev_state_origin", + "type", ] }) @@ -72,27 +49,12 @@ class PduCodec(object): def pdu_from_event(self, event): d = event.get_full_dict() - d["pdu_id"], d["origin"] = decode_event_id( - event.event_id, self.server_name - ) - d["context"] = event.room_id - d["pdu_type"] = event.type - - if hasattr(event, "prev_pdus"): - d["prev_pdus"] = event.prev_pdus - - if hasattr(event, "prev_state"): - d["prev_state_id"], d["prev_state_origin"] = ( - decode_event_id(event.prev_state, self.server_name) - ) - if hasattr(event, "state_key"): d["is_state"] = True kwargs = copy.deepcopy(event.unrecognized_keys) kwargs.update({ k: v for k, v in d.items() - if k not in ["event_id", "room_id", "type"] }) if "origin_server_ts" not in kwargs: diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 7043fcc504..b04fbb4177 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -32,76 +32,6 @@ import logging logger = logging.getLogger(__name__) -class PduActions(object): - """ Defines persistence actions that relate to handling PDUs. - """ - - def __init__(self, datastore): - self.store = datastore - - @log_function - def mark_as_processed(self, pdu): - """ Persist the fact that we have fully processed the given `Pdu` - - Returns: - Deferred - """ - return self.store.mark_pdu_as_processed(pdu.pdu_id, pdu.origin) - - @defer.inlineCallbacks - @log_function - def after_transaction(self, transaction_id, destination, origin): - """ Returns all `Pdu`s that we sent to the given remote home server - after a given transaction id. - - Returns: - Deferred: Results in a list of `Pdu`s - """ - results = yield self.store.get_pdus_after_transaction( - transaction_id, - destination - ) - - defer.returnValue([Pdu.from_pdu_tuple(p) for p in results]) - - @defer.inlineCallbacks - @log_function - def get_all_pdus_from_context(self, context): - results = yield self.store.get_all_pdus_from_context(context) - defer.returnValue([Pdu.from_pdu_tuple(p) for p in results]) - - @defer.inlineCallbacks - @log_function - def backfill(self, context, pdu_list, limit): - """ For a given list of PDU id and origins return the proceeding - `limit` `Pdu`s in the given `context`. - - Returns: - Deferred: Results in a list of `Pdu`s. - """ - results = yield self.store.get_backfill( - context, pdu_list, limit - ) - - defer.returnValue([Pdu.from_pdu_tuple(p) for p in results]) - - @log_function - def is_new(self, pdu): - """ When we receive a `Pdu` from a remote home server, we want to - figure out whether it is `new`, i.e. it is not some historic PDU that - we haven't seen simply because we haven't backfilled back that far. - - Returns: - Deferred: Results in a `bool` - """ - return self.store.is_pdu_new( - pdu_id=pdu.pdu_id, - origin=pdu.origin, - context=pdu.context, - depth=pdu.depth - ) - - class TransactionActions(object): """ Defines persistence actions that relate to handling Transactions. """ @@ -158,7 +88,6 @@ class TransactionActions(object): transaction.transaction_id, transaction.destination, transaction.origin_server_ts, - [(p["pdu_id"], p["origin"]) for p in transaction.pdus] ) @log_function diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py index 4a9414c1d4..838e660a46 100644 --- a/synapse/federation/replication.py +++ b/synapse/federation/replication.py @@ -21,7 +21,7 @@ from twisted.internet import defer from .units import Transaction, Pdu, Edu -from .persistence import PduActions, TransactionActions +from .persistence import TransactionActions from synapse.util.logutils import log_function @@ -57,7 +57,7 @@ class ReplicationLayer(object): self.transport_layer.register_request_handler(self) self.store = hs.get_datastore() - self.pdu_actions = PduActions(self.store) + # self.pdu_actions = PduActions(self.store) self.transaction_actions = TransactionActions(self.store) self._transaction_queue = _TransactionQueue( @@ -106,20 +106,11 @@ class ReplicationLayer(object): self.query_handlers[query_type] = handler - @defer.inlineCallbacks @log_function def send_pdu(self, pdu): """Informs the replication layer about a new PDU generated within the home server that should be transmitted to others. - This will fill out various attributes on the PDU object, e.g. the - `prev_pdus` key. - - *Note:* The home server should always call `send_pdu` even if it knows - that it does not need to be replicated to other home servers. This is - in case e.g. someone else joins via a remote home server and then - backfills. - TODO: Figure out when we should actually resolve the deferred. Args: @@ -132,18 +123,12 @@ class ReplicationLayer(object): order = self._order self._order += 1 - logger.debug("[%s] Persisting PDU", pdu.pdu_id) - - # Save *before* trying to send - yield self.store.persist_event(pdu=pdu) - - logger.debug("[%s] Persisted PDU", pdu.pdu_id) - logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id) + logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id) # TODO, add errback, etc. self._transaction_queue.enqueue_pdu(pdu, order) - logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.pdu_id) + logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.event_id) @log_function def send_edu(self, destination, edu_type, content): @@ -181,7 +166,7 @@ class ReplicationLayer(object): @defer.inlineCallbacks @log_function - def backfill(self, dest, context, limit): + def backfill(self, dest, context, limit, extremities): """Requests some more historic PDUs for the given context from the given destination server. @@ -189,12 +174,12 @@ class ReplicationLayer(object): dest (str): The remote home server to ask. context (str): The context to backfill. limit (int): The maximum number of PDUs to return. + extremities (list): List of PDU id and origins of the first pdus + we have seen from the context Returns: Deferred: Results in the received PDUs. """ - extremities = yield self.store.get_oldest_pdus_in_context(context) - logger.debug("backfill extrem=%s", extremities) # If there are no extremeties then we've (probably) reached the start. @@ -216,7 +201,7 @@ class ReplicationLayer(object): @defer.inlineCallbacks @log_function - def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False): + def get_pdu(self, destination, event_id, outlier=False): """Requests the PDU with given origin and ID from the remote home server. @@ -225,7 +210,7 @@ class ReplicationLayer(object): Args: destination (str): Which home server to query pdu_origin (str): The home server that originally sent the pdu. - pdu_id (str) + event_id (str) outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if it's from an arbitary point in the context as opposed to part of the current block of PDUs. Defaults to `False` @@ -234,8 +219,9 @@ class ReplicationLayer(object): Deferred: Results in the requested PDU. """ - transaction_data = yield self.transport_layer.get_pdu( - destination, pdu_origin, pdu_id) + transaction_data = yield self.transport_layer.get_event( + destination, event_id + ) transaction = Transaction(**transaction_data) @@ -244,13 +230,13 @@ class ReplicationLayer(object): pdu = None if pdu_list: pdu = pdu_list[0] - yield self._handle_new_pdu(pdu) + yield self._handle_new_pdu(destination, pdu) defer.returnValue(pdu) @defer.inlineCallbacks @log_function - def get_state_for_context(self, destination, context): + def get_state_for_context(self, destination, context, event_id=None): """Requests all of the `current` state PDUs for a given context from a remote home server. @@ -263,29 +249,32 @@ class ReplicationLayer(object): """ transaction_data = yield self.transport_layer.get_context_state( - destination, context) + destination, + context, + event_id=event_id, + ) transaction = Transaction(**transaction_data) pdus = [Pdu(outlier=True, **p) for p in transaction.pdus] for pdu in pdus: - yield self._handle_new_pdu(pdu) + yield self._handle_new_pdu(destination, pdu) defer.returnValue(pdus) @defer.inlineCallbacks @log_function def on_context_pdus_request(self, context): - pdus = yield self.pdu_actions.get_all_pdus_from_context( - context + raise NotImplementedError( + "on_context_pdus_request is a security violation" ) - defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) @defer.inlineCallbacks @log_function def on_backfill_request(self, context, versions, limit): - - pdus = yield self.pdu_actions.backfill(context, versions, limit) + pdus = yield self.handler.on_backfill_request( + context, versions, limit + ) defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) @@ -319,7 +308,7 @@ class ReplicationLayer(object): dl = [] for pdu in pdu_list: - dl.append(self._handle_new_pdu(pdu)) + dl.append(self._handle_new_pdu(transaction.origin, pdu)) if hasattr(transaction, "edus"): for edu in [Edu(**x) for x in transaction.edus]: @@ -351,20 +340,26 @@ class ReplicationLayer(object): @defer.inlineCallbacks @log_function - def on_context_state_request(self, context): - results = yield self.store.get_current_state_for_context( - context - ) - - logger.debug("Context returning %d results", len(results)) + def on_context_state_request(self, context, event_id): + if event_id: + pdus = yield self.handler.get_state_for_pdu( + event_id + ) + else: + raise NotImplementedError("Specify an event") + # results = yield self.store.get_current_state_for_context( + # context + # ) + # pdus = [Pdu.from_pdu_tuple(p) for p in results] + # + # logger.debug("Context returning %d results", len(pdus)) - pdus = [Pdu.from_pdu_tuple(p) for p in results] defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) @defer.inlineCallbacks @log_function - def on_pdu_request(self, pdu_origin, pdu_id): - pdu = yield self._get_persisted_pdu(pdu_id, pdu_origin) + def on_pdu_request(self, event_id): + pdu = yield self._get_persisted_pdu(event_id) if pdu: defer.returnValue( @@ -376,20 +371,22 @@ class ReplicationLayer(object): @defer.inlineCallbacks @log_function def on_pull_request(self, origin, versions): - transaction_id = max([int(v) for v in versions]) - - response = yield self.pdu_actions.after_transaction( - transaction_id, - origin, - self.server_name - ) - - if not response: - response = [] - - defer.returnValue( - (200, self._transaction_from_pdus(response).get_dict()) - ) + raise NotImplementedError("Pull transacions not implemented") + + # transaction_id = max([int(v) for v in versions]) + # + # response = yield self.pdu_actions.after_transaction( + # transaction_id, + # origin, + # self.server_name + # ) + # + # if not response: + # response = [] + # + # defer.returnValue( + # (200, self._transaction_from_pdus(response).get_dict()) + # ) @defer.inlineCallbacks def on_query_request(self, query_type, args): @@ -397,21 +394,63 @@ class ReplicationLayer(object): response = yield self.query_handlers[query_type](args) defer.returnValue((200, response)) else: - defer.returnValue((404, "No handler for Query type '%s'" - % (query_type) - )) + defer.returnValue( + (404, "No handler for Query type '%s'" % (query_type, )) + ) @defer.inlineCallbacks + def on_make_join_request(self, context, user_id): + pdu = yield self.handler.on_make_join_request(context, user_id) + defer.returnValue(pdu.get_dict()) + + @defer.inlineCallbacks + def on_invite_request(self, origin, content): + pdu = Pdu(**content) + ret_pdu = yield self.handler.on_send_join_request(origin, pdu) + defer.returnValue((200, ret_pdu.get_dict())) + + @defer.inlineCallbacks + def on_send_join_request(self, origin, content): + pdu = Pdu(**content) + state = yield self.handler.on_send_join_request(origin, pdu) + defer.returnValue((200, self._transaction_from_pdus(state).get_dict())) + + @defer.inlineCallbacks + def make_join(self, destination, context, user_id): + pdu_dict = yield self.transport_layer.make_join( + destination=destination, + context=context, + user_id=user_id, + ) + + logger.debug("Got response to make_join: %s", pdu_dict) + + defer.returnValue(Pdu(**pdu_dict)) + + @defer.inlineCallbacks + def send_join(self, destination, pdu): + _, content = yield self.transport_layer.send_join( + destination, + pdu.room_id, + pdu.event_id, + pdu.get_dict(), + ) + + logger.debug("Got content: %s", content) + pdus = [Pdu(outlier=True, **p) for p in content.get("pdus", [])] + for pdu in pdus: + yield self._handle_new_pdu(destination, pdu) + + defer.returnValue(pdus) + @log_function - def _get_persisted_pdu(self, pdu_id, pdu_origin): + def _get_persisted_pdu(self, event_id): """ Get a PDU from the database with given origin and id. Returns: Deferred: Results in a `Pdu`. """ - pdu_tuple = yield self.store.get_pdu(pdu_id, pdu_origin) - - defer.returnValue(Pdu.from_pdu_tuple(pdu_tuple)) + return self.handler.get_persisted_pdu(event_id) def _transaction_from_pdus(self, pdu_list): """Returns a new Transaction containing the given PDUs suitable for @@ -433,48 +472,60 @@ class ReplicationLayer(object): @defer.inlineCallbacks @log_function - def _handle_new_pdu(self, pdu, backfilled=False): + def _handle_new_pdu(self, origin, pdu, backfilled=False): # We reprocess pdus when we have seen them only as outliers - existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin) + existing = yield self._get_persisted_pdu(pdu.event_id) if existing and (not existing.outlier or pdu.outlier): - logger.debug("Already seen pdu %s %s", pdu.pdu_id, pdu.origin) + logger.debug("Already seen pdu %s", pdu.event_id) defer.returnValue({}) return + state = None + # Get missing pdus if necessary. - is_new = yield self.pdu_actions.is_new(pdu) - if is_new and not pdu.outlier: + if not pdu.outlier: # We only backfill backwards to the min depth. - min_depth = yield self.store.get_min_depth_for_context(pdu.context) + min_depth = yield self.handler.get_min_depth_for_context( + pdu.room_id + ) if min_depth and pdu.depth > min_depth: - for pdu_id, origin, hashes in pdu.prev_pdus: - exists = yield self._get_persisted_pdu(pdu_id, origin) + for event_id, hashes in pdu.prev_events: + exists = yield self._get_persisted_pdu(event_id) if not exists: - logger.debug("Requesting pdu %s %s", pdu_id, origin) + logger.debug("Requesting pdu %s", event_id) try: yield self.get_pdu( pdu.origin, - pdu_id=pdu_id, - pdu_origin=origin + event_id=event_id, ) - logger.debug("Processed pdu %s %s", pdu_id, origin) + logger.debug("Processed pdu %s", event_id) except: # TODO(erikj): Do some more intelligent retries. logger.exception("Failed to get PDU") + else: + # We need to get the state at this event, since we have reached + # a backward extremity edge. + state = yield self.get_state_for_context( + origin, pdu.room_id, pdu.event_id, + ) # Persist the Pdu, but don't mark it as processed yet. - yield self.store.persist_event(pdu=pdu) + # yield self.store.persist_event(pdu=pdu) if not backfilled: - ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled) + ret = yield self.handler.on_receive_pdu( + pdu, + backfilled=backfilled, + state=state, + ) else: ret = None - yield self.pdu_actions.mark_as_processed(pdu) + # yield self.pdu_actions.mark_as_processed(pdu) defer.returnValue(ret) diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py index e7517cac4d..04ad7e63ae 100644 --- a/synapse/federation/transport.py +++ b/synapse/federation/transport.py @@ -72,7 +72,7 @@ class TransportLayer(object): self.received_handler = None @log_function - def get_context_state(self, destination, context): + def get_context_state(self, destination, context, event_id=None): """ Requests all state for a given context (i.e. room) from the given server. @@ -89,54 +89,62 @@ class TransportLayer(object): subpath = "/state/%s/" % context - return self._do_request_for_transaction(destination, subpath) + args = {} + if event_id: + args["event_id"] = event_id + + return self._do_request_for_transaction( + destination, subpath, args=args + ) @log_function - def get_pdu(self, destination, pdu_origin, pdu_id): + def get_event(self, destination, event_id): """ Requests the pdu with give id and origin from the given server. Args: destination (str): The host name of the remote home server we want to get the state from. - pdu_origin (str): The home server which created the PDU. - pdu_id (str): The id of the PDU being requested. + event_id (str): The id of the event being requested. Returns: Deferred: Results in a dict received from the remote homeserver. """ - logger.debug("get_pdu dest=%s, pdu_origin=%s, pdu_id=%s", - destination, pdu_origin, pdu_id) + logger.debug("get_pdu dest=%s, event_id=%s", + destination, event_id) - subpath = "/pdu/%s/%s/" % (pdu_origin, pdu_id) + subpath = "/event/%s/" % (event_id, ) return self._do_request_for_transaction(destination, subpath) @log_function - def backfill(self, dest, context, pdu_tuples, limit): + def backfill(self, dest, context, event_tuples, limit): """ Requests `limit` previous PDUs in a given context before list of PDUs. Args: dest (str) context (str) - pdu_tuples (list) + event_tuples (list) limt (int) Returns: Deferred: Results in a dict received from the remote homeserver. """ logger.debug( - "backfill dest=%s, context=%s, pdu_tuples=%s, limit=%s", - dest, context, repr(pdu_tuples), str(limit) + "backfill dest=%s, context=%s, event_tuples=%s, limit=%s", + dest, context, repr(event_tuples), str(limit) ) - if not pdu_tuples: + if not event_tuples: + # TODO: raise? return - subpath = "/backfill/%s/" % context + subpath = "/backfill/%s/" % (context,) - args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]} - args["limit"] = limit + args = { + "v": event_tuples, + "limit": limit, + } return self._do_request_for_transaction( dest, @@ -198,6 +206,57 @@ class TransportLayer(object): defer.returnValue(response) @defer.inlineCallbacks + @log_function + def make_join(self, destination, context, user_id, retry_on_dns_fail=True): + path = PREFIX + "/make_join/%s/%s" % (context, user_id,) + + response = yield self.client.get_json( + destination=destination, + path=path, + retry_on_dns_fail=retry_on_dns_fail, + ) + + defer.returnValue(response) + + @defer.inlineCallbacks + @log_function + def send_join(self, destination, context, event_id, content): + path = PREFIX + "/send_join/%s/%s" % ( + context, + event_id, + ) + + code, content = yield self.client.put_json( + destination=destination, + path=path, + data=content, + ) + + if not 200 <= code < 300: + raise RuntimeError("Got %d from send_join", code) + + defer.returnValue(json.loads(content)) + + @defer.inlineCallbacks + @log_function + def send_invite(self, destination, context, event_id, content): + path = PREFIX + "/invite/%s/%s" % ( + context, + event_id, + ) + + code, content = yield self.client.put_json( + destination=destination, + path=path, + data=content, + ) + + if not 200 <= code < 300: + raise RuntimeError("Got %d from send_invite", code) + + defer.returnValue(json.loads(content)) + + @defer.inlineCallbacks def _authenticate_request(self, request): json_request = { "method": request.method, @@ -313,10 +372,10 @@ class TransportLayer(object): # data_id pair. self.server.register_path( "GET", - re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"), + re.compile("^" + PREFIX + "/event/([^/]*)/$"), self._with_authentication( - lambda origin, content, query, pdu_origin, pdu_id: - handler.on_pdu_request(pdu_origin, pdu_id) + lambda origin, content, query, event_id: + handler.on_pdu_request(event_id) ) ) @@ -326,7 +385,10 @@ class TransportLayer(object): re.compile("^" + PREFIX + "/state/([^/]*)/$"), self._with_authentication( lambda origin, content, query, context: - handler.on_context_state_request(context) + handler.on_context_state_request( + context, + query.get("event_id", [None])[0], + ) ) ) @@ -362,6 +424,39 @@ class TransportLayer(object): ) ) + self.server.register_path( + "GET", + re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"), + self._with_authentication( + lambda origin, content, query, context, user_id: + self._on_make_join_request( + origin, content, query, context, user_id + ) + ) + ) + + self.server.register_path( + "PUT", + re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)$"), + self._with_authentication( + lambda origin, content, query, context, event_id: + self._on_send_join_request( + origin, content, query, + ) + ) + ) + + self.server.register_path( + "PUT", + re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)$"), + self._with_authentication( + lambda origin, content, query, context, event_id: + self._on_invite_request( + origin, content, query, + ) + ) + ) + @defer.inlineCallbacks @log_function def _on_send_request(self, origin, content, query, transaction_id): @@ -448,124 +543,34 @@ class TransportLayer(object): limit = int(limits[-1]) - versions = [v.split(",", 1) for v in v_list] + versions = v_list return self.request_handler.on_backfill_request( - context, versions, limit) - - -class TransportReceivedHandler(object): - """ Callbacks used when we receive a transaction - """ - def on_incoming_transaction(self, transaction): - """ Called on PUT /send/<transaction_id>, or on response to a request - that we sent (e.g. a backfill request) - - Args: - transaction (synapse.transaction.Transaction): The transaction that - was sent to us. - - Returns: - twisted.internet.defer.Deferred: A deferred that gets fired when - the transaction has finished being processed. - - The result should be a tuple in the form of - `(response_code, respond_body)`, where `response_body` is a python - dict that will get serialized to JSON. - - On errors, the dict should have an `error` key with a brief message - of what went wrong. - """ - pass - - -class TransportRequestHandler(object): - """ Handlers used when someone want's data from us - """ - def on_pull_request(self, versions): - """ Called on GET /pull/?v=... - - This is hit when a remote home server wants to get all data - after a given transaction. Mainly used when a home server comes back - online and wants to get everything it has missed. - - Args: - versions (list): A list of transaction_ids that should be used to - determine what PDUs the remote side have not yet seen. - - Returns: - Deferred: Resultsin a tuple in the form of - `(response_code, respond_body)`, where `response_body` is a python - dict that will get serialized to JSON. - - On errors, the dict should have an `error` key with a brief message - of what went wrong. - """ - pass - - def on_pdu_request(self, pdu_origin, pdu_id): - """ Called on GET /pdu/<pdu_origin>/<pdu_id>/ - - Someone wants a particular PDU. This PDU may or may not have originated - from us. - - Args: - pdu_origin (str) - pdu_id (str) - - Returns: - Deferred: Resultsin a tuple in the form of - `(response_code, respond_body)`, where `response_body` is a python - dict that will get serialized to JSON. - - On errors, the dict should have an `error` key with a brief message - of what went wrong. - """ - pass - - def on_context_state_request(self, context): - """ Called on GET /state/<context>/ - - Gets hit when someone wants all the *current* state for a given - contexts. - - Args: - context (str): The name of the context that we're interested in. - - Returns: - twisted.internet.defer.Deferred: A deferred that gets fired when - the transaction has finished being processed. - - The result should be a tuple in the form of - `(response_code, respond_body)`, where `response_body` is a python - dict that will get serialized to JSON. - - On errors, the dict should have an `error` key with a brief message - of what went wrong. - """ - pass - - def on_backfill_request(self, context, versions, limit): - """ Called on GET /backfill/<context>/?v=...&limit=... + context, versions, limit + ) - Gets hit when we want to backfill backwards on a given context from - the given point. + @defer.inlineCallbacks + @log_function + def _on_make_join_request(self, origin, content, query, context, user_id): + content = yield self.request_handler.on_make_join_request( + context, user_id, + ) + defer.returnValue((200, content)) - Args: - context (str): The context to backfill - versions (list): A list of 2-tuples representing where to backfill - from, in the form `(pdu_id, origin)` - limit (int): How many pdus to return. + @defer.inlineCallbacks + @log_function + def _on_send_join_request(self, origin, content, query): + content = yield self.request_handler.on_send_join_request( + origin, content, + ) - Returns: - Deferred: Results in a tuple in the form of - `(response_code, respond_body)`, where `response_body` is a python - dict that will get serialized to JSON. + defer.returnValue((200, content)) - On errors, the dict should have an `error` key with a brief message - of what went wrong. - """ - pass + @defer.inlineCallbacks + @log_function + def _on_invite_request(self, origin, content, query): + content = yield self.request_handler.on_invite_request( + origin, content, + ) - def on_query_request(self): - """ Called on a GET /query/<query_type> request. """ + defer.returnValue((200, content)) diff --git a/synapse/federation/units.py b/synapse/federation/units.py index adc3385644..c94dcf64cf 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -34,13 +34,13 @@ class Pdu(JsonEncodedObject): A Pdu can be classified as "state". For a given context, we can efficiently retrieve all state pdu's that haven't been clobbered. Clobbering is done - via a unique constraint on the tuple (context, pdu_type, state_key). A pdu + via a unique constraint on the tuple (context, type, state_key). A pdu is a state pdu if `is_state` is True. Example pdu:: { - "pdu_id": "78c", + "event_id": "$78c:example.com", "origin_server_ts": 1404835423000, "origin": "bar", "prev_ids": [ @@ -53,14 +53,14 @@ class Pdu(JsonEncodedObject): """ valid_keys = [ - "pdu_id", - "context", + "event_id", + "room_id", "origin", "origin_server_ts", - "pdu_type", + "type", "destinations", "transaction_id", - "prev_pdus", + "prev_events", "depth", "content", "outlier", @@ -68,8 +68,7 @@ class Pdu(JsonEncodedObject): "signatures", "is_state", # Below this are keys valid only for State Pdus. "state_key", - "prev_state_id", - "prev_state_origin", + "prev_state", "required_power_level", "user_id", ] @@ -81,18 +80,18 @@ class Pdu(JsonEncodedObject): ] required_keys = [ - "pdu_id", - "context", + "event_id", + "room_id", "origin", "origin_server_ts", - "pdu_type", + "type", "content", ] # TODO: We need to make this properly load content rather than # just leaving it as a dict. (OR DO WE?!) - def __init__(self, destinations=[], is_state=False, prev_pdus=[], + def __init__(self, destinations=[], is_state=False, prev_events=[], outlier=False, hashes={}, signatures={}, **kwargs): if is_state: for required_key in ["state_key"]: @@ -102,66 +101,13 @@ class Pdu(JsonEncodedObject): super(Pdu, self).__init__( destinations=destinations, is_state=bool(is_state), - prev_pdus=prev_pdus, + prev_events=prev_events, outlier=outlier, hashes=hashes, signatures=signatures, **kwargs ) - @classmethod - def from_pdu_tuple(cls, pdu_tuple): - """ Converts a PduTuple to a Pdu - - Args: - pdu_tuple (synapse.persistence.transactions.PduTuple): The tuple to - convert - - Returns: - Pdu - """ - if pdu_tuple: - d = copy.copy(pdu_tuple.pdu_entry._asdict()) - d["origin_server_ts"] = d.pop("ts") - - for k in d.keys(): - if d[k] is None: - del d[k] - - d["content"] = json.loads(d["content_json"]) - del d["content_json"] - - args = {f: d[f] for f in cls.valid_keys if f in d} - if "unrecognized_keys" in d and d["unrecognized_keys"]: - args.update(json.loads(d["unrecognized_keys"])) - - hashes = { - alg: encode_base64(hsh) - for alg, hsh in pdu_tuple.hashes.items() - } - - signatures = { - kid: encode_base64(sig) - for kid, sig in pdu_tuple.signatures.items() - } - - prev_pdus = [] - for prev_pdu in pdu_tuple.prev_pdu_list: - prev_hashes = pdu_tuple.edge_hashes.get(prev_pdu, {}) - prev_hashes = { - alg: encode_base64(hsh) for alg, hsh in prev_hashes.items() - } - prev_pdus.append((prev_pdu[0], prev_pdu[1], prev_hashes)) - - return Pdu( - prev_pdus=prev_pdus, - hashes=hashes, - signatures=signatures, - **args - ) - else: - return None - def __str__(self): return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__)) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index de4d23bbb3..28b64565ae 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -16,6 +16,10 @@ from twisted.internet import defer from synapse.api.errors import LimitExceededError +from synapse.util.async import run_on_reactor + +from synapse.crypto.event_signing import add_hashes_and_signatures + class BaseHandler(object): def __init__(self, hs): @@ -30,6 +34,9 @@ class BaseHandler(object): self.clock = hs.get_clock() self.hs = hs + self.signing_key = hs.config.signing_key[0] + self.server_name = hs.hostname + def ratelimit(self, user_id): time_now = self.clock.time() allowed, time_allowed = self.ratelimiter.send_message( @@ -44,9 +51,23 @@ class BaseHandler(object): @defer.inlineCallbacks def _on_new_room_event(self, event, snapshot, extra_destinations=[], - extra_users=[]): + extra_users=[], suppress_auth=False): + yield run_on_reactor() + snapshot.fill_out_prev_events(event) + yield self.state_handler.annotate_state_groups(event) + + yield add_hashes_and_signatures( + event, self.server_name, self.signing_key + ) + + if not suppress_auth: + yield self.auth.check(event, snapshot, raises=True) + + if hasattr(event, "state_key"): + yield self.state_handler.handle_new_event(event, snapshot) + yield self.store.persist_event(event) destinations = set(extra_destinations) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index a56830d520..6e897e915d 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -152,5 +152,6 @@ class DirectoryHandler(BaseHandler): user_id=user_id, ) - yield self.state_handler.handle_new_event(event, snapshot) - yield self._on_new_room_event(event, snapshot, extra_users=[user_id]) + yield self._on_new_room_event( + event, snapshot, extra_users=[user_id], suppress_auth=True + ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index f52591d2a3..bdd28f04bb 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -22,6 +22,8 @@ from synapse.api.constants import Membership from synapse.util.logutils import log_function from synapse.federation.pdu_codec import PduCodec from synapse.api.errors import SynapseError +from synapse.util.async import run_on_reactor +from synapse.types import EventID from twisted.internet import defer, reactor @@ -62,6 +64,9 @@ class FederationHandler(BaseHandler): self.pdu_codec = PduCodec(hs) + # When joining a room we need to queue any events for that room up + self.room_queues = {} + @log_function @defer.inlineCallbacks def handle_new_event(self, event, snapshot): @@ -78,6 +83,8 @@ class FederationHandler(BaseHandler): processing. """ + yield run_on_reactor() + pdu = self.pdu_codec.pdu_from_event(event) if not hasattr(pdu, "destinations") or not pdu.destinations: @@ -87,98 +94,83 @@ class FederationHandler(BaseHandler): @log_function @defer.inlineCallbacks - def on_receive_pdu(self, pdu, backfilled): + def on_receive_pdu(self, pdu, backfilled, state=None): """ Called by the ReplicationLayer when we have a new pdu. We need to - do auth checks and put it throught the StateHandler. + do auth checks and put it through the StateHandler. """ event = self.pdu_codec.event_from_pdu(pdu) logger.debug("Got event: %s", event.event_id) - with (yield self.lock_manager.lock(pdu.context)): - if event.is_state and not backfilled: - is_new_state = yield self.state_handler.handle_new_state( - pdu - ) - else: - is_new_state = False + if event.room_id in self.room_queues: + self.room_queues[event.room_id].append(pdu) + return + + logger.debug("Processing event: %s", event.event_id) + + if state: + state = [self.pdu_codec.event_from_pdu(p) for p in state] + + is_new_state = yield self.state_handler.annotate_state_groups( + event, + old_state=state + ) + + logger.debug("Event: %s", event) + + if not backfilled: + yield self.auth.check(event, None, raises=True) + + is_new_state = is_new_state and not backfilled + # TODO: Implement something in federation that allows us to # respond to PDU. - target_is_mine = False - if hasattr(event, "target_host"): - target_is_mine = event.target_host == self.hs.hostname - - if event.type == InviteJoinEvent.TYPE: - if not target_is_mine: - logger.debug("Ignoring invite/join event %s", event) - return - - # If we receive an invite/join event then we need to join the - # sender to the given room. - # TODO: We should probably auth this or some such - content = event.content - content.update({"membership": Membership.JOIN}) - new_event = self.event_factory.create_event( - etype=RoomMemberEvent.TYPE, - state_key=event.user_id, - room_id=event.room_id, - user_id=event.user_id, - membership=Membership.JOIN, - content=content + with (yield self.room_lock.lock(event.room_id)): + yield self.store.persist_event( + event, + backfilled, + is_new_state=is_new_state ) - yield self.hs.get_handlers().room_member_handler.change_membership( - new_event, - do_auth=False, - ) + room = yield self.store.get_room(event.room_id) - else: - with (yield self.room_lock.lock(event.room_id)): - yield self.store.persist_event( - event, - backfilled, - is_new_state=is_new_state + if not room: + # Huh, let's try and get the current state + try: + yield self.replication_layer.get_state_for_context( + event.origin, event.room_id, event.event_id, ) - room = yield self.store.get_room(event.room_id) - - if not room: - # Huh, let's try and get the current state - try: - yield self.replication_layer.get_state_for_context( - event.origin, event.room_id - ) - - hosts = yield self.store.get_joined_hosts_for_room( - event.room_id - ) - if self.hs.hostname in hosts: - try: - yield self.store.store_room( - room_id=event.room_id, - room_creator_user_id="", - is_public=False, - ) - except: - pass - except: - logger.exception( - "Failed to get current state for room %s", - event.room_id - ) - - if not backfilled: - extra_users = [] - if event.type == RoomMemberEvent.TYPE: - target_user_id = event.state_key - target_user = self.hs.parse_userid(target_user_id) - extra_users.append(target_user) - - yield self.notifier.on_new_room_event( - event, extra_users=extra_users + hosts = yield self.store.get_joined_hosts_for_room( + event.room_id + ) + if self.hs.hostname in hosts: + try: + yield self.store.store_room( + room_id=event.room_id, + room_creator_user_id="", + is_public=False, + ) + except: + pass + except: + logger.exception( + "Failed to get current state for room %s", + event.room_id ) + if not backfilled: + extra_users = [] + if event.type == RoomMemberEvent.TYPE: + target_user_id = event.state_key + target_user = self.hs.parse_userid(target_user_id) + extra_users.append(target_user) + + yield self.notifier.on_new_room_event( + event, extra_users=extra_users + ) + if event.type == RoomMemberEvent.TYPE: if event.membership == Membership.JOIN: user = self.hs.parse_userid(event.state_key) @@ -189,13 +181,28 @@ class FederationHandler(BaseHandler): @log_function @defer.inlineCallbacks def backfill(self, dest, room_id, limit): - pdus = yield self.replication_layer.backfill(dest, room_id, limit) + extremities = yield self.store.get_oldest_events_in_room(room_id) + + pdus = yield self.replication_layer.backfill( + dest, + room_id, + limit, + extremities=[ + self.pdu_codec.decode_event_id(e) + for e in extremities + ] + ) events = [] for pdu in pdus: event = self.pdu_codec.event_from_pdu(pdu) + + # FIXME (erikj): Not sure this actually works :/ + yield self.state_handler.annotate_state_groups(event) + events.append(event) + yield self.store.persist_event(event, backfilled=True) defer.returnValue(events) @@ -203,62 +210,230 @@ class FederationHandler(BaseHandler): @log_function @defer.inlineCallbacks def do_invite_join(self, target_host, room_id, joinee, content, snapshot): - hosts = yield self.store.get_joined_hosts_for_room(room_id) if self.hs.hostname in hosts: # We are already in the room. logger.debug("We're already in the room apparently") defer.returnValue(False) - # First get current state to see if we are already joined. + pdu = yield self.replication_layer.make_join( + target_host, + room_id, + joinee + ) + + logger.debug("Got response to make_join: %s", pdu) + + event = self.pdu_codec.event_from_pdu(pdu) + + # We should assert some things. + assert(event.type == RoomMemberEvent.TYPE) + assert(event.user_id == joinee) + assert(event.state_key == joinee) + assert(event.room_id == room_id) + + event.outlier = False + + self.room_queues[room_id] = [] + try: - yield self.replication_layer.get_state_for_context( - target_host, room_id + event.event_id = self.event_factory.create_event_id() + event.content = content + + state = yield self.replication_layer.send_join( + target_host, + self.pdu_codec.pdu_from_event(event) ) - hosts = yield self.store.get_joined_hosts_for_room(room_id) - if self.hs.hostname in hosts: - # Oh, we were actually in the room already. - logger.debug("We're already in the room apparently") - defer.returnValue(False) - except Exception: - logger.exception("Failed to get current state") - - new_event = self.event_factory.create_event( - etype=InviteJoinEvent.TYPE, - target_host=target_host, - room_id=room_id, - user_id=joinee, - content=content - ) + state = [self.pdu_codec.event_from_pdu(p) for p in state] - new_event.destinations = [target_host] + logger.debug("do_invite_join state: %s", state) - snapshot.fill_out_prev_events(new_event) - yield self.handle_new_event(new_event, snapshot) + is_new_state = yield self.state_handler.annotate_state_groups( + event, + old_state=state + ) - # TODO (erikj): Time out here. - d = defer.Deferred() - self.waiting_for_join_list.setdefault((joinee, room_id), []).append(d) - reactor.callLater(10, d.cancel) + logger.debug("do_invite_join event: %s", event) - try: - yield d - except defer.CancelledError: - raise SynapseError(500, "Unable to join remote room") + try: + yield self.store.store_room( + room_id=room_id, + room_creator_user_id="", + is_public=False + ) + except: + # FIXME + pass - try: - yield self.store.store_room( - room_id=room_id, - room_creator_user_id="", - is_public=False + for e in state: + # FIXME: Auth these. + e.outlier = True + + yield self.state_handler.annotate_state_groups( + e, + ) + + yield self.store.persist_event( + e, + backfilled=False, + is_new_state=False + ) + + yield self.store.persist_event( + event, + backfilled=False, + is_new_state=is_new_state ) - except: - pass + finally: + room_queue = self.room_queues[room_id] + del self.room_queues[room_id] + for p in room_queue: + try: + yield self.on_receive_pdu(p, backfilled=False) + except: + pass defer.returnValue(True) + @defer.inlineCallbacks + @log_function + def on_make_join_request(self, context, user_id): + event = self.event_factory.create_event( + etype=RoomMemberEvent.TYPE, + content={"membership": Membership.JOIN}, + room_id=context, + user_id=user_id, + state_key=user_id, + ) + + snapshot = yield self.store.snapshot_room( + event.room_id, event.user_id, + ) + snapshot.fill_out_prev_events(event) + + yield self.state_handler.annotate_state_groups(event) + yield self.auth.check(event, None, raises=True) + + pdu = self.pdu_codec.pdu_from_event(event) + + defer.returnValue(pdu) + + @defer.inlineCallbacks + @log_function + def on_send_join_request(self, origin, pdu): + event = self.pdu_codec.event_from_pdu(pdu) + + event.outlier = False + + is_new_state = yield self.state_handler.annotate_state_groups(event) + yield self.auth.check(event, None, raises=True) + + # FIXME (erikj): All this is duplicated above :( + + yield self.store.persist_event( + event, + backfilled=False, + is_new_state=is_new_state + ) + + extra_users = [] + if event.type == RoomMemberEvent.TYPE: + target_user_id = event.state_key + target_user = self.hs.parse_userid(target_user_id) + extra_users.append(target_user) + + yield self.notifier.on_new_room_event( + event, extra_users=extra_users + ) + + if event.type == RoomMemberEvent.TYPE: + if event.membership == Membership.JOIN: + user = self.hs.parse_userid(event.state_key) + self.distributor.fire( + "user_joined_room", user=user, room_id=event.room_id + ) + + new_pdu = self.pdu_codec.pdu_from_event(event); + new_pdu.destinations = yield self.store.get_joined_hosts_for_room( + event.room_id + ) + + yield self.replication_layer.send_pdu(new_pdu) + + defer.returnValue([ + self.pdu_codec.pdu_from_event(e) + for e in event.state_events.values() + ]) + + @defer.inlineCallbacks + def get_state_for_pdu(self, event_id): + yield run_on_reactor() + + state_groups = yield self.store.get_state_groups( + [event_id] + ) + + if state_groups: + results = { + (e.type, e.state_key): e for e in state_groups[0].state + } + + event = yield self.store.get_event(event_id) + if hasattr(event, "state_key"): + # Get previous state + if hasattr(event, "prev_state") and event.prev_state: + prev_event = yield self.store.get_event(event.prev_state) + results[(event.type, event.state_key)] = prev_event + else: + del results[(event.type, event.state_key)] + + defer.returnValue( + [ + self.pdu_codec.pdu_from_event(s) + for s in results.values() + ] + ) + else: + defer.returnValue([]) + + @defer.inlineCallbacks + @log_function + def on_backfill_request(self, context, pdu_list, limit): + + events = yield self.store.get_backfill_events( + context, + pdu_list, + limit + ) + + defer.returnValue([ + self.pdu_codec.pdu_from_event(e) + for e in events + ]) + + @defer.inlineCallbacks + @log_function + def get_persisted_pdu(self, event_id): + """ Get a PDU from the database with given origin and id. + + Returns: + Deferred: Results in a `Pdu`. + """ + event = yield self.store.get_event( + event_id, + allow_none=True, + ) + + if event: + defer.returnValue(self.pdu_codec.pdu_from_event(event)) + else: + defer.returnValue(None) + + @log_function + def get_min_depth_for_context(self, context): + return self.store.get_min_depth(context) @log_function def _on_user_joined(self, user, room_id): diff --git a/synapse/handlers/login.py b/synapse/handlers/login.py index 3f152e18f0..99d15261d4 100644 --- a/synapse/handlers/login.py +++ b/synapse/handlers/login.py @@ -54,7 +54,7 @@ class LoginHandler(BaseHandler): # pull out the hash for this user if they exist user_info = yield self.store.get_user_by_id(user_id=user) if not user_info: - logger.warn("Attempted to login as %s but they do not exist.", user) + logger.warn("Attempted to login as %s but they do not exist", user) raise LoginError(403, "", errcode=Codes.FORBIDDEN) stored_hash = user_info[0]["password_hash"] diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 7b2b8549ed..c6f6ab14d1 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -83,10 +83,9 @@ class MessageHandler(BaseHandler): snapshot = yield self.store.snapshot_room(event.room_id, event.user_id) - if not suppress_auth: - yield self.auth.check(event, snapshot, raises=True) - - yield self._on_new_room_event(event, snapshot) + yield self._on_new_room_event( + event, snapshot, suppress_auth=suppress_auth + ) self.hs.get_handlers().presence_handler.bump_presence_active_time( user @@ -115,8 +114,12 @@ class MessageHandler(BaseHandler): user = self.hs.parse_userid(user_id) - events, next_token = yield data_source.get_pagination_rows( - user, pagin_config, room_id + events, next_key = yield data_source.get_pagination_rows( + user, pagin_config.get_source_config("room"), room_id + ) + + next_token = pagin_config.from_token.copy_and_replace( + "room_key", next_key ) chunk = { @@ -145,10 +148,6 @@ class MessageHandler(BaseHandler): state_key=event.state_key, ) - yield self.auth.check(event, snapshot, raises=True) - - yield self.state_handler.handle_new_event(event, snapshot) - yield self._on_new_room_event(event, snapshot) @defer.inlineCallbacks @@ -197,7 +196,7 @@ class MessageHandler(BaseHandler): raise RoomError( 403, "Member does not meet private room rules.") - data = yield self.store.get_current_state( + data = yield self.state_handler.get_current_state( room_id, event_type, state_key ) defer.returnValue(data) @@ -217,8 +216,6 @@ class MessageHandler(BaseHandler): def send_feedback(self, event): snapshot = yield self.store.snapshot_room(event.room_id, event.user_id) - yield self.auth.check(event, snapshot, raises=True) - # store message in db yield self._on_new_room_event(event, snapshot) @@ -235,7 +232,7 @@ class MessageHandler(BaseHandler): yield self.auth.check_joined_room(room_id, user_id) # TODO: This is duplicating logic from snapshot_all_rooms - current_state = yield self.store.get_current_state(room_id) + current_state = yield self.state_handler.get_current_state(room_id) defer.returnValue([self.hs.serialize_event(c) for c in current_state]) @defer.inlineCallbacks @@ -271,7 +268,7 @@ class MessageHandler(BaseHandler): presence_stream = self.hs.get_event_sources().sources["presence"] pagination_config = PaginationConfig(from_token=now_token) presence, _ = yield presence_stream.get_pagination_rows( - user, pagination_config, None + user, pagination_config.get_source_config("presence"), None ) public_rooms = yield self.store.get_rooms(is_public=True) @@ -312,7 +309,7 @@ class MessageHandler(BaseHandler): "end": end_token.to_string(), } - current_state = yield self.store.get_current_state( + current_state = yield self.state_handler.get_current_state( event.room_id ) d["state"] = [self.hs.serialize_event(c) for c in current_state] diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index b2af09f090..2ccc2245b7 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -76,9 +76,7 @@ class PresenceHandler(BaseHandler): "stopped_user_eventstream", self.stopped_user_eventstream ) - distributor.observe("user_joined_room", - self.user_joined_room - ) + distributor.observe("user_joined_room", self.user_joined_room) distributor.declare("collect_presencelike_data") @@ -156,14 +154,12 @@ class PresenceHandler(BaseHandler): defer.returnValue(True) if (yield self.store.user_rooms_intersect( - [u.to_string() for u in observer_user, observed_user] - )): + [u.to_string() for u in observer_user, observed_user])): defer.returnValue(True) if (yield self.store.is_presence_visible( - observed_localpart=observed_user.localpart, - observer_userid=observer_user.to_string(), - )): + observed_localpart=observed_user.localpart, + observer_userid=observer_user.to_string())): defer.returnValue(True) defer.returnValue(False) @@ -171,7 +167,8 @@ class PresenceHandler(BaseHandler): @defer.inlineCallbacks def get_state(self, target_user, auth_user): if target_user.is_mine: - visible = yield self.is_presence_visible(observer_user=auth_user, + visible = yield self.is_presence_visible( + observer_user=auth_user, observed_user=target_user ) @@ -219,9 +216,9 @@ class PresenceHandler(BaseHandler): ) if state["presence"] not in self.STATE_LEVELS: - raise SynapseError(400, "'%s' is not a valid presence state" % - state["presence"] - ) + raise SynapseError(400, "'%s' is not a valid presence state" % ( + state["presence"], + )) logger.debug("Updating presence state of %s to %s", target_user.localpart, state["presence"]) @@ -229,7 +226,7 @@ class PresenceHandler(BaseHandler): state_to_store = dict(state) state_to_store["state"] = state_to_store.pop("presence") - statuscache=self._get_or_offline_usercache(target_user) + statuscache = self._get_or_offline_usercache(target_user) was_level = self.STATE_LEVELS[statuscache.get_state()["presence"]] now_level = self.STATE_LEVELS[state["presence"]] @@ -649,8 +646,9 @@ class PresenceHandler(BaseHandler): del state["user_id"] if "presence" not in state: - logger.warning("Received a presence 'push' EDU from %s without" - + " a 'presence' key", origin + logger.warning( + "Received a presence 'push' EDU from %s without a" + " 'presence' key", origin ) continue @@ -745,7 +743,7 @@ class PresenceHandler(BaseHandler): defer.returnValue((localusers, remote_domains)) def push_update_to_clients(self, observed_user, users_to_push=[], - room_ids=[], statuscache=None): + room_ids=[], statuscache=None): self.notifier.on_new_user_event( users_to_push, room_ids, @@ -765,8 +763,7 @@ class PresenceEventSource(object): presence = self.hs.get_handlers().presence_handler if (yield presence.store.user_rooms_intersect( - [u.to_string() for u in observer_user, observed_user] - )): + [u.to_string() for u in observer_user, observed_user])): defer.returnValue(True) if observed_user.is_mine: @@ -823,15 +820,12 @@ class PresenceEventSource(object): def get_pagination_rows(self, user, pagination_config, key): # TODO (erikj): Does this make sense? Ordering? - from_token = pagination_config.from_token - to_token = pagination_config.to_token - observer_user = user - from_key = int(from_token.presence_key) + from_key = int(pagination_config.from_key) - if to_token: - to_key = int(to_token.presence_key) + if pagination_config.to_key: + to_key = int(pagination_config.to_key) else: to_key = -1 @@ -841,7 +835,7 @@ class PresenceEventSource(object): updates = [] # TODO(paul): use a DeferredList ? How to limit concurrency. for observed_user in cachemap.keys(): - if not (to_key < cachemap[observed_user].serial < from_key): + if not (to_key < cachemap[observed_user].serial <= from_key): continue if (yield self.is_visible(observer_user, observed_user)): @@ -849,30 +843,15 @@ class PresenceEventSource(object): # TODO(paul): limit - updates = [(k, cachemap[k]) for k in cachemap - if to_key < cachemap[k].serial < from_key] - if updates: clock = self.clock earliest_serial = max([x[1].serial for x in updates]) data = [x[1].make_event(user=x[0], clock=clock) for x in updates] - if to_token: - next_token = to_token - else: - next_token = from_token - - next_token = next_token.copy_and_replace( - "presence_key", earliest_serial - ) - defer.returnValue((data, next_token)) + defer.returnValue((data, earliest_serial)) else: - if not to_token: - to_token = from_token.copy_and_replace( - "presence_key", 0 - ) - defer.returnValue(([], to_token)) + defer.returnValue(([], 0)) class UserPresenceCache(object): diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index dab9b03f04..4cd0a06093 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -218,5 +218,6 @@ class ProfileHandler(BaseHandler): user_id=j.state_key, ) - yield self.state_handler.handle_new_event(new_event, snapshot) - yield self._on_new_room_event(new_event, snapshot) + yield self._on_new_room_event( + new_event, snapshot, suppress_auth=True + ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 94b7890b5e..7df9d9b82d 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -15,7 +15,6 @@ """Contains functions for registering clients.""" from twisted.internet import defer -from twisted.python import log from synapse.types import UserID from synapse.api.errors import ( @@ -64,9 +63,11 @@ class RegistrationHandler(BaseHandler): user_id = user.to_string() token = self._generate_token(user_id) - yield self.store.register(user_id=user_id, + yield self.store.register( + user_id=user_id, token=token, - password_hash=password_hash) + password_hash=password_hash + ) self.distributor.fire("registered_user", user) else: @@ -127,7 +128,7 @@ class RegistrationHandler(BaseHandler): try: threepid = yield self._threepid_from_creds(c) except: - log.err() + logger.exception("Couldn't validate 3pid") raise RegistrationError(400, "Couldn't validate 3pid") if not threepid: @@ -181,8 +182,11 @@ class RegistrationHandler(BaseHandler): data = yield httpCli.post_urlencoded_get_json( creds['idServer'], "/_matrix/identity/api/v1/3pid/bind", - {'sid': creds['sid'], 'clientSecret': creds['clientSecret'], - 'mxid': mxid} + { + 'sid': creds['sid'], + 'clientSecret': creds['clientSecret'], + 'mxid': mxid, + } ) defer.returnValue(data) @@ -223,5 +227,3 @@ class RegistrationHandler(BaseHandler): } ) defer.returnValue(data) - - diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 21ae03df0d..ffc0892f1a 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -129,8 +129,9 @@ class RoomCreationHandler(BaseHandler): logger.debug("Event: %s", event) - yield self.state_handler.handle_new_event(event, snapshot) - yield self._on_new_room_event(event, snapshot, extra_users=[user]) + yield self._on_new_room_event( + event, snapshot, extra_users=[user], suppress_auth=True + ) for event in creation_events: yield handle_event(event) @@ -391,8 +392,6 @@ class RoomMemberHandler(BaseHandler): yield self._do_join(event, snapshot, do_auth=do_auth) else: # This is not a JOIN, so we can handle it normally. - if do_auth: - yield self.auth.check(event, snapshot, raises=True) # If we're banning someone, set a req power level if event.membership == Membership.BAN: @@ -414,6 +413,7 @@ class RoomMemberHandler(BaseHandler): event, membership=event.content["membership"], snapshot=snapshot, + do_auth=do_auth, ) defer.returnValue({"room_id": room_id}) @@ -502,14 +502,11 @@ class RoomMemberHandler(BaseHandler): if not have_joined: logger.debug("Doing normal join") - if do_auth: - yield self.auth.check(event, snapshot, raises=True) - - yield self.state_handler.handle_new_event(event, snapshot) yield self._do_local_membership_update( event, membership=event.content["membership"], snapshot=snapshot, + do_auth=do_auth, ) user = self.hs.parse_userid(event.user_id) @@ -553,7 +550,8 @@ class RoomMemberHandler(BaseHandler): defer.returnValue([r.room_id for r in rooms]) - def _do_local_membership_update(self, event, membership, snapshot): + def _do_local_membership_update(self, event, membership, snapshot, + do_auth): destinations = [] # If we're inviting someone, then we should also send it to that @@ -570,9 +568,10 @@ class RoomMemberHandler(BaseHandler): return self._on_new_room_event( event, snapshot, extra_destinations=destinations, - extra_users=[target_user] + extra_users=[target_user], suppress_auth=(not do_auth), ) + class RoomListHandler(BaseHandler): @defer.inlineCallbacks @@ -612,23 +611,14 @@ class RoomEventSource(object): return self.store.get_room_events_max_id() @defer.inlineCallbacks - def get_pagination_rows(self, user, pagination_config, key): - from_token = pagination_config.from_token - to_token = pagination_config.to_token - limit = pagination_config.limit - direction = pagination_config.direction - - to_key = to_token.room_key if to_token else None - + def get_pagination_rows(self, user, config, key): events, next_key = yield self.store.paginate_room_events( room_id=key, - from_key=from_token.room_key, - to_key=to_key, - direction=direction, - limit=limit, + from_key=config.from_key, + to_key=config.to_key, + direction=config.direction, + limit=config.limit, with_feedback=True ) - next_token = from_token.copy_and_replace("room_key", next_key) - - defer.returnValue((events, next_token)) + defer.returnValue((events, next_key)) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 0ca4e5c31e..d88a53242c 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -96,9 +96,10 @@ class TypingNotificationHandler(BaseHandler): remotedomains = set() rm_handler = self.homeserver.get_handlers().room_member_handler - yield rm_handler.fetch_room_distributions_into(room_id, - localusers=localusers, remotedomains=remotedomains, - ignore_user=user) + yield rm_handler.fetch_room_distributions_into( + room_id, localusers=localusers, remotedomains=remotedomains, + ignore_user=user + ) for u in localusers: self.push_update_to_clients( @@ -130,8 +131,9 @@ class TypingNotificationHandler(BaseHandler): localusers = set() rm_handler = self.homeserver.get_handlers().room_member_handler - yield rm_handler.fetch_room_distributions_into(room_id, - localusers=localusers) + yield rm_handler.fetch_room_distributions_into( + room_id, localusers=localusers + ) for u in localusers: self.push_update_to_clients( @@ -142,7 +144,7 @@ class TypingNotificationHandler(BaseHandler): ) def push_update_to_clients(self, room_id, observer_user, observed_user, - typing): + typing): # TODO(paul) steal this from presence.py pass @@ -158,4 +160,4 @@ class TypingNotificationEventSource(object): return 0 def get_pagination_rows(self, user, pagination_config, key): - return ([], pagination_config.from_token) + return ([], pagination_config.from_key) diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index 9bff9ec169..f9811bfa04 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -12,4 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/synapse/http/client.py b/synapse/http/client.py index 46c90dbb76..c34b086eb9 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -16,7 +16,9 @@ from twisted.internet import defer, reactor from twisted.internet.error import DNSLookupError -from twisted.web.client import _AgentBase, _URI, readBody, FileBodyProducer, PartialDownloadError +from twisted.web.client import ( + _AgentBase, _URI, readBody, FileBodyProducer, PartialDownloadError +) from twisted.web.http_headers import Headers from synapse.http.endpoint import matrix_endpoint @@ -97,7 +99,7 @@ class BaseHttpClient(object): retries_left = 5 - endpoint = self._getEndpoint(reactor, destination); + endpoint = self._getEndpoint(reactor, destination) while True: @@ -181,7 +183,7 @@ class MatrixHttpClient(BaseHttpClient): auth_headers = [] - for key,sig in request["signatures"][self.server_name].items(): + for key, sig in request["signatures"][self.server_name].items(): auth_headers.append(bytes( "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % ( self.server_name, key, sig, @@ -276,7 +278,6 @@ class MatrixHttpClient(BaseHttpClient): defer.returnValue(json.loads(body)) - def _getEndpoint(self, reactor, destination): return matrix_endpoint( reactor, destination, timeout=10, @@ -351,6 +352,7 @@ class IdentityServerHttpClient(BaseHttpClient): defer.returnValue(json.loads(body)) + class CaptchaServerHttpClient(MatrixHttpClient): """Separate HTTP client for talking to google's captcha servers""" @@ -384,6 +386,7 @@ class CaptchaServerHttpClient(MatrixHttpClient): else: raise e + def _print_ex(e): if hasattr(e, "reasons") and e.reasons: for ex in e.reasons: diff --git a/synapse/http/content_repository.py b/synapse/http/content_repository.py index 7dd4a859f8..3159ffff0a 100644 --- a/synapse/http/content_repository.py +++ b/synapse/http/content_repository.py @@ -38,8 +38,8 @@ class ContentRepoResource(resource.Resource): Uploads are POSTed to wherever this Resource is linked to. This resource returns a "content token" which can be used to GET this content again. The - token is typically a path, but it may not be. Tokens can expire, be one-time - uses, etc. + token is typically a path, but it may not be. Tokens can expire, be + one-time uses, etc. In this case, the token is a path to the file and contains 3 interesting sections: @@ -175,10 +175,9 @@ class ContentRepoResource(resource.Resource): with open(fname, "wb") as f: f.write(request.content.read()) - # FIXME (erikj): These should use constants. file_name = os.path.basename(fname) - # FIXME: we can't assume what the public mounted path of the repo is + # FIXME: we can't assume what the repo's public mounted path is # ...plus self-signed SSL won't work to remote clients anyway # ...and we can't assume that it's SSL anyway, as we might want to # server it via the non-SSL listener... @@ -201,6 +200,3 @@ class ContentRepoResource(resource.Resource): 500, json.dumps({"error": "Internal server error"}), send_cors=True) - - - diff --git a/synapse/notifier.py b/synapse/notifier.py index 5b02c71d1e..f38c410e33 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -167,7 +167,8 @@ class Notifier(object): ) def eb(failure): - logger.error("Failed to notify listener", + logger.error( + "Failed to notify listener", exc_info=( failure.type, failure.value, @@ -207,7 +208,7 @@ class Notifier(object): ) if timeout: - reactor.callLater(timeout/1000, self._timeout_listener, listener) + reactor.callLater(timeout/1000.0, self._timeout_listener, listener) self._register_with_keys(listener) diff --git a/synapse/rest/base.py b/synapse/rest/base.py index 2e8e3fa7d4..dc784c1527 100644 --- a/synapse/rest/base.py +++ b/synapse/rest/base.py @@ -18,6 +18,11 @@ from synapse.api.urls import CLIENT_PREFIX from synapse.rest.transactions import HttpTransactionStore import re +import logging + + +logger = logging.getLogger(__name__) + def client_path_pattern(path_regex): """Creates a regex compiled client path with the correct client path diff --git a/synapse/rest/events.py b/synapse/rest/events.py index 097195d7cc..92ff5e5ca7 100644 --- a/synapse/rest/events.py +++ b/synapse/rest/events.py @@ -20,6 +20,12 @@ from synapse.api.errors import SynapseError from synapse.streams.config import PaginationConfig from synapse.rest.base import RestServlet, client_path_pattern +import logging + + +logger = logging.getLogger(__name__) + + class EventStreamRestServlet(RestServlet): PATTERN = client_path_pattern("/events$") @@ -29,18 +35,22 @@ class EventStreamRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): auth_user = yield self.auth.get_user_by_req(request) - - handler = self.handlers.event_stream_handler - pagin_config = PaginationConfig.from_request(request) - timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS - if "timeout" in request.args: - try: - timeout = int(request.args["timeout"][0]) - except ValueError: - raise SynapseError(400, "timeout must be in milliseconds.") - - chunk = yield handler.get_stream(auth_user.to_string(), pagin_config, - timeout=timeout) + try: + handler = self.handlers.event_stream_handler + pagin_config = PaginationConfig.from_request(request) + timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS + if "timeout" in request.args: + try: + timeout = int(request.args["timeout"][0]) + except ValueError: + raise SynapseError(400, "timeout must be in milliseconds.") + + chunk = yield handler.get_stream( + auth_user.to_string(), pagin_config, timeout=timeout + ) + except: + logger.exception("Event stream failed") + raise defer.returnValue((200, chunk)) diff --git a/synapse/rest/profile.py b/synapse/rest/profile.py index dad5a208c7..72e02d8dd8 100644 --- a/synapse/rest/profile.py +++ b/synapse/rest/profile.py @@ -108,9 +108,9 @@ class ProfileRestServlet(RestServlet): ) defer.returnValue((200, { - "displayname": displayname, - "avatar_url": avatar_url - })) + "displayname": displayname, + "avatar_url": avatar_url + })) def register_servlets(hs, http_server): diff --git a/synapse/rest/register.py b/synapse/rest/register.py index 804117ee09..5c15614ea9 100644 --- a/synapse/rest/register.py +++ b/synapse/rest/register.py @@ -60,40 +60,45 @@ class RegisterRestServlet(RestServlet): def on_GET(self, request): if self.hs.config.enable_registration_captcha: - return (200, { - "flows": [ + return ( + 200, + {"flows": [ { "type": LoginType.RECAPTCHA, - "stages": ([LoginType.RECAPTCHA, - LoginType.EMAIL_IDENTITY, - LoginType.PASSWORD]) + "stages": [ + LoginType.RECAPTCHA, + LoginType.EMAIL_IDENTITY, + LoginType.PASSWORD + ] }, { "type": LoginType.RECAPTCHA, "stages": [LoginType.RECAPTCHA, LoginType.PASSWORD] } - ] - }) + ]} + ) else: - return (200, { - "flows": [ + return ( + 200, + {"flows": [ { "type": LoginType.EMAIL_IDENTITY, - "stages": ([LoginType.EMAIL_IDENTITY, - LoginType.PASSWORD]) + "stages": [ + LoginType.EMAIL_IDENTITY, LoginType.PASSWORD + ] }, { "type": LoginType.PASSWORD } - ] - }) + ]} + ) @defer.inlineCallbacks def on_POST(self, request): register_json = _parse_json(request) - session = (register_json["session"] if "session" in register_json - else None) + session = (register_json["session"] + if "session" in register_json else None) login_type = None if "type" not in register_json: raise SynapseError(400, "Missing 'type' key.") @@ -122,7 +127,9 @@ class RegisterRestServlet(RestServlet): defer.returnValue((200, response)) except KeyError as e: logger.exception(e) - raise SynapseError(400, "Missing JSON keys for login type %s." % login_type) + raise SynapseError(400, "Missing JSON keys for login type %s." % ( + login_type, + )) def on_OPTIONS(self, request): return (200, {}) @@ -183,8 +190,10 @@ class RegisterRestServlet(RestServlet): session["user"] = register_json["user"] defer.returnValue(None) else: - raise SynapseError(400, "Captcha bypass HMAC incorrect", - errcode=Codes.CAPTCHA_NEEDED) + raise SynapseError( + 400, "Captcha bypass HMAC incorrect", + errcode=Codes.CAPTCHA_NEEDED + ) challenge = None user_response = None @@ -230,12 +239,15 @@ class RegisterRestServlet(RestServlet): if ("user" in session and "user" in register_json and session["user"] != register_json["user"]): - raise SynapseError(400, "Cannot change user ID during registration") + raise SynapseError( + 400, "Cannot change user ID during registration" + ) password = register_json["password"].encode("utf-8") - desired_user_id = (register_json["user"].encode("utf-8") if "user" - in register_json else None) - if desired_user_id and urllib.quote(desired_user_id) != desired_user_id: + desired_user_id = (register_json["user"].encode("utf-8") + if "user" in register_json else None) + if (desired_user_id + and urllib.quote(desired_user_id) != desired_user_id): raise SynapseError( 400, "User ID must only contain characters which do not " + diff --git a/synapse/rest/room.py b/synapse/rest/room.py index c72bdc2c34..ec0ce78fda 100644 --- a/synapse/rest/room.py +++ b/synapse/rest/room.py @@ -48,7 +48,9 @@ class RoomCreateRestServlet(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, txn_id): try: - defer.returnValue(self.txns.get_client_transaction(request, txn_id)) + defer.returnValue( + self.txns.get_client_transaction(request, txn_id) + ) except KeyError: pass @@ -98,8 +100,8 @@ class RoomStateEventRestServlet(RestServlet): no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$" # /room/$roomid/state/$eventtype/$statekey - state_key = ("/rooms/(?P<room_id>[^/]*)/state/" + - "(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$") + state_key = ("/rooms/(?P<room_id>[^/]*)/state/" + "(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$") http_server.register_path("GET", client_path_pattern(state_key), @@ -133,7 +135,9 @@ class RoomStateEventRestServlet(RestServlet): ) if not data: - raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) + raise SynapseError( + 404, "Event not found.", errcode=Codes.NOT_FOUND + ) defer.returnValue((200, data[0].get_dict()["content"])) @defer.inlineCallbacks @@ -195,7 +199,9 @@ class RoomSendEventRestServlet(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_id, event_type, txn_id): try: - defer.returnValue(self.txns.get_client_transaction(request, txn_id)) + defer.returnValue( + self.txns.get_client_transaction(request, txn_id) + ) except KeyError: pass @@ -254,7 +260,9 @@ class JoinRoomAliasServlet(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_identifier, txn_id): try: - defer.returnValue(self.txns.get_client_transaction(request, txn_id)) + defer.returnValue( + self.txns.get_client_transaction(request, txn_id) + ) except KeyError: pass @@ -293,7 +301,8 @@ class RoomMemberListRestServlet(RestServlet): target_user = self.hs.parse_userid(event["user_id"]) # Presence is an optional cache; don't fail if we can't fetch it try: - presence_state = yield self.handlers.presence_handler.get_state( + presence_handler = self.handlers.presence_handler + presence_state = yield presence_handler.get_state( target_user=target_user, auth_user=user ) event["content"].update(presence_state) @@ -359,11 +368,11 @@ class RoomInitialSyncRestServlet(RestServlet): # { state event } , { state event } # ] # } - # Probably worth keeping the keys room_id and membership for parity with - # /initialSync even though they must be joined to sync this and know the - # room ID, so clients can reuse the same code (room_id and membership - # are MANDATORY for /initialSync, so the code will expect it to be - # there) + # Probably worth keeping the keys room_id and membership for parity + # with /initialSync even though they must be joined to sync this and + # know the room ID, so clients can reuse the same code (room_id and + # membership are MANDATORY for /initialSync, so the code will expect + # it to be there) defer.returnValue((200, {})) @@ -388,8 +397,8 @@ class RoomMembershipRestServlet(RestServlet): def register(self, http_server): # /rooms/$roomid/[invite|join|leave] - PATTERN = ("/rooms/(?P<room_id>[^/]*)/" + - "(?P<membership_action>join|invite|leave|ban|kick)") + PATTERN = ("/rooms/(?P<room_id>[^/]*)/" + "(?P<membership_action>join|invite|leave|ban|kick)") register_txn_path(self, PATTERN, http_server) @defer.inlineCallbacks @@ -422,7 +431,9 @@ class RoomMembershipRestServlet(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_id, membership_action, txn_id): try: - defer.returnValue(self.txns.get_client_transaction(request, txn_id)) + defer.returnValue( + self.txns.get_client_transaction(request, txn_id) + ) except KeyError: pass @@ -431,6 +442,7 @@ class RoomMembershipRestServlet(RestServlet): self.txns.store_client_transaction(request, txn_id, response) defer.returnValue(response) + class RoomRedactEventRestServlet(RestServlet): def register(self, http_server): PATTERN = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") @@ -457,7 +469,9 @@ class RoomRedactEventRestServlet(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_id, event_id, txn_id): try: - defer.returnValue(self.txns.get_client_transaction(request, txn_id)) + defer.returnValue( + self.txns.get_client_transaction(request, txn_id) + ) except KeyError: pass @@ -503,10 +517,10 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False): ) if with_get: http_server.register_path( - "GET", - client_path_pattern(regex_string + "/(?P<txn_id>[^/]*)$"), - servlet.on_GET - ) + "GET", + client_path_pattern(regex_string + "/(?P<txn_id>[^/]*)$"), + servlet.on_GET + ) def register_servlets(hs, http_server): diff --git a/synapse/rest/transactions.py b/synapse/rest/transactions.py index e06dcc8c57..93c0122f30 100644 --- a/synapse/rest/transactions.py +++ b/synapse/rest/transactions.py @@ -30,9 +30,9 @@ class HttpTransactionStore(object): """Retrieve a response for this request. Args: - key (str): A transaction-independent key for this request. Typically - this is a combination of the path (without the transaction id) and - the user's access token. + key (str): A transaction-independent key for this request. Usually + this is a combination of the path (without the transaction id) + and the user's access token. txn_id (str): The transaction ID for this request Returns: A tuple of (HTTP response code, response content) or None. @@ -51,9 +51,9 @@ class HttpTransactionStore(object): """Stores an HTTP response tuple. Args: - key (str): A transaction-independent key for this request. Typically - this is a combination of the path (without the transaction id) and - the user's access token. + key (str): A transaction-independent key for this request. Usually + this is a combination of the path (without the transaction id) + and the user's access token. txn_id (str): The transaction ID for this request. response (tuple): A tuple of (HTTP response code, response content) """ @@ -92,5 +92,3 @@ class HttpTransactionStore(object): token = request.args["access_token"][0] path_without_txn_id = request.path.rsplit("/", 1)[0] return path_without_txn_id + "/" + token - - diff --git a/synapse/rest/voip.py b/synapse/rest/voip.py index 0d0243a249..432c2475f8 100644 --- a/synapse/rest/voip.py +++ b/synapse/rest/voip.py @@ -34,23 +34,23 @@ class VoipRestServlet(RestServlet): turnSecret = self.hs.config.turn_shared_secret userLifetime = self.hs.config.turn_user_lifetime if not turnUris or not turnSecret or not userLifetime: - defer.returnValue( (200, {}) ) + defer.returnValue((200, {})) expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000 username = "%d:%s" % (expiry, auth_user.to_string()) - + mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1) - # We need to use standard base64 encoding here, *not* syutil's encode_base64 - # because we need to add the standard padding to get the same result as the - # TURN server. + # We need to use standard base64 encoding here, *not* syutil's + # encode_base64 because we need to add the standard padding to get the + # same result as the TURN server. password = base64.b64encode(mac.digest()) - defer.returnValue( (200, { + defer.returnValue((200, { 'username': username, 'password': password, 'ttl': userLifetime / 1000, 'uris': turnUris, - }) ) + })) def on_OPTIONS(self, request): return (200, {}) diff --git a/synapse/server.py b/synapse/server.py index a4d2d4aba5..d770b20b19 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -28,7 +28,7 @@ from synapse.handlers import Handlers from synapse.rest import RestServletFactory from synapse.state import StateHandler from synapse.storage import DataStore -from synapse.types import UserID, RoomAlias, RoomID +from synapse.types import UserID, RoomAlias, RoomID, EventID from synapse.util import Clock from synapse.util.distributor import Distributor from synapse.util.lockutils import LockManager @@ -143,6 +143,11 @@ class BaseHomeServer(object): object.""" return RoomID.from_string(s, hs=self) + def parse_eventid(self, s): + """Parse the string given by 's' as a Event ID and return a EventID + object.""" + return EventID.from_string(s, hs=self) + def serialize_event(self, e): return serialize_event(self, e) diff --git a/synapse/state.py b/synapse/state.py index bc6b928ec7..f4efc287c9 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -16,11 +16,14 @@ from twisted.internet import defer -from synapse.federation.pdu_codec import encode_event_id, decode_event_id from synapse.util.logutils import log_function +from synapse.util.async import run_on_reactor + +from synapse.types import EventID from collections import namedtuple +import copy import logging import hashlib @@ -35,13 +38,14 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) class StateHandler(object): - """ Repsonsible for doing state conflict resolution. + """ Responsible for doing state conflict resolution. """ def __init__(self, hs): self.store = hs.get_datastore() self._replication = hs.get_replication_layer() self.server_name = hs.hostname + self.hs = hs @defer.inlineCallbacks @log_function @@ -50,7 +54,7 @@ class StateHandler(object): to update the state and b) works out what the prev_state should be. Returns: - Deferred: Resolved with a boolean indicating if we succesfully + Deferred: Resolved with a boolean indicating if we successfully updated the state. Raised: @@ -61,200 +65,149 @@ class StateHandler(object): if not hasattr(event, "state_key"): return - key = KeyStateTuple( - event.room_id, - event.type, - _get_state_key_from_event(event) - ) - # Now I need to fill out the prev state and work out if it has auth # (w.r.t. to power levels) snapshot.fill_out_prev_events(event) + yield self.annotate_state_groups(event) - current_state = snapshot.prev_state_pdu - - if current_state: - event.prev_state = encode_event_id( - current_state.pdu_id, current_state.origin + if event.old_state_events: + current_state = event.old_state_events.get( + (event.type, event.state_key) ) - # TODO check current_state to see if the min power level is less - # than the power level of the user - # power_level = self._get_power_level_for_event(event) - - pdu_id, origin = decode_event_id(event.event_id, self.server_name) - - yield self.store.update_current_state( - pdu_id=pdu_id, - origin=origin, - context=key.context, - pdu_type=key.type, - state_key=key.state_key - ) + if current_state: + event.prev_state = current_state.event_id defer.returnValue(True) @defer.inlineCallbacks @log_function - def handle_new_state(self, new_pdu): - """ Apply conflict resolution to `new_pdu`. - - This should be called on every new state pdu, regardless of whether or - not there is a conflict. - - This function is safe against the race of it getting called with two - `PDU`s trying to update the same state. - """ - - # This needs to be done in a transaction. + def annotate_state_groups(self, event, old_state=None): + yield run_on_reactor() - is_new = yield self._handle_new_state(new_pdu) + if old_state: + event.state_group = None + event.old_state_events = { + (s.type, s.state_key): s for s in old_state + } + event.state_events = event.old_state_events - logger.debug("is_new: %s %s %s", is_new, new_pdu.pdu_id, new_pdu.origin) + if hasattr(event, "state_key"): + event.state_events[(event.type, event.state_key)] = event - if is_new: - yield self.store.update_current_state( - pdu_id=new_pdu.pdu_id, - origin=new_pdu.origin, - context=new_pdu.context, - pdu_type=new_pdu.pdu_type, - state_key=new_pdu.state_key - ) - - defer.returnValue(is_new) - - def _get_power_level_for_event(self, event): - # return self._persistence.get_power_level_for_user(event.room_id, - # event.sender) - return event.power_level + defer.returnValue(False) + return - @defer.inlineCallbacks - @log_function - def _handle_new_state(self, new_pdu): - tree, missing_branch = yield self.store.get_unresolved_state_tree( - new_pdu - ) - new_branch, current_branch = tree + if hasattr(event, "outlier") and event.outlier: + event.state_group = None + event.old_state_events = None + event.state_events = {} + defer.returnValue(False) + return - logger.debug( - "_handle_new_state new=%s, current=%s", - new_branch, current_branch + new_state = yield self.resolve_state_groups( + [e for e, _ in event.prev_events] ) - if missing_branch is not None: - # We're missing some PDUs. Fetch them. - # TODO (erikj): Limit this. - missing_prev = tree[missing_branch][-1] - - pdu_id = missing_prev.prev_state_id - origin = missing_prev.prev_state_origin - - is_missing = yield self.store.get_pdu(pdu_id, origin) is None - if not is_missing: - raise Exception("Conflict resolution failed") - - yield self._replication.get_pdu( - destination=missing_prev.origin, - pdu_origin=origin, - pdu_id=pdu_id, - outlier=True - ) + event.old_state_events = copy.deepcopy(new_state) - updated_current = yield self._handle_new_state(new_pdu) - defer.returnValue(updated_current) + if hasattr(event, "state_key"): + new_state[(event.type, event.state_key)] = event - if not current_branch: - # There is no current state - defer.returnValue(True) - return + event.state_group = None + event.state_events = new_state - n = new_branch[-1] - c = current_branch[-1] + defer.returnValue(hasattr(event, "state_key")) - common_ancestor = n.pdu_id == c.pdu_id and n.origin == c.origin + @defer.inlineCallbacks + def get_current_state(self, room_id, event_type=None, state_key=""): + events = yield self.store.get_latest_events_in_room(room_id) - if common_ancestor: - # We found a common ancestor! + event_ids = [ + e_id + for e_id, _, _ in events + ] - if len(current_branch) == 1: - # This is a direct clobber so we can just... - defer.returnValue(True) + res = yield self.resolve_state_groups(event_ids) - else: - # We didn't find a common ancestor. This is probably fine. - pass + if event_type: + defer.returnValue(res.get((event_type, state_key))) + return - result = yield self._do_conflict_res( - new_branch, current_branch, common_ancestor - ) - defer.returnValue(result) + defer.returnValue(res.values()) @defer.inlineCallbacks - def _do_conflict_res(self, new_branch, current_branch, common_ancestor): - conflict_res = [ - self._do_power_level_conflict_res, - self._do_chain_length_conflict_res, - self._do_hash_conflict_res, - ] - - for algo in conflict_res: - new_res, curr_res = yield defer.maybeDeferred( - algo, - new_branch, current_branch, common_ancestor - ) - - if new_res < curr_res: - defer.returnValue(False) - elif new_res > curr_res: - defer.returnValue(True) + @log_function + def resolve_state_groups(self, event_ids): + state_groups = yield self.store.get_state_groups( + event_ids + ) - raise Exception("Conflict resolution failed.") + state = {} + for group in state_groups: + for s in group.state: + state.setdefault( + (s.type, s.state_key), + {} + )[s.event_id] = s + + unconflicted_state = { + k: v.values()[0] for k, v in state.items() + if len(v.values()) == 1 + } + + conflicted_state = { + k: v.values() + for k, v in state.items() + if len(v.values()) > 1 + } + + try: + new_state = {} + new_state.update(unconflicted_state) + for key, events in conflicted_state.items(): + new_state[key] = yield self._resolve_state_events(events) + except: + logger.exception("Failed to resolve state") + raise + + defer.returnValue(new_state) @defer.inlineCallbacks - def _do_power_level_conflict_res(self, new_branch, current_branch, - common_ancestor): + @log_function + def _resolve_state_events(self, events): + curr_events = events + new_powers_deferreds = [] - for e in new_branch[:-1] if common_ancestor else new_branch: - if hasattr(e, "user_id"): - new_powers_deferreds.append( - self.store.get_power_level(e.context, e.user_id) - ) - - current_powers_deferreds = [] - for e in current_branch[:-1] if common_ancestor else current_branch: - if hasattr(e, "user_id"): - current_powers_deferreds.append( - self.store.get_power_level(e.context, e.user_id) - ) + for e in curr_events: + new_powers_deferreds.append( + self.store.get_power_level(e.room_id, e.user_id) + ) new_powers = yield defer.gatherResults( new_powers_deferreds, consumeErrors=True ) - current_powers = yield defer.gatherResults( - current_powers_deferreds, - consumeErrors=True - ) - - max_power_new = max(new_powers) - max_power_current = max(current_powers) + max_power = max([int(p) for p in new_powers]) - defer.returnValue( - (max_power_new, max_power_current) - ) - - def _do_chain_length_conflict_res(self, new_branch, current_branch, - common_ancestor): - return (len(new_branch), len(current_branch)) + curr_events = [ + z[0] for z in zip(curr_events, new_powers) + if int(z[1]) == max_power + ] - def _do_hash_conflict_res(self, new_branch, current_branch, - common_ancestor): - new_str = "".join([p.pdu_id + p.origin for p in new_branch]) - c_str = "".join([p.pdu_id + p.origin for p in current_branch]) + if not curr_events: + raise RuntimeError("Max didn't get a max?") + elif len(curr_events) == 1: + defer.returnValue(curr_events[0]) - return ( - hashlib.sha1(new_str).hexdigest(), - hashlib.sha1(c_str).hexdigest() + # TODO: For now, just choose the one with the largest event_id. + defer.returnValue( + sorted( + curr_events, + key=lambda e: hashlib.sha1( + e.event_id + e.user_id + e.room_id + e.type + ).hexdigest() + )[0] ) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 1639e2c973..6b8fed4502 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -37,14 +37,17 @@ from .registration import RegistrationStore from .room import RoomStore from .roommember import RoomMemberStore from .stream import StreamStore -from .pdu import StatePduStore, PduStore, PdusTable from .transactions import TransactionStore from .keys import KeyStore +from .event_federation import EventFederationStore + +from .state import StateStore from .signatures import SignatureStore from syutil.base64util import decode_base64 -from synapse.crypto.event_signing import compute_pdu_event_reference_hash +from synapse.crypto.event_signing import compute_event_reference_hash + import json import logging @@ -56,7 +59,6 @@ logger = logging.getLogger(__name__) SCHEMAS = [ "transactions", - "pdu", "users", "profiles", "presence", @@ -64,7 +66,9 @@ SCHEMAS = [ "room_aliases", "keys", "redactions", - "signatures", + "state", + "event_edges", + "event_signatures", ] @@ -79,10 +83,12 @@ class _RollbackButIsFineException(Exception): """ pass + class DataStore(RoomMemberStore, RoomStore, RegistrationStore, StreamStore, ProfileStore, FeedbackStore, - PresenceStore, PduStore, StatePduStore, TransactionStore, - DirectoryStore, KeyStore, SignatureStore): + PresenceStore, TransactionStore, + DirectoryStore, KeyStore, StateStore, SignatureStore, + EventFederationStore, ): def __init__(self, hs): super(DataStore, self).__init__(hs) @@ -105,6 +111,7 @@ class DataStore(RoomMemberStore, RoomStore, try: yield self.runInteraction( + "persist_event", self._persist_pdu_event_txn, pdu=pdu, event=event, @@ -125,7 +132,8 @@ class DataStore(RoomMemberStore, RoomStore, "type", "room_id", "content", - "unrecognized_keys" + "unrecognized_keys", + "depth", ], allow_none=allow_none, ) @@ -139,68 +147,12 @@ class DataStore(RoomMemberStore, RoomStore, def _persist_pdu_event_txn(self, txn, pdu=None, event=None, backfilled=False, stream_ordering=None, is_new_state=True): - if pdu is not None: - self._persist_event_pdu_txn(txn, pdu) if event is not None: return self._persist_event_txn( txn, event, backfilled, stream_ordering, is_new_state=is_new_state, ) - def _persist_event_pdu_txn(self, txn, pdu): - cols = dict(pdu.__dict__) - unrec_keys = dict(pdu.unrecognized_keys) - del cols["hashes"] - del cols["signatures"] - del cols["content"] - del cols["prev_pdus"] - cols["content_json"] = json.dumps(pdu.content) - - unrec_keys.update({ - k: v for k, v in cols.items() - if k not in PdusTable.fields - }) - - cols["unrecognized_keys"] = json.dumps(unrec_keys) - - cols["ts"] = cols.pop("origin_server_ts") - - logger.debug("Persisting: %s", repr(cols)) - - for hash_alg, hash_base64 in pdu.hashes.items(): - hash_bytes = decode_base64(hash_base64) - self._store_pdu_content_hash_txn( - txn, pdu.pdu_id, pdu.origin, hash_alg, hash_bytes, - ) - - signatures = pdu.signatures.get(pdu.origin, {}) - - for key_id, signature_base64 in signatures.items(): - signature_bytes = decode_base64(signature_base64) - self._store_pdu_origin_signature_txn( - txn, pdu.pdu_id, pdu.origin, key_id, signature_bytes, - ) - - for prev_pdu_id, prev_origin, prev_hashes in pdu.prev_pdus: - for alg, hash_base64 in prev_hashes.items(): - hash_bytes = decode_base64(hash_base64) - self._store_prev_pdu_hash_txn( - txn, pdu.pdu_id, pdu.origin, prev_pdu_id, prev_origin, alg, - hash_bytes - ) - - (ref_alg, ref_hash_bytes) = compute_pdu_event_reference_hash(pdu) - self._store_pdu_reference_hash_txn( - txn, pdu.pdu_id, pdu.origin, ref_alg, ref_hash_bytes - ) - - if pdu.is_state: - self._persist_state_txn(txn, pdu.prev_pdus, cols) - else: - self._persist_pdu_txn(txn, pdu.prev_pdus, cols) - - self._update_min_depth_for_context_txn(txn, pdu.context, pdu.depth) - @log_function def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None, is_new_state=True): @@ -225,6 +177,10 @@ class DataStore(RoomMemberStore, RoomStore, elif event.type == RoomRedactionEvent.TYPE: self._store_redaction(txn, event) + outlier = False + if hasattr(event, "outlier"): + outlier = event.outlier + vals = { "topological_ordering": event.depth, "event_id": event.event_id, @@ -232,25 +188,33 @@ class DataStore(RoomMemberStore, RoomStore, "room_id": event.room_id, "content": json.dumps(event.content), "processed": True, + "outlier": outlier, + "depth": event.depth, } if stream_ordering is not None: vals["stream_ordering"] = stream_ordering - if hasattr(event, "outlier"): - vals["outlier"] = event.outlier - else: - vals["outlier"] = False - unrec = { k: v for k, v in event.get_full_dict().items() - if k not in vals.keys() and k not in ["redacted", "redacted_because"] + if k not in vals.keys() and k not in [ + "redacted", + "redacted_because", + "signatures", + "hashes", + "prev_events", + ] } vals["unrecognized_keys"] = json.dumps(unrec) try: - self._simple_insert_txn(txn, "events", vals) + self._simple_insert_txn( + txn, + "events", + vals, + or_replace=(not outlier), + ) except: logger.warn( "Failed to persist, probably duplicate: %s", @@ -259,6 +223,16 @@ class DataStore(RoomMemberStore, RoomStore, ) raise _RollbackButIsFineException("_persist_event") + self._handle_prev_events( + txn, + outlier=outlier, + event_id=event.event_id, + prev_events=event.prev_events, + room_id=event.room_id, + ) + + self._store_state_groups_txn(txn, event) + is_state = hasattr(event, "state_key") and event.state_key is not None if is_new_state and is_state: vals = { @@ -284,6 +258,35 @@ class DataStore(RoomMemberStore, RoomStore, } ) + for hash_alg, hash_base64 in event.hashes.items(): + hash_bytes = decode_base64(hash_base64) + self._store_event_content_hash_txn( + txn, event.event_id, hash_alg, hash_bytes, + ) + + if hasattr(event, "signatures"): + signatures = event.signatures.get(event.origin, {}) + + for key_id, signature_base64 in signatures.items(): + signature_bytes = decode_base64(signature_base64) + self._store_event_origin_signature_txn( + txn, event.event_id, event.origin, key_id, signature_bytes, + ) + + for prev_event_id, prev_hashes in event.prev_events: + for alg, hash_base64 in prev_hashes.items(): + hash_bytes = decode_base64(hash_base64) + self._store_prev_event_hash_txn( + txn, event.event_id, prev_event_id, alg, hash_bytes + ) + + (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event) + self._store_event_reference_hash_txn( + txn, event.event_id, ref_alg, ref_hash_bytes + ) + + self._update_min_depth_for_room_txn(txn, event.room_id, event.depth) + def _store_redaction(self, txn, event): txn.execute( "INSERT OR IGNORE INTO redactions " @@ -366,29 +369,19 @@ class DataStore(RoomMemberStore, RoomStore, """ def _snapshot(txn): membership_state = self._get_room_member(txn, user_id, room_id) - prev_pdus = self._get_latest_pdus_in_context( - txn, room_id - ) - - if state_type is not None and state_key is not None: - prev_state_pdu = self._get_current_state_pdu( - txn, room_id, state_type, state_key - ) - else: - prev_state_pdu = None + prev_events = self._get_latest_events_in_room(txn, room_id) return Snapshot( store=self, room_id=room_id, user_id=user_id, - prev_pdus=prev_pdus, + prev_events=prev_events, membership_state=membership_state, state_type=state_type, state_key=state_key, - prev_state_pdu=prev_state_pdu, ) - return self.runInteraction(_snapshot) + return self.runInteraction("snapshot_room", _snapshot) class Snapshot(object): @@ -397,7 +390,7 @@ class Snapshot(object): store (DataStore): The datastore. room_id (RoomId): The room of the snapshot. user_id (UserId): The user this snapshot is for. - prev_pdus (list): The list of PDU ids this snapshot is after. + prev_events (list): The list of event ids this snapshot is after. membership_state (RoomMemberEvent): The current state of the user in the room. state_type (str, optional): State type captured by the snapshot @@ -406,29 +399,29 @@ class Snapshot(object): the previous value of the state type and key in the room. """ - def __init__(self, store, room_id, user_id, prev_pdus, + def __init__(self, store, room_id, user_id, prev_events, membership_state, state_type=None, state_key=None, prev_state_pdu=None): self.store = store self.room_id = room_id self.user_id = user_id - self.prev_pdus = prev_pdus + self.prev_events = prev_events self.membership_state = membership_state self.state_type = state_type self.state_key = state_key self.prev_state_pdu = prev_state_pdu def fill_out_prev_events(self, event): - if hasattr(event, "prev_pdus"): + if hasattr(event, "prev_events"): return - event.prev_pdus = [ - (p_id, origin, hashes) - for p_id, origin, hashes, _ in self.prev_pdus + event.prev_events = [ + (event_id, hashes) + for event_id, hashes, _ in self.prev_events ] - if self.prev_pdus: - event.depth = max([int(v) for _, _, _, v in self.prev_pdus]) + 1 + if self.prev_events: + event.depth = max([int(v) for _, _, v in self.prev_events]) + 1 else: event.depth = 0 @@ -487,9 +480,10 @@ def prepare_database(db_conn): db_conn.commit() else: - sql_script = "BEGIN TRANSACTION;" + sql_script = "BEGIN TRANSACTION;\n" for sql_loc in SCHEMAS: sql_script += read_schema(sql_loc) + sql_script += "\n" sql_script += "COMMIT TRANSACTION;" c.executescript(sql_script) db_conn.commit() diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 65a86e9056..464b12f032 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -19,54 +19,66 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.api.events.utils import prune_event from synapse.util.logutils import log_function +from syutil.base64util import encode_base64 import collections import copy import json +import sys +import time logger = logging.getLogger(__name__) sql_logger = logging.getLogger("synapse.storage.SQL") +transaction_logger = logging.getLogger("synapse.storage.txn") class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging to the .execute() method.""" - __slots__ = ["txn"] + __slots__ = ["txn", "name"] - def __init__(self, txn): + def __init__(self, txn, name): object.__setattr__(self, "txn", txn) + object.__setattr__(self, "name", name) - def __getattribute__(self, name): - if name == "execute": - return object.__getattribute__(self, "execute") - - return getattr(object.__getattribute__(self, "txn"), name) + def __getattr__(self, name): + return getattr(self.txn, name) def __setattr__(self, name, value): - setattr(object.__getattribute__(self, "txn"), name, value) + setattr(self.txn, name, value) def execute(self, sql, *args, **kwargs): # TODO(paul): Maybe use 'info' and 'debug' for values? - sql_logger.debug("[SQL] %s", sql) + sql_logger.debug("[SQL] {%s} %s", self.name, sql) try: if args and args[0]: values = args[0] - sql_logger.debug("[SQL values] " + - ", ".join(("<%s>",) * len(values)), *values) + sql_logger.debug( + "[SQL values] {%s} " + ", ".join(("<%s>",) * len(values)), + self.name, + *values + ) except: # Don't let logging failures stop SQL from working pass - # TODO(paul): Here would be an excellent place to put some timing - # measurements, and log (warning?) slow queries. - return object.__getattribute__(self, "txn").execute( - sql, *args, **kwargs - ) + start = time.clock() * 1000 + try: + return self.txn.execute( + sql, *args, **kwargs + ) + except: + logger.exception("[SQL FAIL] {%s}", self.name) + raise + finally: + end = time.clock() * 1000 + sql_logger.debug("[SQL time] {%s} %f", self.name, end - start) class SQLBaseStore(object): + _TXN_ID = 0 def __init__(self, hs): self.hs = hs @@ -74,10 +86,30 @@ class SQLBaseStore(object): self.event_factory = hs.get_event_factory() self._clock = hs.get_clock() - def runInteraction(self, func, *args, **kwargs): + def runInteraction(self, desc, func, *args, **kwargs): """Wraps the .runInteraction() method on the underlying db_pool.""" def inner_func(txn, *args, **kwargs): - return func(LoggingTransaction(txn), *args, **kwargs) + start = time.clock() * 1000 + txn_id = SQLBaseStore._TXN_ID + + # We don't really need these to be unique, so lets stop it from + # growing really large. + self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1) + + name = "%s-%x" % (desc, txn_id, ) + + transaction_logger.debug("[TXN START] {%s}", name) + try: + return func(LoggingTransaction(txn, name), *args, **kwargs) + except: + logger.exception("[TXN FAIL] {%s}", name) + raise + finally: + end = time.clock() * 1000 + transaction_logger.debug( + "[TXN END] {%s} %f", + name, end - start + ) return self._db_pool.runInteraction(inner_func, *args, **kwargs) @@ -113,7 +145,7 @@ class SQLBaseStore(object): else: return cursor.fetchall() - return self.runInteraction(interaction) + return self.runInteraction("_execute", interaction) def _execute_and_decode(self, query, *args): return self._execute(self.cursor_to_dict, query, *args) @@ -130,6 +162,7 @@ class SQLBaseStore(object): or_replace : bool; if True performs an INSERT OR REPLACE """ return self.runInteraction( + "_simple_insert", self._simple_insert_txn, table, values, or_replace=or_replace, or_ignore=or_ignore, ) @@ -170,7 +203,6 @@ class SQLBaseStore(object): table, keyvalues, retcols=retcols, allow_none=allow_none ) - @defer.inlineCallbacks def _simple_select_one_onecol(self, table, keyvalues, retcol, allow_none=False): """Executes a SELECT query on the named table, which is expected to @@ -181,19 +213,41 @@ class SQLBaseStore(object): keyvalues : dict of column names and values to select the row with retcol : string giving the name of the column to return """ - ret = yield self._simple_select_one( + return self.runInteraction( + "_simple_select_one_onecol_txn", + self._simple_select_one_onecol_txn, + table, keyvalues, retcol, allow_none=allow_none, + ) + + def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol, + allow_none=False): + ret = self._simple_select_onecol_txn( + txn, table=table, keyvalues=keyvalues, - retcols=[retcol], - allow_none=allow_none + retcol=retcol, ) if ret: - defer.returnValue(ret[retcol]) + return ret[0] else: - defer.returnValue(None) + if allow_none: + return None + else: + raise StoreError(404, "No row found") + + def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol): + sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % { + "retcol": retcol, + "table": table, + "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()), + } + + txn.execute(sql, keyvalues.values()) + + return [r[0] for r in txn.fetchall()] + - @defer.inlineCallbacks def _simple_select_onecol(self, table, keyvalues, retcol): """Executes a SELECT query on the named table, which returns a list comprising of the values of the named column from the selected rows. @@ -206,19 +260,11 @@ class SQLBaseStore(object): Returns: Deferred: Results in a list """ - sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % { - "retcol": retcol, - "table": table, - "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()), - } - - def func(txn): - txn.execute(sql, keyvalues.values()) - return txn.fetchall() - - res = yield self.runInteraction(func) - - defer.returnValue([r[0] for r in res]) + return self.runInteraction( + "_simple_select_onecol", + self._simple_select_onecol_txn, + table, keyvalues, retcol + ) def _simple_select_list(self, table, keyvalues, retcols): """Executes a SELECT query on the named table, which may return zero or @@ -239,7 +285,7 @@ class SQLBaseStore(object): txn.execute(sql, keyvalues.values()) return self.cursor_to_dict(txn) - return self.runInteraction(func) + return self.runInteraction("_simple_select_list", func) def _simple_update_one(self, table, keyvalues, updatevalues, retcols=None): @@ -307,7 +353,7 @@ class SQLBaseStore(object): raise StoreError(500, "More than one row matched") return ret - return self.runInteraction(func) + return self.runInteraction("_simple_selectupdate_one", func) def _simple_delete_one(self, table, keyvalues): """Executes a DELETE query on the named table, expecting to delete a @@ -319,7 +365,7 @@ class SQLBaseStore(object): """ sql = "DELETE FROM %s WHERE %s" % ( table, - " AND ".join("%s = ?" % (k) for k in keyvalues) + " AND ".join("%s = ?" % (k, ) for k in keyvalues) ) def func(txn): @@ -328,7 +374,25 @@ class SQLBaseStore(object): raise StoreError(404, "No row found") if txn.rowcount > 1: raise StoreError(500, "more than one row matched") - return self.runInteraction(func) + return self.runInteraction("_simple_delete_one", func) + + def _simple_delete(self, table, keyvalues): + """Executes a DELETE query on the named table. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + """ + + return self.runInteraction("_simple_delete", self._simple_delete_txn) + + def _simple_delete_txn(self, txn, table, keyvalues): + sql = "DELETE FROM %s WHERE %s" % ( + table, + " AND ".join("%s = ?" % (k, ) for k in keyvalues) + ) + + return txn.execute(sql, keyvalues.values()) def _simple_max_id(self, table): """Executes a SELECT query on the named table, expecting to return the @@ -346,7 +410,7 @@ class SQLBaseStore(object): return 0 return max_id - return self.runInteraction(func) + return self.runInteraction("_simple_max_id", func) def _parse_event_from_row(self, row_dict): d = copy.deepcopy({k: v for k, v in row_dict.items()}) @@ -370,7 +434,9 @@ class SQLBaseStore(object): ) def _parse_events(self, rows): - return self.runInteraction(self._parse_events_txn, rows) + return self.runInteraction( + "_parse_events", self._parse_events_txn, rows + ) def _parse_events_txn(self, txn, rows): events = [self._parse_event_from_row(r) for r in rows] @@ -378,6 +444,17 @@ class SQLBaseStore(object): sql = "SELECT * FROM events WHERE event_id = ?" for ev in events: + signatures = self._get_event_origin_signatures_txn( + txn, ev.event_id, + ) + + ev.signatures = { + k: encode_base64(v) for k, v in signatures.items() + } + + prev_events = self._get_latest_events_in_room(txn, ev.room_id) + ev.prev_events = [(e_id, s,) for e_id, s, _ in prev_events] + if hasattr(ev, "prev_state"): # Load previous state_content. # TODO: Should we be pulling this out above? diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 52373a28a6..d6a7113b9c 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -95,6 +95,7 @@ class DirectoryStore(SQLBaseStore): def delete_room_alias(self, room_alias): return self.runInteraction( + "delete_room_alias", self._delete_room_alias_txn, room_alias, ) diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py new file mode 100644 index 0000000000..dcc116bad2 --- /dev/null +++ b/synapse/storage/event_federation.py @@ -0,0 +1,253 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore +from syutil.base64util import encode_base64 + +import logging + + +logger = logging.getLogger(__name__) + + +class EventFederationStore(SQLBaseStore): + + def get_oldest_events_in_room(self, room_id): + return self.runInteraction( + "get_oldest_events_in_room", + self._get_oldest_events_in_room_txn, + room_id, + ) + + def _get_oldest_events_in_room_txn(self, txn, room_id): + return self._simple_select_onecol_txn( + txn, + table="event_backward_extremities", + keyvalues={ + "room_id": room_id, + }, + retcol="event_id", + ) + + def get_latest_events_in_room(self, room_id): + return self.runInteraction( + "get_latest_events_in_room", + self._get_latest_events_in_room, + room_id, + ) + + def _get_latest_events_in_room(self, txn, room_id): + self._simple_select_onecol_txn( + txn, + table="event_forward_extremities", + keyvalues={ + "room_id": room_id, + }, + retcol="event_id", + ) + + sql = ( + "SELECT e.event_id, e.depth FROM events as e " + "INNER JOIN event_forward_extremities as f " + "ON e.event_id = f.event_id " + "WHERE f.room_id = ?" + ) + + txn.execute(sql, (room_id, )) + + results = [] + for event_id, depth in txn.fetchall(): + hashes = self._get_event_reference_hashes_txn(txn, event_id) + prev_hashes = { + k: encode_base64(v) for k, v in hashes.items() + if k == "sha256" + } + results.append((event_id, prev_hashes, depth)) + + return results + + def get_min_depth(self, room_id): + return self.runInteraction( + "get_min_depth", + self._get_min_depth_interaction, + room_id, + ) + + def _get_min_depth_interaction(self, txn, room_id): + min_depth = self._simple_select_one_onecol_txn( + txn, + table="room_depth", + keyvalues={"room_id": room_id,}, + retcol="min_depth", + allow_none=True, + ) + + return int(min_depth) if min_depth is not None else None + + def _update_min_depth_for_room_txn(self, txn, room_id, depth): + min_depth = self._get_min_depth_interaction(txn, room_id) + + do_insert = depth < min_depth if min_depth else True + + if do_insert: + self._simple_insert_txn( + txn, + table="room_depth", + values={ + "room_id": room_id, + "min_depth": depth, + }, + or_replace=True, + ) + + def _handle_prev_events(self, txn, outlier, event_id, prev_events, + room_id): + for e_id, _ in prev_events: + # TODO (erikj): This could be done as a bulk insert + self._simple_insert_txn( + txn, + table="event_edges", + values={ + "event_id": event_id, + "prev_event_id": e_id, + "room_id": room_id, + }, + or_ignore=True, + ) + + # Update the extremities table if this is not an outlier. + if not outlier: + for e_id, _ in prev_events: + # TODO (erikj): This could be done as a bulk insert + self._simple_delete_txn( + txn, + table="event_forward_extremities", + keyvalues={ + "event_id": e_id, + "room_id": room_id, + } + ) + + + + # We only insert as a forward extremity the new pdu if there are no + # other pdus that reference it as a prev pdu + query = ( + "INSERT OR IGNORE INTO %(table)s (event_id, room_id) " + "SELECT ?, ? WHERE NOT EXISTS (" + "SELECT 1 FROM %(event_edges)s WHERE " + "prev_event_id = ? " + ")" + ) % { + "table": "event_forward_extremities", + "event_edges": "event_edges", + } + + logger.debug("query: %s", query) + + txn.execute(query, (event_id, room_id, event_id)) + + # Insert all the prev_pdus as a backwards thing, they'll get + # deleted in a second if they're incorrect anyway. + for e_id, _ in prev_events: + # TODO (erikj): This could be done as a bulk insert + self._simple_insert_txn( + txn, + table="event_backward_extremities", + values={ + "event_id": e_id, + "room_id": room_id, + }, + or_ignore=True, + ) + + # Also delete from the backwards extremities table all ones that + # reference pdus that we have already seen + query = ( + "DELETE FROM event_backward_extremities WHERE EXISTS (" + "SELECT 1 FROM events " + "WHERE " + "event_backward_extremities.event_id = events.event_id " + "AND not events.outlier " + ")" + ) + txn.execute(query) + + + def get_backfill_events(self, room_id, event_list, limit): + """Get a list of Events for a given topic that occured before (and + including) the pdus in pdu_list. Return a list of max size `limit`. + + Args: + txn + room_id (str) + event_list (list) + limit (int) + + Return: + list: A list of PduTuples + """ + return self.runInteraction( + "get_backfill_events", + self._get_backfill_events, room_id, event_list, limit + ) + + def _get_backfill_events(self, txn, room_id, event_list, limit): + logger.debug( + "_get_backfill_events: %s, %s, %s", + room_id, repr(event_list), limit + ) + + # We seed the pdu_results with the things from the pdu_list. + event_results = event_list + + front = event_list + + query = ( + "SELECT prev_event_id FROM event_edges " + "WHERE room_id = ? AND event_id = ? " + "LIMIT ?" + ) + + # We iterate through all event_ids in `front` to select their previous + # events. These are dumped in `new_front`. + # We continue until we reach the limit *or* new_front is empty (i.e., + # we've run out of things to select + while front and len(event_results) < limit: + + new_front = [] + for event_id in front: + logger.debug( + "_backfill_interaction: id=%s", + event_id + ) + + txn.execute( + query, + (room_id, event_id, limit - len(event_results)) + ) + + for row in txn.fetchall(): + logger.debug( + "_backfill_interaction: got id=%s", + *row + ) + new_front.append(row) + + front = new_front + event_results += new_front + + # We also want to update the `prev_pdus` attributes before returning. + return self._get_pdu_tuples(txn, event_results) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 4feb8335ba..fd705138e6 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -18,9 +18,10 @@ from _base import SQLBaseStore from twisted.internet import defer import OpenSSL -from syutil.crypto.signing_key import decode_verify_key_bytes +from syutil.crypto.signing_key import decode_verify_key_bytes import hashlib + class KeyStore(SQLBaseStore): """Persistence for signature verification keys and tls X.509 certificates """ diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py deleted file mode 100644 index 3a90c382f0..0000000000 --- a/synapse/storage/pdu.py +++ /dev/null @@ -1,932 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from twisted.internet import defer - -from ._base import SQLBaseStore, Table, JoinHelper - -from synapse.federation.units import Pdu -from synapse.util.logutils import log_function - -from syutil.base64util import encode_base64 - -from collections import namedtuple - -import logging - - -logger = logging.getLogger(__name__) - - -class PduStore(SQLBaseStore): - """A collection of queries for handling PDUs. - """ - - def get_pdu(self, pdu_id, origin): - """Given a pdu_id and origin, get a PDU. - - Args: - txn - pdu_id (str) - origin (str) - - Returns: - PduTuple: If the pdu does not exist in the database, returns None - """ - - return self.runInteraction( - self._get_pdu_tuple, pdu_id, origin - ) - - def _get_pdu_tuple(self, txn, pdu_id, origin): - res = self._get_pdu_tuples(txn, [(pdu_id, origin)]) - return res[0] if res else None - - def _get_pdu_tuples(self, txn, pdu_id_tuples): - results = [] - for pdu_id, origin in pdu_id_tuples: - txn.execute( - PduEdgesTable.select_statement("pdu_id = ? AND origin = ?"), - (pdu_id, origin) - ) - - edges = [ - (r.prev_pdu_id, r.prev_origin) - for r in PduEdgesTable.decode_results(txn.fetchall()) - ] - - edge_hashes = self._get_prev_pdu_hashes_txn(txn, pdu_id, origin) - - hashes = self._get_pdu_content_hashes_txn(txn, pdu_id, origin) - signatures = self._get_pdu_origin_signatures_txn( - txn, pdu_id, origin - ) - - query = ( - "SELECT %(fields)s FROM %(pdus)s as p " - "LEFT JOIN %(state)s as s " - "ON p.pdu_id = s.pdu_id AND p.origin = s.origin " - "WHERE p.pdu_id = ? AND p.origin = ? " - ) % { - "fields": _pdu_state_joiner.get_fields( - PdusTable="p", StatePdusTable="s"), - "pdus": PdusTable.table_name, - "state": StatePdusTable.table_name, - } - - txn.execute(query, (pdu_id, origin)) - - row = txn.fetchone() - if row: - results.append(PduTuple( - PduEntry(*row), edges, hashes, signatures, edge_hashes - )) - - return results - - def get_current_state_for_context(self, context): - """Get a list of PDUs that represent the current state for a given - context - - Args: - context (str) - - Returns: - list: A list of PduTuples - """ - - return self.runInteraction( - self._get_current_state_for_context, - context - ) - - def _get_current_state_for_context(self, txn, context): - query = ( - "SELECT pdu_id, origin FROM %s WHERE context = ?" - % CurrentStateTable.table_name - ) - - logger.debug("get_current_state %s, Args=%s", query, context) - txn.execute(query, (context,)) - - res = txn.fetchall() - - logger.debug("get_current_state %d results", len(res)) - - return self._get_pdu_tuples(txn, res) - - def _persist_pdu_txn(self, txn, prev_pdus, cols): - """Inserts a (non-state) PDU into the database. - - Args: - txn, - prev_pdus (list) - **cols: The columns to insert into the PdusTable. - """ - entry = PdusTable.EntryType( - **{k: cols.get(k, None) for k in PdusTable.fields} - ) - - txn.execute(PdusTable.insert_statement(), entry) - - self._handle_prev_pdus( - txn, entry.outlier, entry.pdu_id, entry.origin, - prev_pdus, entry.context - ) - - def mark_pdu_as_processed(self, pdu_id, pdu_origin): - """Mark a received PDU as processed. - - Args: - txn - pdu_id (str) - pdu_origin (str) - """ - - return self.runInteraction( - self._mark_as_processed, pdu_id, pdu_origin - ) - - def _mark_as_processed(self, txn, pdu_id, pdu_origin): - txn.execute("UPDATE %s SET have_processed = 1" % PdusTable.table_name) - - def get_all_pdus_from_context(self, context): - """Get a list of all PDUs for a given context.""" - return self.runInteraction( - self._get_all_pdus_from_context, context, - ) - - def _get_all_pdus_from_context(self, txn, context): - query = ( - "SELECT pdu_id, origin FROM %s " - "WHERE context = ?" - ) % PdusTable.table_name - - txn.execute(query, (context,)) - - return self._get_pdu_tuples(txn, txn.fetchall()) - - def get_backfill(self, context, pdu_list, limit): - """Get a list of Pdus for a given topic that occured before (and - including) the pdus in pdu_list. Return a list of max size `limit`. - - Args: - txn - context (str) - pdu_list (list) - limit (int) - - Return: - list: A list of PduTuples - """ - return self.runInteraction( - self._get_backfill, context, pdu_list, limit - ) - - def _get_backfill(self, txn, context, pdu_list, limit): - logger.debug( - "backfill: %s, %s, %s", - context, repr(pdu_list), limit - ) - - # We seed the pdu_results with the things from the pdu_list. - pdu_results = pdu_list - - front = pdu_list - - query = ( - "SELECT prev_pdu_id, prev_origin FROM %(edges_table)s " - "WHERE context = ? AND pdu_id = ? AND origin = ? " - "LIMIT ?" - ) % { - "edges_table": PduEdgesTable.table_name, - } - - # We iterate through all pdu_ids in `front` to select their previous - # pdus. These are dumped in `new_front`. We continue until we reach the - # limit *or* new_front is empty (i.e., we've run out of things to - # select - while front and len(pdu_results) < limit: - - new_front = [] - for pdu_id, origin in front: - logger.debug( - "_backfill_interaction: i=%s, o=%s", - pdu_id, origin - ) - - txn.execute( - query, - (context, pdu_id, origin, limit - len(pdu_results)) - ) - - for row in txn.fetchall(): - logger.debug( - "_backfill_interaction: got i=%s, o=%s", - *row - ) - new_front.append(row) - - front = new_front - pdu_results += new_front - - # We also want to update the `prev_pdus` attributes before returning. - return self._get_pdu_tuples(txn, pdu_results) - - def get_min_depth_for_context(self, context): - """Get the current minimum depth for a context - - Args: - txn - context (str) - """ - return self.runInteraction( - self._get_min_depth_for_context, context - ) - - def _get_min_depth_for_context(self, txn, context): - return self._get_min_depth_interaction(txn, context) - - def _get_min_depth_interaction(self, txn, context): - txn.execute( - "SELECT min_depth FROM %s WHERE context = ?" - % ContextDepthTable.table_name, - (context,) - ) - - row = txn.fetchone() - - return row[0] if row else None - - def _update_min_depth_for_context_txn(self, txn, context, depth): - """Update the minimum `depth` of the given context, which is the line - on which we stop backfilling backwards. - - Args: - context (str) - depth (int) - """ - min_depth = self._get_min_depth_interaction(txn, context) - - do_insert = depth < min_depth if min_depth else True - - if do_insert: - txn.execute( - "INSERT OR REPLACE INTO %s (context, min_depth) " - "VALUES (?,?)" % ContextDepthTable.table_name, - (context, depth) - ) - - def _get_latest_pdus_in_context(self, txn, context): - """Get's a list of the most current pdus for a given context. This is - used when we are sending a Pdu and need to fill out the `prev_pdus` - key - - Args: - txn - context - """ - query = ( - "SELECT p.pdu_id, p.origin, p.depth FROM %(pdus)s as p " - "INNER JOIN %(forward)s as f ON p.pdu_id = f.pdu_id " - "AND f.origin = p.origin " - "WHERE f.context = ?" - ) % { - "pdus": PdusTable.table_name, - "forward": PduForwardExtremitiesTable.table_name, - } - - logger.debug("get_prev query: %s", query) - - txn.execute( - query, - (context, ) - ) - - results = [] - for pdu_id, origin, depth in txn.fetchall(): - hashes = self._get_pdu_reference_hashes_txn(txn, pdu_id, origin) - sha256_bytes = hashes["sha256"] - prev_hashes = {"sha256": encode_base64(sha256_bytes)} - results.append((pdu_id, origin, prev_hashes, depth)) - - return results - - @defer.inlineCallbacks - def get_oldest_pdus_in_context(self, context): - """Get a list of Pdus that we haven't backfilled beyond yet (and havent - seen). This list is used when we want to backfill backwards and is the - list we send to the remote server. - - Args: - txn - context (str) - - Returns: - list: A list of PduIdTuple. - """ - results = yield self._execute( - None, - "SELECT pdu_id, origin FROM %(back)s WHERE context = ?" - % {"back": PduBackwardExtremitiesTable.table_name, }, - context - ) - - defer.returnValue([PduIdTuple(i, o) for i, o in results]) - - def is_pdu_new(self, pdu_id, origin, context, depth): - """For a given Pdu, try and figure out if it's 'new', i.e., if it's - not something we got randomly from the past, for example when we - request the current state of the room that will probably return a bunch - of pdus from before we joined. - - Args: - txn - pdu_id (str) - origin (str) - context (str) - depth (int) - - Returns: - bool - """ - - return self.runInteraction( - self._is_pdu_new, - pdu_id=pdu_id, - origin=origin, - context=context, - depth=depth - ) - - def _is_pdu_new(self, txn, pdu_id, origin, context, depth): - # If depth > min depth in back table, then we classify it as new. - # OR if there is nothing in the back table, then it kinda needs to - # be a new thing. - query = ( - "SELECT min(p.depth) FROM %(edges)s as e " - "INNER JOIN %(back)s as b " - "ON e.prev_pdu_id = b.pdu_id AND e.prev_origin = b.origin " - "INNER JOIN %(pdus)s as p " - "ON e.pdu_id = p.pdu_id AND p.origin = e.origin " - "WHERE p.context = ?" - ) % { - "pdus": PdusTable.table_name, - "edges": PduEdgesTable.table_name, - "back": PduBackwardExtremitiesTable.table_name, - } - - txn.execute(query, (context,)) - - min_depth, = txn.fetchone() - - if not min_depth or depth > int(min_depth): - logger.debug( - "is_new true: id=%s, o=%s, d=%s min_depth=%s", - pdu_id, origin, depth, min_depth - ) - return True - - # If this pdu is in the forwards table, then it also is a new one - query = ( - "SELECT * FROM %(forward)s WHERE pdu_id = ? AND origin = ?" - ) % { - "forward": PduForwardExtremitiesTable.table_name, - } - - txn.execute(query, (pdu_id, origin)) - - # Did we get anything? - if txn.fetchall(): - logger.debug( - "is_new true: id=%s, o=%s, d=%s was forward", - pdu_id, origin, depth - ) - return True - - logger.debug( - "is_new false: id=%s, o=%s, d=%s", - pdu_id, origin, depth - ) - - # FINE THEN. It's probably old. - return False - - @staticmethod - @log_function - def _handle_prev_pdus(txn, outlier, pdu_id, origin, prev_pdus, - context): - txn.executemany( - PduEdgesTable.insert_statement(), - [(pdu_id, origin, p[0], p[1], context) for p in prev_pdus] - ) - - # Update the extremities table if this is not an outlier. - if not outlier: - - # First, we delete the new one from the forwards extremities table. - query = ( - "DELETE FROM %s WHERE pdu_id = ? AND origin = ?" - % PduForwardExtremitiesTable.table_name - ) - txn.executemany(query, list(p[:2] for p in prev_pdus)) - - # We only insert as a forward extremety the new pdu if there are no - # other pdus that reference it as a prev pdu - query = ( - "INSERT INTO %(table)s (pdu_id, origin, context) " - "SELECT ?, ?, ? WHERE NOT EXISTS (" - "SELECT 1 FROM %(pdu_edges)s WHERE " - "prev_pdu_id = ? AND prev_origin = ?" - ")" - ) % { - "table": PduForwardExtremitiesTable.table_name, - "pdu_edges": PduEdgesTable.table_name - } - - logger.debug("query: %s", query) - - txn.execute(query, (pdu_id, origin, context, pdu_id, origin)) - - # Insert all the prev_pdus as a backwards thing, they'll get - # deleted in a second if they're incorrect anyway. - txn.executemany( - PduBackwardExtremitiesTable.insert_statement(), - [(i, o, context) for i, o, _ in prev_pdus] - ) - - # Also delete from the backwards extremities table all ones that - # reference pdus that we have already seen - query = ( - "DELETE FROM %(pdu_back)s WHERE EXISTS (" - "SELECT 1 FROM %(pdus)s AS pdus " - "WHERE " - "%(pdu_back)s.pdu_id = pdus.pdu_id " - "AND %(pdu_back)s.origin = pdus.origin " - "AND not pdus.outlier " - ")" - ) % { - "pdu_back": PduBackwardExtremitiesTable.table_name, - "pdus": PdusTable.table_name, - } - txn.execute(query) - - -class StatePduStore(SQLBaseStore): - """A collection of queries for handling state PDUs. - """ - - def _persist_state_txn(self, txn, prev_pdus, cols): - """Inserts a state PDU into the database - - Args: - txn, - prev_pdus (list) - **cols: The columns to insert into the PdusTable and StatePdusTable - """ - pdu_entry = PdusTable.EntryType( - **{k: cols.get(k, None) for k in PdusTable.fields} - ) - state_entry = StatePdusTable.EntryType( - **{k: cols.get(k, None) for k in StatePdusTable.fields} - ) - - logger.debug("Inserting pdu: %s", repr(pdu_entry)) - logger.debug("Inserting state: %s", repr(state_entry)) - - txn.execute(PdusTable.insert_statement(), pdu_entry) - txn.execute(StatePdusTable.insert_statement(), state_entry) - - self._handle_prev_pdus( - txn, - pdu_entry.outlier, pdu_entry.pdu_id, pdu_entry.origin, prev_pdus, - pdu_entry.context - ) - - def get_unresolved_state_tree(self, new_state_pdu): - return self.runInteraction( - self._get_unresolved_state_tree, new_state_pdu - ) - - @log_function - def _get_unresolved_state_tree(self, txn, new_pdu): - current = self._get_current_interaction( - txn, - new_pdu.context, new_pdu.pdu_type, new_pdu.state_key - ) - - ReturnType = namedtuple( - "StateReturnType", ["new_branch", "current_branch"] - ) - return_value = ReturnType([new_pdu], []) - - if not current: - logger.debug("get_unresolved_state_tree No current state.") - return (return_value, None) - - return_value.current_branch.append(current) - - enum_branches = self._enumerate_state_branches( - txn, new_pdu, current - ) - - missing_branch = None - for branch, prev_state, state in enum_branches: - if state: - return_value[branch].append(state) - else: - # We don't have prev_state :( - missing_branch = branch - break - - return (return_value, missing_branch) - - def update_current_state(self, pdu_id, origin, context, pdu_type, - state_key): - return self.runInteraction( - self._update_current_state, - pdu_id, origin, context, pdu_type, state_key - ) - - def _update_current_state(self, txn, pdu_id, origin, context, pdu_type, - state_key): - query = ( - "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)" - ) % { - "curr": CurrentStateTable.table_name, - "fields": CurrentStateTable.get_fields_string(), - "qs": ", ".join(["?"] * len(CurrentStateTable.fields)) - } - - query_args = CurrentStateTable.EntryType( - pdu_id=pdu_id, - origin=origin, - context=context, - pdu_type=pdu_type, - state_key=state_key - ) - - txn.execute(query, query_args) - - def get_current_state_pdu(self, context, pdu_type, state_key): - """For a given context, pdu_type, state_key 3-tuple, return what is - currently considered the current state. - - Args: - txn - context (str) - pdu_type (str) - state_key (str) - - Returns: - PduEntry - """ - - return self.runInteraction( - self._get_current_state_pdu, context, pdu_type, state_key - ) - - def _get_current_state_pdu(self, txn, context, pdu_type, state_key): - return self._get_current_interaction(txn, context, pdu_type, state_key) - - def _get_current_interaction(self, txn, context, pdu_type, state_key): - logger.debug( - "_get_current_interaction %s %s %s", - context, pdu_type, state_key - ) - - fields = _pdu_state_joiner.get_fields( - PdusTable="p", StatePdusTable="s") - - current_query = ( - "SELECT %(fields)s FROM %(state)s as s " - "INNER JOIN %(pdus)s as p " - "ON s.pdu_id = p.pdu_id AND s.origin = p.origin " - "INNER JOIN %(curr)s as c " - "ON s.pdu_id = c.pdu_id AND s.origin = c.origin " - "WHERE s.context = ? AND s.pdu_type = ? AND s.state_key = ? " - ) % { - "fields": fields, - "curr": CurrentStateTable.table_name, - "state": StatePdusTable.table_name, - "pdus": PdusTable.table_name, - } - - txn.execute( - current_query, - (context, pdu_type, state_key) - ) - - row = txn.fetchone() - - result = PduEntry(*row) if row else None - - if not result: - logger.debug("_get_current_interaction not found") - else: - logger.debug( - "_get_current_interaction found %s %s", - result.pdu_id, result.origin - ) - - return result - - def handle_new_state(self, new_pdu): - """Actually perform conflict resolution on the new_pdu on the - assumption we have all the pdus required to perform it. - - Args: - new_pdu - - Returns: - bool: True if the new_pdu clobbered the current state, False if not - """ - return self.runInteraction( - self._handle_new_state, new_pdu - ) - - def _handle_new_state(self, txn, new_pdu): - logger.debug( - "handle_new_state %s %s", - new_pdu.pdu_id, new_pdu.origin - ) - - current = self._get_current_interaction( - txn, - new_pdu.context, new_pdu.pdu_type, new_pdu.state_key - ) - - is_current = False - - if (not current or not current.prev_state_id - or not current.prev_state_origin): - # Oh, we don't have any state for this yet. - is_current = True - elif (current.pdu_id == new_pdu.prev_state_id - and current.origin == new_pdu.prev_state_origin): - # Oh! A direct clobber. Just do it. - is_current = True - else: - ## - # Ok, now loop through until we get to a common ancestor. - max_new = int(new_pdu.power_level) - max_current = int(current.power_level) - - enum_branches = self._enumerate_state_branches( - txn, new_pdu, current - ) - for branch, prev_state, state in enum_branches: - if not state: - raise RuntimeError( - "Could not find state_pdu %s %s" % - ( - prev_state.prev_state_id, - prev_state.prev_state_origin - ) - ) - - if branch == 0: - max_new = max(int(state.depth), max_new) - else: - max_current = max(int(state.depth), max_current) - - is_current = max_new > max_current - - if is_current: - logger.debug("handle_new_state make current") - - # Right, this is a new thing, so woo, just insert it. - txn.execute( - "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)" - % { - "curr": CurrentStateTable.table_name, - "fields": CurrentStateTable.get_fields_string(), - "qs": ", ".join(["?"] * len(CurrentStateTable.fields)) - }, - CurrentStateTable.EntryType( - *(new_pdu.__dict__[k] for k in CurrentStateTable.fields) - ) - ) - else: - logger.debug("handle_new_state not current") - - logger.debug("handle_new_state done") - - return is_current - - @log_function - def _enumerate_state_branches(self, txn, pdu_a, pdu_b): - branch_a = pdu_a - branch_b = pdu_b - - while True: - if (branch_a.pdu_id == branch_b.pdu_id - and branch_a.origin == branch_b.origin): - # Woo! We found a common ancestor - logger.debug("_enumerate_state_branches Found common ancestor") - break - - do_branch_a = ( - hasattr(branch_a, "prev_state_id") and - branch_a.prev_state_id - ) - - do_branch_b = ( - hasattr(branch_b, "prev_state_id") and - branch_b.prev_state_id - ) - - logger.debug( - "do_branch_a=%s, do_branch_b=%s", - do_branch_a, do_branch_b - ) - - if do_branch_a and do_branch_b: - do_branch_a = int(branch_a.depth) > int(branch_b.depth) - - if do_branch_a: - pdu_tuple = PduIdTuple( - branch_a.prev_state_id, - branch_a.prev_state_origin - ) - - prev_branch = branch_a - - logger.debug("getting branch_a prev %s", pdu_tuple) - branch_a = self._get_pdu_tuple(txn, *pdu_tuple) - if branch_a: - branch_a = Pdu.from_pdu_tuple(branch_a) - - logger.debug("branch_a=%s", branch_a) - - yield (0, prev_branch, branch_a) - - if not branch_a: - break - elif do_branch_b: - pdu_tuple = PduIdTuple( - branch_b.prev_state_id, - branch_b.prev_state_origin - ) - - prev_branch = branch_b - - logger.debug("getting branch_b prev %s", pdu_tuple) - branch_b = self._get_pdu_tuple(txn, *pdu_tuple) - if branch_b: - branch_b = Pdu.from_pdu_tuple(branch_b) - - logger.debug("branch_b=%s", branch_b) - - yield (1, prev_branch, branch_b) - - if not branch_b: - break - else: - break - - -class PdusTable(Table): - table_name = "pdus" - - fields = [ - "pdu_id", - "origin", - "context", - "pdu_type", - "ts", - "depth", - "is_state", - "content_json", - "unrecognized_keys", - "outlier", - "have_processed", - ] - - EntryType = namedtuple("PdusEntry", fields) - - -class PduDestinationsTable(Table): - table_name = "pdu_destinations" - - fields = [ - "pdu_id", - "origin", - "destination", - "delivered_ts", - ] - - EntryType = namedtuple("PduDestinationsEntry", fields) - - -class PduEdgesTable(Table): - table_name = "pdu_edges" - - fields = [ - "pdu_id", - "origin", - "prev_pdu_id", - "prev_origin", - "context" - ] - - EntryType = namedtuple("PduEdgesEntry", fields) - - -class PduForwardExtremitiesTable(Table): - table_name = "pdu_forward_extremities" - - fields = [ - "pdu_id", - "origin", - "context", - ] - - EntryType = namedtuple("PduForwardExtremitiesEntry", fields) - - -class PduBackwardExtremitiesTable(Table): - table_name = "pdu_backward_extremities" - - fields = [ - "pdu_id", - "origin", - "context", - ] - - EntryType = namedtuple("PduBackwardExtremitiesEntry", fields) - - -class ContextDepthTable(Table): - table_name = "context_depth" - - fields = [ - "context", - "min_depth", - ] - - EntryType = namedtuple("ContextDepthEntry", fields) - - -class StatePdusTable(Table): - table_name = "state_pdus" - - fields = [ - "pdu_id", - "origin", - "context", - "pdu_type", - "state_key", - "power_level", - "prev_state_id", - "prev_state_origin", - ] - - EntryType = namedtuple("StatePdusEntry", fields) - - -class CurrentStateTable(Table): - table_name = "current_state" - - fields = [ - "pdu_id", - "origin", - "context", - "pdu_type", - "state_key", - ] - - EntryType = namedtuple("CurrentStateEntry", fields) - -_pdu_state_joiner = JoinHelper(PdusTable, StatePdusTable) - - -# TODO: These should probably be put somewhere more sensible -PduIdTuple = namedtuple("PduIdTuple", ("pdu_id", "origin")) - -PduEntry = _pdu_state_joiner.EntryType -""" We are always interested in the join of the PdusTable and StatePdusTable, -rather than just the PdusTable. - -This does not include a prev_pdus key. -""" - -PduTuple = namedtuple( - "PduTuple", - ("pdu_entry", "prev_pdu_list", "hashes", "signatures", "edge_hashes") -) -""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent -the `prev_pdus` key of a PDU. -""" diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 719806f82b..a2ca6f9a69 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -62,8 +62,10 @@ class RegistrationStore(SQLBaseStore): Raises: StoreError if the user_id could not be registered. """ - yield self.runInteraction(self._register, user_id, token, - password_hash) + yield self.runInteraction( + "register", + self._register, user_id, token, password_hash + ) def _register(self, txn, user_id, token, password_hash): now = int(self.clock.time()) @@ -100,6 +102,7 @@ class RegistrationStore(SQLBaseStore): StoreError if no user was found. """ return self.runInteraction( + "get_user_by_token", self._query_for_auth, token ) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 8cd46334cf..7e48ce9cc3 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -150,6 +150,7 @@ class RoomStore(SQLBaseStore): def get_power_level(self, room_id, user_id): return self.runInteraction( + "get_power_level", self._get_power_level, room_id, user_id, ) @@ -183,6 +184,7 @@ class RoomStore(SQLBaseStore): def get_ops_levels(self, room_id): return self.runInteraction( + "get_ops_levels", self._get_ops_levels, room_id, ) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index ceeef5880e..93329703a2 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -33,7 +33,9 @@ class RoomMemberStore(SQLBaseStore): target_user_id = event.state_key domain = self.hs.parse_userid(target_user_id).domain except: - logger.exception("Failed to parse target_user_id=%s", target_user_id) + logger.exception( + "Failed to parse target_user_id=%s", target_user_id + ) raise logger.debug( @@ -65,7 +67,8 @@ class RoomMemberStore(SQLBaseStore): # Check if this was the last person to have left. member_events = self._get_members_query_txn( txn, - where_clause="c.room_id = ? AND m.membership = ? AND m.user_id != ?", + where_clause=("c.room_id = ? AND m.membership = ?" + " AND m.user_id != ?"), where_values=(event.room_id, Membership.JOIN, target_user_id,) ) @@ -120,7 +123,6 @@ class RoomMemberStore(SQLBaseStore): else: return None - def get_room_members(self, room_id, membership=None): """Retrieve the current room member list for a room. diff --git a/synapse/storage/schema/edge_pdus.sql b/synapse/storage/schema/edge_pdus.sql deleted file mode 100644 index 8a00868065..0000000000 --- a/synapse/storage/schema/edge_pdus.sql +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2014 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -CREATE TABLE IF NOT EXISTS context_edge_pdus( - id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this - pdu_id TEXT, - origin TEXT, - context TEXT, - CONSTRAINT context_edge_pdu_id_origin UNIQUE (pdu_id, origin) -); - -CREATE TABLE IF NOT EXISTS origin_edge_pdus( - id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this - pdu_id TEXT, - origin TEXT, - CONSTRAINT origin_edge_pdu_id_origin UNIQUE (pdu_id, origin) -); - -CREATE INDEX IF NOT EXISTS context_edge_pdu_id ON context_edge_pdus(pdu_id, origin); -CREATE INDEX IF NOT EXISTS origin_edge_pdu_id ON origin_edge_pdus(pdu_id, origin); diff --git a/synapse/storage/schema/event_edges.sql b/synapse/storage/schema/event_edges.sql new file mode 100644 index 0000000000..e5f768c705 --- /dev/null +++ b/synapse/storage/schema/event_edges.sql @@ -0,0 +1,49 @@ + +CREATE TABLE IF NOT EXISTS event_forward_extremities( + event_id TEXT, + room_id TEXT, + CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE +); + +CREATE INDEX IF NOT EXISTS ev_extrem_room ON event_forward_extremities(room_id); +CREATE INDEX IF NOT EXISTS ev_extrem_id ON event_forward_extremities(event_id); + + +CREATE TABLE IF NOT EXISTS event_backward_extremities( + event_id TEXT, + room_id TEXT, + CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE +); + +CREATE INDEX IF NOT EXISTS ev_b_extrem_room ON event_backward_extremities(room_id); +CREATE INDEX IF NOT EXISTS ev_b_extrem_id ON event_backward_extremities(event_id); + + +CREATE TABLE IF NOT EXISTS event_edges( + event_id TEXT, + prev_event_id TEXT, + room_id TEXT, + CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id) +); + +CREATE INDEX IF NOT EXISTS ev_edges_id ON event_edges(event_id); +CREATE INDEX IF NOT EXISTS ev_edges_prev_id ON event_edges(prev_event_id); + + +CREATE TABLE IF NOT EXISTS room_depth( + room_id TEXT, + min_depth INTEGER, + CONSTRAINT uniqueness UNIQUE (room_id) +); + +CREATE INDEX IF NOT EXISTS room_depth_room ON room_depth(room_id); + + +create TABLE IF NOT EXISTS event_destinations( + event_id TEXT, + destination TEXT, + delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered + CONSTRAINT uniqueness UNIQUE (event_id, destination) ON CONFLICT REPLACE +); + +CREATE INDEX IF NOT EXISTS event_destinations_id ON event_destinations(event_id); diff --git a/synapse/storage/schema/event_signatures.sql b/synapse/storage/schema/event_signatures.sql new file mode 100644 index 0000000000..5491c7ecec --- /dev/null +++ b/synapse/storage/schema/event_signatures.sql @@ -0,0 +1,65 @@ +/* Copyright 2014 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS event_content_hashes ( + event_id TEXT, + algorithm TEXT, + hash BLOB, + CONSTRAINT uniqueness UNIQUE (event_id, algorithm) +); + +CREATE INDEX IF NOT EXISTS event_content_hashes_id ON event_content_hashes( + event_id +); + + +CREATE TABLE IF NOT EXISTS event_reference_hashes ( + event_id TEXT, + algorithm TEXT, + hash BLOB, + CONSTRAINT uniqueness UNIQUE (event_id, algorithm) +); + +CREATE INDEX IF NOT EXISTS event_reference_hashes_id ON event_reference_hashes ( + event_id +); + + +CREATE TABLE IF NOT EXISTS event_origin_signatures ( + event_id TEXT, + origin TEXT, + key_id TEXT, + signature BLOB, + CONSTRAINT uniqueness UNIQUE (event_id, key_id) +); + +CREATE INDEX IF NOT EXISTS event_origin_signatures_id ON event_origin_signatures ( + event_id +); + + +CREATE TABLE IF NOT EXISTS event_edge_hashes( + event_id TEXT, + prev_event_id TEXT, + algorithm TEXT, + hash BLOB, + CONSTRAINT uniqueness UNIQUE ( + event_id, prev_event_id, algorithm + ) +); + +CREATE INDEX IF NOT EXISTS event_edge_hashes_id ON event_edge_hashes( + event_id +); diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/im.sql index 3aa83f5c8c..8d6f655993 100644 --- a/synapse/storage/schema/im.sql +++ b/synapse/storage/schema/im.sql @@ -23,6 +23,7 @@ CREATE TABLE IF NOT EXISTS events( unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, + depth INTEGER DEFAULT 0 NOT NULL, CONSTRAINT ev_uniq UNIQUE (event_id) ); diff --git a/synapse/storage/schema/pdu.sql b/synapse/storage/schema/pdu.sql deleted file mode 100644 index 16e111a56c..0000000000 --- a/synapse/storage/schema/pdu.sql +++ /dev/null @@ -1,106 +0,0 @@ -/* Copyright 2014 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ --- Stores pdus and their content -CREATE TABLE IF NOT EXISTS pdus( - pdu_id TEXT, - origin TEXT, - context TEXT, - pdu_type TEXT, - ts INTEGER, - depth INTEGER DEFAULT 0 NOT NULL, - is_state BOOL, - content_json TEXT, - unrecognized_keys TEXT, - outlier BOOL NOT NULL, - have_processed BOOL, - CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin) -); - --- Stores what the current state pdu is for a given (context, pdu_type, key) tuple -CREATE TABLE IF NOT EXISTS state_pdus( - pdu_id TEXT, - origin TEXT, - context TEXT, - pdu_type TEXT, - state_key TEXT, - power_level TEXT, - prev_state_id TEXT, - prev_state_origin TEXT, - CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin) - CONSTRAINT prev_pdu_id_origin UNIQUE (prev_state_id, prev_state_origin) -); - -CREATE TABLE IF NOT EXISTS current_state( - pdu_id TEXT, - origin TEXT, - context TEXT, - pdu_type TEXT, - state_key TEXT, - CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin) - CONSTRAINT uniqueness UNIQUE (context, pdu_type, state_key) ON CONFLICT REPLACE -); - --- Stores where each pdu we want to send should be sent and the delivery status. -create TABLE IF NOT EXISTS pdu_destinations( - pdu_id TEXT, - origin TEXT, - destination TEXT, - delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered - CONSTRAINT uniqueness UNIQUE (pdu_id, origin, destination) ON CONFLICT REPLACE -); - -CREATE TABLE IF NOT EXISTS pdu_forward_extremities( - pdu_id TEXT, - origin TEXT, - context TEXT, - CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE -); - -CREATE TABLE IF NOT EXISTS pdu_backward_extremities( - pdu_id TEXT, - origin TEXT, - context TEXT, - CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE -); - -CREATE TABLE IF NOT EXISTS pdu_edges( - pdu_id TEXT, - origin TEXT, - prev_pdu_id TEXT, - prev_origin TEXT, - context TEXT, - CONSTRAINT uniqueness UNIQUE (pdu_id, origin, prev_pdu_id, prev_origin, context) -); - -CREATE TABLE IF NOT EXISTS context_depth( - context TEXT, - min_depth INTEGER, - CONSTRAINT uniqueness UNIQUE (context) -); - -CREATE INDEX IF NOT EXISTS context_depth_context ON context_depth(context); - - -CREATE INDEX IF NOT EXISTS pdu_id ON pdus(pdu_id, origin); - -CREATE INDEX IF NOT EXISTS dests_id ON pdu_destinations (pdu_id, origin); --- CREATE INDEX IF NOT EXISTS dests ON pdu_destinations (destination); - -CREATE INDEX IF NOT EXISTS pdu_extrem_context ON pdu_forward_extremities(context); -CREATE INDEX IF NOT EXISTS pdu_extrem_id ON pdu_forward_extremities(pdu_id, origin); - -CREATE INDEX IF NOT EXISTS pdu_edges_id ON pdu_edges(pdu_id, origin); - -CREATE INDEX IF NOT EXISTS pdu_b_extrem_context ON pdu_backward_extremities(context); diff --git a/synapse/storage/schema/signatures.sql b/synapse/storage/schema/signatures.sql deleted file mode 100644 index 1c45a51bec..0000000000 --- a/synapse/storage/schema/signatures.sql +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2014 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS pdu_content_hashes ( - pdu_id TEXT, - origin TEXT, - algorithm TEXT, - hash BLOB, - CONSTRAINT uniqueness UNIQUE (pdu_id, origin, algorithm) -); - -CREATE INDEX IF NOT EXISTS pdu_content_hashes_id ON pdu_content_hashes ( - pdu_id, origin -); - -CREATE TABLE IF NOT EXISTS pdu_reference_hashes ( - pdu_id TEXT, - origin TEXT, - algorithm TEXT, - hash BLOB, - CONSTRAINT uniqueness UNIQUE (pdu_id, origin, algorithm) -); - -CREATE INDEX IF NOT EXISTS pdu_reference_hashes_id ON pdu_reference_hashes ( - pdu_id, origin -); - -CREATE TABLE IF NOT EXISTS pdu_origin_signatures ( - pdu_id TEXT, - origin TEXT, - key_id TEXT, - signature BLOB, - CONSTRAINT uniqueness UNIQUE (pdu_id, origin, key_id) -); - -CREATE INDEX IF NOT EXISTS pdu_origin_signatures_id ON pdu_origin_signatures ( - pdu_id, origin -); - -CREATE TABLE IF NOT EXISTS pdu_edge_hashes( - pdu_id TEXT, - origin TEXT, - prev_pdu_id TEXT, - prev_origin TEXT, - algorithm TEXT, - hash BLOB, - CONSTRAINT uniqueness UNIQUE ( - pdu_id, origin, prev_pdu_id, prev_origin, algorithm - ) -); - -CREATE INDEX IF NOT EXISTS pdu_edge_hashes_id ON pdu_edge_hashes( - pdu_id, origin -); diff --git a/synapse/storage/schema/state.sql b/synapse/storage/schema/state.sql new file mode 100644 index 0000000000..b44c56b519 --- /dev/null +++ b/synapse/storage/schema/state.sql @@ -0,0 +1,33 @@ +/* Copyright 2014 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS state_groups( + id INTEGER PRIMARY KEY, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS state_groups_state( + state_group INTEGER NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS event_to_state_groups( + event_id TEXT NOT NULL, + state_group INTEGER NOT NULL +); \ No newline at end of file diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index 82be946d3f..b4b3d5d7ea 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -17,139 +17,149 @@ from _base import SQLBaseStore class SignatureStore(SQLBaseStore): - """Persistence for PDU signatures and hashes""" + """Persistence for event signatures and hashes""" - def _get_pdu_content_hashes_txn(self, txn, pdu_id, origin): - """Get all the hashes for a given PDU. + def _get_event_content_hashes_txn(self, txn, event_id): + """Get all the hashes for a given Event. Args: txn (cursor): - pdu_id (str): Id for the PDU. - origin (str): origin of the PDU. + event_id (str): Id for the Event. Returns: A dict of algorithm -> hash. """ query = ( "SELECT algorithm, hash" - " FROM pdu_content_hashes" - " WHERE pdu_id = ? and origin = ?" + " FROM event_content_hashes" + " WHERE event_id = ?" ) - txn.execute(query, (pdu_id, origin)) + txn.execute(query, (event_id, )) return dict(txn.fetchall()) - def _store_pdu_content_hash_txn(self, txn, pdu_id, origin, algorithm, + def _store_event_content_hash_txn(self, txn, event_id, algorithm, hash_bytes): - """Store a hash for a PDU + """Store a hash for a Event Args: txn (cursor): - pdu_id (str): Id for the PDU. - origin (str): origin of the PDU. + event_id (str): Id for the Event. algorithm (str): Hashing algorithm. hash_bytes (bytes): Hash function output bytes. """ - self._simple_insert_txn(txn, "pdu_content_hashes", { - "pdu_id": pdu_id, - "origin": origin, - "algorithm": algorithm, - "hash": buffer(hash_bytes), - }) + self._simple_insert_txn( + txn, + "event_content_hashes", + { + "event_id": event_id, + "algorithm": algorithm, + "hash": buffer(hash_bytes), + }, + or_ignore=True, + ) - def _get_pdu_reference_hashes_txn(self, txn, pdu_id, origin): + def _get_event_reference_hashes_txn(self, txn, event_id): """Get all the hashes for a given PDU. Args: txn (cursor): - pdu_id (str): Id for the PDU. - origin (str): origin of the PDU. + event_id (str): Id for the Event. Returns: A dict of algorithm -> hash. """ query = ( "SELECT algorithm, hash" - " FROM pdu_reference_hashes" - " WHERE pdu_id = ? and origin = ?" + " FROM event_reference_hashes" + " WHERE event_id = ?" ) - txn.execute(query, (pdu_id, origin)) + txn.execute(query, (event_id, )) return dict(txn.fetchall()) - def _store_pdu_reference_hash_txn(self, txn, pdu_id, origin, algorithm, + def _store_event_reference_hash_txn(self, txn, event_id, algorithm, hash_bytes): """Store a hash for a PDU Args: txn (cursor): - pdu_id (str): Id for the PDU. - origin (str): origin of the PDU. + event_id (str): Id for the Event. algorithm (str): Hashing algorithm. hash_bytes (bytes): Hash function output bytes. """ - self._simple_insert_txn(txn, "pdu_reference_hashes", { - "pdu_id": pdu_id, - "origin": origin, - "algorithm": algorithm, - "hash": buffer(hash_bytes), - }) + self._simple_insert_txn( + txn, + "event_reference_hashes", + { + "event_id": event_id, + "algorithm": algorithm, + "hash": buffer(hash_bytes), + }, + or_ignore=True, + ) - def _get_pdu_origin_signatures_txn(self, txn, pdu_id, origin): + def _get_event_origin_signatures_txn(self, txn, event_id): """Get all the signatures for a given PDU. Args: txn (cursor): - pdu_id (str): Id for the PDU. - origin (str): origin of the PDU. + event_id (str): Id for the Event. Returns: A dict of key_id -> signature_bytes. """ query = ( "SELECT key_id, signature" - " FROM pdu_origin_signatures" - " WHERE pdu_id = ? and origin = ?" + " FROM event_origin_signatures" + " WHERE event_id = ? " ) - txn.execute(query, (pdu_id, origin)) + txn.execute(query, (event_id, )) return dict(txn.fetchall()) - def _store_pdu_origin_signature_txn(self, txn, pdu_id, origin, key_id, - signature_bytes): + def _store_event_origin_signature_txn(self, txn, event_id, origin, key_id, + signature_bytes): """Store a signature from the origin server for a PDU. Args: txn (cursor): - pdu_id (str): Id for the PDU. - origin (str): origin of the PDU. + event_id (str): Id for the Event. + origin (str): origin of the Event. key_id (str): Id for the signing key. signature (bytes): The signature. """ - self._simple_insert_txn(txn, "pdu_origin_signatures", { - "pdu_id": pdu_id, - "origin": origin, - "key_id": key_id, - "signature": buffer(signature_bytes), - }) + self._simple_insert_txn( + txn, + "event_origin_signatures", + { + "event_id": event_id, + "origin": origin, + "key_id": key_id, + "signature": buffer(signature_bytes), + }, + or_ignore=True, + ) - def _get_prev_pdu_hashes_txn(self, txn, pdu_id, origin): + def _get_prev_event_hashes_txn(self, txn, event_id): """Get all the hashes for previous PDUs of a PDU Args: txn (cursor): - pdu_id (str): Id of the PDU. - origin (str): Origin of the PDU. + event_id (str): Id for the Event. Returns: dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes. """ query = ( - "SELECT prev_pdu_id, prev_origin, algorithm, hash" - " FROM pdu_edge_hashes" - " WHERE pdu_id = ? and origin = ?" + "SELECT prev_event_id, algorithm, hash" + " FROM event_edge_hashes" + " WHERE event_id = ?" ) - txn.execute(query, (pdu_id, origin)) + txn.execute(query, (event_id, )) results = {} - for prev_pdu_id, prev_origin, algorithm, hash_bytes in txn.fetchall(): - hashes = results.setdefault((prev_pdu_id, prev_origin), {}) + for prev_event_id, algorithm, hash_bytes in txn.fetchall(): + hashes = results.setdefault(prev_event_id, {}) hashes[algorithm] = hash_bytes return results - def _store_prev_pdu_hash_txn(self, txn, pdu_id, origin, prev_pdu_id, - prev_origin, algorithm, hash_bytes): - self._simple_insert_txn(txn, "pdu_edge_hashes", { - "pdu_id": pdu_id, - "origin": origin, - "prev_pdu_id": prev_pdu_id, - "prev_origin": prev_origin, - "algorithm": algorithm, - "hash": buffer(hash_bytes), - }) + def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id, + algorithm, hash_bytes): + self._simple_insert_txn( + txn, + "event_edge_hashes", + { + "event_id": event_id, + "prev_event_id": prev_event_id, + "algorithm": algorithm, + "hash": buffer(hash_bytes), + }, + or_ignore=True, + ) \ No newline at end of file diff --git a/synapse/storage/state.py b/synapse/storage/state.py new file mode 100644 index 0000000000..e08acd6404 --- /dev/null +++ b/synapse/storage/state.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore +from twisted.internet import defer + +from collections import namedtuple + + +StateGroup = namedtuple("StateGroup", ("group", "state")) + + +class StateStore(SQLBaseStore): + + @defer.inlineCallbacks + def get_state_groups(self, event_ids): + groups = set() + for event_id in event_ids: + group = yield self._simple_select_one_onecol( + table="event_to_state_groups", + keyvalues={"event_id": event_id}, + retcol="state_group", + allow_none=True, + ) + if group: + groups.add(group) + + res = [] + for group in groups: + state_ids = yield self._simple_select_onecol( + table="state_groups_state", + keyvalues={"state_group": group}, + retcol="event_id", + ) + state = [] + for state_id in state_ids: + s = yield self.get_event( + state_id, + allow_none=True, + ) + if s: + state.append(s) + + res.append(StateGroup(group, state)) + + defer.returnValue(res) + + def store_state_groups(self, event): + return self.runInteraction( + "store_state_groups", + self._store_state_groups_txn, event + ) + + def _store_state_groups_txn(self, txn, event): + if not event.state_events: + return + + state_group = event.state_group + if not state_group: + state_group = self._simple_insert_txn( + txn, + table="state_groups", + values={ + "room_id": event.room_id, + "event_id": event.event_id, + } + ) + + for state in event.state_events.values(): + self._simple_insert_txn( + txn, + table="state_groups_state", + values={ + "state_group": state_group, + "room_id": state.room_id, + "type": state.type, + "state_key": state.state_key, + "event_id": state.event_id, + } + ) + + self._simple_insert_txn( + txn, + table="event_to_state_groups", + values={ + "state_group": state_group, + "event_id": event.event_id, + } + ) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index d61f909939..8f7f61d29d 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -309,7 +309,10 @@ class StreamStore(SQLBaseStore): defer.returnValue(ret) def get_room_events_max_id(self): - return self.runInteraction(self._get_room_events_max_id_txn) + return self.runInteraction( + "get_room_events_max_id", + self._get_room_events_max_id_txn + ) def _get_room_events_max_id_txn(self, txn): txn.execute( diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 2ba8e30efe..00d0f48082 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -14,7 +14,6 @@ # limitations under the License. from ._base import SQLBaseStore, Table -from .pdu import PdusTable from collections import namedtuple @@ -42,6 +41,7 @@ class TransactionStore(SQLBaseStore): """ return self.runInteraction( + "get_received_txn_response", self._get_received_txn_response, transaction_id, origin ) @@ -73,6 +73,7 @@ class TransactionStore(SQLBaseStore): """ return self.runInteraction( + "set_received_txn_response", self._set_received_txn_response, transaction_id, origin, code, response_dict ) @@ -88,7 +89,7 @@ class TransactionStore(SQLBaseStore): txn.execute(query, (code, response_json, transaction_id, origin)) def prep_send_transaction(self, transaction_id, destination, - origin_server_ts, pdu_list): + origin_server_ts): """Persists an outgoing transaction and calculates the values for the previous transaction id list. @@ -99,19 +100,19 @@ class TransactionStore(SQLBaseStore): transaction_id (str) destination (str) origin_server_ts (int) - pdu_list (list) Returns: list: A list of previous transaction ids. """ return self.runInteraction( + "prep_send_transaction", self._prep_send_transaction, - transaction_id, destination, origin_server_ts, pdu_list + transaction_id, destination, origin_server_ts ) def _prep_send_transaction(self, txn, transaction_id, destination, - origin_server_ts, pdu_list): + origin_server_ts): # First we find out what the prev_txs should be. # Since we know that we are only sending one transaction at a time, @@ -139,15 +140,15 @@ class TransactionStore(SQLBaseStore): # Update the tx id -> pdu id mapping - values = [ - (transaction_id, destination, pdu[0], pdu[1]) - for pdu in pdu_list - ] - - logger.debug("Inserting: %s", repr(values)) - - query = TransactionsToPduTable.insert_statement() - txn.executemany(query, values) + # values = [ + # (transaction_id, destination, pdu[0], pdu[1]) + # for pdu in pdu_list + # ] + # + # logger.debug("Inserting: %s", repr(values)) + # + # query = TransactionsToPduTable.insert_statement() + # txn.executemany(query, values) return prev_txns @@ -161,6 +162,7 @@ class TransactionStore(SQLBaseStore): response_json (str) """ return self.runInteraction( + "delivered_txn", self._delivered_txn, transaction_id, destination, code, response_dict ) @@ -186,6 +188,7 @@ class TransactionStore(SQLBaseStore): list: A list of `ReceivedTransactionsTable.EntryType` """ return self.runInteraction( + "get_transactions_after", self._get_transactions_after, transaction_id, destination ) @@ -202,49 +205,6 @@ class TransactionStore(SQLBaseStore): return ReceivedTransactionsTable.decode_results(txn.fetchall()) - def get_pdus_after_transaction(self, transaction_id, destination): - """For a given local transaction_id that we sent to a given destination - home server, return a list of PDUs that were sent to that destination - after it. - - Args: - txn - transaction_id (str) - destination (str) - - Returns - list: A list of PduTuple - """ - return self.runInteraction( - self._get_pdus_after_transaction, - transaction_id, destination - ) - - def _get_pdus_after_transaction(self, txn, transaction_id, destination): - - # Query that first get's all transaction_ids with an id greater than - # the one given from the `sent_transactions` table. Then JOIN on this - # from the `tx->pdu` table to get a list of (pdu_id, origin) that - # specify the pdus that were sent in those transactions. - query = ( - "SELECT pdu_id, pdu_origin FROM %(tx_pdu)s as tp " - "INNER JOIN %(sent_tx)s as st " - "ON tp.transaction_id = st.transaction_id " - "AND tp.destination = st.destination " - "WHERE st.id > (" - "SELECT id FROM %(sent_tx)s " - "WHERE transaction_id = ? AND destination = ?" - ) % { - "tx_pdu": TransactionsToPduTable.table_name, - "sent_tx": SentTransactions.table_name, - } - - txn.execute(query, (transaction_id, destination)) - - pdus = PdusTable.decode_results(txn.fetchall()) - - return self._get_pdu_tuples(txn, pdus) - class ReceivedTransactionsTable(Table): table_name = "received_transactions" diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 6483ce2e25..527507e5cd 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -22,6 +22,19 @@ import logging logger = logging.getLogger(__name__) +class SourcePaginationConfig(object): + + """A configuration object which stores pagination parameters for a + specific event source.""" + + def __init__(self, from_key=None, to_key=None, direction='f', + limit=0): + self.from_key = from_key + self.to_key = to_key + self.direction = 'f' if direction == 'f' else 'b' + self.limit = int(limit) + + class PaginationConfig(object): """A configuration object which stores pagination parameters.""" @@ -82,3 +95,13 @@ class PaginationConfig(object): "<PaginationConfig from_tok=%s, to_tok=%s, " "direction=%s, limit=%s>" ) % (self.from_token, self.to_token, self.direction, self.limit) + + def get_source_config(self, source_name): + keyname = "%s_key" % source_name + + return SourcePaginationConfig( + from_key=getattr(self.from_token, keyname), + to_key=getattr(self.to_token, keyname) if self.to_token else None, + direction=self.direction, + limit=self.limit, + ) diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 41715436b0..fb698d2d71 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -35,7 +35,7 @@ class NullSource(object): return defer.succeed(0) def get_pagination_rows(self, user, pagination_config, key): - return defer.succeed(([], pagination_config.from_token)) + return defer.succeed(([], pagination_config.from_key)) class EventSources(object): diff --git a/synapse/types.py b/synapse/types.py index c51bc8e4f2..649ff2f7d7 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -78,6 +78,11 @@ class DomainSpecificString( """Create a structure on the local domain""" return cls(localpart=localpart, domain=hs.hostname, is_mine=True) + @classmethod + def create(cls, localpart, domain, hs): + is_mine = domain == hs.hostname + return cls(localpart=localpart, domain=domain, is_mine=is_mine) + class UserID(DomainSpecificString): """Structure representing a user ID.""" @@ -94,6 +99,11 @@ class RoomID(DomainSpecificString): SIGIL = "!" +class EventID(DomainSpecificString): + """Structure representing an event id. """ + SIGIL = "$" + + class StreamToken( namedtuple( "Token", diff --git a/synapse/util/async.py b/synapse/util/async.py index 647ea6142c..bf578f8bfb 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -21,3 +21,10 @@ def sleep(seconds): d = defer.Deferred() reactor.callLater(seconds, d.callback, seconds) return d + + +def run_on_reactor(): + """ This will cause the rest of the function to be invoked upon the next + iteration of the main loop + """ + return sleep(0) \ No newline at end of file diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 1de50e049f..eddbe5837f 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -42,7 +42,8 @@ class Distributor(object): if name in self.signals: raise KeyError("%r already has a signal named %s" % (self, name)) - self.signals[name] = Signal(name, + self.signals[name] = Signal( + name, suppress_failures=self.suppress_failures, ) diff --git a/synapse/util/emailutils.py b/synapse/util/emailutils.py index cdb0abd7ea..7038cab6c2 100644 --- a/synapse/util/emailutils.py +++ b/synapse/util/emailutils.py @@ -42,8 +42,8 @@ def send_email(smtp_server, from_addr, to_addr, subject, body): EmailException if there was a problem sending the mail. """ if not smtp_server or not from_addr or not to_addr: - raise EmailException("Need SMTP server, from and to addresses. Check " + - "the config to set these.") + raise EmailException("Need SMTP server, from and to addresses. Check" + " the config to set these.") msg = MIMEMultipart('alternative') msg['Subject'] = subject @@ -68,4 +68,4 @@ def send_email(smtp_server, from_addr, to_addr, subject, body): twisted.python.log.err() ese = EmailException() ese.cause = origException - raise ese \ No newline at end of file + raise ese diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py index 6c99705747..c91eb897a8 100644 --- a/synapse/util/jsonobject.py +++ b/synapse/util/jsonobject.py @@ -13,9 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. - import copy + class JsonEncodedObject(object): """ A common base class for defining protocol units that are represented as JSON. @@ -89,6 +89,7 @@ class JsonEncodedObject(object): def __str__(self): return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__)) + def _encode(obj): if type(obj) is list: return [_encode(o) for o in obj] diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 1850deacf5..fdc2e8de4a 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -29,6 +29,7 @@ from synapse.server import HomeServer from synapse.api.constants import PresenceState from synapse.api.errors import SynapseError from synapse.handlers.presence import PresenceHandler, UserPresenceCache +from synapse.streams.config import SourcePaginationConfig OFFLINE = PresenceState.OFFLINE @@ -676,6 +677,21 @@ class PresencePushTestCase(unittest.TestCase): msg="Presence event should be visible to self-reflection" ) + config = SourcePaginationConfig(from_key=1, to_key=0) + (chunk, _) = yield self.event_source.get_pagination_rows( + self.u_apple, config, None + ) + self.assertEquals(chunk, + [ + {"type": "m.presence", + "content": { + "user_id": "@apple:test", + "presence": ONLINE, + "last_active_ago": 0, + }}, + ] + ) + # Banana sees it because of presence subscription (events, _) = yield self.event_source.get_new_events_for_user( self.u_banana, 0, None diff --git a/webclient/app-controller.js b/webclient/app-controller.js index 7d61207554..e4b7cd286f 100644 --- a/webclient/app-controller.js +++ b/webclient/app-controller.js @@ -53,7 +53,7 @@ angular.module('MatrixWebClientController', ['matrixService', 'mPresence', 'even * Open a given page. * @param {String} url url of the page */ - $scope.goToPage = function(url) { + $rootScope.goToPage = function(url) { $location.url(url); }; diff --git a/webclient/app-directive.js b/webclient/app-directive.js index 75283598ab..c1ba0af3a9 100644 --- a/webclient/app-directive.js +++ b/webclient/app-directive.js @@ -40,4 +40,45 @@ angular.module('matrixWebClient') } } }; -}]); \ No newline at end of file +}]) +.directive('asjson', function() { + return { + restrict: 'A', + require: 'ngModel', + link: function (scope, element, attrs, ngModelCtrl) { + function isValidJson(model) { + var flag = true; + try { + angular.fromJson(model); + } catch (err) { + flag = false; + } + return flag; + }; + + function string2JSON(text) { + try { + var j = angular.fromJson(text); + ngModelCtrl.$setValidity('json', true); + return j; + } catch (err) { + //returning undefined results in a parser error as of angular-1.3-rc.0, and will not go through $validators + //return undefined + ngModelCtrl.$setValidity('json', false); + return text; + } + }; + + function JSON2String(object) { + return angular.toJson(object, true); + }; + + //$validators is an object, where key is the error + //ngModelCtrl.$validators.json = isValidJson; + + //array pipelines + ngModelCtrl.$parsers.push(string2JSON); + ngModelCtrl.$formatters.push(JSON2String); + } + } +}); diff --git a/webclient/app-filter.js b/webclient/app-filter.js index 39ea1d637d..f19db4141d 100644 --- a/webclient/app-filter.js +++ b/webclient/app-filter.js @@ -76,6 +76,17 @@ angular.module('matrixWebClient') return filtered; }; }) +.filter('stateEventsFilter', function($sce) { + return function(events) { + var filtered = {}; + angular.forEach(events, function(value, key) { + if (value && typeof(value.state_key) === "string") { + filtered[key] = value; + } + }); + return filtered; + }; +}) .filter('unsafe', ['$sce', function($sce) { return function(text) { return $sce.trustAsHtml(text); diff --git a/webclient/app.css b/webclient/app.css index bdf475d635..5ab8e2b8fd 100755 --- a/webclient/app.css +++ b/webclient/app.css @@ -403,6 +403,7 @@ textarea, input { } .roomNameSection, .roomTopicSection { + text-align: right; float: right; width: 100%; } @@ -412,9 +413,40 @@ textarea, input { } .roomHeaderInfo { + text-align: right; float: right; margin-top: 15px; - width: 50%; +} + +/*** Room Info Dialog ***/ + +.room-info { + border-collapse: collapse; + width: 100%; +} + +.room-info-event { + border-bottom: 1pt solid black; +} + +.room-info-event-meta { + padding-top: 1em; + padding-bottom: 1em; +} + +.room-info-event-content { + padding-top: 1em; + padding-bottom: 1em; +} + +.monospace { + font-family: monospace; +} + +.room-info-textarea-content { + height: auto; + width: 100%; + resize: vertical; } /*** Participant list ***/ diff --git a/webclient/app.js b/webclient/app.js index 099e2170a0..c091f8c6cf 100644 --- a/webclient/app.js +++ b/webclient/app.js @@ -30,8 +30,10 @@ var matrixWebClient = angular.module('matrixWebClient', [ 'MatrixCall', 'eventStreamService', 'eventHandlerService', + 'notificationService', 'infinite-scroll', - 'ui.bootstrap' + 'ui.bootstrap', + 'monospaced.elastic' ]); matrixWebClient.config(['$routeProvider', '$provide', '$httpProvider', diff --git a/webclient/components/matrix/event-handler-service.js b/webclient/components/matrix/event-handler-service.js index e7109c0cb4..e63584510b 100644 --- a/webclient/components/matrix/event-handler-service.js +++ b/webclient/components/matrix/event-handler-service.js @@ -27,8 +27,8 @@ Typically, this service will store events or broadcast them to any listeners if typically all the $on method would do is update its own $scope. */ angular.module('eventHandlerService', []) -.factory('eventHandlerService', ['matrixService', '$rootScope', '$q', '$timeout', 'mPresence', -function(matrixService, $rootScope, $q, $timeout, mPresence) { +.factory('eventHandlerService', ['matrixService', '$rootScope', '$q', '$timeout', 'mPresence', 'notificationService', +function(matrixService, $rootScope, $q, $timeout, mPresence, notificationService) { var ROOM_CREATE_EVENT = "ROOM_CREATE_EVENT"; var MSG_EVENT = "MSG_EVENT"; var MEMBER_EVENT = "MEMBER_EVENT"; @@ -45,44 +45,6 @@ function(matrixService, $rootScope, $q, $timeout, mPresence) { var eventMap = {}; $rootScope.presence = {}; - - // TODO: This is attached to the rootScope so .html can just go containsBingWord - // for determining classes so it is easy to highlight bing messages. It seems a - // bit strange to put the impl in this service though, but I can't think of a better - // file to put it in. - $rootScope.containsBingWord = function(content) { - if (!content || $.type(content) != "string") { - return false; - } - var bingWords = matrixService.config().bingWords; - var shouldBing = false; - - // case-insensitive name check for user_id OR display_name if they exist - var myUserId = matrixService.config().user_id; - if (myUserId) { - myUserId = myUserId.toLocaleLowerCase(); - } - var myDisplayName = matrixService.config().display_name; - if (myDisplayName) { - myDisplayName = myDisplayName.toLocaleLowerCase(); - } - if ( (myDisplayName && content.toLocaleLowerCase().indexOf(myDisplayName) != -1) || - (myUserId && content.toLocaleLowerCase().indexOf(myUserId) != -1) ) { - shouldBing = true; - } - - // bing word list check - if (bingWords && !shouldBing) { - for (var i=0; i<bingWords.length; i++) { - var re = RegExp(bingWords[i]); - if (content.search(re) != -1) { - shouldBing = true; - break; - } - } - } - return shouldBing; - }; var initialSyncDeferred; @@ -172,6 +134,17 @@ function(matrixService, $rootScope, $q, $timeout, mPresence) { }; var handleMessage = function(event, isLiveEvent) { + // Check for empty event content + var hasContent = false; + for (var prop in event.content) { + hasContent = true; + break; + } + if (!hasContent) { + // empty json object is a redacted event, so ignore. + return; + } + if (isLiveEvent) { if (event.user_id === matrixService.config().user_id && (event.content.msgtype === "m.text" || event.content.msgtype === "m.emote") ) { @@ -190,7 +163,12 @@ function(matrixService, $rootScope, $q, $timeout, mPresence) { } if (window.Notification && event.user_id != matrixService.config().user_id) { - var shouldBing = $rootScope.containsBingWord(event.content.body); + var shouldBing = notificationService.containsBingWord( + matrixService.config().user_id, + matrixService.config().display_name, + matrixService.config().bingWords, + event.content.body + ); // Ideally we would notify only when the window is hidden (i.e. document.hidden = true). // @@ -220,17 +198,29 @@ function(matrixService, $rootScope, $q, $timeout, mPresence) { if (event.content.msgtype === "m.emote") { message = "* " + displayname + " " + message; } + else if (event.content.msgtype === "m.image") { + message = displayname + " sent an image."; + } + + var roomTitle = matrixService.getRoomIdToAliasMapping(event.room_id); + var theRoom = $rootScope.events.rooms[event.room_id]; + if (!roomTitle && theRoom && theRoom["m.room.name"] && theRoom["m.room.name"].content) { + roomTitle = theRoom["m.room.name"].content.name; + } - var notification = new window.Notification( - displayname + - " (" + (matrixService.getRoomIdToAliasMapping(event.room_id) || event.room_id) + ")", // FIXME: don't leak room_ids here - { - "body": message, - "icon": member ? member.avatar_url : undefined - }); - $timeout(function() { - notification.close(); - }, 5 * 1000); + if (!roomTitle) { + roomTitle = event.room_id; + } + + notificationService.showNotification( + displayname + " (" + roomTitle + ")", + message, + member ? member.avatar_url : undefined, + function() { + console.log("notification.onclick() room=" + event.room_id); + $rootScope.goToPage('room/' + event.room_id); + } + ); } } } @@ -319,6 +309,31 @@ function(matrixService, $rootScope, $q, $timeout, mPresence) { $rootScope.events.rooms[event.room_id].messages.push(event); } }; + + var handleRedaction = function(event, isLiveEvent) { + if (!isLiveEvent) { + // we have nothing to remove, so just ignore it. + console.log("Received redacted event: "+JSON.stringify(event)); + return; + } + + // we need to remove something possibly: do we know the redacted + // event ID? + if (eventMap[event.redacts]) { + // remove event from list of messages in this room. + var eventList = $rootScope.events.rooms[event.room_id].messages; + for (var i=0; i<eventList.length; i++) { + if (eventList[i].event_id === event.redacts) { + console.log("Removing event " + event.redacts); + eventList.splice(i, 1); + break; + } + } + + // broadcast the redaction so controllers can nuke this + console.log("Redacted an event."); + } + } /** * Get the index of the event in $rootScope.events.rooms[room_id].messages @@ -481,7 +496,17 @@ function(matrixService, $rootScope, $q, $timeout, mPresence) { case 'm.room.topic': handleRoomTopic(event, isLiveEvent, isStateEvent); break; + case 'm.room.redaction': + handleRedaction(event, isLiveEvent); + break; default: + // if it is a state event, then just add it in so it + // displays on the Room Info screen. + if (typeof(event.state_key) === "string") { // incls. 0-len strings + if (event.room_id) { + handleRoomDateEvent(event, isLiveEvent, false); + } + } console.log("Unable to handle event type " + event.type); console.log(JSON.stringify(event, undefined, 4)); break; diff --git a/webclient/components/matrix/matrix-filter.js b/webclient/components/matrix/matrix-filter.js index e6f2acc5fd..3d64a569a1 100644 --- a/webclient/components/matrix/matrix-filter.js +++ b/webclient/components/matrix/matrix-filter.js @@ -47,7 +47,6 @@ angular.module('matrixFilter', []) else if (room.members && !isPublicRoom) { // Do not rename public room var user_id = matrixService.config().user_id; - // Else, build the name from its users // Limit the room renaming to 1:1 room if (2 === Object.keys(room.members).length) { @@ -65,8 +64,16 @@ angular.module('matrixFilter', []) var otherUserId; - if (Object.keys(room.members)[0] && Object.keys(room.members)[0] !== user_id) { + if (Object.keys(room.members)[0]) { otherUserId = Object.keys(room.members)[0]; + // this could be an invite event (from event stream) + if (otherUserId === user_id && + room.members[user_id].content.membership === "invite") { + // this is us being invited to this room, so the + // *user_id* is the other user ID and not the state + // key. + otherUserId = room.members[user_id].user_id; + } } else { // it's got to be an invite, or failing that a self-chat; diff --git a/webclient/components/matrix/matrix-service.js b/webclient/components/matrix/matrix-service.js index a4f0568bce..1840cf46c0 100644 --- a/webclient/components/matrix/matrix-service.js +++ b/webclient/components/matrix/matrix-service.js @@ -438,6 +438,14 @@ angular.module('matrixService', []) return this.sendMessage(room_id, msg_id, content); }, + redactEvent: function(room_id, event_id) { + var path = "/rooms/$room_id/redact/$event_id"; + path = path.replace("$room_id", room_id); + path = path.replace("$event_id", event_id); + var content = {}; + return doRequest("POST", path, undefined, content); + }, + // get a snapshot of the members in a room. getMemberList: function(room_id) { // Like the cmd client, escape room ids diff --git a/webclient/components/matrix/notification-service.js b/webclient/components/matrix/notification-service.js new file mode 100644 index 0000000000..9a911413c3 --- /dev/null +++ b/webclient/components/matrix/notification-service.js @@ -0,0 +1,104 @@ +/* +Copyright 2014 OpenMarket Ltd + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +'use strict'; + +/* +This service manages notifications: enabling, creating and showing them. This +also contains 'bing word' logic. +*/ +angular.module('notificationService', []) +.factory('notificationService', ['$timeout', function($timeout) { + + var getLocalPartFromUserId = function(user_id) { + if (!user_id) { + return null; + } + var localpartRegex = /@(.*):\w+/i + var results = localpartRegex.exec(user_id); + if (results && results.length == 2) { + return results[1]; + } + return null; + }; + + return { + + containsBingWord: function(userId, displayName, bingWords, content) { + // case-insensitive name check for user_id OR display_name if they exist + var userRegex = ""; + if (userId) { + var localpart = getLocalPartFromUserId(userId); + if (localpart) { + localpart = localpart.toLocaleLowerCase(); + userRegex += "\\b" + localpart + "\\b"; + } + } + if (displayName) { + displayName = displayName.toLocaleLowerCase(); + if (userRegex.length > 0) { + userRegex += "|"; + } + userRegex += "\\b" + displayName + "\\b"; + } + + var regexList = [new RegExp(userRegex, 'i')]; + + // bing word list check + if (bingWords && bingWords.length > 0) { + for (var i=0; i<bingWords.length; i++) { + var re = RegExp(bingWords[i], 'i'); + regexList.push(re); + } + } + return this.hasMatch(regexList, content); + }, + + hasMatch: function(regExps, content) { + if (!content || $.type(content) != "string") { + return false; + } + + if (regExps && regExps.length > 0) { + for (var i=0; i<regExps.length; i++) { + if (content.search(regExps[i]) != -1) { + return true; + } + } + } + return false; + }, + + showNotification: function(title, body, icon, onclick) { + var notification = new window.Notification( + title, + { + "body": body, + "icon": icon + } + ); + + if (onclick) { + notification.onclick = onclick; + } + + $timeout(function() { + notification.close(); + }, 5 * 1000); + } + }; + +}]); diff --git a/webclient/index.html b/webclient/index.html index 35c8051298..bc011a6c72 100644 --- a/webclient/index.html +++ b/webclient/index.html @@ -20,6 +20,7 @@ <script type='text/javascript' src="js/ui-bootstrap-tpls-0.11.2.js"></script> <script type='text/javascript' src='js/ng-infinite-scroll-matrix.js'></script> <script type='text/javascript' src='js/autofill-event.js'></script> + <script type='text/javascript' src='js/elastic.js'></script> <script src="app.js"></script> <script src="config.js"></script> <script src="app-controller.js"></script> @@ -40,6 +41,7 @@ <script src="components/matrix/matrix-phone-service.js"></script> <script src="components/matrix/event-stream-service.js"></script> <script src="components/matrix/event-handler-service.js"></script> + <script src="components/matrix/notification-service.js"></script> <script src="components/matrix/presence-service.js"></script> <script src="components/fileInput/file-input-directive.js"></script> <script src="components/fileUpload/file-upload-service.js"></script> diff --git a/webclient/js/elastic.js b/webclient/js/elastic.js new file mode 100644 index 0000000000..d585d81109 --- /dev/null +++ b/webclient/js/elastic.js @@ -0,0 +1,216 @@ +/* + * angular-elastic v2.4.0 + * (c) 2014 Monospaced http://monospaced.com + * License: MIT + */ + +angular.module('monospaced.elastic', []) + + .constant('msdElasticConfig', { + append: '' + }) + + .directive('msdElastic', [ + '$timeout', '$window', 'msdElasticConfig', + function($timeout, $window, config) { + 'use strict'; + + return { + require: 'ngModel', + restrict: 'A, C', + link: function(scope, element, attrs, ngModel) { + + // cache a reference to the DOM element + var ta = element[0], + $ta = element; + + // ensure the element is a textarea, and browser is capable + if (ta.nodeName !== 'TEXTAREA' || !$window.getComputedStyle) { + return; + } + + // set these properties before measuring dimensions + $ta.css({ + 'overflow': 'hidden', + 'overflow-y': 'hidden', + 'word-wrap': 'break-word' + }); + + // force text reflow + var text = ta.value; + ta.value = ''; + ta.value = text; + + var append = attrs.msdElastic ? attrs.msdElastic.replace(/\\n/g, '\n') : config.append, + $win = angular.element($window), + mirrorInitStyle = 'position: absolute; top: -999px; right: auto; bottom: auto;' + + 'left: 0; overflow: hidden; -webkit-box-sizing: content-box;' + + '-moz-box-sizing: content-box; box-sizing: content-box;' + + 'min-height: 0 !important; height: 0 !important; padding: 0;' + + 'word-wrap: break-word; border: 0;', + $mirror = angular.element('<textarea tabindex="-1" ' + + 'style="' + mirrorInitStyle + '"/>').data('elastic', true), + mirror = $mirror[0], + taStyle = getComputedStyle(ta), + resize = taStyle.getPropertyValue('resize'), + borderBox = taStyle.getPropertyValue('box-sizing') === 'border-box' || + taStyle.getPropertyValue('-moz-box-sizing') === 'border-box' || + taStyle.getPropertyValue('-webkit-box-sizing') === 'border-box', + boxOuter = !borderBox ? {width: 0, height: 0} : { + width: parseInt(taStyle.getPropertyValue('border-right-width'), 10) + + parseInt(taStyle.getPropertyValue('padding-right'), 10) + + parseInt(taStyle.getPropertyValue('padding-left'), 10) + + parseInt(taStyle.getPropertyValue('border-left-width'), 10), + height: parseInt(taStyle.getPropertyValue('border-top-width'), 10) + + parseInt(taStyle.getPropertyValue('padding-top'), 10) + + parseInt(taStyle.getPropertyValue('padding-bottom'), 10) + + parseInt(taStyle.getPropertyValue('border-bottom-width'), 10) + }, + minHeightValue = parseInt(taStyle.getPropertyValue('min-height'), 10), + heightValue = parseInt(taStyle.getPropertyValue('height'), 10), + minHeight = Math.max(minHeightValue, heightValue) - boxOuter.height, + maxHeight = parseInt(taStyle.getPropertyValue('max-height'), 10), + mirrored, + active, + copyStyle = ['font-family', + 'font-size', + 'font-weight', + 'font-style', + 'letter-spacing', + 'line-height', + 'text-transform', + 'word-spacing', + 'text-indent']; + + // exit if elastic already applied (or is the mirror element) + if ($ta.data('elastic')) { + return; + } + + // Opera returns max-height of -1 if not set + maxHeight = maxHeight && maxHeight > 0 ? maxHeight : 9e4; + + // append mirror to the DOM + if (mirror.parentNode !== document.body) { + angular.element(document.body).append(mirror); + } + + // set resize and apply elastic + $ta.css({ + 'resize': (resize === 'none' || resize === 'vertical') ? 'none' : 'horizontal' + }).data('elastic', true); + + /* + * methods + */ + + function initMirror() { + var mirrorStyle = mirrorInitStyle; + + mirrored = ta; + // copy the essential styles from the textarea to the mirror + taStyle = getComputedStyle(ta); + angular.forEach(copyStyle, function(val) { + mirrorStyle += val + ':' + taStyle.getPropertyValue(val) + ';'; + }); + mirror.setAttribute('style', mirrorStyle); + } + + function adjust() { + var taHeight, + taComputedStyleWidth, + mirrorHeight, + width, + overflow; + + if (mirrored !== ta) { + initMirror(); + } + + // active flag prevents actions in function from calling adjust again + if (!active) { + active = true; + + mirror.value = ta.value + append; // optional whitespace to improve animation + mirror.style.overflowY = ta.style.overflowY; + + taHeight = ta.style.height === '' ? 'auto' : parseInt(ta.style.height, 10); + + taComputedStyleWidth = getComputedStyle(ta).getPropertyValue('width'); + + // ensure getComputedStyle has returned a readable 'used value' pixel width + if (taComputedStyleWidth.substr(taComputedStyleWidth.length - 2, 2) === 'px') { + // update mirror width in case the textarea width has changed + width = parseInt(taComputedStyleWidth, 10) - boxOuter.width; + mirror.style.width = width + 'px'; + } + + mirrorHeight = mirror.scrollHeight; + + if (mirrorHeight > maxHeight) { + mirrorHeight = maxHeight; + overflow = 'scroll'; + } else if (mirrorHeight < minHeight) { + mirrorHeight = minHeight; + } + mirrorHeight += boxOuter.height; + + ta.style.overflowY = overflow || 'hidden'; + + if (taHeight !== mirrorHeight) { + ta.style.height = mirrorHeight + 'px'; + scope.$emit('elastic:resize', $ta); + } + + // small delay to prevent an infinite loop + $timeout(function() { + active = false; + }, 1); + + } + } + + function forceAdjust() { + active = false; + adjust(); + } + + /* + * initialise + */ + + // listen + if ('onpropertychange' in ta && 'oninput' in ta) { + // IE9 + ta['oninput'] = ta.onkeyup = adjust; + } else { + ta['oninput'] = adjust; + } + + $win.bind('resize', forceAdjust); + + scope.$watch(function() { + return ngModel.$modelValue; + }, function(newValue) { + forceAdjust(); + }); + + scope.$on('elastic:adjust', function() { + initMirror(); + forceAdjust(); + }); + + $timeout(adjust); + + /* + * destroy + */ + + scope.$on('$destroy', function() { + $mirror.remove(); + $win.unbind('resize', forceAdjust); + }); + } + }; + } + ]); diff --git a/webclient/mobile.css b/webclient/mobile.css index 7c62a072d5..6fa9221ccf 100644 --- a/webclient/mobile.css +++ b/webclient/mobile.css @@ -65,13 +65,16 @@ } #roomName { - float: left; - font-size: 14px ! important; + font-size: 12px ! important; margin-top: 0px ! important; } + + .roomTopicSection { + display: none; + } #roomPage { - top: 35px ! important; + top: 40px ! important; left: 5px ! important; right: 5px ! important; bottom: 70px ! important; diff --git a/webclient/room/room-controller.js b/webclient/room/room-controller.js index 78520a829d..486ead0da9 100644 --- a/webclient/room/room-controller.js +++ b/webclient/room/room-controller.js @@ -15,11 +15,21 @@ limitations under the License. */ angular.module('RoomController', ['ngSanitize', 'matrixFilter', 'mFileInput']) -.controller('RoomController', ['$modal', '$filter', '$scope', '$timeout', '$routeParams', '$location', '$rootScope', 'matrixService', 'mPresence', 'eventHandlerService', 'mFileUpload', 'matrixPhoneService', 'MatrixCall', - function($modal, $filter, $scope, $timeout, $routeParams, $location, $rootScope, matrixService, mPresence, eventHandlerService, mFileUpload, matrixPhoneService, MatrixCall) { +.controller('RoomController', ['$modal', '$filter', '$scope', '$timeout', '$routeParams', '$location', '$rootScope', 'matrixService', 'mPresence', 'eventHandlerService', 'mFileUpload', 'matrixPhoneService', 'MatrixCall', 'notificationService', + function($modal, $filter, $scope, $timeout, $routeParams, $location, $rootScope, matrixService, mPresence, eventHandlerService, mFileUpload, matrixPhoneService, MatrixCall, notificationService) { 'use strict'; var MESSAGES_PER_PAGINATION = 30; var THUMBNAIL_SIZE = 320; + + // .html needs this + $scope.containsBingWord = function(content) { + return notificationService.containsBingWord( + matrixService.config().user_id, + matrixService.config().display_name, + matrixService.config().bingWords, + content + ); + }; // Room ids. Computed and resolved in onInit $scope.room_id = undefined; @@ -133,7 +143,9 @@ angular.module('RoomController', ['ngSanitize', 'matrixFilter', 'mFileInput']) // Do not autoscroll to the bottom to display the new event if the user is not at the bottom. // Exception: in case where the event is from the user, we want to force scroll to the bottom var objDiv = document.getElementById("messageTableWrapper"); - if ((objDiv.offsetHeight + objDiv.scrollTop >= objDiv.scrollHeight) || force) { + // add a 10px buffer to this check so if the message list is not *quite* + // at the bottom it still scrolls since it basically is at the bottom. + if ((10 + objDiv.offsetHeight + objDiv.scrollTop >= objDiv.scrollHeight) || force) { $timeout(function() { objDiv.scrollTop = objDiv.scrollHeight; @@ -189,16 +201,20 @@ angular.module('RoomController', ['ngSanitize', 'matrixFilter', 'mFileInput']) // Notify when a user joins if ((document.hidden || matrixService.presence.unavailable === mPresence.getState()) && event.state_key !== $scope.state.user_id && "join" === event.membership) { - var notification = new window.Notification( - event.content.displayname + - " (" + (matrixService.getRoomIdToAliasMapping(event.room_id) || event.room_id) + ")", // FIXME: don't leak room_ids here - { - "body": event.content.displayname + " joined", - "icon": event.content.avatar_url ? event.content.avatar_url : undefined - }); - $timeout(function() { - notification.close(); - }, 5 * 1000); + var userName = event.content.displayname; + if (!userName) { + userName = event.state_key; + } + notificationService.showNotification( + userName + + " (" + (matrixService.getRoomIdToAliasMapping(event.room_id) || event.room_id) + ")", + userName + " joined", + event.content.avatar_url ? event.content.avatar_url : undefined, + function() { + console.log("notification.onclick() room=" + event.room_id); + $rootScope.goToPage('room/' + event.room_id); + } + ); } } } @@ -983,10 +999,87 @@ angular.module('RoomController', ['ngSanitize', 'matrixFilter', 'mFileInput']) }; $scope.openJson = function(content) { - console.log("Displaying modal dialog for " + JSON.stringify(content)); + $scope.event_selected = content; + // scope this so the template can check power levels and enable/disable + // buttons + $scope.pow = matrixService.getUserPowerLevel; + + var modalInstance = $modal.open({ + templateUrl: 'eventInfoTemplate.html', + controller: 'EventInfoController', + scope: $scope + }); + + modalInstance.result.then(function(action) { + if (action === "redact") { + var eventId = $scope.event_selected.event_id; + console.log("Redacting event ID " + eventId); + matrixService.redactEvent( + $scope.event_selected.room_id, + eventId + ).then(function(response) { + console.log("Redaction = " + JSON.stringify(response)); + }, function(error) { + console.error("Failed to redact event: "+JSON.stringify(error)); + if (error.data.error) { + $scope.feedback = error.data.error; + } + }); + } + }, function() { + // any dismiss code + }); + }; + + $scope.openRoomInfo = function() { + $scope.roomInfo = {}; + $scope.roomInfo.newEvent = { + content: {}, + type: "", + state_key: "" + }; + + var stateFilter = $filter("stateEventsFilter"); + var stateEvents = stateFilter($scope.events.rooms[$scope.room_id]); + // The modal dialog will 2-way bind this field, so we MUST make a deep + // copy of the state events else we will be *actually adjusing our view + // of the world* when fiddling with the JSON!! Apparently parse/stringify + // is faster than jQuery's extend when doing deep copies. + $scope.roomInfo.stateEvents = JSON.parse(JSON.stringify(stateEvents)); var modalInstance = $modal.open({ - template: "<pre>" + angular.toJson(content, true) + "</pre>" + templateUrl: 'roomInfoTemplate.html', + controller: 'RoomInfoController', + size: 'lg', + scope: $scope }); }; -}]); +}]) +.controller('EventInfoController', function($scope, $modalInstance) { + console.log("Displaying modal dialog for >>>> " + JSON.stringify($scope.event_selected)); + $scope.redact = function() { + console.log("User level = "+$scope.pow($scope.room_id, $scope.state.user_id)+ + " Redact level = "+$scope.events.rooms[$scope.room_id]["m.room.ops_levels"].content.redact_level); + console.log("Redact event >> " + JSON.stringify($scope.event_selected)); + $modalInstance.close("redact"); + }; +}) +.controller('RoomInfoController', function($scope, $modalInstance, $filter, matrixService) { + console.log("Displaying room info."); + + $scope.submit = function(event) { + if (event.content) { + console.log("submit >>> " + JSON.stringify(event.content)); + matrixService.sendStateEvent($scope.room_id, event.type, + event.content, event.state_key).then(function(response) { + $modalInstance.dismiss(); + }, function(err) { + $scope.feedback = err.data.error; + } + ); + } + }; + + $scope.dismiss = $modalInstance.dismiss; + +}); diff --git a/webclient/room/room.html b/webclient/room/room.html index e753b037fe..5265f42dd8 100644 --- a/webclient/room/room.html +++ b/webclient/room/room.html @@ -1,5 +1,59 @@ <div ng-controller="RoomController" data-ng-init="onInit()" class="room" style="height: 100%;"> + <script type="text/ng-template" id="eventInfoTemplate.html"> + <div class="modal-body"> + <pre> {{event_selected | json}} </pre> + </div> + <div class="modal-footer"> + <button ng-click="redact()" type="button" class="btn btn-danger" + ng-disabled="!events.rooms[room_id]['m.room.ops_levels'].content.redact_level || !pow(room_id, state.user_id) || pow(room_id, state.user_id) < events.rooms[room_id]['m.room.ops_levels'].content.redact_level" + title="Delete this event on all home servers. This cannot be undone."> + Redact + </button> + </div> + </script> + + <script type="text/ng-template" id="roomInfoTemplate.html"> + <div class="modal-body"> + <table class="room-info"> + <tr ng-repeat="(key, event) in roomInfo.stateEvents" class="room-info-event"> + <td class="room-info-event-meta" width="30%"> + <span class="monospace">{{ key }}</span> + <br/> + {{ (event.origin_server_ts) | date:'MMM d HH:mm' }} + <br/> + Set by: <span class="monospace">{{ event.user_id }}</span> + <br/> + <span ng-show="event.required_power_level >= 0">Required power level: {{event.required_power_level}}<br/></span> + <button ng-click="submit(event)" type="button" class="btn btn-success" ng-disabled="!event.content"> + Submit + </button> + </td> + <td class="room-info-event-content" width="70%"> + <textarea class="room-info-textarea-content" msd-elastic ng-model="event.content" asjson></textarea> + </td> + </tr> + <tr> + <td class="room-info-event-meta" width="30%"> + <input ng-model="roomInfo.newEvent.type" placeholder="your.event.type" /> + <br/> + <button ng-click="submit(roomInfo.newEvent)" type="button" class="btn btn-success" ng-disabled="!roomInfo.newEvent.content || !roomInfo.newEvent.type"> + Submit + </button> + </td> + <td class="room-info-event-content" width="70%"> + <textarea class="room-info-textarea-content" msd-elastic ng-model="roomInfo.newEvent.content" asjson></textarea> + </td> + </tr> + </table> + </div> + <div class="modal-footer"> + <button ng-click="dismiss()" type="button" class="btn"> + Close + </button> + </div> + </script> + <div id="roomHeader"> <a href ng-click="goToPage('/')"><img src="img/logo-small.png" width="100" height="43" alt="[matrix]"/></a> <div class="roomHeaderInfo"> @@ -79,15 +133,15 @@ </div> </td> <td class="avatar"> - <img class="avatarImage" ng-src="{{ members[msg.user_id].avatar_url || 'img/default-profile.png' }}" width="32" height="32" + <img class="avatarImage" ng-src="{{ members[msg.user_id].avatar_url || 'img/default-profile.png' }}" width="32" height="32" title="{{msg.user_id}}" ng-hide="events.rooms[room_id].messages[$index - 1].user_id === msg.user_id || msg.user_id === state.user_id"/> </td> <td ng-class="(!msg.content.membership && ('m.room.topic' !== msg.type && 'm.room.name' !== msg.type))? (msg.content.msgtype === 'm.emote' ? 'emote text' : 'text') : 'membership text'"> - <div class="bubble"> - <span ng-if="'join' === msg.content.membership && msg.changedKey === 'membership'" ng-click="openJson(msg)"> + <div class="bubble" ng-click="openJson(msg)"> + <span ng-if="'join' === msg.content.membership && msg.changedKey === 'membership'"> {{ members[msg.state_key].displayname || msg.state_key }} joined </span> - <span ng-if="'leave' === msg.content.membership && msg.changedKey === 'membership'" ng-click="openJson(msg)"> + <span ng-if="'leave' === msg.content.membership && msg.changedKey === 'membership'"> <span ng-if="msg.user_id === msg.state_key"> {{ members[msg.state_key].displayname || msg.state_key }} left </span> @@ -101,7 +155,7 @@ </span> </span> <span ng-if="'invite' === msg.content.membership && msg.changedKey === 'membership' || - 'ban' === msg.content.membership && msg.changedKey === 'membership'" ng-click="openJson(msg)"> + 'ban' === msg.content.membership && msg.changedKey === 'membership'"> {{ members[msg.user_id].displayname || msg.user_id }} {{ {"invite": "invited", "ban": "banned"}[msg.content.membership] }} {{ members[msg.state_key].displayname || msg.state_key }} @@ -109,25 +163,24 @@ : {{ msg.content.reason }} </span> </span> - <span ng-if="msg.changedKey === 'displayname'" ng-click="openJson(msg)"> + <span ng-if="msg.changedKey === 'displayname'"> {{ msg.user_id }} changed their display name from {{ msg.prev_content.displayname }} to {{ msg.content.displayname }} </span> <span ng-show='msg.content.msgtype === "m.emote"' ng-class="msg.echo_msg_state" ng-bind-html="'* ' + (members[msg.user_id].displayname || msg.user_id) + ' ' + msg.content.body | linky:'_blank'" - ng-click="openJson(msg)"/> + /> <span ng-show='msg.content.msgtype === "m.text"' class="message" - ng-click="openJson(msg)" ng-class="containsBingWord(msg.content.body) && msg.user_id != state.user_id ? msg.echo_msg_state + ' messageBing' : msg.echo_msg_state" ng-bind-html="(msg.content.msgtype === 'm.text' && msg.type === 'm.room.message' && msg.content.format === 'org.matrix.custom.html') ? (msg.content.formatted_body | unsanitizedLinky) : (msg.content.msgtype === 'm.text' && msg.type === 'm.room.message') ? (msg.content.body | linky:'_blank') : '' "/> - <span ng-show='msg.type === "m.call.invite" && msg.user_id == state.user_id' ng-click="openJson(msg)">Outgoing Call{{ isWebRTCSupported ? '' : ' (But your browser does not support VoIP)' }}</span> - <span ng-show='msg.type === "m.call.invite" && msg.user_id != state.user_id' ng-click="openJson(msg)">Incoming Call{{ isWebRTCSupported ? '' : ' (But your browser does not support VoIP)' }}</span> + <span ng-show='msg.type === "m.call.invite" && msg.user_id == state.user_id'>Outgoing Call{{ isWebRTCSupported ? '' : ' (But your browser does not support VoIP)' }}</span> + <span ng-show='msg.type === "m.call.invite" && msg.user_id != state.user_id'>Incoming Call{{ isWebRTCSupported ? '' : ' (But your browser does not support VoIP)' }}</span> <div ng-show='msg.content.msgtype === "m.image"'> <div ng-hide='msg.content.thumbnail_url' ng-style="msg.content.body.h && { 'height' : (msg.content.body.h < 320) ? msg.content.body.h : 320}"> @@ -135,15 +188,15 @@ </div> <div ng-show='msg.content.thumbnail_url' ng-style="{ 'height' : msg.content.thumbnail_info.h }"> <img class="image mouse-pointer" ng-src="{{ msg.content.thumbnail_url }}" - ng-click="$parent.fullScreenImageURL = msg.content.url"/> + ng-click="$parent.fullScreenImageURL = msg.content.url; $event.stopPropagation();"/> </div> </div> - <span ng-if="'m.room.topic' === msg.type" ng-click="openJson(msg)"> + <span ng-if="'m.room.topic' === msg.type"> {{ members[msg.user_id].displayname || msg.user_id }} changed the topic to: {{ msg.content.topic }} </span> - <span ng-if="'m.room.name' === msg.type" ng-click="openJson(msg)"> + <span ng-if="'m.room.name' === msg.type"> {{ members[msg.user_id].displayname || msg.user_id }} changed the room name to: {{ msg.content.name }} </span> @@ -204,6 +257,9 @@ > Video Call </button> + <button ng-click="openRoomInfo()"> + Room Info + </button> </div> {{ feedback }} |