diff options
author | Mark Haines <mark.haines@matrix.org> | 2014-09-22 18:54:00 +0100 |
---|---|---|
committer | Mark Haines <mark.haines@matrix.org> | 2014-09-22 18:54:00 +0100 |
commit | 09d79b0a9bf7a194383830d2e55530c70f2366b6 (patch) | |
tree | 76573bac3ca48deeca6cd33f91ed2ee3408dffb2 /synapse | |
parent | SYN-39: Add documentation explaining how to check a signature (diff) | |
parent | Show display name changes in the message list. (diff) | |
download | synapse-09d79b0a9bf7a194383830d2e55530c70f2366b6.tar.xz |
Merge branch 'develop' into server2server_signing
Diffstat (limited to 'synapse')
44 files changed, 1359 insertions, 428 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index 440e633966..bba551b2c4 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a synapse home server. """ -__version__ = "0.2.1" +__version__ = "0.3.3" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index b4eda3df01..8f32191b57 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -18,8 +18,8 @@ from twisted.internet import defer from synapse.api.constants import Membership, JoinRules -from synapse.api.errors import AuthError, StoreError, Codes -from synapse.api.events.room import RoomMemberEvent +from synapse.api.errors import AuthError, StoreError, Codes, SynapseError +from synapse.api.events.room import RoomMemberEvent, RoomPowerLevelsEvent from synapse.util.logutils import log_function import logging @@ -67,6 +67,9 @@ class Auth(object): else: yield self._can_send_event(event) + if event.type == RoomPowerLevelsEvent.TYPE: + yield self._check_power_levels(event) + defer.returnValue(True) else: raise AuthError(500, "Unknown event: %s" % event) @@ -172,7 +175,7 @@ class Auth(object): if kick_level: kick_level = int(kick_level) else: - kick_level = 5 + kick_level = 50 if user_level < kick_level: raise AuthError( @@ -189,7 +192,7 @@ class Auth(object): if ban_level: ban_level = int(ban_level) else: - ban_level = 5 # FIXME (erikj): What should we do here? + ban_level = 50 # FIXME (erikj): What should we do here? if user_level < ban_level: raise AuthError(403, "You don't have permission to ban") @@ -305,7 +308,9 @@ class Auth(object): else: user_level = 0 - logger.debug("Checking power level for %s, %s", event.user_id, user_level) + logger.debug( + "Checking power level for %s, %s", event.user_id, user_level + ) if current_state and hasattr(current_state, "required_power_level"): req = current_state.required_power_level @@ -315,3 +320,101 @@ class Auth(object): 403, "You don't have permission to change that state" ) + + @defer.inlineCallbacks + def _check_power_levels(self, event): + for k, v in event.content.items(): + if k == "default": + continue + + # FIXME (erikj): We don't want hsob_Ts in content. + if k == "hsob_ts": + continue + + try: + self.hs.parse_userid(k) + except: + raise SynapseError(400, "Not a valid user_id: %s" % (k,)) + + try: + int(v) + except: + raise SynapseError(400, "Not a valid power level: %s" % (v,)) + + current_state = yield self.store.get_current_state( + event.room_id, + event.type, + event.state_key, + ) + + if not current_state: + return + else: + current_state = current_state[0] + + user_level = yield self.store.get_power_level( + event.room_id, + event.user_id, + ) + + if user_level: + user_level = int(user_level) + else: + user_level = 0 + + old_list = current_state.content + + # FIXME (erikj) + old_people = {k: v for k, v in old_list.items() if k.startswith("@")} + new_people = { + k: v for k, v in event.content.items() + if k.startswith("@") + } + + removed = set(old_people.keys()) - set(new_people.keys()) + added = set(old_people.keys()) - set(new_people.keys()) + same = set(old_people.keys()) & set(new_people.keys()) + + for r in removed: + if int(old_list.content[r]) > user_level: + raise AuthError( + 403, + "You don't have permission to remove user: %s" % (r, ) + ) + + for n in added: + if int(event.content[n]) > user_level: + raise AuthError( + 403, + "You don't have permission to add ops level greater " + "than your own" + ) + + for s in same: + if int(event.content[s]) != int(old_list[s]): + if int(event.content[s]) > user_level: + raise AuthError( + 403, + "You don't have permission to add ops level greater " + "than your own" + ) + + if "default" in old_list: + old_default = int(old_list["default"]) + + if old_default > user_level: + raise AuthError( + 403, + "You don't have permission to add ops level greater than " + "your own" + ) + + if "default" in event.content: + new_default = int(event.content["default"]) + + if new_default > user_level: + raise AuthError( + 403, + "You don't have permission to add ops level greater " + "than your own" + ) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index fcef062fc9..618d3d7577 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -50,3 +50,12 @@ class JoinRules(object): KNOCK = u"knock" INVITE = u"invite" PRIVATE = u"private" + + +class LoginType(object): + PASSWORD = u"m.login.password" + OAUTH = u"m.login.oauth2" + EMAIL_CODE = u"m.login.email.code" + EMAIL_URL = u"m.login.email.url" + EMAIL_IDENTITY = u"m.login.email.identity" + RECAPTCHA = u"m.login.recaptcha" \ No newline at end of file diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 84afe4fa37..88175602c4 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -29,6 +29,8 @@ class Codes(object): NOT_FOUND = "M_NOT_FOUND" UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN" LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED" + CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" + CAPTCHA_INVALID = "M_CAPTCHA_INVALID" class CodeMessageException(Exception): @@ -101,6 +103,19 @@ class StoreError(SynapseError): pass +class InvalidCaptchaError(SynapseError): + def __init__(self, code=400, msg="Invalid captcha.", error_url=None, + errcode=Codes.CAPTCHA_INVALID): + super(InvalidCaptchaError, self).__init__(code, msg, errcode) + self.error_url = error_url + + def error_dict(self): + return cs_error( + self.msg, + self.errcode, + error_url=self.error_url, + ) + class LimitExceededError(SynapseError): """A client has sent too many requests and is being throttled. """ diff --git a/synapse/api/events/__init__.py b/synapse/api/events/__init__.py index f95468fc65..0cee196851 100644 --- a/synapse/api/events/__init__.py +++ b/synapse/api/events/__init__.py @@ -17,6 +17,19 @@ from synapse.api.errors import SynapseError, Codes from synapse.util.jsonobject import JsonEncodedObject +def serialize_event(hs, e): + # FIXME(erikj): To handle the case of presence events and the like + if not isinstance(e, SynapseEvent): + return e + + d = e.get_dict() + if "age_ts" in d: + d["age"] = int(hs.get_clock().time_msec()) - d["age_ts"] + del d["age_ts"] + + return d + + class SynapseEvent(JsonEncodedObject): """Base class for Synapse events. These are JSON objects which must abide @@ -43,6 +56,8 @@ class SynapseEvent(JsonEncodedObject): "content", # HTTP body, JSON "state_key", "required_power_level", + "age_ts", + "prev_content", ] internal_keys = [ @@ -141,7 +156,8 @@ class SynapseEvent(JsonEncodedObject): return "Missing %s key" % key if type(content[key]) != type(template[key]): - return "Key %s is of the wrong type." % key + return "Key %s is of the wrong type (got %s, want %s)" % ( + key, type(content[key]), type(template[key])) if type(content[key]) == dict: # we must go deeper @@ -157,7 +173,8 @@ class SynapseEvent(JsonEncodedObject): class SynapseStateEvent(SynapseEvent): - def __init__(self, **kwargs): + + def __init__(self, **kwargs): if "state_key" not in kwargs: kwargs["state_key"] = "" super(SynapseStateEvent, self).__init__(**kwargs) diff --git a/synapse/api/events/factory.py b/synapse/api/events/factory.py index a3b293e024..d3d96d73eb 100644 --- a/synapse/api/events/factory.py +++ b/synapse/api/events/factory.py @@ -47,15 +47,26 @@ class EventFactory(object): self._event_list[event_class.TYPE] = event_class self.clock = hs.get_clock() + self.hs = hs def create_event(self, etype=None, **kwargs): kwargs["type"] = etype if "event_id" not in kwargs: - kwargs["event_id"] = random_string(10) + kwargs["event_id"] = "%s@%s" % ( + random_string(10), self.hs.hostname + ) if "ts" not in kwargs: kwargs["ts"] = int(self.clock.time_msec()) + # The "age" key is a delta timestamp that should be converted into an + # absolute timestamp the minute we see it. + if "age" in kwargs: + kwargs["age_ts"] = int(self.clock.time_msec()) - int(kwargs["age"]) + del kwargs["age"] + elif "age_ts" not in kwargs: + kwargs["age_ts"] = int(self.clock.time_msec()) + if etype in self._event_list: handler = self._event_list[etype] else: diff --git a/synapse/api/events/room.py b/synapse/api/events/room.py index 33f0f0cb99..3a4dbc58ce 100644 --- a/synapse/api/events/room.py +++ b/synapse/api/events/room.py @@ -173,3 +173,10 @@ class RoomOpsPowerLevelsEvent(SynapseStateEvent): def get_content_template(self): return {} + + +class RoomAliasesEvent(SynapseStateEvent): + TYPE = "m.room.aliases" + + def get_content_template(self): + return {} diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 49cf928cc1..2f1b954902 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage import read_schema +from synapse.storage import prepare_database from synapse.server import HomeServer @@ -36,30 +36,14 @@ from daemonize import Daemonize import twisted.manhole.telnet import logging -import sqlite3 import os import re import sys +import sqlite3 logger = logging.getLogger(__name__) -SCHEMAS = [ - "transactions", - "pdu", - "users", - "profiles", - "presence", - "im", - "room_aliases", -] - - -# Remember to update this number every time an incompatible change is made to -# database schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 2 - - class SynapseHomeServer(HomeServer): def build_http_client(self): @@ -80,52 +64,12 @@ class SynapseHomeServer(HomeServer): ) def build_db_pool(self): - """ Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we - don't have to worry about overwriting existing content. - """ - logging.info("Preparing database: %s...", self.db_name) - - with sqlite3.connect(self.db_name) as db_conn: - c = db_conn.cursor() - c.execute("PRAGMA user_version") - row = c.fetchone() - - if row and row[0]: - user_version = row[0] - - if user_version > SCHEMA_VERSION: - raise ValueError("Cannot use this database as it is too " + - "new for the server to understand" - ) - elif user_version < SCHEMA_VERSION: - logging.info("Upgrading database from version %d", - user_version - ) - - # Run every version since after the current version. - for v in range(user_version + 1, SCHEMA_VERSION + 1): - sql_script = read_schema("delta/v%d" % (v)) - c.executescript(sql_script) - - db_conn.commit() - - else: - for sql_loc in SCHEMAS: - sql_script = read_schema(sql_loc) - - c.executescript(sql_script) - db_conn.commit() - c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION) - - c.close() - - logging.info("Database prepared in %s.", self.db_name) - - pool = adbapi.ConnectionPool( - 'sqlite3', self.db_name, check_same_thread=False, - cp_min=1, cp_max=1) - - return pool + return adbapi.ConnectionPool( + "sqlite3", self.get_db_name(), + check_same_thread=False, + cp_min=1, + cp_max=1 + ) def create_resource_tree(self, web_client, redirect_root_to_web_client): """Create the resource tree for this Home Server. @@ -230,10 +174,6 @@ class SynapseHomeServer(HomeServer): logger.info("Synapse now listening on port %d", unsecure_port) -def run(): - reactor.run() - - def setup(): config = HomeServerConfig.load_config( "Synapse Homeserver", @@ -268,7 +208,15 @@ def setup(): web_client=config.webclient, redirect_root_to_web_client=True, ) - hs.start_listening(config.bind_port, config.unsecure_port) + + db_name = hs.get_db_name() + + logging.info("Preparing database: %s...", db_name) + + with sqlite3.connect(db_name) as db_conn: + prepare_database(db_conn) + + logging.info("Database prepared in %s.", db_name) hs.get_db_pool() @@ -279,12 +227,14 @@ def setup(): f.namespace['hs'] = hs reactor.listenTCP(config.manhole, f, interface='127.0.0.1') + hs.start_listening(config.bind_port, config.unsecure_port) + if config.daemonize: print config.pid_file daemon = Daemonize( app="synapse-homeserver", pid=config.pid_file, - action=run, + action=reactor.run, auto_close_fds=False, verbose=True, logger=logger, @@ -292,7 +242,7 @@ def setup(): daemon.start() else: - run() + reactor.run() if __name__ == '__main__': diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py new file mode 100644 index 0000000000..8ebcfc3623 --- /dev/null +++ b/synapse/config/captcha.py @@ -0,0 +1,46 @@ +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + + +class CaptchaConfig(Config): + + def __init__(self, args): + super(CaptchaConfig, self).__init__(args) + self.recaptcha_private_key = args.recaptcha_private_key + self.enable_registration_captcha = args.enable_registration_captcha + self.captcha_ip_origin_is_x_forwarded = ( + args.captcha_ip_origin_is_x_forwarded + ) + + @classmethod + def add_arguments(cls, parser): + super(CaptchaConfig, cls).add_arguments(parser) + group = parser.add_argument_group("recaptcha") + group.add_argument( + "--recaptcha-private-key", type=str, default="YOUR_PRIVATE_KEY", + help="The matching private key for the web client's public key." + ) + group.add_argument( + "--enable-registration-captcha", type=bool, default=False, + help="Enables ReCaptcha checks when registering, preventing signup" + + " unless a captcha is answered. Requires a valid ReCaptcha " + + "public/private key." + ) + group.add_argument( + "--captcha_ip_origin_is_x_forwarded", type=bool, default=False, + help="When checking captchas, use the X-Forwarded-For (XFF) header" + + " as the client IP and not the actual client IP." + ) \ No newline at end of file diff --git a/synapse/config/email.py b/synapse/config/email.py new file mode 100644 index 0000000000..9bcc5a8fea --- /dev/null +++ b/synapse/config/email.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + + +class EmailConfig(Config): + + def __init__(self, args): + super(EmailConfig, self).__init__(args) + self.email_from_address = args.email_from_address + self.email_smtp_server = args.email_smtp_server + + @classmethod + def add_arguments(cls, parser): + super(EmailConfig, cls).add_arguments(parser) + email_group = parser.add_argument_group("email") + email_group.add_argument( + "--email-from-address", + default="FROM@EXAMPLE.COM", + help="The address to send emails from (e.g. for password resets)." + ) + email_group.add_argument( + "--email-smtp-server", + default="", + help="The SMTP server to send emails from (e.g. for password resets)." + ) \ No newline at end of file diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 76e2cdeddd..4b810a2302 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -19,11 +19,16 @@ from .logger import LoggingConfig from .database import DatabaseConfig from .ratelimiting import RatelimitConfig from .repository import ContentRepositoryConfig +from .captcha import CaptchaConfig +from .email import EmailConfig + class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, - RatelimitConfig, ContentRepositoryConfig): + RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, + EmailConfig): pass -if __name__=='__main__': + +if __name__ == '__main__': import sys HomeServerConfig.load_config("Generate config", sys.argv[1:], "HomeServer") diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py index e12510017f..96b82f00cb 100644 --- a/synapse/federation/replication.py +++ b/synapse/federation/replication.py @@ -291,6 +291,13 @@ class ReplicationLayer(object): def on_incoming_transaction(self, transaction_data): transaction = Transaction(**transaction_data) + for p in transaction.pdus: + if "age" in p: + p["age_ts"] = int(self._clock.time_msec()) - int(p["age"]) + del p["age"] + + pdu_list = [Pdu(**p) for p in transaction.pdus] + logger.debug("[%s] Got transaction", transaction.transaction_id) response = yield self.transaction_actions.have_responded(transaction) @@ -303,8 +310,6 @@ class ReplicationLayer(object): logger.debug("[%s] Transacition is new", transaction.transaction_id) - pdu_list = [Pdu(**p) for p in transaction.pdus] - dl = [] for pdu in pdu_list: dl.append(self._handle_new_pdu(pdu)) @@ -405,9 +410,14 @@ class ReplicationLayer(object): """Returns a new Transaction containing the given PDUs suitable for transmission. """ + pdus = [p.get_dict() for p in pdu_list] + for p in pdus: + if "age_ts" in pdus: + p["age"] = int(self.clock.time_msec()) - p["age_ts"] + return Transaction( - pdus=[p.get_dict() for p in pdu_list], origin=self.server_name, + pdus=pdus, ts=int(self._clock.time_msec()), destination=None, ) @@ -593,8 +603,21 @@ class _TransactionQueue(object): logger.debug("TX [%s] Sending transaction...", destination) # Actually send the transaction + + # FIXME (erikj): This is a bit of a hack to make the Pdu age + # keys work + def cb(transaction): + now = int(self._clock.time_msec()) + if "pdus" in transaction: + for p in transaction["pdus"]: + if "age_ts" in p: + p["age"] = now - int(p["age_ts"]) + + return transaction + code, response = yield self.transport_layer.send_transaction( - transaction + transaction, + on_send_callback=cb, ) logger.debug("TX [%s] Sent transaction", destination) diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py index 6e62ae7c74..afc777ec9e 100644 --- a/synapse/federation/transport.py +++ b/synapse/federation/transport.py @@ -144,7 +144,7 @@ class TransportLayer(object): @defer.inlineCallbacks @log_function - def send_transaction(self, transaction): + def send_transaction(self, transaction, on_send_callback=None): """ Sends the given Transaction to it's destination Args: @@ -165,10 +165,23 @@ class TransportLayer(object): data = transaction.get_dict() + # FIXME (erikj): This is a bit of a hack to make the Pdu age + # keys work + def cb(destination, method, path_bytes, producer): + if not on_send_callback: + return + + transaction = json.loads(producer.body) + + new_transaction = on_send_callback(transaction) + + producer.reset(new_transaction) + code, response = yield self.client.put_json( transaction.destination, path=PREFIX + "/send/%s/" % transaction.transaction_id, - data=data + data=data, + on_send_callback=cb, ) logger.debug( diff --git a/synapse/federation/units.py b/synapse/federation/units.py index 9740431279..622fe66a8f 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -69,6 +69,7 @@ class Pdu(JsonEncodedObject): "prev_state_id", "prev_state_origin", "required_power_level", + "user_id", ] internal_keys = [ diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 9989fe8670..de4d23bbb3 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -42,9 +42,6 @@ class BaseHandler(object): retry_after_ms=int(1000*(time_allowed - time_now)), ) - -class BaseRoomHandler(BaseHandler): - @defer.inlineCallbacks def _on_new_room_event(self, event, snapshot, extra_destinations=[], extra_users=[]): diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 1b9e831fc0..4ab00a761a 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -19,8 +19,10 @@ from ._base import BaseHandler from synapse.api.errors import SynapseError from synapse.http.client import HttpClient +from synapse.api.events.room import RoomAliasesEvent import logging +import sqlite3 logger = logging.getLogger(__name__) @@ -37,7 +39,8 @@ class DirectoryHandler(BaseHandler): ) @defer.inlineCallbacks - def create_association(self, room_alias, room_id, servers=None): + def create_association(self, user_id, room_alias, room_id, servers=None): + # TODO(erikj): Do auth. if not room_alias.is_mine: @@ -54,12 +57,37 @@ class DirectoryHandler(BaseHandler): if not servers: raise SynapseError(400, "Failed to get server list") - yield self.store.create_room_alias_association( - room_alias, - room_id, - servers + + try: + yield self.store.create_room_alias_association( + room_alias, + room_id, + servers + ) + except sqlite3.IntegrityError: + defer.returnValue("Already exists") + + # TODO: Send the room event. + + aliases = yield self.store.get_aliases_for_room(room_id) + + event = self.event_factory.create_event( + etype=RoomAliasesEvent.TYPE, + state_key=self.hs.hostname, + room_id=room_id, + user_id=user_id, + content={"aliases": aliases}, + ) + + snapshot = yield self.store.snapshot_room( + room_id=room_id, + user_id=user_id, ) + yield self.state_handler.handle_new_event(event, snapshot) + yield self._on_new_room_event(event, snapshot, extra_users=[user_id]) + + @defer.inlineCallbacks def get_association(self, room_alias): room_id = None diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index fd24a11fb8..93dcd40324 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -15,7 +15,6 @@ from twisted.internet import defer -from synapse.api.events import SynapseEvent from synapse.util.logutils import log_function from ._base import BaseHandler @@ -71,10 +70,7 @@ class EventStreamHandler(BaseHandler): auth_user, room_ids, pagin_config, timeout ) - chunks = [ - e.get_dict() if isinstance(e, SynapseEvent) else e - for e in events - ] + chunks = [self.hs.serialize_event(e) for e in events] chunk = { "chunk": chunks, @@ -92,7 +88,9 @@ class EventStreamHandler(BaseHandler): # 10 seconds of grace to allow the client to reconnect again # before we think they're gone def _later(): - logger.debug("_later stopped_user_eventstream %s", auth_user) + logger.debug( + "_later stopped_user_eventstream %s", auth_user + ) self.distributor.fire( "stopped_user_eventstream", auth_user ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 59cbf71d78..001c6c110c 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -93,22 +93,18 @@ class FederationHandler(BaseHandler): """ event = self.pdu_codec.event_from_pdu(pdu) + logger.debug("Got event: %s", event.event_id) + with (yield self.lock_manager.lock(pdu.context)): if event.is_state and not backfilled: is_new_state = yield self.state_handler.handle_new_state( pdu ) - if not is_new_state: - return else: is_new_state = False # TODO: Implement something in federation that allows us to # respond to PDU. - if hasattr(event, "state_key") and not is_new_state: - logger.debug("Ignoring old state.") - return - target_is_mine = False if hasattr(event, "target_host"): target_is_mine = event.target_host == self.hs.hostname @@ -139,7 +135,11 @@ class FederationHandler(BaseHandler): else: with (yield self.room_lock.lock(event.room_id)): - yield self.store.persist_event(event, backfilled) + yield self.store.persist_event( + event, + backfilled, + is_new_state=is_new_state + ) room = yield self.store.get_room(event.room_id) diff --git a/synapse/handlers/login.py b/synapse/handlers/login.py index 6ee7ce5a2d..80ffdd2726 100644 --- a/synapse/handlers/login.py +++ b/synapse/handlers/login.py @@ -17,9 +17,13 @@ from twisted.internet import defer from ._base import BaseHandler from synapse.api.errors import LoginError, Codes +from synapse.http.client import PlainHttpClient +from synapse.util.emailutils import EmailException +import synapse.util.emailutils as emailutils import bcrypt import logging +import urllib logger = logging.getLogger(__name__) @@ -62,4 +66,41 @@ class LoginHandler(BaseHandler): defer.returnValue(token) else: logger.warn("Failed password login for user %s", user) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) \ No newline at end of file + raise LoginError(403, "", errcode=Codes.FORBIDDEN) + + @defer.inlineCallbacks + def reset_password(self, user_id, email): + is_valid = yield self._check_valid_association(user_id, email) + logger.info("reset_password user=%s email=%s valid=%s", user_id, email, + is_valid) + if is_valid: + try: + # send an email out + emailutils.send_email( + smtp_server=self.hs.config.email_smtp_server, + from_addr=self.hs.config.email_from_address, + to_addr=email, + subject="Password Reset", + body="TODO." + ) + except EmailException as e: + logger.exception(e) + + @defer.inlineCallbacks + def _check_valid_association(self, user_id, email): + identity = yield self._query_email(email) + if identity and "mxid" in identity: + if identity["mxid"] == user_id: + defer.returnValue(True) + return + defer.returnValue(False) + + @defer.inlineCallbacks + def _query_email(self, email): + httpCli = PlainHttpClient(self.hs) + data = yield httpCli.get_json( + 'matrix.org:8090', # TODO FIXME This should be configurable. + "/_matrix/identity/api/v1/lookup?medium=email&address=" + + "%s" % urllib.quote(email) + ) + defer.returnValue(data) \ No newline at end of file diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index dad2bbd1a4..14fae689f2 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -19,7 +19,7 @@ from synapse.api.constants import Membership from synapse.api.events.room import RoomTopicEvent from synapse.api.errors import RoomError from synapse.streams.config import PaginationConfig -from ._base import BaseRoomHandler +from ._base import BaseHandler import logging @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) -class MessageHandler(BaseRoomHandler): +class MessageHandler(BaseHandler): def __init__(self, hs): super(MessageHandler, self).__init__(hs) @@ -124,7 +124,7 @@ class MessageHandler(BaseRoomHandler): ) chunk = { - "chunk": [e.get_dict() for e in events], + "chunk": [self.hs.serialize_event(e) for e in events], "start": pagin_config.from_token.to_string(), "end": next_token.to_string(), } @@ -268,6 +268,9 @@ class MessageHandler(BaseRoomHandler): user, pagination_config, None ) + public_rooms = yield self.store.get_rooms(is_public=True) + public_room_ids = [r["room_id"] for r in public_rooms] + limit = pagin_config.limit if not limit: limit = 10 @@ -276,6 +279,8 @@ class MessageHandler(BaseRoomHandler): d = { "room_id": event.room_id, "membership": event.membership, + "visibility": ("public" if event.room_id in + public_room_ids else "private"), } if event.membership == Membership.INVITE: @@ -296,7 +301,7 @@ class MessageHandler(BaseRoomHandler): end_token = now_token.copy_and_replace("room_key", token[1]) d["messages"] = { - "chunk": [m.get_dict() for m in messages], + "chunk": [self.hs.serialize_event(m) for m in messages], "start": start_token.to_string(), "end": end_token.to_string(), } @@ -304,7 +309,7 @@ class MessageHandler(BaseRoomHandler): current_state = yield self.store.get_current_state( event.room_id ) - d["state"] = [c.get_dict() for c in current_state] + d["state"] = [self.hs.serialize_event(c) for c in current_state] except: logger.exception("Failed to get snapshot") diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index c79bb6ff76..b2af09f090 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -796,11 +796,12 @@ class PresenceEventSource(object): updates = [] # TODO(paul): use a DeferredList ? How to limit concurrency. for observed_user in cachemap.keys(): - if not (from_key < cachemap[observed_user].serial): + cached = cachemap[observed_user] + if not (from_key < cached.serial): continue if (yield self.is_visible(observer_user, observed_user)): - updates.append((observed_user, cachemap[observed_user])) + updates.append((observed_user, cached)) # TODO(paul): limit diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 023d8c0cf2..dab9b03f04 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -15,9 +15,9 @@ from twisted.internet import defer -from synapse.api.errors import SynapseError, AuthError - -from synapse.api.errors import CodeMessageException +from synapse.api.errors import SynapseError, AuthError, CodeMessageException +from synapse.api.constants import Membership +from synapse.api.events.room import RoomMemberEvent from ._base import BaseHandler @@ -97,6 +97,8 @@ class ProfileHandler(BaseHandler): } ) + yield self._update_join_states(target_user) + @defer.inlineCallbacks def get_avatar_url(self, target_user): if target_user.is_mine: @@ -144,6 +146,8 @@ class ProfileHandler(BaseHandler): } ) + yield self._update_join_states(target_user) + @defer.inlineCallbacks def collect_presencelike_data(self, user, state): if not user.is_mine: @@ -180,3 +184,39 @@ class ProfileHandler(BaseHandler): ) defer.returnValue(response) + + @defer.inlineCallbacks + def _update_join_states(self, user): + if not user.is_mine: + return + + joins = yield self.store.get_rooms_for_user_where_membership_is( + user.to_string(), + [Membership.JOIN], + ) + + for j in joins: + snapshot = yield self.store.snapshot_room( + j.room_id, j.state_key, RoomMemberEvent.TYPE, + j.state_key + ) + + content = { + "membership": j.content["membership"], + "prev": j.content["membership"], + } + + yield self.distributor.fire( + "collect_presencelike_data", user, content + ) + + new_event = self.event_factory.create_event( + etype=j.type, + room_id=j.room_id, + state_key=j.state_key, + content=content, + user_id=j.state_key, + ) + + yield self.state_handler.handle_new_event(new_event, snapshot) + yield self._on_new_room_event(new_event, snapshot) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index bee052274f..a019d770d4 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -17,7 +17,9 @@ from twisted.internet import defer from synapse.types import UserID -from synapse.api.errors import SynapseError, RegistrationError +from synapse.api.errors import ( + SynapseError, RegistrationError, InvalidCaptchaError +) from ._base import BaseHandler import synapse.util.stringutils as stringutils from synapse.http.client import PlainHttpClient @@ -38,7 +40,7 @@ class RegistrationHandler(BaseHandler): self.distributor.declare("registered_user") @defer.inlineCallbacks - def register(self, localpart=None, password=None, threepidCreds=None): + def register(self, localpart=None, password=None): """Registers a new client on the server. Args: @@ -51,20 +53,6 @@ class RegistrationHandler(BaseHandler): Raises: RegistrationError if there was a problem registering. """ - - if threepidCreds: - for c in threepidCreds: - logger.info("validating theeepidcred sid %s on id server %s", c['sid'], c['idServer']) - try: - threepid = yield self._threepid_from_creds(c) - except: - logger.err() - raise RegistrationError(400, "Couldn't validate 3pid") - - if not threepid: - raise RegistrationError(400, "Couldn't validate 3pid") - logger.info("got threepid medium %s address %s", threepid['medium'], threepid['address']) - password_hash = None if password: password_hash = bcrypt.hashpw(password, bcrypt.gensalt()) @@ -106,15 +94,54 @@ class RegistrationHandler(BaseHandler): raise RegistrationError( 500, "Cannot generate user ID.") - # Now we have a matrix ID, bind it to the threepids we were given - if threepidCreds: - for c in threepidCreds: - # XXX: This should be a deferred list, shouldn't it? - yield self._bind_threepid(c, user_id) - - defer.returnValue((user_id, token)) + @defer.inlineCallbacks + def check_recaptcha(self, ip, private_key, challenge, response): + """Checks a recaptcha is correct.""" + + captcha_response = yield self._validate_captcha( + ip, + private_key, + challenge, + response + ) + if not captcha_response["valid"]: + logger.info("Invalid captcha entered from %s. Error: %s", + ip, captcha_response["error_url"]) + raise InvalidCaptchaError( + error_url=captcha_response["error_url"] + ) + else: + logger.info("Valid captcha entered from %s", ip) + + @defer.inlineCallbacks + def register_email(self, threepidCreds): + """Registers emails with an identity server.""" + + for c in threepidCreds: + logger.info("validating theeepidcred sid %s on id server %s", + c['sid'], c['idServer']) + try: + threepid = yield self._threepid_from_creds(c) + except: + logger.err() + raise RegistrationError(400, "Couldn't validate 3pid") + + if not threepid: + raise RegistrationError(400, "Couldn't validate 3pid") + logger.info("got threepid medium %s address %s", + threepid['medium'], threepid['address']) + + @defer.inlineCallbacks + def bind_emails(self, user_id, threepidCreds): + """Links emails with a user ID and informs an identity server.""" + + # Now we have a matrix ID, bind it to the threepids we were given + for c in threepidCreds: + # XXX: This should be a deferred list, shouldn't it? + yield self._bind_threepid(c, user_id) + def _generate_token(self, user_id): # urlsafe variant uses _ and - so use . as the separator and replace # all =s with .s so http clients don't quote =s when it is used as @@ -129,16 +156,17 @@ class RegistrationHandler(BaseHandler): def _threepid_from_creds(self, creds): httpCli = PlainHttpClient(self.hs) # XXX: make this configurable! - trustedIdServers = [ 'matrix.org:8090' ] + trustedIdServers = ['matrix.org:8090'] if not creds['idServer'] in trustedIdServers: - logger.warn('%s is not a trusted ID server: rejecting 3pid credentials', creds['idServer']) + logger.warn('%s is not a trusted ID server: rejecting 3pid ' + + 'credentials', creds['idServer']) defer.returnValue(None) data = yield httpCli.get_json( creds['idServer'], "/_matrix/identity/api/v1/3pid/getValidated3pid", - { 'sid': creds['sid'], 'clientSecret': creds['clientSecret'] } + {'sid': creds['sid'], 'clientSecret': creds['clientSecret']} ) - + if 'medium' in data: defer.returnValue(data) defer.returnValue(None) @@ -149,9 +177,45 @@ class RegistrationHandler(BaseHandler): data = yield httpCli.post_urlencoded_get_json( creds['idServer'], "/_matrix/identity/api/v1/3pid/bind", - { 'sid': creds['sid'], 'clientSecret': creds['clientSecret'], 'mxid':mxid } + {'sid': creds['sid'], 'clientSecret': creds['clientSecret'], + 'mxid': mxid} ) defer.returnValue(data) - - + + @defer.inlineCallbacks + def _validate_captcha(self, ip_addr, private_key, challenge, response): + """Validates the captcha provided. + + Returns: + dict: Containing 'valid'(bool) and 'error_url'(str) if invalid. + + """ + response = yield self._submit_captcha(ip_addr, private_key, challenge, + response) + # parse Google's response. Lovely format.. + lines = response.split('\n') + json = { + "valid": lines[0] == 'true', + "error_url": "http://www.google.com/recaptcha/api/challenge?" + + "error=%s" % lines[1] + } + defer.returnValue(json) + + @defer.inlineCallbacks + def _submit_captcha(self, ip_addr, private_key, challenge, response): + client = PlainHttpClient(self.hs) + data = yield client.post_urlencoded_get_raw( + "www.google.com:80", + "/recaptcha/api/verify", + # twisted dislikes google's response, no content length. + accept_partial=True, + args={ + 'privatekey': private_key, + 'remoteip': ip_addr, + 'challenge': challenge, + 'response': response + } + ) + defer.returnValue(data) + diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 8171e9eb45..5bc1280432 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -25,14 +25,14 @@ from synapse.api.events.room import ( RoomSendEventLevelEvent, RoomOpsPowerLevelsEvent, RoomNameEvent, ) from synapse.util import stringutils -from ._base import BaseRoomHandler +from ._base import BaseHandler import logging logger = logging.getLogger(__name__) -class RoomCreationHandler(BaseRoomHandler): +class RoomCreationHandler(BaseHandler): @defer.inlineCallbacks def create_room(self, user_id, room_id, config): @@ -65,6 +65,13 @@ class RoomCreationHandler(BaseRoomHandler): else: room_alias = None + invite_list = config.get("invite", []) + for i in invite_list: + try: + self.hs.parse_userid(i) + except: + raise SynapseError(400, "Invalid user_id: %s" % (i,)) + is_public = config.get("visibility", None) == "public" if room_id: @@ -105,7 +112,9 @@ class RoomCreationHandler(BaseRoomHandler): ) if room_alias: - yield self.store.create_room_alias_association( + directory_handler = self.hs.get_handlers().directory_handler + yield directory_handler.create_association( + user_id=user_id, room_id=room_id, room_alias=room_alias, servers=[self.hs.hostname], @@ -132,7 +141,7 @@ class RoomCreationHandler(BaseRoomHandler): etype=RoomNameEvent.TYPE, room_id=room_id, user_id=user_id, - required_power_level=5, + required_power_level=50, content={"name": name}, ) @@ -143,7 +152,7 @@ class RoomCreationHandler(BaseRoomHandler): etype=RoomNameEvent.TYPE, room_id=room_id, user_id=user_id, - required_power_level=5, + required_power_level=50, content={"name": name}, ) @@ -155,7 +164,7 @@ class RoomCreationHandler(BaseRoomHandler): etype=RoomTopicEvent.TYPE, room_id=room_id, user_id=user_id, - required_power_level=5, + required_power_level=50, content={"topic": topic}, ) @@ -176,6 +185,25 @@ class RoomCreationHandler(BaseRoomHandler): do_auth=False ) + content = {"membership": Membership.INVITE} + for invitee in invite_list: + invite_event = self.event_factory.create_event( + etype=RoomMemberEvent.TYPE, + state_key=invitee, + room_id=room_id, + user_id=user_id, + content=content + ) + + yield self.hs.get_handlers().room_member_handler.change_membership( + invite_event, + do_auth=False + ) + + yield self.hs.get_handlers().room_member_handler.change_membership( + join_event, + do_auth=False + ) result = {"room_id": room_id} if room_alias: result["room_alias"] = room_alias.to_string() @@ -186,7 +214,7 @@ class RoomCreationHandler(BaseRoomHandler): event_keys = { "room_id": room_id, "user_id": creator.to_string(), - "required_power_level": 10, + "required_power_level": 100, } def create(etype, **content): @@ -203,7 +231,7 @@ class RoomCreationHandler(BaseRoomHandler): power_levels_event = self.event_factory.create_event( etype=RoomPowerLevelsEvent.TYPE, - content={creator.to_string(): 10, "default": 0}, + content={creator.to_string(): 100, "default": 0}, **event_keys ) @@ -215,7 +243,7 @@ class RoomCreationHandler(BaseRoomHandler): add_state_event = create( etype=RoomAddStateLevelEvent.TYPE, - level=10, + level=100, ) send_event = create( @@ -225,8 +253,8 @@ class RoomCreationHandler(BaseRoomHandler): ops = create( etype=RoomOpsPowerLevelsEvent.TYPE, - ban_level=5, - kick_level=5, + ban_level=50, + kick_level=50, ) return [ @@ -239,7 +267,7 @@ class RoomCreationHandler(BaseRoomHandler): ] -class RoomMemberHandler(BaseRoomHandler): +class RoomMemberHandler(BaseHandler): # TODO(paul): This handler currently contains a messy conflation of # low-level API that works on UserID objects and so on, and REST-level # API that takes ID strings and returns pagination chunks. These concerns @@ -307,7 +335,7 @@ class RoomMemberHandler(BaseRoomHandler): member_list = yield self.store.get_room_members(room_id=room_id) event_list = [ - entry.get_dict() + self.hs.serialize_event(entry) for entry in member_list ] chunk_data = { @@ -560,11 +588,17 @@ class RoomMemberHandler(BaseRoomHandler): extra_users=[target_user] ) -class RoomListHandler(BaseRoomHandler): +class RoomListHandler(BaseHandler): @defer.inlineCallbacks def get_public_room_list(self): chunk = yield self.store.get_rooms(is_public=True) + for room in chunk: + joined_members = yield self.store.get_room_members( + room_id=room["room_id"], + membership=Membership.JOIN + ) + room["num_joined_members"] = len(joined_members) # FIXME (erikj): START is no longer a valid value defer.returnValue({"start": "START", "end": "END", "chunk": chunk}) diff --git a/synapse/http/client.py b/synapse/http/client.py index ebf1aa47c4..eb11bfd4d5 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -16,7 +16,7 @@ from twisted.internet import defer, reactor from twisted.internet.error import DNSLookupError -from twisted.web.client import _AgentBase, _URI, readBody, FileBodyProducer +from twisted.web.client import _AgentBase, _URI, readBody, FileBodyProducer, PartialDownloadError from twisted.web.http_headers import Headers from synapse.http.endpoint import matrix_endpoint @@ -122,7 +122,7 @@ class TwistedHttpClient(HttpClient): self.hs = hs @defer.inlineCallbacks - def put_json(self, destination, path, data): + def put_json(self, destination, path, data, on_send_callback=None): if destination in _destination_mappings: destination = _destination_mappings[destination] @@ -131,7 +131,8 @@ class TwistedHttpClient(HttpClient): "PUT", path.encode("ascii"), producer=_JsonProducer(data), - headers_dict={"Content-Type": ["application/json"]} + headers_dict={"Content-Type": ["application/json"]}, + on_send_callback=on_send_callback, ) logger.debug("Getting resp body") @@ -188,11 +189,37 @@ class TwistedHttpClient(HttpClient): body = yield readBody(response) defer.returnValue(json.loads(body)) + + # XXX FIXME : I'm so sorry. + @defer.inlineCallbacks + def post_urlencoded_get_raw(self, destination, path, accept_partial=False, args={}): + if destination in _destination_mappings: + destination = _destination_mappings[destination] + + query_bytes = urllib.urlencode(args, True) + + response = yield self._create_request( + destination.encode("ascii"), + "POST", + path.encode("ascii"), + producer=FileBodyProducer(StringIO(urllib.urlencode(args))), + headers_dict={"Content-Type": ["application/x-www-form-urlencoded"]} + ) + + try: + body = yield readBody(response) + defer.returnValue(body) + except PartialDownloadError as e: + if accept_partial: + defer.returnValue(e.response) + else: + raise e + @defer.inlineCallbacks def _create_request(self, destination, method, path_bytes, param_bytes=b"", query_bytes=b"", producer=None, headers_dict={}, - retry_on_dns_fail=True): + retry_on_dns_fail=True, on_send_callback=None): """ Creates and sends a request to the given url """ headers_dict[b"User-Agent"] = [b"Synapse"] @@ -216,6 +243,9 @@ class TwistedHttpClient(HttpClient): endpoint = self._getEndpoint(reactor, destination); while True: + if on_send_callback: + on_send_callback(destination, method, path_bytes, producer) + try: response = yield self.agent.request( destination, @@ -284,6 +314,9 @@ class _JsonProducer(object): """ Used by the twisted http client to create the HTTP body from json """ def __init__(self, jsn): + self.reset(jsn) + + def reset(self, jsn): self.body = encode_canonical_json(jsn) self.length = len(self.body) diff --git a/synapse/rest/directory.py b/synapse/rest/directory.py index 18df7c8d8b..31849246a1 100644 --- a/synapse/rest/directory.py +++ b/synapse/rest/directory.py @@ -45,6 +45,8 @@ class ClientDirectoryServer(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_alias): + user = yield self.auth.get_user_by_req(request) + content = _parse_json(request) if not "room_id" in content: raise SynapseError(400, "Missing room_id key", @@ -69,12 +71,13 @@ class ClientDirectoryServer(RestServlet): try: yield dir_handler.create_association( - room_alias, room_id, servers + user.to_string(), room_alias, room_id, servers ) except SynapseError as e: raise e except: logger.exception("Failed to create association") + raise defer.returnValue((200, {})) diff --git a/synapse/rest/events.py b/synapse/rest/events.py index 7fde143200..097195d7cc 100644 --- a/synapse/rest/events.py +++ b/synapse/rest/events.py @@ -59,7 +59,7 @@ class EventRestServlet(RestServlet): event = yield handler.get_event(auth_user, event_id) if event: - defer.returnValue((200, event.get_dict())) + defer.returnValue((200, self.hs.serialize_event(event))) else: defer.returnValue((404, "Event not found.")) diff --git a/synapse/rest/login.py b/synapse/rest/login.py index c7bf901c8e..ad71f6c61d 100644 --- a/synapse/rest/login.py +++ b/synapse/rest/login.py @@ -70,7 +70,28 @@ class LoginFallbackRestServlet(RestServlet): def on_GET(self, request): # TODO(kegan): This should be returning some HTML which is capable of # hitting LoginRestServlet - return (200, "") + return (200, {}) + + +class PasswordResetRestServlet(RestServlet): + PATTERN = client_path_pattern("/login/reset") + + @defer.inlineCallbacks + def on_POST(self, request): + reset_info = _parse_json(request) + try: + email = reset_info["email"] + user_id = reset_info["user_id"] + handler = self.handlers.login_handler + yield handler.reset_password(user_id, email) + # purposefully give no feedback to avoid people hammering different + # combinations. + defer.returnValue((200, {})) + except KeyError: + raise SynapseError( + 400, + "Missing keys. Requires 'email' and 'user_id'." + ) def _parse_json(request): @@ -85,3 +106,4 @@ def _parse_json(request): def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) + # TODO PasswordResetRestServlet(hs).register(http_server) diff --git a/synapse/rest/profile.py b/synapse/rest/profile.py index 2e17f87fa1..dad5a208c7 100644 --- a/synapse/rest/profile.py +++ b/synapse/rest/profile.py @@ -51,7 +51,7 @@ class ProfileDisplaynameRestServlet(RestServlet): yield self.handlers.profile_handler.set_displayname( user, auth_user, new_name) - defer.returnValue((200, "")) + defer.returnValue((200, {})) def on_OPTIONS(self, request, user_id): return (200, {}) @@ -86,7 +86,7 @@ class ProfileAvatarURLRestServlet(RestServlet): yield self.handlers.profile_handler.set_avatar_url( user, auth_user, new_name) - defer.returnValue((200, "")) + defer.returnValue((200, {})) def on_OPTIONS(self, request, user_id): return (200, {}) diff --git a/synapse/rest/register.py b/synapse/rest/register.py index b8de3b250d..af528a44f6 100644 --- a/synapse/rest/register.py +++ b/synapse/rest/register.py @@ -16,58 +16,219 @@ """This module contains REST servlets to do with registration: /register""" from twisted.internet import defer -from synapse.api.errors import SynapseError +from synapse.api.errors import SynapseError, Codes +from synapse.api.constants import LoginType from base import RestServlet, client_path_pattern +import synapse.util.stringutils as stringutils import json +import logging import urllib +logger = logging.getLogger(__name__) + class RegisterRestServlet(RestServlet): + """Handles registration with the home server. + + This servlet is in control of the registration flow; the registration + handler doesn't have a concept of multi-stages or sessions. + """ + PATTERN = client_path_pattern("/register$") + def __init__(self, hs): + super(RegisterRestServlet, self).__init__(hs) + # sessions are stored as: + # self.sessions = { + # "session_id" : { __session_dict__ } + # } + # TODO: persistent storage + self.sessions = {} + + def on_GET(self, request): + if self.hs.config.enable_registration_captcha: + return (200, { + "flows": [ + { + "type": LoginType.RECAPTCHA, + "stages": ([LoginType.RECAPTCHA, + LoginType.EMAIL_IDENTITY, + LoginType.PASSWORD]) + }, + { + "type": LoginType.RECAPTCHA, + "stages": [LoginType.RECAPTCHA, LoginType.PASSWORD] + } + ] + }) + else: + return (200, { + "flows": [ + { + "type": LoginType.EMAIL_IDENTITY, + "stages": ([LoginType.EMAIL_IDENTITY, + LoginType.PASSWORD]) + }, + { + "type": LoginType.PASSWORD + } + ] + }) + @defer.inlineCallbacks def on_POST(self, request): - desired_user_id = None - password = None + register_json = _parse_json(request) + + session = (register_json["session"] if "session" in register_json + else None) + login_type = None + if "type" not in register_json: + raise SynapseError(400, "Missing 'type' key.") + + try: + login_type = register_json["type"] + stages = { + LoginType.RECAPTCHA: self._do_recaptcha, + LoginType.PASSWORD: self._do_password, + LoginType.EMAIL_IDENTITY: self._do_email_identity + } + + session_info = self._get_session_info(request, session) + logger.debug("%s : session info %s request info %s", + login_type, session_info, register_json) + response = yield stages[login_type]( + request, + register_json, + session_info + ) + + if "access_token" not in response: + # isn't a final response + response["session"] = session_info["id"] + + defer.returnValue((200, response)) + except KeyError as e: + logger.exception(e) + raise SynapseError(400, "Missing JSON keys for login type %s." % login_type) + + def on_OPTIONS(self, request): + return (200, {}) + + def _get_session_info(self, request, session_id): + if not session_id: + # create a new session + while session_id is None or session_id in self.sessions: + session_id = stringutils.random_string(24) + self.sessions[session_id] = { + "id": session_id, + LoginType.EMAIL_IDENTITY: False, + LoginType.RECAPTCHA: False + } + + return self.sessions[session_id] + + def _save_session(self, session): + # TODO: Persistent storage + logger.debug("Saving session %s", session) + self.sessions[session["id"]] = session + + def _remove_session(self, session): + logger.debug("Removing session %s", session) + self.sessions.pop(session["id"]) + + @defer.inlineCallbacks + def _do_recaptcha(self, request, register_json, session): + if not self.hs.config.enable_registration_captcha: + raise SynapseError(400, "Captcha not required.") + + challenge = None + user_response = None try: - register_json = json.loads(request.content.read()) - if "password" in register_json: - password = register_json["password"].encode("utf-8") - - if type(register_json["user_id"]) == unicode: - desired_user_id = register_json["user_id"].encode("utf-8") - if urllib.quote(desired_user_id) != desired_user_id: - raise SynapseError( - 400, - "User ID must only contain characters which do not " + - "require URL encoding.") - except ValueError: - defer.returnValue((400, "No JSON object.")) + challenge = register_json["challenge"] + user_response = register_json["response"] except KeyError: - pass # user_id is optional + raise SynapseError(400, "Captcha response is required", + errcode=Codes.CAPTCHA_NEEDED) + + # May be an X-Forwarding-For header depending on config + ip_addr = request.getClientIP() + if self.hs.config.captcha_ip_origin_is_x_forwarded: + # use the header + if request.requestHeaders.hasHeader("X-Forwarded-For"): + ip_addr = request.requestHeaders.getRawHeaders( + "X-Forwarded-For")[0] + + handler = self.handlers.registration_handler + yield handler.check_recaptcha( + ip_addr, + self.hs.config.recaptcha_private_key, + challenge, + user_response + ) + session[LoginType.RECAPTCHA] = True # mark captcha as done + self._save_session(session) + defer.returnValue({ + "next": [LoginType.PASSWORD, LoginType.EMAIL_IDENTITY] + }) + + @defer.inlineCallbacks + def _do_email_identity(self, request, register_json, session): + if (self.hs.config.enable_registration_captcha and + not session[LoginType.RECAPTCHA]): + raise SynapseError(400, "Captcha is required.") - threepidCreds = None - if 'threepidCreds' in register_json: - threepidCreds = register_json['threepidCreds'] + threepidCreds = register_json['threepidCreds'] + handler = self.handlers.registration_handler + yield handler.register_email(threepidCreds) + session["threepidCreds"] = threepidCreds # store creds for next stage + session[LoginType.EMAIL_IDENTITY] = True # mark email as done + self._save_session(session) + defer.returnValue({ + "next": LoginType.PASSWORD + }) + + @defer.inlineCallbacks + def _do_password(self, request, register_json, session): + if (self.hs.config.enable_registration_captcha and + not session[LoginType.RECAPTCHA]): + # captcha should've been done by this stage! + raise SynapseError(400, "Captcha is required.") + password = register_json["password"].encode("utf-8") + desired_user_id = (register_json["user"].encode("utf-8") if "user" + in register_json else None) + if desired_user_id and urllib.quote(desired_user_id) != desired_user_id: + raise SynapseError( + 400, + "User ID must only contain characters which do not " + + "require URL encoding.") handler = self.handlers.registration_handler (user_id, token) = yield handler.register( localpart=desired_user_id, - password=password, - threepidCreds=threepidCreds) + password=password + ) + + if session[LoginType.EMAIL_IDENTITY]: + yield handler.bind_emails(user_id, session["threepidCreds"]) result = { "user_id": user_id, "access_token": token, "home_server": self.hs.hostname, } - defer.returnValue( - (200, result) - ) + self._remove_session(session) + defer.returnValue(result) - def on_OPTIONS(self, request): - return (200, {}) + +def _parse_json(request): + try: + content = json.loads(request.content.read()) + if type(content) != dict: + raise SynapseError(400, "Content must be a JSON object.") + return content + except ValueError: + raise SynapseError(400, "Content not JSON.") def register_servlets(hs, http_server): diff --git a/synapse/rest/room.py b/synapse/rest/room.py index 308b447090..ecb1e346d9 100644 --- a/synapse/rest/room.py +++ b/synapse/rest/room.py @@ -154,14 +154,14 @@ class RoomStateEventRestServlet(RestServlet): # membership events are special handler = self.handlers.room_member_handler yield handler.change_membership(event) - defer.returnValue((200, "")) + defer.returnValue((200, {})) else: # store random bits of state msg_handler = self.handlers.message_handler yield msg_handler.store_room_data( event=event ) - defer.returnValue((200, "")) + defer.returnValue((200, {})) # TODO: Needs unit testing for generic events + feedback @@ -249,7 +249,7 @@ class JoinRoomAliasServlet(RestServlet): ) handler = self.handlers.room_member_handler yield handler.change_membership(event) - defer.returnValue((200, "")) + defer.returnValue((200, {})) @defer.inlineCallbacks def on_PUT(self, request, room_identifier, txn_id): @@ -378,7 +378,7 @@ class RoomTriggerBackfill(RestServlet): handler = self.handlers.federation_handler events = yield handler.backfill(remote_server, room_id, limit) - res = [event.get_dict() for event in events] + res = [self.hs.serialize_event(event) for event in events] defer.returnValue((200, res)) @@ -416,7 +416,7 @@ class RoomMembershipRestServlet(RestServlet): ) handler = self.handlers.room_member_handler yield handler.change_membership(event) - defer.returnValue((200, "")) + defer.returnValue((200, {})) @defer.inlineCallbacks def on_PUT(self, request, room_id, membership_action, txn_id): diff --git a/synapse/server.py b/synapse/server.py index 83368ea5a7..cdea49e6ab 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -20,6 +20,7 @@ # Imports required for the default HomeServer() implementation from synapse.federation import initialize_http_replication +from synapse.api.events import serialize_event from synapse.api.events.factory import EventFactory from synapse.notifier import Notifier from synapse.api.auth import Auth @@ -57,6 +58,7 @@ class BaseHomeServer(object): DEPENDENCIES = [ 'clock', 'http_client', + 'db_name', 'db_pool', 'persistence_service', 'replication_layer', @@ -138,6 +140,9 @@ class BaseHomeServer(object): object.""" return RoomID.from_string(s, hs=self) + def serialize_event(self, e): + return serialize_event(self, e) + # Build magic accessors for every dependency for depname in BaseHomeServer.DEPENDENCIES: BaseHomeServer._make_dependency_method(depname) diff --git a/synapse/state.py b/synapse/state.py index 36d8210eb5..9db84c9b5c 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -16,7 +16,7 @@ from twisted.internet import defer -from synapse.federation.pdu_codec import encode_event_id +from synapse.federation.pdu_codec import encode_event_id, decode_event_id from synapse.util.logutils import log_function from collections import namedtuple @@ -87,9 +87,11 @@ class StateHandler(object): # than the power level of the user # power_level = self._get_power_level_for_event(event) + pdu_id, origin = decode_event_id(event.event_id, self.server_name) + yield self.store.update_current_state( - pdu_id=event.event_id, - origin=self.server_name, + pdu_id=pdu_id, + origin=origin, context=key.context, pdu_type=key.type, state_key=key.state_key @@ -113,6 +115,8 @@ class StateHandler(object): is_new = yield self._handle_new_state(new_pdu) + logger.debug("is_new: %s %s %s", is_new, new_pdu.pdu_id, new_pdu.origin) + if is_new: yield self.store.update_current_state( pdu_id=new_pdu.pdu_id, @@ -132,7 +136,9 @@ class StateHandler(object): @defer.inlineCallbacks @log_function def _handle_new_state(self, new_pdu): - tree = yield self.store.get_unresolved_state_tree(new_pdu) + tree, missing_branch = yield self.store.get_unresolved_state_tree( + new_pdu + ) new_branch, current_branch = tree logger.debug( @@ -140,6 +146,28 @@ class StateHandler(object): new_branch, current_branch ) + if missing_branch is not None: + # We're missing some PDUs. Fetch them. + # TODO (erikj): Limit this. + missing_prev = tree[missing_branch][-1] + + pdu_id = missing_prev.prev_state_id + origin = missing_prev.prev_state_origin + + is_missing = yield self.store.get_pdu(pdu_id, origin) is None + if not is_missing: + raise Exception("Conflict resolution failed") + + yield self._replication.get_pdu( + destination=missing_prev.origin, + pdu_origin=origin, + pdu_id=pdu_id, + outlier=True + ) + + updated_current = yield self._handle_new_state(new_pdu) + defer.returnValue(updated_current) + if not current_branch: # There is no current state defer.returnValue(True) @@ -148,84 +176,85 @@ class StateHandler(object): n = new_branch[-1] c = current_branch[-1] - if n.pdu_id == c.pdu_id and n.origin == c.origin: - # We have all the PDUs we need, so we can just do the conflict - # resolution. + common_ancestor = n.pdu_id == c.pdu_id and n.origin == c.origin + + if common_ancestor: + # We found a common ancestor! if len(current_branch) == 1: # This is a direct clobber so we can just... defer.returnValue(True) - conflict_res = [ - self._do_power_level_conflict_res, - self._do_chain_length_conflict_res, - self._do_hash_conflict_res, - ] - - for algo in conflict_res: - new_res, curr_res = algo(new_branch, current_branch) - - if new_res < curr_res: - defer.returnValue(False) - elif new_res > curr_res: - defer.returnValue(True) - - raise Exception("Conflict resolution failed.") - else: - # We need to ask for PDUs. - missing_prev = max( - new_branch[-1], current_branch[-1], - key=lambda x: x.depth - ) - - if not hasattr(missing_prev, "prev_state_id"): - # FIXME Hmm - # temporary fallback - for algo in conflict_res: - new_res, curr_res = algo(new_branch, current_branch) + # We didn't find a common ancestor. This is probably fine. + pass - if new_res < curr_res: - defer.returnValue(False) - elif new_res > curr_res: - defer.returnValue(True) - return + result = yield self._do_conflict_res( + new_branch, current_branch, common_ancestor + ) + defer.returnValue(result) - pdu_id = missing_prev.prev_state_id - origin = missing_prev.prev_state_origin + @defer.inlineCallbacks + def _do_conflict_res(self, new_branch, current_branch, common_ancestor): + conflict_res = [ + self._do_power_level_conflict_res, + self._do_chain_length_conflict_res, + self._do_hash_conflict_res, + ] - is_missing = yield self.store.get_pdu(pdu_id, origin) is None + for algo in conflict_res: + new_res, curr_res = yield defer.maybeDeferred( + algo, + new_branch, current_branch, common_ancestor + ) - if not is_missing: - raise Exception("Conflict resolution failed.") + if new_res < curr_res: + defer.returnValue(False) + elif new_res > curr_res: + defer.returnValue(True) - yield self._replication.get_pdu( - destination=missing_prev.origin, - pdu_origin=origin, - pdu_id=pdu_id, - outlier=True - ) + raise Exception("Conflict resolution failed.") - updated_current = yield self._handle_new_state(new_pdu) - defer.returnValue(updated_current) + @defer.inlineCallbacks + def _do_power_level_conflict_res(self, new_branch, current_branch, + common_ancestor): + new_powers_deferreds = [] + for e in new_branch[:-1] if common_ancestor else new_branch: + if hasattr(e, "user_id"): + new_powers_deferreds.append( + self.store.get_power_level(e.context, e.user_id) + ) + + current_powers_deferreds = [] + for e in current_branch[:-1] if common_ancestor else current_branch: + if hasattr(e, "user_id"): + current_powers_deferreds.append( + self.store.get_power_level(e.context, e.user_id) + ) + + new_powers = yield defer.gatherResults( + new_powers_deferreds, + consumeErrors=True + ) - def _do_power_level_conflict_res(self, new_branch, current_branch): - max_power_new = max( - new_branch[:-1], - key=lambda t: t.power_level - ).power_level + current_powers = yield defer.gatherResults( + current_powers_deferreds, + consumeErrors=True + ) - max_power_current = max( - current_branch[:-1], - key=lambda t: t.power_level - ).power_level + max_power_new = max(new_powers) + max_power_current = max(current_powers) - return (max_power_new, max_power_current) + defer.returnValue( + (max_power_new, max_power_current) + ) - def _do_chain_length_conflict_res(self, new_branch, current_branch): + def _do_chain_length_conflict_res(self, new_branch, current_branch, + common_ancestor): return (len(new_branch), len(current_branch)) - def _do_hash_conflict_res(self, new_branch, current_branch): + def _do_hash_conflict_res(self, new_branch, current_branch, + common_ancestor): new_str = "".join([p.pdu_id + p.origin for p in new_branch]) c_str = "".join([p.pdu_id + p.origin for p in current_branch]) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index d97014f4da..66658f6721 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -36,7 +36,7 @@ from .registration import RegistrationStore from .room import RoomStore from .roommember import RoomMemberStore from .stream import StreamStore -from .pdu import StatePduStore, PduStore +from .pdu import StatePduStore, PduStore, PdusTable from .transactions import TransactionStore from .keys import KeyStore @@ -48,6 +48,28 @@ import os logger = logging.getLogger(__name__) +SCHEMAS = [ + "transactions", + "pdu", + "users", + "profiles", + "presence", + "im", + "room_aliases", +] + + +# Remember to update this number every time an incompatible change is made to +# database schema files, so the users will be informed on server restarts. +SCHEMA_VERSION = 3 + + +class _RollbackButIsFineException(Exception): + """ This exception is used to rollback a transaction without implying + something went wrong. + """ + pass + class DataStore(RoomMemberStore, RoomStore, RegistrationStore, StreamStore, ProfileStore, FeedbackStore, PresenceStore, PduStore, StatePduStore, TransactionStore, @@ -63,7 +85,8 @@ class DataStore(RoomMemberStore, RoomStore, @defer.inlineCallbacks @log_function - def persist_event(self, event=None, backfilled=False, pdu=None): + def persist_event(self, event=None, backfilled=False, pdu=None, + is_new_state=True): stream_ordering = None if backfilled: if not self.min_token_deferred.called: @@ -71,17 +94,20 @@ class DataStore(RoomMemberStore, RoomStore, self.min_token -= 1 stream_ordering = self.min_token - latest = yield self._db_pool.runInteraction( - self._persist_pdu_event_txn, - pdu=pdu, - event=event, - backfilled=backfilled, - stream_ordering=stream_ordering, - ) - defer.returnValue(latest) + try: + yield self.runInteraction( + self._persist_pdu_event_txn, + pdu=pdu, + event=event, + backfilled=backfilled, + stream_ordering=stream_ordering, + is_new_state=is_new_state, + ) + except _RollbackButIsFineException as e: + pass @defer.inlineCallbacks - def get_event(self, event_id): + def get_event(self, event_id, allow_none=False): events_dict = yield self._simple_select_one( "events", {"event_id": event_id}, @@ -92,18 +118,24 @@ class DataStore(RoomMemberStore, RoomStore, "content", "unrecognized_keys" ], + allow_none=allow_none, ) + if not events_dict: + defer.returnValue(None) + event = self._parse_event_from_row(events_dict) defer.returnValue(event) def _persist_pdu_event_txn(self, txn, pdu=None, event=None, - backfilled=False, stream_ordering=None): + backfilled=False, stream_ordering=None, + is_new_state=True): if pdu is not None: self._persist_event_pdu_txn(txn, pdu) if event is not None: return self._persist_event_txn( - txn, event, backfilled, stream_ordering + txn, event, backfilled, stream_ordering, + is_new_state=is_new_state, ) def _persist_event_pdu_txn(self, txn, pdu): @@ -112,6 +144,12 @@ class DataStore(RoomMemberStore, RoomStore, del cols["content"] del cols["prev_pdus"] cols["content_json"] = json.dumps(pdu.content) + + unrec_keys.update({ + k: v for k, v in cols.items() + if k not in PdusTable.fields + }) + cols["unrecognized_keys"] = json.dumps(unrec_keys) logger.debug("Persisting: %s", repr(cols)) @@ -124,7 +162,8 @@ class DataStore(RoomMemberStore, RoomStore, self._update_min_depth_for_context_txn(txn, pdu.context, pdu.depth) @log_function - def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None): + def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None, + is_new_state=True): if event.type == RoomMemberEvent.TYPE: self._store_room_member_txn(txn, event) elif event.type == FeedbackEvent.TYPE: @@ -171,13 +210,14 @@ class DataStore(RoomMemberStore, RoomStore, try: self._simple_insert_txn(txn, "events", vals) except: - logger.exception( + logger.warn( "Failed to persist, probably duplicate: %s", - event.event_id + event.event_id, + exc_info=True, ) - return + raise _RollbackButIsFineException("_persist_event") - if not backfilled and hasattr(event, "state_key"): + if is_new_state and hasattr(event, "state_key"): vals = { "event_id": event.event_id, "room_id": event.room_id, @@ -201,8 +241,6 @@ class DataStore(RoomMemberStore, RoomStore, } ) - return self._get_room_events_max_id_txn(txn) - @defer.inlineCallbacks def get_current_state(self, room_id, event_type=None, state_key=""): sql = ( @@ -220,7 +258,8 @@ class DataStore(RoomMemberStore, RoomStore, results = yield self._execute_and_decode(sql, *args) - defer.returnValue([self._parse_event_from_row(r) for r in results]) + events = yield self._parse_events(results) + defer.returnValue(events) @defer.inlineCallbacks def _get_min_token(self): @@ -269,7 +308,7 @@ class DataStore(RoomMemberStore, RoomStore, prev_state_pdu=prev_state_pdu, ) - return self._db_pool.runInteraction(_snapshot) + return self.runInteraction(_snapshot) class Snapshot(object): @@ -339,3 +378,42 @@ def read_schema(schema): """ with open(schema_path(schema)) as schema_file: return schema_file.read() + + +def prepare_database(db_conn): + """ Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we + don't have to worry about overwriting existing content. + """ + c = db_conn.cursor() + c.execute("PRAGMA user_version") + row = c.fetchone() + + if row and row[0]: + user_version = row[0] + + if user_version > SCHEMA_VERSION: + raise ValueError("Cannot use this database as it is too " + + "new for the server to understand" + ) + elif user_version < SCHEMA_VERSION: + logging.info("Upgrading database from version %d", + user_version + ) + + # Run every version since after the current version. + for v in range(user_version + 1, SCHEMA_VERSION + 1): + sql_script = read_schema("delta/v%d" % (v)) + c.executescript(sql_script) + + db_conn.commit() + + else: + for sql_loc in SCHEMAS: + sql_script = read_schema(sql_loc) + + c.executescript(sql_script) + db_conn.commit() + c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION) + + c.close() + diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index bae50e7d1f..76ed7d06fb 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -17,6 +17,7 @@ import logging from twisted.internet import defer from synapse.api.errors import StoreError +from synapse.util.logutils import log_function import collections import copy @@ -25,6 +26,44 @@ import json logger = logging.getLogger(__name__) +sql_logger = logging.getLogger("synapse.storage.SQL") + + +class LoggingTransaction(object): + """An object that almost-transparently proxies for the 'txn' object + passed to the constructor. Adds logging to the .execute() method.""" + __slots__ = ["txn"] + + def __init__(self, txn): + object.__setattr__(self, "txn", txn) + + def __getattribute__(self, name): + if name == "execute": + return object.__getattribute__(self, "execute") + + return getattr(object.__getattribute__(self, "txn"), name) + + def __setattr__(self, name, value): + setattr(object.__getattribute__(self, "txn"), name, value) + + def execute(self, sql, *args, **kwargs): + # TODO(paul): Maybe use 'info' and 'debug' for values? + sql_logger.debug("[SQL] %s", sql) + try: + if args and args[0]: + values = args[0] + sql_logger.debug("[SQL values] " + + ", ".join(("<%s>",) * len(values)), *values) + except: + # Don't let logging failures stop SQL from working + pass + + # TODO(paul): Here would be an excellent place to put some timing + # measurements, and log (warning?) slow queries. + return object.__getattribute__(self, "txn").execute( + sql, *args, **kwargs + ) + class SQLBaseStore(object): @@ -34,6 +73,13 @@ class SQLBaseStore(object): self.event_factory = hs.get_event_factory() self._clock = hs.get_clock() + def runInteraction(self, func, *args, **kwargs): + """Wraps the .runInteraction() method on the underlying db_pool.""" + def inner_func(txn, *args, **kwargs): + return func(LoggingTransaction(txn), *args, **kwargs) + + return self._db_pool.runInteraction(inner_func, *args, **kwargs) + def cursor_to_dict(self, cursor): """Converts a SQL cursor into an list of dicts. @@ -59,11 +105,6 @@ class SQLBaseStore(object): Returns: The result of decoder(results) """ - logger.debug( - "[SQL] %s Args=%s Func=%s", - query, args, decoder.__name__ if decoder else None - ) - def interaction(txn): cursor = txn.execute(query, args) if decoder: @@ -71,7 +112,7 @@ class SQLBaseStore(object): else: return cursor.fetchall() - return self._db_pool.runInteraction(interaction) + return self.runInteraction(interaction) def _execute_and_decode(self, query, *args): return self._execute(self.cursor_to_dict, query, *args) @@ -87,10 +128,11 @@ class SQLBaseStore(object): values : dict of new column names and values for them or_replace : bool; if True performs an INSERT OR REPLACE """ - return self._db_pool.runInteraction( + return self.runInteraction( self._simple_insert_txn, table, values, or_replace=or_replace ) + @log_function def _simple_insert_txn(self, txn, table, values, or_replace=False): sql = "%s INTO %s (%s) VALUES(%s)" % ( ("INSERT OR REPLACE" if or_replace else "INSERT"), @@ -98,6 +140,12 @@ class SQLBaseStore(object): ", ".join(k for k in values), ", ".join("?" for k in values) ) + + logger.debug( + "[SQL] %s Args=%s Func=%s", + sql, values.values(), + ) + txn.execute(sql, values.values()) return txn.lastrowid @@ -164,7 +212,7 @@ class SQLBaseStore(object): txn.execute(sql, keyvalues.values()) return txn.fetchall() - res = yield self._db_pool.runInteraction(func) + res = yield self.runInteraction(func) defer.returnValue([r[0] for r in res]) @@ -187,7 +235,7 @@ class SQLBaseStore(object): txn.execute(sql, keyvalues.values()) return self.cursor_to_dict(txn) - return self._db_pool.runInteraction(func) + return self.runInteraction(func) def _simple_update_one(self, table, keyvalues, updatevalues, retcols=None): @@ -255,7 +303,7 @@ class SQLBaseStore(object): raise StoreError(500, "More than one row matched") return ret - return self._db_pool.runInteraction(func) + return self.runInteraction(func) def _simple_delete_one(self, table, keyvalues): """Executes a DELETE query on the named table, expecting to delete a @@ -276,7 +324,7 @@ class SQLBaseStore(object): raise StoreError(404, "No row found") if txn.rowcount > 1: raise StoreError(500, "more than one row matched") - return self._db_pool.runInteraction(func) + return self.runInteraction(func) def _simple_max_id(self, table): """Executes a SELECT query on the named table, expecting to return the @@ -294,7 +342,7 @@ class SQLBaseStore(object): return 0 return max_id - return self._db_pool.runInteraction(func) + return self.runInteraction(func) def _parse_event_from_row(self, row_dict): d = copy.deepcopy({k: v for k, v in row_dict.items() if v}) @@ -307,11 +355,34 @@ class SQLBaseStore(object): d["content"] = json.loads(d["content"]) del d["unrecognized_keys"] + if "age_ts" not in d: + # For compatibility + d["age_ts"] = d["ts"] if "ts" in d else 0 + return self.event_factory.create_event( etype=d["type"], **d ) + def _parse_events(self, rows): + return self.runInteraction(self._parse_events_txn, rows) + + def _parse_events_txn(self, txn, rows): + events = [self._parse_event_from_row(r) for r in rows] + + sql = "SELECT * FROM events WHERE event_id = ?" + + for ev in events: + if hasattr(ev, "prev_state"): + # Load previous state_content. + # TODO: Should we be pulling this out above? + cursor = txn.execute(sql, (ev.prev_state,)) + prevs = self.cursor_to_dict(cursor) + if prevs: + prev = self._parse_event_from_row(prevs[0]) + ev.prev_content = prev.content + + return events class Table(object): """ A base class used to store information about a particular table. diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index bf55449253..540eb4c2c4 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -92,3 +92,10 @@ class DirectoryStore(SQLBaseStore): "server": server, } ) + + def get_aliases_for_room(self, room_id): + return self._simple_select_onecol( + "room_aliases", + {"room_id": room_id}, + "room_alias", + ) diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py index 0bf97e37ee..d70467dcd6 100644 --- a/synapse/storage/pdu.py +++ b/synapse/storage/pdu.py @@ -17,6 +17,7 @@ from twisted.internet import defer from ._base import SQLBaseStore, Table, JoinHelper +from synapse.federation.units import Pdu from synapse.util.logutils import log_function from collections import namedtuple @@ -42,7 +43,7 @@ class PduStore(SQLBaseStore): PduTuple: If the pdu does not exist in the database, returns None """ - return self._db_pool.runInteraction( + return self.runInteraction( self._get_pdu_tuple, pdu_id, origin ) @@ -94,7 +95,7 @@ class PduStore(SQLBaseStore): list: A list of PduTuples """ - return self._db_pool.runInteraction( + return self.runInteraction( self._get_current_state_for_context, context ) @@ -142,7 +143,7 @@ class PduStore(SQLBaseStore): pdu_origin (str) """ - return self._db_pool.runInteraction( + return self.runInteraction( self._mark_as_processed, pdu_id, pdu_origin ) @@ -151,7 +152,7 @@ class PduStore(SQLBaseStore): def get_all_pdus_from_context(self, context): """Get a list of all PDUs for a given context.""" - return self._db_pool.runInteraction( + return self.runInteraction( self._get_all_pdus_from_context, context, ) @@ -178,7 +179,7 @@ class PduStore(SQLBaseStore): Return: list: A list of PduTuples """ - return self._db_pool.runInteraction( + return self.runInteraction( self._get_backfill, context, pdu_list, limit ) @@ -239,7 +240,7 @@ class PduStore(SQLBaseStore): txn context (str) """ - return self._db_pool.runInteraction( + return self.runInteraction( self._get_min_depth_for_context, context ) @@ -308,8 +309,8 @@ class PduStore(SQLBaseStore): @defer.inlineCallbacks def get_oldest_pdus_in_context(self, context): - """Get a list of Pdus that we haven't backfilled beyond yet (and haven't - seen). This list is used when we want to backfill backwards and is the + """Get a list of Pdus that we haven't backfilled beyond yet (and havent + seen). This list is used when we want to backfill backwards and is the list we send to the remote server. Args: @@ -345,7 +346,7 @@ class PduStore(SQLBaseStore): bool """ - return self._db_pool.runInteraction( + return self.runInteraction( self._is_pdu_new, pdu_id=pdu_id, origin=origin, @@ -498,7 +499,7 @@ class StatePduStore(SQLBaseStore): ) def get_unresolved_state_tree(self, new_state_pdu): - return self._db_pool.runInteraction( + return self.runInteraction( self._get_unresolved_state_tree, new_state_pdu ) @@ -516,7 +517,7 @@ class StatePduStore(SQLBaseStore): if not current: logger.debug("get_unresolved_state_tree No current state.") - return return_value + return (return_value, None) return_value.current_branch.append(current) @@ -524,17 +525,20 @@ class StatePduStore(SQLBaseStore): txn, new_pdu, current ) + missing_branch = None for branch, prev_state, state in enum_branches: if state: return_value[branch].append(state) else: + # We don't have prev_state :( + missing_branch = branch break - return return_value + return (return_value, missing_branch) def update_current_state(self, pdu_id, origin, context, pdu_type, state_key): - return self._db_pool.runInteraction( + return self.runInteraction( self._update_current_state, pdu_id, origin, context, pdu_type, state_key ) @@ -573,7 +577,7 @@ class StatePduStore(SQLBaseStore): PduEntry """ - return self._db_pool.runInteraction( + return self.runInteraction( self._get_current_state_pdu, context, pdu_type, state_key ) @@ -622,53 +626,6 @@ class StatePduStore(SQLBaseStore): return result - def get_next_missing_pdu(self, new_pdu): - """When we get a new state pdu we need to check whether we need to do - any conflict resolution, if we do then we need to check if we need - to go back and request some more state pdus that we haven't seen yet. - - Args: - txn - new_pdu - - Returns: - PduIdTuple: A pdu that we are missing, or None if we have all the - pdus required to do the conflict resolution. - """ - return self._db_pool.runInteraction( - self._get_next_missing_pdu, new_pdu - ) - - def _get_next_missing_pdu(self, txn, new_pdu): - logger.debug( - "get_next_missing_pdu %s %s", - new_pdu.pdu_id, new_pdu.origin - ) - - current = self._get_current_interaction( - txn, - new_pdu.context, new_pdu.pdu_type, new_pdu.state_key - ) - - if (not current or not current.prev_state_id - or not current.prev_state_origin): - return None - - # Oh look, it's a straight clobber, so wooooo almost no-op. - if (new_pdu.prev_state_id == current.pdu_id - and new_pdu.prev_state_origin == current.origin): - return None - - enum_branches = self._enumerate_state_branches(txn, new_pdu, current) - for branch, prev_state, state in enum_branches: - if not state: - return PduIdTuple( - prev_state.prev_state_id, - prev_state.prev_state_origin - ) - - return None - def handle_new_state(self, new_pdu): """Actually perform conflict resolution on the new_pdu on the assumption we have all the pdus required to perform it. @@ -679,7 +636,7 @@ class StatePduStore(SQLBaseStore): Returns: bool: True if the new_pdu clobbered the current state, False if not """ - return self._db_pool.runInteraction( + return self.runInteraction( self._handle_new_state, new_pdu ) @@ -752,24 +709,11 @@ class StatePduStore(SQLBaseStore): return is_current - @classmethod @log_function - def _enumerate_state_branches(cls, txn, pdu_a, pdu_b): + def _enumerate_state_branches(self, txn, pdu_a, pdu_b): branch_a = pdu_a branch_b = pdu_b - get_query = ( - "SELECT %(fields)s FROM %(pdus)s as p " - "LEFT JOIN %(state)s as s " - "ON p.pdu_id = s.pdu_id AND p.origin = s.origin " - "WHERE p.pdu_id = ? AND p.origin = ? " - ) % { - "fields": _pdu_state_joiner.get_fields( - PdusTable="p", StatePdusTable="s"), - "pdus": PdusTable.table_name, - "state": StatePdusTable.table_name, - } - while True: if (branch_a.pdu_id == branch_b.pdu_id and branch_a.origin == branch_b.origin): @@ -801,13 +745,12 @@ class StatePduStore(SQLBaseStore): branch_a.prev_state_origin ) - logger.debug("getting branch_a prev %s", pdu_tuple) - txn.execute(get_query, pdu_tuple) - prev_branch = branch_a - res = txn.fetchone() - branch_a = PduEntry(*res) if res else None + logger.debug("getting branch_a prev %s", pdu_tuple) + branch_a = self._get_pdu_tuple(txn, *pdu_tuple) + if branch_a: + branch_a = Pdu.from_pdu_tuple(branch_a) logger.debug("branch_a=%s", branch_a) @@ -820,14 +763,13 @@ class StatePduStore(SQLBaseStore): branch_b.prev_state_id, branch_b.prev_state_origin ) - txn.execute(get_query, pdu_tuple) - - logger.debug("getting branch_b prev %s", pdu_tuple) prev_branch = branch_b - res = txn.fetchone() - branch_b = PduEntry(*res) if res else None + logger.debug("getting branch_b prev %s", pdu_tuple) + branch_b = self._get_pdu_tuple(txn, *pdu_tuple) + if branch_b: + branch_b = Pdu.from_pdu_tuple(branch_b) logger.debug("branch_b=%s", branch_b) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index fd762bc643..db20b1daa0 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -62,7 +62,7 @@ class RegistrationStore(SQLBaseStore): Raises: StoreError if the user_id could not be registered. """ - yield self._db_pool.runInteraction(self._register, user_id, token, + yield self.runInteraction(self._register, user_id, token, password_hash) def _register(self, txn, user_id, token, password_hash): @@ -99,7 +99,7 @@ class RegistrationStore(SQLBaseStore): Raises: StoreError if no user was found. """ - user_id = yield self._db_pool.runInteraction(self._query_for_auth, + user_id = yield self.runInteraction(self._query_for_auth, token) defer.returnValue(user_id) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 017169ce00..5adf8cdf1b 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -149,7 +149,7 @@ class RoomStore(SQLBaseStore): defer.returnValue(None) def get_power_level(self, room_id, user_id): - return self._db_pool.runInteraction( + return self.runInteraction( self._get_power_level, room_id, user_id, ) @@ -182,7 +182,7 @@ class RoomStore(SQLBaseStore): return None def get_ops_levels(self, room_id): - return self._db_pool.runInteraction( + return self.runInteraction( self._get_ops_levels, room_id, ) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 75c9a60101..04b4067d03 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -18,6 +18,7 @@ from twisted.internet import defer from ._base import SQLBaseStore from synapse.api.constants import Membership +from synapse.util.logutils import log_function import logging @@ -29,8 +30,18 @@ class RoomMemberStore(SQLBaseStore): def _store_room_member_txn(self, txn, event): """Store a room member in the database. """ - target_user_id = event.state_key - domain = self.hs.parse_userid(target_user_id).domain + try: + target_user_id = event.state_key + domain = self.hs.parse_userid(target_user_id).domain + except: + logger.exception("Failed to parse target_user_id=%s", target_user_id) + raise + + logger.debug( + "_store_room_member_txn: target_user_id=%s, membership=%s", + target_user_id, + event.membership, + ) self._simple_insert_txn( txn, @@ -51,12 +62,30 @@ class RoomMemberStore(SQLBaseStore): "VALUES (?, ?)" ) txn.execute(sql, (event.room_id, domain)) - else: - sql = ( - "DELETE FROM room_hosts WHERE room_id = ? AND host = ?" + elif event.membership != Membership.INVITE: + # Check if this was the last person to have left. + member_events = self._get_members_query_txn( + txn, + where_clause="c.room_id = ? AND m.membership = ? AND m.user_id != ?", + where_values=(event.room_id, Membership.JOIN, target_user_id,) ) - txn.execute(sql, (event.room_id, domain)) + joined_domains = set() + for e in member_events: + try: + joined_domains.add( + self.hs.parse_userid(e.state_key).domain + ) + except: + # FIXME: How do we deal with invalid user ids in the db? + logger.exception("Invalid user_id: %s", event.state_key) + + if domain not in joined_domains: + sql = ( + "DELETE FROM room_hosts WHERE room_id = ? AND host = ?" + ) + + txn.execute(sql, (event.room_id, domain)) @defer.inlineCallbacks def get_room_member(self, user_id, room_id): @@ -88,7 +117,7 @@ class RoomMemberStore(SQLBaseStore): txn.execute(sql, (user_id, room_id)) rows = self.cursor_to_dict(txn) if rows: - return self._parse_event_from_row(rows[0]) + return self._parse_events_txn(txn, rows)[0] else: return None @@ -120,7 +149,7 @@ class RoomMemberStore(SQLBaseStore): membership_list (list): A list of synapse.api.constants.Membership values which the user must be in. Returns: - A list of dicts with "room_id" and "membership" keys. + A list of RoomMemberEvent objects """ if not membership_list: return defer.succeed(None) @@ -146,8 +175,13 @@ class RoomMemberStore(SQLBaseStore): vals = where_dict.values() return self._get_members_query(clause, vals) - @defer.inlineCallbacks def _get_members_query(self, where_clause, where_values): + return self._db_pool.runInteraction( + self._get_members_query_txn, + where_clause, where_values + ) + + def _get_members_query_txn(self, txn, where_clause, where_values): sql = ( "SELECT e.* FROM events as e " "INNER JOIN room_memberships as m " @@ -157,18 +191,18 @@ class RoomMemberStore(SQLBaseStore): "WHERE %s " ) % (where_clause,) - rows = yield self._execute_and_decode(sql, *where_values) - - # logger.debug("_get_members_query Got rows %s", rows) + txn.execute(sql, where_values) + rows = self.cursor_to_dict(txn) - results = [self._parse_event_from_row(r) for r in rows] - defer.returnValue(results) + results = self._parse_events_txn(txn, rows) + return results @defer.inlineCallbacks - def user_rooms_intersect(self, user_list): - """ Checks whether a list of users share a room. + def user_rooms_intersect(self, user_id_list): + """ Checks whether all the users whose IDs are given in a list share a + room. """ - user_list_clause = " OR ".join(["m.user_id = ?"] * len(user_list)) + user_list_clause = " OR ".join(["m.user_id = ?"] * len(user_id_list)) sql = ( "SELECT m.room_id FROM room_memberships as m " "INNER JOIN current_state_events as c " @@ -178,8 +212,8 @@ class RoomMemberStore(SQLBaseStore): "GROUP BY m.room_id HAVING COUNT(m.room_id) = ?" ) % {"clause": user_list_clause} - args = user_list - args.append(len(user_list)) + args = list(user_id_list) + args.append(len(user_id_list)) rows = yield self._execute(None, sql, *args) diff --git a/synapse/storage/schema/delta/v3.sql b/synapse/storage/schema/delta/v3.sql new file mode 100644 index 0000000000..cade295989 --- /dev/null +++ b/synapse/storage/schema/delta/v3.sql @@ -0,0 +1,27 @@ +/* Copyright 2014 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE INDEX IF NOT EXISTS room_aliases_alias ON room_aliases(room_alias); +CREATE INDEX IF NOT EXISTS room_aliases_id ON room_aliases(room_id); + + +CREATE INDEX IF NOT EXISTS room_alias_servers_alias ON room_alias_servers(room_alias); + +DELETE FROM room_aliases WHERE rowid NOT IN (SELECT max(rowid) FROM room_aliases GROUP BY room_alias, room_id); + +CREATE UNIQUE INDEX IF NOT EXISTS room_aliases_uniq ON room_aliases(room_alias, room_id); + +PRAGMA user_version = 3; diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 2cb0067a67..a76fecf24f 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -146,7 +146,7 @@ class StreamStore(SQLBaseStore): current_room_membership_sql = ( "SELECT m.room_id FROM room_memberships as m " "INNER JOIN current_state_events as c ON m.event_id = c.event_id " - "WHERE m.user_id = ?" + "WHERE m.user_id = ? AND m.membership = 'join'" ) # We also want to get any membership events about that user, e.g. @@ -188,7 +188,7 @@ class StreamStore(SQLBaseStore): user_id, user_id, from_id, to_id ) - ret = [self._parse_event_from_row(r) for r in rows] + ret = yield self._parse_events(rows) if rows: key = "s%d" % max([r["stream_ordering"] for r in rows]) @@ -243,9 +243,11 @@ class StreamStore(SQLBaseStore): # TODO (erikj): We should work out what to do here instead. next_token = to_key if to_key else from_key + events = yield self._parse_events(rows) + defer.returnValue( ( - [self._parse_event_from_row(r) for r in rows], + events, next_token ) ) @@ -277,15 +279,14 @@ class StreamStore(SQLBaseStore): else: token = (end_token, end_token) - defer.returnValue( - ( - [self._parse_event_from_row(r) for r in rows], - token - ) - ) + events = yield self._parse_events(rows) + + ret = (events, token) + + defer.returnValue(ret) def get_room_events_max_id(self): - return self._db_pool.runInteraction(self._get_room_events_max_id_txn) + return self.runInteraction(self._get_room_events_max_id_txn) def _get_room_events_max_id_txn(self, txn): txn.execute( diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 7467e1035b..ab4599b468 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -41,7 +41,7 @@ class TransactionStore(SQLBaseStore): this transaction or a 2-tuple of (int, dict) """ - return self._db_pool.runInteraction( + return self.runInteraction( self._get_received_txn_response, transaction_id, origin ) @@ -72,7 +72,7 @@ class TransactionStore(SQLBaseStore): response_json (str) """ - return self._db_pool.runInteraction( + return self.runInteraction( self._set_received_txn_response, transaction_id, origin, code, response_dict ) @@ -104,7 +104,7 @@ class TransactionStore(SQLBaseStore): list: A list of previous transaction ids. """ - return self._db_pool.runInteraction( + return self.runInteraction( self._prep_send_transaction, transaction_id, destination, ts, pdu_list ) @@ -159,7 +159,7 @@ class TransactionStore(SQLBaseStore): code (int) response_json (str) """ - return self._db_pool.runInteraction( + return self.runInteraction( self._delivered_txn, transaction_id, destination, code, response_dict ) @@ -184,7 +184,7 @@ class TransactionStore(SQLBaseStore): Returns: list: A list of `ReceivedTransactionsTable.EntryType` """ - return self._db_pool.runInteraction( + return self.runInteraction( self._get_transactions_after, transaction_id, destination ) @@ -214,7 +214,7 @@ class TransactionStore(SQLBaseStore): Returns list: A list of PduTuple """ - return self._db_pool.runInteraction( + return self.runInteraction( self._get_pdus_after_transaction, transaction_id, destination ) diff --git a/synapse/util/emailutils.py b/synapse/util/emailutils.py new file mode 100644 index 0000000000..cdb0abd7ea --- /dev/null +++ b/synapse/util/emailutils.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" This module allows you to send out emails. +""" +import email.utils +import smtplib +import twisted.python.log +from email.mime.text import MIMEText +from email.mime.multipart import MIMEMultipart + +import logging + +logger = logging.getLogger(__name__) + + +class EmailException(Exception): + pass + + +def send_email(smtp_server, from_addr, to_addr, subject, body): + """Sends an email. + + Args: + smtp_server(str): The SMTP server to use. + from_addr(str): The address to send from. + to_addr(str): The address to send to. + subject(str): The subject of the email. + body(str): The plain text body of the email. + Raises: + EmailException if there was a problem sending the mail. + """ + if not smtp_server or not from_addr or not to_addr: + raise EmailException("Need SMTP server, from and to addresses. Check " + + "the config to set these.") + + msg = MIMEMultipart('alternative') + msg['Subject'] = subject + msg['From'] = from_addr + msg['To'] = to_addr + plain_part = MIMEText(body) + msg.attach(plain_part) + + raw_from = email.utils.parseaddr(from_addr)[1] + raw_to = email.utils.parseaddr(to_addr)[1] + if not raw_from or not raw_to: + raise EmailException("Couldn't parse from/to address.") + + logger.info("Sending email to %s on server %s with subject %s", + to_addr, smtp_server, subject) + + try: + smtp = smtplib.SMTP(smtp_server) + smtp.sendmail(raw_from, raw_to, msg.as_string()) + smtp.quit() + except Exception as origException: + twisted.python.log.err() + ese = EmailException() + ese.cause = origException + raise ese \ No newline at end of file |