diff options
57 files changed, 1036 insertions, 731 deletions
diff --git a/CHANGES.rst b/CHANGES.rst index 5a284c3853..3cd08938a2 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,11 @@ +Changes in synapse 0.5.1 (2014-11-26) +===================================== +See UPGRADES.rst for specific instructions on how to upgrade. + + * Fix bug where we served up an Event that did not match its signatures. + * Fix regression where we no longer correctly handled the case where a + homeserver receives an event for a room it doesn't recognise (but is in.) + Changes in synapse 0.5.0 (2014-11-19) ===================================== This release includes changes to the federation protocol and client-server API diff --git a/README.rst b/README.rst index 542f199874..5e020081a8 100644 --- a/README.rst +++ b/README.rst @@ -69,8 +69,8 @@ command line utility which lets you easily see what the JSON APIs are up to). Meanwhile, iOS and Android SDKs and clients are currently in development and available from: - * https://github.com/matrix-org/matrix-ios-sdk - * https://github.com/matrix-org/matrix-android-sdk +- https://github.com/matrix-org/matrix-ios-sdk +- https://github.com/matrix-org/matrix-android-sdk We'd like to invite you to join #matrix:matrix.org (via http://matrix.org/alpha), run a homeserver, take a look at the Matrix spec at http://matrix.org/docs/spec, experiment with the APIs and the demo @@ -94,7 +94,7 @@ header files for python C extensions. Installing prerequisites on Ubuntu or Debian:: $ sudo apt-get install build-essential python2.7-dev libffi-dev \ - python-pip python-setuptools + python-pip python-setuptools sqlite3 Installing prerequisites on Mac OS X:: @@ -125,7 +125,7 @@ created. To reset the installation:: pip seems to leak *lots* of memory during installation. For instance, a Linux host with 512MB of RAM may run out of memory whilst installing Twisted. If this happens, you will have to individually install the dependencies which are -failing, e.g.: +failing, e.g.:: $ pip install --user twisted @@ -148,7 +148,7 @@ Troubleshooting Running ----------------------- If ``synctl`` fails with ``pkg_resources.DistributionNotFound`` errors you may -need a newer version of setuptools than that provided by your OS. +need a newer version of setuptools than that provided by your OS.:: $ sudo pip install setuptools --upgrade @@ -172,7 +172,7 @@ Homeserver Development ====================== To check out a homeserver for development, clone the git repo into a working -directory of your choice: +directory of your choice:: $ git clone https://github.com/matrix-org/synapse.git $ cd synapse diff --git a/UPGRADE.rst b/UPGRADE.rst index 961f4da31c..5ebdd455c1 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -1,3 +1,12 @@ +Upgrading to v0.5.1 +=================== + +Depending on precisely when you installed v0.5.0 you may have ended up with +a stale release of the reference matrix webclient installed as a python module. +To uninstall it and ensure you are depending on the latest module, please run:: + + $ pip uninstall syweb + Upgrading to v0.5.0 =================== diff --git a/VERSION b/VERSION index 8f0916f768..4b9fcbec10 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.5.0 +0.5.1 diff --git a/scripts/check_signature.py b/scripts/check_signature.py index e146e18e24..59e3d603ac 100644 --- a/scripts/check_signature.py +++ b/scripts/check_signature.py @@ -23,7 +23,7 @@ def get_targets(server_name): for srv in answers: yield (srv.target, srv.port) except dns.resolver.NXDOMAIN: - yield (server_name, 8480) + yield (server_name, 8448) def get_server_keys(server_name, target, port): url = "https://%s:%i/_matrix/key/v1" % (target, port) diff --git a/setup.py b/setup.py index d0d649612d..6b4320f0c9 100755 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ setup( description="Reference Synapse Home Server", install_requires=[ "syutil==0.0.2", - "matrix_angular_sdk==0.5.0", + "matrix_angular_sdk==0.5.1", "Twisted>=14.0.0", "service_identity>=1.0.0", "pyopenssl>=0.14", @@ -45,7 +45,7 @@ setup( dependency_links=[ "https://github.com/matrix-org/syutil/tarball/v0.0.2#egg=syutil-0.0.2", "https://github.com/pyca/pynacl/tarball/52dbe2dc33f1#egg=pynacl-0.3.0", - "https://github.com/matrix-org/matrix-angular-sdk/tarball/v0.5.0/#egg=matrix_angular_sdk-0.5.0", + "https://github.com/matrix-org/matrix-angular-sdk/tarball/v0.5.1/#egg=matrix_angular_sdk-0.5.1", ], setup_requires=[ "setuptools_trial", diff --git a/synapse/__init__.py b/synapse/__init__.py index 14564e735e..1c10c2074e 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a synapse home server. """ -__version__ = "0.5.0" +__version__ = "0.5.1" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 6d8a9e4df7..fb911e51a6 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -38,79 +38,66 @@ class Auth(object): self.store = hs.get_datastore() self.state = hs.get_state_handler() - def check(self, event, raises=False): + def check(self, event, auth_events): """ Checks if this event is correctly authed. Returns: True if the auth checks pass. - Raises: - AuthError if there was a problem authorising this event. This will - be raised only if raises=True. """ try: - if hasattr(event, "room_id"): - if event.old_state_events is None: - # Oh, we don't know what the state of the room was, so we - # are trusting that this is allowed (at least for now) - logger.warn("Trusting event: %s", event.event_id) - return True - - if hasattr(event, "outlier") and event.outlier is True: - # TODO (erikj): Auth for outliers is done differently. - return True + if not hasattr(event, "room_id"): + raise AuthError(500, "Event has no room_id: %s" % event) + if auth_events is None: + # Oh, we don't know what the state of the room was, so we + # are trusting that this is allowed (at least for now) + logger.warn("Trusting event: %s", event.event_id) + return True - if event.type == RoomCreateEvent.TYPE: - # FIXME - return True + if event.type == RoomCreateEvent.TYPE: + # FIXME + return True - # FIXME: Temp hack - if event.type == RoomAliasesEvent.TYPE: - return True + # FIXME: Temp hack + if event.type == RoomAliasesEvent.TYPE: + return True - if event.type == RoomMemberEvent.TYPE: - allowed = self.is_membership_change_allowed(event) - if allowed: - logger.debug("Allowing! %s", event) - else: - logger.debug("Denying! %s", event) - return allowed + if event.type == RoomMemberEvent.TYPE: + allowed = self.is_membership_change_allowed( + event, auth_events + ) + if allowed: + logger.debug("Allowing! %s", event) + else: + logger.debug("Denying! %s", event) + return allowed - self.check_event_sender_in_room(event) - self._can_send_event(event) + self.check_event_sender_in_room(event, auth_events) + self._can_send_event(event, auth_events) - if event.type == RoomPowerLevelsEvent.TYPE: - self._check_power_levels(event) + if event.type == RoomPowerLevelsEvent.TYPE: + self._check_power_levels(event, auth_events) - if event.type == RoomRedactionEvent.TYPE: - self._check_redaction(event) + if event.type == RoomRedactionEvent.TYPE: + self._check_redaction(event, auth_events) - logger.debug("Allowing! %s", event) - return True - else: - raise AuthError(500, "Unknown event: %s" % event) + logger.debug("Allowing! %s", event) except AuthError as e: logger.info( "Event auth check failed on event %s with msg: %s", event, e.msg ) logger.info("Denying! %s", event) - if raises: - raise - - return False + raise @defer.inlineCallbacks def check_joined_room(self, room_id, user_id): - try: - member = yield self.store.get_room_member( - room_id=room_id, - user_id=user_id - ) - self._check_joined_room(member, user_id, room_id) - defer.returnValue(member) - except AttributeError: - pass - defer.returnValue(None) + member = yield self.state.get_current_state( + room_id=room_id, + event_type=RoomMemberEvent.TYPE, + state_key=user_id + ) + self._check_joined_room(member, user_id, room_id) + defer.returnValue(member) @defer.inlineCallbacks def check_host_in_room(self, room_id, host): @@ -130,9 +117,9 @@ class Auth(object): defer.returnValue(False) - def check_event_sender_in_room(self, event): + def check_event_sender_in_room(self, event, auth_events): key = (RoomMemberEvent.TYPE, event.user_id, ) - member_event = event.state_events.get(key) + member_event = auth_events.get(key) return self._check_joined_room( member_event, @@ -147,15 +134,15 @@ class Auth(object): )) @log_function - def is_membership_change_allowed(self, event): + def is_membership_change_allowed(self, event, auth_events): membership = event.content["membership"] # Check if this is the room creator joining: if len(event.prev_events) == 1 and Membership.JOIN == membership: # Get room creation event: key = (RoomCreateEvent.TYPE, "", ) - create = event.old_state_events.get(key) - if event.prev_events[0][0] == create.event_id: + create = auth_events.get(key) + if create and event.prev_events[0][0] == create.event_id: if create.content["creator"] == event.state_key: return True @@ -163,19 +150,19 @@ class Auth(object): # get info about the caller key = (RoomMemberEvent.TYPE, event.user_id, ) - caller = event.old_state_events.get(key) + caller = auth_events.get(key) caller_in_room = caller and caller.membership == Membership.JOIN caller_invited = caller and caller.membership == Membership.INVITE # get info about the target key = (RoomMemberEvent.TYPE, target_user_id, ) - target = event.old_state_events.get(key) + target = auth_events.get(key) target_in_room = target and target.membership == Membership.JOIN key = (RoomJoinRulesEvent.TYPE, "", ) - join_rule_event = event.old_state_events.get(key) + join_rule_event = auth_events.get(key) if join_rule_event: join_rule = join_rule_event.content.get( "join_rule", JoinRules.INVITE @@ -186,11 +173,13 @@ class Auth(object): user_level = self._get_power_level_from_event_state( event, event.user_id, + auth_events, ) ban_level, kick_level, redact_level = ( self._get_ops_level_from_event_state( - event + event, + auth_events, ) ) @@ -260,9 +249,9 @@ class Auth(object): return True - def _get_power_level_from_event_state(self, event, user_id): + def _get_power_level_from_event_state(self, event, user_id, auth_events): key = (RoomPowerLevelsEvent.TYPE, "", ) - power_level_event = event.old_state_events.get(key) + power_level_event = auth_events.get(key) level = None if power_level_event: level = power_level_event.content.get("users", {}).get(user_id) @@ -270,16 +259,16 @@ class Auth(object): level = power_level_event.content.get("users_default", 0) else: key = (RoomCreateEvent.TYPE, "", ) - create_event = event.old_state_events.get(key) + create_event = auth_events.get(key) if (create_event is not None and - create_event.content["creator"] == user_id): + create_event.content["creator"] == user_id): return 100 return level - def _get_ops_level_from_event_state(self, event): + def _get_ops_level_from_event_state(self, event, auth_events): key = (RoomPowerLevelsEvent.TYPE, "", ) - power_level_event = event.old_state_events.get(key) + power_level_event = auth_events.get(key) if power_level_event: return ( @@ -375,6 +364,11 @@ class Auth(object): key = (RoomMemberEvent.TYPE, event.user_id, ) member_event = event.old_state_events.get(key) + key = (RoomCreateEvent.TYPE, "", ) + create_event = event.old_state_events.get(key) + if create_event: + auth_events.append(create_event.event_id) + if join_rule_event: join_rule = join_rule_event.content.get("join_rule") is_public = join_rule == JoinRules.PUBLIC if join_rule else False @@ -406,9 +400,9 @@ class Auth(object): event.auth_events = zip(auth_events, hashes) @log_function - def _can_send_event(self, event): + def _can_send_event(self, event, auth_events): key = (RoomPowerLevelsEvent.TYPE, "", ) - send_level_event = event.old_state_events.get(key) + send_level_event = auth_events.get(key) send_level = None if send_level_event: send_level = send_level_event.content.get("events", {}).get( @@ -432,6 +426,7 @@ class Auth(object): user_level = self._get_power_level_from_event_state( event, event.user_id, + auth_events, ) if user_level: @@ -468,14 +463,16 @@ class Auth(object): return True - def _check_redaction(self, event): + def _check_redaction(self, event, auth_events): user_level = self._get_power_level_from_event_state( event, event.user_id, + auth_events, ) _, _, redact_level = self._get_ops_level_from_event_state( - event + event, + auth_events, ) if user_level < redact_level: @@ -484,7 +481,7 @@ class Auth(object): "You don't have permission to redact events" ) - def _check_power_levels(self, event): + def _check_power_levels(self, event, auth_events): user_list = event.content.get("users", {}) # Validate users for k, v in user_list.items(): @@ -499,7 +496,7 @@ class Auth(object): raise SynapseError(400, "Not a valid power level: %s" % (v,)) key = (event.type, event.state_key, ) - current_state = event.old_state_events.get(key) + current_state = auth_events.get(key) if not current_state: return @@ -507,6 +504,7 @@ class Auth(object): user_level = self._get_power_level_from_event_state( event, event.user_id, + auth_events, ) # Check other levels: diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 33d15072af..581439ceb3 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -17,6 +17,8 @@ import logging +logger = logging.getLogger(__name__) + class Codes(object): UNAUTHORIZED = "M_UNAUTHORIZED" @@ -38,7 +40,7 @@ class CodeMessageException(Exception): """An exception with integer code and message string attributes.""" def __init__(self, code, msg): - logging.error("%s: %s, %s", type(self).__name__, code, msg) + logger.info("%s: %s, %s", type(self).__name__, code, msg) super(CodeMessageException, self).__init__("%d: %s" % (code, msg)) self.code = code self.msg = msg @@ -140,7 +142,8 @@ def cs_exception(exception): if isinstance(exception, CodeMessageException): return exception.error_dict() else: - logging.error("Unknown exception type: %s", type(exception)) + logger.error("Unknown exception type: %s", type(exception)) + return {} def cs_error(msg, code=Codes.UNKNOWN, **kwargs): diff --git a/synapse/api/events/__init__.py b/synapse/api/events/__init__.py index a01c4a1351..22939d011a 100644 --- a/synapse/api/events/__init__.py +++ b/synapse/api/events/__init__.py @@ -83,6 +83,8 @@ class SynapseEvent(JsonEncodedObject): "content", ] + outlier = False + def __init__(self, raises=True, **kwargs): super(SynapseEvent, self).__init__(**kwargs) # if "content" in kwargs: @@ -123,6 +125,7 @@ class SynapseEvent(JsonEncodedObject): pdu_json.pop("outlier", None) pdu_json.pop("replaces_state", None) pdu_json.pop("redacted", None) + pdu_json.pop("prev_content", None) state_hash = pdu_json.pop("state_hash", None) if state_hash is not None: pdu_json.setdefault("unsigned", {})["state_hash"] = state_hash diff --git a/synapse/api/events/validator.py b/synapse/api/events/validator.py index 2d4f2a3aa7..067215f6ef 100644 --- a/synapse/api/events/validator.py +++ b/synapse/api/events/validator.py @@ -84,4 +84,4 @@ class EventValidator(object): template[key][0] ) if msg: - return msg \ No newline at end of file + return msg diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 85284a4919..855fe8e170 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -26,7 +26,7 @@ from twisted.web.server import Site from synapse.http.server import JsonResource, RootRedirect from synapse.http.content_repository import ContentRepoResource from synapse.http.server_key_resource import LocalKey -from synapse.http.client import MatrixHttpClient +from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.api.urls import ( CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, SERVER_KEY_PREFIX, @@ -51,7 +51,7 @@ logger = logging.getLogger(__name__) class SynapseHomeServer(HomeServer): def build_http_client(self): - return MatrixHttpClient(self) + return MatrixFederationHttpClient(self) def build_resource_for_client(self): return JsonResource() @@ -116,7 +116,7 @@ class SynapseHomeServer(HomeServer): # extra resources to existing nodes. See self._resource_id for the key. resource_mappings = {} for (full_path, resource) in desired_tree: - logging.info("Attaching %s to path %s", resource, full_path) + logger.info("Attaching %s to path %s", resource, full_path) last_resource = self.root_resource for path_seg in full_path.split('/')[1:-1]: if not path_seg in last_resource.listNames(): @@ -221,12 +221,12 @@ def setup(): db_name = hs.get_db_name() - logging.info("Preparing database: %s...", db_name) + logger.info("Preparing database: %s...", db_name) with sqlite3.connect(db_name) as db_conn: prepare_database(db_conn) - logging.info("Database prepared in %s.", db_name) + logger.info("Database prepared in %s.", db_name) hs.get_db_pool() @@ -257,13 +257,16 @@ def setup(): else: reactor.run() + def run(): with LoggingContext("run"): reactor.run() + def main(): with LoggingContext("main"): setup() + if __name__ == '__main__': main() diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py index abe055a64c..52a0b729f4 100755 --- a/synapse/app/synctl.py +++ b/synapse/app/synctl.py @@ -21,11 +21,12 @@ import signal SYNAPSE = ["python", "-m", "synapse.app.homeserver"] -CONFIGFILE="homeserver.yaml" -PIDFILE="homeserver.pid" +CONFIGFILE = "homeserver.yaml" +PIDFILE = "homeserver.pid" + +GREEN = "\x1b[1;32m" +NORMAL = "\x1b[m" -GREEN="\x1b[1;32m" -NORMAL="\x1b[m" def start(): if not os.path.exists(CONFIGFILE): @@ -43,12 +44,14 @@ def start(): subprocess.check_call(args) print GREEN + "started" + NORMAL + def stop(): if os.path.exists(PIDFILE): pid = int(open(PIDFILE).read()) os.kill(pid, signal.SIGTERM) print GREEN + "stopped" + NORMAL + def main(): action = sys.argv[1] if sys.argv[1:] else "usage" if action == "start": @@ -62,5 +65,6 @@ def main(): sys.stderr.write("Usage: %s [start|stop|restart]\n" % (sys.argv[0],)) sys.exit(1) -if __name__=='__main__': + +if __name__ == "__main__": main() diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index 4dff2c0ec2..a9d8953239 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) def check_event_content_hash(event, hash_algorithm=hashlib.sha256): """Check whether the hash for this PDU matches the contents""" computed_hash = _compute_content_hash(event, hash_algorithm) - logging.debug("Expecting hash: %s", encode_base64(computed_hash.digest())) + logger.debug("Expecting hash: %s", encode_base64(computed_hash.digest())) if computed_hash.name not in event.hashes: raise SynapseError( 400, diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index bb1f400b54..3f37c99261 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -17,7 +17,7 @@ from twisted.web.http import HTTPClient from twisted.internet.protocol import Factory from twisted.internet import defer, reactor -from synapse.http.endpoint import matrix_endpoint +from synapse.http.endpoint import matrix_federation_endpoint from synapse.util.logcontext import PreserveLoggingContext import json import logging @@ -31,7 +31,7 @@ def fetch_server_key(server_name, ssl_context_factory): """Fetch the keys for a remote server.""" factory = SynapseKeyClientFactory() - endpoint = matrix_endpoint( + endpoint = matrix_federation_endpoint( reactor, server_name, ssl_context_factory, timeout=30 ) @@ -48,7 +48,7 @@ def fetch_server_key(server_name, ssl_context_factory): class SynapseKeyClientError(Exception): - """The key wasn't retireved from the remote server.""" + """The key wasn't retrieved from the remote server.""" pass diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 694aed3a7d..ceb03ce6c2 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -135,7 +135,7 @@ class Keyring(object): time_now_ms = self.clock.time_msec() - self.store.store_server_certificate( + yield self.store.store_server_certificate( server_name, server_name, time_now_ms, @@ -143,7 +143,7 @@ class Keyring(object): ) for key_id, key in verify_keys.items(): - self.store.store_server_verify_key( + yield self.store.store_server_verify_key( server_name, server_name, time_now_ms, key ) diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py index 65a53ae17c..6bfb30b42d 100644 --- a/synapse/federation/replication.py +++ b/synapse/federation/replication.py @@ -24,6 +24,7 @@ from .units import Transaction, Edu from .persistence import TransactionActions from synapse.util.logutils import log_function +from synapse.util.logcontext import PreserveLoggingContext import logging @@ -319,19 +320,20 @@ class ReplicationLayer(object): logger.debug("[%s] Transacition is new", transaction.transaction_id) - dl = [] - for pdu in pdu_list: - dl.append(self._handle_new_pdu(transaction.origin, pdu)) + with PreserveLoggingContext(): + dl = [] + for pdu in pdu_list: + dl.append(self._handle_new_pdu(transaction.origin, pdu)) - if hasattr(transaction, "edus"): - for edu in [Edu(**x) for x in transaction.edus]: - self.received_edu( - transaction.origin, - edu.edu_type, - edu.content - ) + if hasattr(transaction, "edus"): + for edu in [Edu(**x) for x in transaction.edus]: + self.received_edu( + transaction.origin, + edu.edu_type, + edu.content + ) - results = yield defer.DeferredList(dl) + results = yield defer.DeferredList(dl) ret = [] for r in results: @@ -425,7 +427,9 @@ class ReplicationLayer(object): time_now = self._clock.time_msec() defer.returnValue((200, { "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], - "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]], + "auth_chain": [ + p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] + ], })) @defer.inlineCallbacks @@ -436,7 +440,9 @@ class ReplicationLayer(object): ( 200, { - "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus], + "auth_chain": [ + a.get_pdu_json(time_now) for a in auth_pdus + ], } ) ) @@ -457,7 +463,7 @@ class ReplicationLayer(object): @defer.inlineCallbacks def send_join(self, destination, pdu): - time_now = self._clock.time_msec() + time_now = self._clock.time_msec() _, content = yield self.transport_layer.send_join( destination, pdu.room_id, @@ -475,11 +481,17 @@ class ReplicationLayer(object): # FIXME: We probably want to do something with the auth_chain given # to us - # auth_chain = [ - # Pdu(outlier=True, **p) for p in content.get("auth_chain", []) - # ] + auth_chain = [ + self.event_from_pdu_json(p, outlier=True) + for p in content.get("auth_chain", []) + ] - defer.returnValue(state) + auth_chain.sort(key=lambda e: e.depth) + + defer.returnValue({ + "state": state, + "auth_chain": auth_chain, + }) @defer.inlineCallbacks def send_invite(self, destination, context, event_id, pdu): @@ -498,13 +510,15 @@ class ReplicationLayer(object): defer.returnValue(self.event_from_pdu_json(pdu_dict)) @log_function - def _get_persisted_pdu(self, origin, event_id): + def _get_persisted_pdu(self, origin, event_id, do_auth=True): """ Get a PDU from the database with given origin and id. Returns: Deferred: Results in a `Pdu`. """ - return self.handler.get_persisted_pdu(origin, event_id) + return self.handler.get_persisted_pdu( + origin, event_id, do_auth=do_auth + ) def _transaction_from_pdus(self, pdu_list): """Returns a new Transaction containing the given PDUs suitable for @@ -523,7 +537,9 @@ class ReplicationLayer(object): @log_function def _handle_new_pdu(self, origin, pdu, backfilled=False): # We reprocess pdus when we have seen them only as outliers - existing = yield self._get_persisted_pdu(origin, pdu.event_id) + existing = yield self._get_persisted_pdu( + origin, pdu.event_id, do_auth=False + ) if existing and (not existing.outlier or pdu.outlier): logger.debug("Already seen pdu %s", pdu.event_id) @@ -532,6 +548,36 @@ class ReplicationLayer(object): state = None + # We need to make sure we have all the auth events. + for e_id, _ in pdu.auth_events: + exists = yield self._get_persisted_pdu( + origin, + e_id, + do_auth=False + ) + + if not exists: + try: + logger.debug( + "_handle_new_pdu fetch missing auth event %s from %s", + e_id, + origin, + ) + + yield self.get_pdu( + origin, + event_id=e_id, + outlier=True, + ) + + logger.debug("Processed pdu %s", e_id) + except: + logger.warn( + "Failed to get auth event %s from %s", + e_id, + origin + ) + # Get missing pdus if necessary. if not pdu.outlier: # We only backfill backwards to the min depth. @@ -539,16 +585,28 @@ class ReplicationLayer(object): pdu.room_id ) + logger.debug( + "_handle_new_pdu min_depth for %s: %d", + pdu.room_id, min_depth + ) + if min_depth and pdu.depth > min_depth: for event_id, hashes in pdu.prev_events: - exists = yield self._get_persisted_pdu(origin, event_id) + exists = yield self._get_persisted_pdu( + origin, + event_id, + do_auth=False + ) if not exists: - logger.debug("Requesting pdu %s", event_id) + logger.debug( + "_handle_new_pdu requesting pdu %s", + event_id + ) try: yield self.get_pdu( - pdu.origin, + origin, event_id=event_id, ) logger.debug("Processed pdu %s", event_id) @@ -558,6 +616,10 @@ class ReplicationLayer(object): else: # We need to get the state at this event, since we have reached # a backward extremity edge. + logger.debug( + "_handle_new_pdu getting state for %s", + pdu.room_id + ) state = yield self.get_state_for_context( origin, pdu.room_id, pdu.event_id, ) @@ -649,7 +711,8 @@ class _TransactionQueue(object): (pdu, deferred, order) ) - self._attempt_new_transaction(destination) + with PreserveLoggingContext(): + self._attempt_new_transaction(destination) deferreds.append(deferred) @@ -669,7 +732,9 @@ class _TransactionQueue(object): deferred.errback(failure) else: logger.exception("Failed to send edu", failure) - self._attempt_new_transaction(destination).addErrback(eb) + + with PreserveLoggingContext(): + self._attempt_new_transaction(destination).addErrback(eb) return deferred diff --git a/synapse/federation/units.py b/synapse/federation/units.py index 6e708edb8c..1bcd0548c2 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -25,7 +25,6 @@ import logging logger = logging.getLogger(__name__) - class Edu(JsonEncodedObject): """ An Edu represents a piece of data sent from one homeserver to another. diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 30c6733063..15adc9dc2c 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -78,7 +78,7 @@ class BaseHandler(object): if not suppress_auth: logger.debug("Authing...") - self.auth.check(event, raises=True) + self.auth.check(event, auth_events=event.old_state_events) logger.debug("Authed") else: logger.debug("Suppressed auth.") @@ -112,7 +112,7 @@ class BaseHandler(object): event.destinations = list(destinations) - self.notifier.on_new_room_event(event, extra_users=extra_users) + yield self.notifier.on_new_room_event(event, extra_users=extra_users) federation_handler = self.hs.get_handlers().federation_handler yield federation_handler.handle_new_event(event, snapshot) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index ed9b0f8551..3b37e49e6f 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -17,7 +17,7 @@ from twisted.internet import defer from ._base import BaseHandler -from synapse.api.errors import SynapseError +from synapse.api.errors import SynapseError, Codes, CodeMessageException from synapse.api.events.room import RoomAliasesEvent import logging @@ -84,22 +84,32 @@ class DirectoryHandler(BaseHandler): room_id = result.room_id servers = result.servers else: - result = yield self.federation.make_query( - destination=room_alias.domain, - query_type="directory", - args={ - "room_alias": room_alias.to_string(), - }, - retry_on_dns_fail=False, - ) + try: + result = yield self.federation.make_query( + destination=room_alias.domain, + query_type="directory", + args={ + "room_alias": room_alias.to_string(), + }, + retry_on_dns_fail=False, + ) + except CodeMessageException as e: + logging.warn("Error retrieving alias") + if e.code == 404: + result = None + else: + raise if result and "room_id" in result and "servers" in result: room_id = result["room_id"] servers = result["servers"] if not room_id: - defer.returnValue({}) - return + raise SynapseError( + 404, + "Room alias %r not found" % (room_alias.to_string(),), + Codes.NOT_FOUND + ) extra_servers = yield self.store.get_joined_hosts_for_room(room_id) servers = list(set(extra_servers) | set(servers)) @@ -128,8 +138,11 @@ class DirectoryHandler(BaseHandler): "servers": result.servers, }) else: - raise SynapseError(404, "Room alias \"%s\" not found", room_alias) - + raise SynapseError( + 404, + "Room alias %r not found" % (room_alias.to_string(),), + Codes.NOT_FOUND + ) @defer.inlineCallbacks def send_room_alias_update_event(self, user_id, room_id): diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 4993c92b74..d59221a4fb 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -56,7 +56,7 @@ class EventStreamHandler(BaseHandler): self.clock.cancel_call_later( self._stop_timer_per_user.pop(auth_user)) else: - self.distributor.fire( + yield self.distributor.fire( "started_user_eventstream", auth_user ) self._streams_per_user[auth_user] += 1 @@ -65,8 +65,10 @@ class EventStreamHandler(BaseHandler): pagin_config.from_token = None rm_handler = self.hs.get_handlers().room_member_handler + logger.debug("BETA") room_ids = yield rm_handler.get_rooms_for_user(auth_user) + logger.debug("ALPHA") with PreserveLoggingContext(): events, tokens = yield self.notifier.get_events_for( auth_user, room_ids, pagin_config, timeout @@ -93,7 +95,7 @@ class EventStreamHandler(BaseHandler): logger.debug( "_later stopped_user_eventstream %s", auth_user ) - self.distributor.fire( + yield self.distributor.fire( "stopped_user_eventstream", auth_user ) del self._stop_timer_per_user[auth_user] diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 492005a170..252c1f1684 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -24,7 +24,8 @@ from synapse.api.constants import Membership from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor from synapse.crypto.event_signing import ( - compute_event_signature, check_event_content_hash + compute_event_signature, check_event_content_hash, + add_hashes_and_signatures, ) from syutil.jsonutil import encode_canonical_json @@ -122,7 +123,8 @@ class FederationHandler(BaseHandler): event.origin, redacted_pdu_json ) except SynapseError as e: - logger.warn("Signature check failed for %s redacted to %s", + logger.warn( + "Signature check failed for %s redacted to %s", encode_canonical_json(pdu.get_pdu_json()), encode_canonical_json(redacted_pdu_json), ) @@ -140,15 +142,27 @@ class FederationHandler(BaseHandler): ) event = redacted_event - is_new_state = yield self.state_handler.annotate_event_with_state( - event, - old_state=state - ) - logger.debug("Event: %s", event) + # FIXME (erikj): Awful hack to make the case where we are not currently + # in the room work + current_state = None + if state: + is_in_room = yield self.auth.check_host_in_room( + event.room_id, + self.server_name + ) + if not is_in_room: + logger.debug("Got event for room we're not in.") + current_state = state + try: - self.auth.check(event, raises=True) + yield self._handle_new_event( + event, + state=state, + backfilled=backfilled, + current_state=current_state, + ) except AuthError as e: raise FederationError( "ERROR", @@ -157,43 +171,14 @@ class FederationHandler(BaseHandler): affected=event.event_id, ) - is_new_state = is_new_state and not backfilled - - # TODO: Implement something in federation that allows us to - # respond to PDU. - - yield self.store.persist_event( - event, - backfilled, - is_new_state=is_new_state - ) - room = yield self.store.get_room(event.room_id) if not room: - # Huh, let's try and get the current state - try: - yield self.replication_layer.get_state_for_context( - event.origin, event.room_id, event.event_id, - ) - - hosts = yield self.store.get_joined_hosts_for_room( - event.room_id - ) - if self.hs.hostname in hosts: - try: - yield self.store.store_room( - room_id=event.room_id, - room_creator_user_id="", - is_public=False, - ) - except: - pass - except: - logger.exception( - "Failed to get current state for room %s", - event.room_id - ) + yield self.store.store_room( + room_id=event.room_id, + room_creator_user_id="", + is_public=False, + ) if not backfilled: extra_users = [] @@ -209,7 +194,7 @@ class FederationHandler(BaseHandler): if event.type == RoomMemberEvent.TYPE: if event.membership == Membership.JOIN: user = self.hs.parse_userid(event.state_key) - self.distributor.fire( + yield self.distributor.fire( "user_joined_room", user=user, room_id=event.room_id ) @@ -254,6 +239,8 @@ class FederationHandler(BaseHandler): pdu=event ) + + defer.returnValue(pdu) @defer.inlineCallbacks @@ -275,6 +262,8 @@ class FederationHandler(BaseHandler): We suspend processing of any received events from this room until we have finished processing the join. """ + logger.debug("Joining %s to %s", joinee, room_id) + pdu = yield self.replication_layer.make_join( target_host, room_id, @@ -297,19 +286,28 @@ class FederationHandler(BaseHandler): try: event.event_id = self.event_factory.create_event_id() + event.origin = self.hs.hostname event.content = content - state = yield self.replication_layer.send_join( + if not hasattr(event, "signatures"): + event.signatures = {} + + add_hashes_and_signatures( + event, + self.hs.hostname, + self.hs.config.signing_key[0], + ) + + ret = yield self.replication_layer.send_join( target_host, event ) - logger.debug("do_invite_join state: %s", state) + state = ret["state"] + auth_chain = ret["auth_chain"] - yield self.state_handler.annotate_event_with_state( - event, - old_state=state - ) + logger.debug("do_invite_join auth_chain: %s", auth_chain) + logger.debug("do_invite_join state: %s", state) logger.debug("do_invite_join event: %s", event) @@ -323,34 +321,41 @@ class FederationHandler(BaseHandler): # FIXME pass - for e in state: - # FIXME: Auth these. + for e in auth_chain: e.outlier = True - - yield self.state_handler.annotate_event_with_state( - e, + yield self._handle_new_event(e) + yield self.notifier.on_new_room_event( + e, extra_users=[joinee] ) - yield self.store.persist_event( - e, - backfilled=False, - is_new_state=True + for e in state: + # FIXME: Auth these. + e.outlier = True + yield self._handle_new_event(e) + yield self.notifier.on_new_room_event( + e, extra_users=[joinee] ) - yield self.store.persist_event( + yield self._handle_new_event( event, - backfilled=False, - is_new_state=True + state=state, + current_state=state + ) + + yield self.notifier.on_new_room_event( + event, extra_users=[joinee] ) + + logger.debug("Finished joining %s to %s", joinee, room_id) finally: room_queue = self.room_queues[room_id] del self.room_queues[room_id] for p in room_queue: try: - yield self.on_receive_pdu(p, backfilled=False) + self.on_receive_pdu(p, backfilled=False) except: - pass + logger.exception("Couldn't handle pdu") defer.returnValue(True) @@ -374,7 +379,7 @@ class FederationHandler(BaseHandler): yield self.state_handler.annotate_event_with_state(event) yield self.auth.add_auth_events(event) - self.auth.check(event, raises=True) + self.auth.check(event, auth_events=event.old_state_events) pdu = event @@ -390,16 +395,7 @@ class FederationHandler(BaseHandler): event.outlier = False - is_new_state = yield self.state_handler.annotate_event_with_state(event) - self.auth.check(event, raises=True) - - # FIXME (erikj): All this is duplicated above :( - - yield self.store.persist_event( - event, - backfilled=False, - is_new_state=is_new_state - ) + yield self._handle_new_event(event) extra_users = [] if event.type == RoomMemberEvent.TYPE: @@ -412,9 +408,9 @@ class FederationHandler(BaseHandler): ) if event.type == RoomMemberEvent.TYPE: - if event.membership == Membership.JOIN: + if event.content["membership"] == Membership.JOIN: user = self.hs.parse_userid(event.state_key) - self.distributor.fire( + yield self.distributor.fire( "user_joined_room", user=user, room_id=event.room_id ) @@ -527,7 +523,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks @log_function - def get_persisted_pdu(self, origin, event_id): + def get_persisted_pdu(self, origin, event_id, do_auth=True): """ Get a PDU from the database with given origin and id. Returns: @@ -539,12 +535,13 @@ class FederationHandler(BaseHandler): ) if event: - in_room = yield self.auth.check_host_in_room( - event.room_id, - origin - ) - if not in_room: - raise AuthError(403, "Host not in room.") + if do_auth: + in_room = yield self.auth.check_host_in_room( + event.room_id, + origin + ) + if not in_room: + raise AuthError(403, "Host not in room.") defer.returnValue(event) else: @@ -562,3 +559,65 @@ class FederationHandler(BaseHandler): ) while waiters: waiters.pop().callback(None) + + @defer.inlineCallbacks + def _handle_new_event(self, event, state=None, backfilled=False, + current_state=None): + if state: + for s in state: + yield self._handle_new_event(s) + + is_new_state = yield self.state_handler.annotate_event_with_state( + event, + old_state=state + ) + + if event.old_state_events: + known_ids = set( + [s.event_id for s in event.old_state_events.values()] + ) + for e_id, _ in event.auth_events: + if e_id not in known_ids: + e = yield self.store.get_event( + e_id, + allow_none=True, + ) + + if not e: + # TODO: Do some conflict res to make sure that we're + # not the ones who are wrong. + logger.info( + "Rejecting %s as %s not in %s", + event.event_id, e_id, known_ids, + ) + raise AuthError(403, "Auth events are stale") + + auth_events = event.old_state_events + else: + # We need to get the auth events from somewhere. + + # TODO: Don't just hit the DBs? + + auth_events = {} + for e_id, _ in event.auth_events: + e = yield self.store.get_event( + e_id, + allow_none=True, + ) + + if not e: + raise AuthError( + 403, + "Can't find auth event %s." % (e_id, ) + ) + + auth_events[(e.type, e.state_key)] = e + + self.auth.check(event, auth_events=auth_events) + + yield self.store.persist_event( + event, + backfilled=backfilled, + is_new_state=(is_new_state and not backfilled), + current_state=current_state, + ) diff --git a/synapse/handlers/login.py b/synapse/handlers/login.py index 99d15261d4..c98ae2cfb5 100644 --- a/synapse/handlers/login.py +++ b/synapse/handlers/login.py @@ -17,13 +17,12 @@ from twisted.internet import defer from ._base import BaseHandler from synapse.api.errors import LoginError, Codes -from synapse.http.client import IdentityServerHttpClient +from synapse.http.client import SimpleHttpClient from synapse.util.emailutils import EmailException import synapse.util.emailutils as emailutils import bcrypt import logging -import urllib logger = logging.getLogger(__name__) @@ -97,10 +96,16 @@ class LoginHandler(BaseHandler): @defer.inlineCallbacks def _query_email(self, email): - httpCli = IdentityServerHttpClient(self.hs) + httpCli = SimpleHttpClient(self.hs) data = yield httpCli.get_json( - 'matrix.org:8090', # TODO FIXME This should be configurable. - "/_matrix/identity/api/v1/lookup?medium=email&address=" + - "%s" % urllib.quote(email) + # TODO FIXME This should be configurable. + # XXX: ID servers need to use HTTPS + "http://%s%s" % ( + "matrix.org:8090", "/_matrix/identity/api/v1/lookup" + ), + { + 'medium': 'email', + 'address': email + } ) defer.returnValue(data) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index de70486b29..42dc4d46f3 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.api.constants import Membership from synapse.api.errors import RoomError from synapse.streams.config import PaginationConfig +from synapse.util.logcontext import PreserveLoggingContext from ._base import BaseHandler import logging @@ -86,9 +87,10 @@ class MessageHandler(BaseHandler): event, snapshot, suppress_auth=suppress_auth ) - self.hs.get_handlers().presence_handler.bump_presence_active_time( - user - ) + with PreserveLoggingContext(): + self.hs.get_handlers().presence_handler.bump_presence_active_time( + user + ) @defer.inlineCallbacks def get_messages(self, user_id=None, room_id=None, pagin_config=None, @@ -241,7 +243,7 @@ class MessageHandler(BaseHandler): public_room_ids = [r["room_id"] for r in public_rooms] limit = pagin_config.limit - if not limit: + if limit is None: limit = 10 for event in room_list: @@ -296,7 +298,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def room_initial_sync(self, user_id, room_id, pagin_config=None, - feedback=False): + feedback=False): yield self.auth.check_joined_room(room_id, user_id) # TODO(paul): I wish I was called with user objects not user_id @@ -304,7 +306,7 @@ class MessageHandler(BaseHandler): auth_user = self.hs.parse_userid(user_id) # TODO: These concurrently - state_tuples = yield self.store.get_current_state(room_id) + state_tuples = yield self.state_handler.get_current_state(room_id) state = [self.hs.serialize_event(x) for x in state_tuples] member_event = (yield self.store.get_room_member( @@ -340,8 +342,8 @@ class MessageHandler(BaseHandler): ) presence.append(member_presence) except Exception: - logger.exception("Failed to get member presence of %r", - m.user_id + logger.exception( + "Failed to get member presence of %r", m.user_id ) defer.returnValue({ diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index fcc92a8e32..b55d589daf 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -19,6 +19,7 @@ from synapse.api.errors import SynapseError, AuthError from synapse.api.constants import PresenceState from synapse.util.logutils import log_function +from synapse.util.logcontext import PreserveLoggingContext from ._base import BaseHandler @@ -142,7 +143,7 @@ class PresenceHandler(BaseHandler): return UserPresenceCache() def registered_user(self, user): - self.store.create_presence(user.localpart) + return self.store.create_presence(user.localpart) @defer.inlineCallbacks def is_presence_visible(self, observer_user, observed_user): @@ -241,14 +242,12 @@ class PresenceHandler(BaseHandler): was_level = self.STATE_LEVELS[statuscache.get_state()["presence"]] now_level = self.STATE_LEVELS[state["presence"]] - yield defer.DeferredList([ - self.store.set_presence_state( - target_user.localpart, state_to_store - ), - self.distributor.fire( - "collect_presencelike_data", target_user, state - ), - ]) + yield self.store.set_presence_state( + target_user.localpart, state_to_store + ) + yield self.distributor.fire( + "collect_presencelike_data", target_user, state + ) if now_level > was_level: state["last_active"] = self.clock.time_msec() @@ -256,14 +255,15 @@ class PresenceHandler(BaseHandler): now_online = state["presence"] != PresenceState.OFFLINE was_polling = target_user in self._user_cachemap - if now_online and not was_polling: - self.start_polling_presence(target_user, state=state) - elif not now_online and was_polling: - self.stop_polling_presence(target_user) + with PreserveLoggingContext(): + if now_online and not was_polling: + self.start_polling_presence(target_user, state=state) + elif not now_online and was_polling: + self.stop_polling_presence(target_user) - # TODO(paul): perform a presence push as part of start/stop poll so - # we don't have to do this all the time - self.changed_presencelike_data(target_user, state) + # TODO(paul): perform a presence push as part of start/stop poll so + # we don't have to do this all the time + self.changed_presencelike_data(target_user, state) def bump_presence_active_time(self, user, now=None): if now is None: @@ -277,7 +277,7 @@ class PresenceHandler(BaseHandler): self._user_cachemap_latest_serial += 1 statuscache.update(state, serial=self._user_cachemap_latest_serial) - self.push_presence(user, statuscache=statuscache) + return self.push_presence(user, statuscache=statuscache) @log_function def started_user_eventstream(self, user): @@ -381,8 +381,10 @@ class PresenceHandler(BaseHandler): yield self.store.set_presence_list_accepted( observer_user.localpart, observed_user.to_string() ) - - self.start_polling_presence(observer_user, target_user=observed_user) + with PreserveLoggingContext(): + self.start_polling_presence( + observer_user, target_user=observed_user + ) @defer.inlineCallbacks def deny_presence(self, observed_user, observer_user): @@ -401,7 +403,10 @@ class PresenceHandler(BaseHandler): observer_user.localpart, observed_user.to_string() ) - self.stop_polling_presence(observer_user, target_user=observed_user) + with PreserveLoggingContext(): + self.stop_polling_presence( + observer_user, target_user=observed_user + ) @defer.inlineCallbacks def get_presence_list(self, observer_user, accepted=None): @@ -710,7 +715,8 @@ class PresenceHandler(BaseHandler): if not self._remote_sendmap[user]: del self._remote_sendmap[user] - yield defer.DeferredList(deferreds) + with PreserveLoggingContext(): + yield defer.DeferredList(deferreds) @defer.inlineCallbacks def push_update_to_local_and_remote(self, observed_user, statuscache, diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 7853bf5098..814b3b68fe 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.api.constants import Membership +from synapse.util.logcontext import PreserveLoggingContext from ._base import BaseHandler @@ -46,7 +47,7 @@ class ProfileHandler(BaseHandler): ) def registered_user(self, user): - self.store.create_profile(user.localpart) + return self.store.create_profile(user.localpart) @defer.inlineCallbacks def get_displayname(self, target_user): @@ -152,13 +153,14 @@ class ProfileHandler(BaseHandler): if not user.is_mine: defer.returnValue(None) - (displayname, avatar_url) = yield defer.gatherResults( - [ - self.store.get_profile_displayname(user.localpart), - self.store.get_profile_avatar_url(user.localpart), - ], - consumeErrors=True - ) + with PreserveLoggingContext(): + (displayname, avatar_url) = yield defer.gatherResults( + [ + self.store.get_profile_displayname(user.localpart), + self.store.get_profile_avatar_url(user.localpart), + ], + consumeErrors=True + ) state["displayname"] = displayname state["avatar_url"] = avatar_url diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 7df9d9b82d..48c326ebf0 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -22,7 +22,7 @@ from synapse.api.errors import ( ) from ._base import BaseHandler import synapse.util.stringutils as stringutils -from synapse.http.client import IdentityServerHttpClient +from synapse.http.client import SimpleHttpClient from synapse.http.client import CaptchaServerHttpClient import base64 @@ -69,7 +69,7 @@ class RegistrationHandler(BaseHandler): password_hash=password_hash ) - self.distributor.fire("registered_user", user) + yield self.distributor.fire("registered_user", user) else: # autogen a random user ID attempts = 0 @@ -133,7 +133,7 @@ class RegistrationHandler(BaseHandler): if not threepid: raise RegistrationError(400, "Couldn't validate 3pid") - logger.info("got threepid medium %s address %s", + logger.info("got threepid with medium '%s' and address '%s'", threepid['medium'], threepid['address']) @defer.inlineCallbacks @@ -159,7 +159,7 @@ class RegistrationHandler(BaseHandler): def _threepid_from_creds(self, creds): # TODO: get this from the homeserver rather than creating a new one for # each request - httpCli = IdentityServerHttpClient(self.hs) + httpCli = SimpleHttpClient(self.hs) # XXX: make this configurable! trustedIdServers = ['matrix.org:8090'] if not creds['idServer'] in trustedIdServers: @@ -167,8 +167,11 @@ class RegistrationHandler(BaseHandler): 'credentials', creds['idServer']) defer.returnValue(None) data = yield httpCli.get_json( - creds['idServer'], - "/_matrix/identity/api/v1/3pid/getValidated3pid", + # XXX: This should be HTTPS + "http://%s%s" % ( + creds['idServer'], + "/_matrix/identity/api/v1/3pid/getValidated3pid" + ), {'sid': creds['sid'], 'clientSecret': creds['clientSecret']} ) @@ -178,16 +181,21 @@ class RegistrationHandler(BaseHandler): @defer.inlineCallbacks def _bind_threepid(self, creds, mxid): - httpCli = IdentityServerHttpClient(self.hs) + yield + logger.debug("binding threepid") + httpCli = SimpleHttpClient(self.hs) data = yield httpCli.post_urlencoded_get_json( - creds['idServer'], - "/_matrix/identity/api/v1/3pid/bind", + # XXX: Change when ID servers are all HTTPS + "http://%s%s" % ( + creds['idServer'], "/_matrix/identity/api/v1/3pid/bind" + ), { 'sid': creds['sid'], 'clientSecret': creds['clientSecret'], 'mxid': mxid, } ) + logger.debug("bound threepid") defer.returnValue(data) @defer.inlineCallbacks @@ -215,10 +223,7 @@ class RegistrationHandler(BaseHandler): # each request client = CaptchaServerHttpClient(self.hs) data = yield client.post_urlencoded_get_raw( - "www.google.com:80", - "/recaptcha/api/verify", - # twisted dislikes google's response, no content length. - accept_partial=True, + "http://www.google.com:80/recaptcha/api/verify", args={ 'privatekey': private_key, 'remoteip': ip_addr, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 7d9458e1d0..88955160c5 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -178,7 +178,9 @@ class RoomCreationHandler(BaseHandler): if room_alias: result["room_alias"] = room_alias.to_string() - directory_handler.send_room_alias_update_event(user_id, room_id) + yield directory_handler.send_room_alias_update_event( + user_id, room_id + ) defer.returnValue(result) @@ -211,7 +213,6 @@ class RoomCreationHandler(BaseHandler): **event_keys ) - power_levels_event = self.event_factory.create_event( etype=RoomPowerLevelsEvent.TYPE, content={ @@ -480,7 +481,7 @@ class RoomMemberHandler(BaseHandler): ) user = self.hs.parse_userid(event.user_id) - self.distributor.fire( + yield self.distributor.fire( "user_joined_room", user=user, room_id=room_id ) diff --git a/synapse/http/client.py b/synapse/http/client.py index dea61ba1e0..048a428905 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -15,308 +15,45 @@ from twisted.internet import defer, reactor -from twisted.internet.error import DNSLookupError from twisted.web.client import ( - _AgentBase, _URI, readBody, FileBodyProducer, PartialDownloadError + Agent, readBody, FileBodyProducer, PartialDownloadError ) from twisted.web.http_headers import Headers -from synapse.http.endpoint import matrix_endpoint -from synapse.util.async import sleep -from synapse.util.logcontext import PreserveLoggingContext - -from syutil.jsonutil import encode_canonical_json - -from synapse.api.errors import CodeMessageException, SynapseError - -from syutil.crypto.jsonsign import sign_json - from StringIO import StringIO import json import logging import urllib -import urlparse logger = logging.getLogger(__name__) -class MatrixHttpAgent(_AgentBase): - - def __init__(self, reactor, pool=None): - _AgentBase.__init__(self, reactor, pool) - - def request(self, destination, endpoint, method, path, params, query, - headers, body_producer): - - host = b"" - port = 0 - fragment = b"" - - parsed_URI = _URI(b"http", destination, host, port, path, params, - query, fragment) - - # Set the connection pool key to be the destination. - key = destination - - return self._requestWithEndpoint(key, endpoint, method, parsed_URI, - headers, body_producer, - parsed_URI.originForm) - - -class BaseHttpClient(object): - """Base class for HTTP clients using twisted. +class SimpleHttpClient(object): """ - - def __init__(self, hs): - self.agent = MatrixHttpAgent(reactor) - self.hs = hs - - @defer.inlineCallbacks - def _create_request(self, destination, method, path_bytes, - body_callback, headers_dict={}, param_bytes=b"", - query_bytes=b"", retry_on_dns_fail=True): - """ Creates and sends a request to the given url - """ - headers_dict[b"User-Agent"] = [b"Synapse"] - headers_dict[b"Host"] = [destination] - - url_bytes = urlparse.urlunparse( - ("", "", path_bytes, param_bytes, query_bytes, "",) - ) - - logger.debug("Sending request to %s: %s %s", - destination, method, url_bytes) - - logger.debug( - "Types: %s", - [ - type(destination), type(method), type(path_bytes), - type(param_bytes), - type(query_bytes) - ] - ) - - retries_left = 5 - - endpoint = self._getEndpoint(reactor, destination) - - while True: - - producer = None - if body_callback: - producer = body_callback(method, url_bytes, headers_dict) - - try: - with PreserveLoggingContext(): - response = yield self.agent.request( - destination, - endpoint, - method, - path_bytes, - param_bytes, - query_bytes, - Headers(headers_dict), - producer - ) - - logger.debug("Got response to %s", method) - break - except Exception as e: - if not retry_on_dns_fail and isinstance(e, DNSLookupError): - logger.warn("DNS Lookup failed to %s with %s", destination, - e) - raise SynapseError(400, "Domain specified not found.") - - logger.exception("Got error in _create_request") - _print_ex(e) - - if retries_left: - yield sleep(2 ** (5 - retries_left)) - retries_left -= 1 - else: - raise - - if 200 <= response.code < 300: - # We need to update the transactions table to say it was sent? - pass - else: - # :'( - # Update transactions table? - logger.error( - "Got response %d %s", response.code, response.phrase - ) - raise CodeMessageException( - response.code, response.phrase - ) - - defer.returnValue(response) - - -class MatrixHttpClient(BaseHttpClient): - """ Wrapper around the twisted HTTP client api. Implements - - Attributes: - agent (twisted.web.client.Agent): The twisted Agent used to send the - requests. + A simple, no-frills HTTP client with methods that wrap up common ways of + using HTTP in Matrix """ - - RETRY_DNS_LOOKUP_FAILURES = "__retry_dns" - def __init__(self, hs): - self.signing_key = hs.config.signing_key[0] - self.server_name = hs.hostname - BaseHttpClient.__init__(self, hs) - - def sign_request(self, destination, method, url_bytes, headers_dict, - content=None): - request = { - "method": method, - "uri": url_bytes, - "origin": self.server_name, - "destination": destination, - } - - if content is not None: - request["content"] = content - - request = sign_json(request, self.server_name, self.signing_key) - - auth_headers = [] - - for key, sig in request["signatures"][self.server_name].items(): - auth_headers.append(bytes( - "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % ( - self.server_name, key, sig, - ) - )) - - headers_dict[b"Authorization"] = auth_headers - - @defer.inlineCallbacks - def put_json(self, destination, path, data={}, json_data_callback=None): - """ Sends the specifed json data using PUT - - Args: - destination (str): The remote server to send the HTTP request - to. - path (str): The HTTP path. - data (dict): A dict containing the data that will be used as - the request body. This will be encoded as JSON. - json_data_callback (callable): A callable returning the dict to - use as the request body. - - Returns: - Deferred: Succeeds when we get a 2xx HTTP response. The result - will be the decoded JSON body. On a 4xx or 5xx error response a - CodeMessageException is raised. - """ - - if not json_data_callback: - def json_data_callback(): - return data - - def body_callback(method, url_bytes, headers_dict): - json_data = json_data_callback() - self.sign_request( - destination, method, url_bytes, headers_dict, json_data - ) - producer = _JsonProducer(json_data) - return producer - - response = yield self._create_request( - destination.encode("ascii"), - "PUT", - path.encode("ascii"), - body_callback=body_callback, - headers_dict={"Content-Type": ["application/json"]}, - ) - - logger.debug("Getting resp body") - body = yield readBody(response) - logger.debug("Got resp body") - - defer.returnValue((response.code, body)) - - @defer.inlineCallbacks - def get_json(self, destination, path, args={}, retry_on_dns_fail=True): - """ Get's some json from the given host homeserver and path - - Args: - destination (str): The remote server to send the HTTP request - to. - path (str): The HTTP path. - args (dict): A dictionary used to create query strings, defaults to - None. - **Note**: The value of each key is assumed to be an iterable - and *not* a string. - - Returns: - Deferred: Succeeds when we get *any* HTTP response. - - The result of the deferred is a tuple of `(code, response)`, - where `response` is a dict representing the decoded JSON body. - """ - logger.debug("get_json args: %s", args) - - encoded_args = {} - for k, vs in args.items(): - if isinstance(vs, basestring): - vs = [vs] - encoded_args[k] = [v.encode("UTF-8") for v in vs] - - query_bytes = urllib.urlencode(encoded_args, True) - logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail) - - def body_callback(method, url_bytes, headers_dict): - self.sign_request(destination, method, url_bytes, headers_dict) - return None - - response = yield self._create_request( - destination.encode("ascii"), - "GET", - path.encode("ascii"), - query_bytes=query_bytes, - body_callback=body_callback, - retry_on_dns_fail=retry_on_dns_fail - ) - - body = yield readBody(response) - - defer.returnValue(json.loads(body)) - - def _getEndpoint(self, reactor, destination): - return matrix_endpoint( - reactor, destination, timeout=10, - ssl_context_factory=self.hs.tls_context_factory - ) - - -class IdentityServerHttpClient(BaseHttpClient): - """Separate HTTP client for talking to the Identity servers since they - don't use SRV records and talk x-www-form-urlencoded rather than JSON. - """ - def _getEndpoint(self, reactor, destination): - #TODO: This should be talking TLS - return matrix_endpoint(reactor, destination, timeout=10) + self.hs = hs + # The default context factory in Twisted 14.0.0 (which we require) is + # BrowserLikePolicyForHTTPS which will do regular cert validation + # 'like a browser' + self.agent = Agent(reactor) @defer.inlineCallbacks - def post_urlencoded_get_json(self, destination, path, args={}): + def post_urlencoded_get_json(self, uri, args={}): logger.debug("post_urlencoded_get_json args: %s", args) query_bytes = urllib.urlencode(args, True) - def body_callback(method, url_bytes, headers_dict): - return FileBodyProducer(StringIO(query_bytes)) - - response = yield self._create_request( - destination.encode("ascii"), + response = yield self.agent.request( "POST", - path.encode("ascii"), - body_callback=body_callback, - headers_dict={ + uri.encode("ascii"), + headers=Headers({ "Content-Type": ["application/x-www-form-urlencoded"] - } + }), + bodyProducer=FileBodyProducer(StringIO(query_bytes)) ) body = yield readBody(response) @@ -324,13 +61,11 @@ class IdentityServerHttpClient(BaseHttpClient): defer.returnValue(json.loads(body)) @defer.inlineCallbacks - def get_json(self, destination, path, args={}, retry_on_dns_fail=True): - """ Get's some json from the given host homeserver and path + def get_json(self, uri, args={}): + """ Get's some json from the given host and path Args: - destination (str): The remote server to send the HTTP request - to. - path (str): The HTTP path. + uri (str): The URI to request, not including query parameters args (dict): A dictionary used to create query strings, defaults to None. **Note**: The value of each key is assumed to be an iterable @@ -342,18 +77,15 @@ class IdentityServerHttpClient(BaseHttpClient): The result of the deferred is a tuple of `(code, response)`, where `response` is a dict representing the decoded JSON body. """ - logger.debug("get_json args: %s", args) - query_bytes = urllib.urlencode(args, True) - logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail) + yield + if len(args): + query_bytes = urllib.urlencode(args, True) + uri = "%s?%s" % (uri, query_bytes) - response = yield self._create_request( - destination.encode("ascii"), + response = yield self.agent.request( "GET", - path.encode("ascii"), - query_bytes=query_bytes, - retry_on_dns_fail=retry_on_dns_fail, - body_callback=None + uri.encode("ascii"), ) body = yield readBody(response) @@ -361,38 +93,31 @@ class IdentityServerHttpClient(BaseHttpClient): defer.returnValue(json.loads(body)) -class CaptchaServerHttpClient(MatrixHttpClient): - """Separate HTTP client for talking to google's captcha servers""" - - def _getEndpoint(self, reactor, destination): - return matrix_endpoint(reactor, destination, timeout=10) +class CaptchaServerHttpClient(SimpleHttpClient): + """ + Separate HTTP client for talking to google's captcha servers + Only slightly special because accepts partial download responses + """ @defer.inlineCallbacks - def post_urlencoded_get_raw(self, destination, path, accept_partial=False, - args={}): + def post_urlencoded_get_raw(self, url, args={}): query_bytes = urllib.urlencode(args, True) - def body_callback(method, url_bytes, headers_dict): - return FileBodyProducer(StringIO(query_bytes)) - - response = yield self._create_request( - destination.encode("ascii"), + response = yield self.agent.request( "POST", - path.encode("ascii"), - body_callback=body_callback, - headers_dict={ + url.encode("ascii"), + bodyProducer=FileBodyProducer(StringIO(query_bytes)), + headers=Headers({ "Content-Type": ["application/x-www-form-urlencoded"] - } + }) ) try: body = yield readBody(response) defer.returnValue(body) except PartialDownloadError as e: - if accept_partial: - defer.returnValue(e.response) - else: - raise e + # twisted dislikes google's response, no content length. + defer.returnValue(e.response) def _print_ex(e): @@ -401,24 +126,3 @@ def _print_ex(e): _print_ex(ex) else: logger.exception(e) - - -class _JsonProducer(object): - """ Used by the twisted http client to create the HTTP body from json - """ - def __init__(self, jsn): - self.reset(jsn) - - def reset(self, jsn): - self.body = encode_canonical_json(jsn) - self.length = len(self.body) - - def startProducing(self, consumer): - consumer.write(self.body) - return defer.succeed(None) - - def pauseProducing(self): - pass - - def stopProducing(self): - pass diff --git a/synapse/http/content_repository.py b/synapse/http/content_repository.py index 1306b35271..64ecb5346e 100644 --- a/synapse/http/content_repository.py +++ b/synapse/http/content_repository.py @@ -131,12 +131,14 @@ class ContentRepoResource(resource.Resource): request.setHeader('Content-Type', content_type) # cache for at least a day. - # XXX: we might want to turn this off for data we don't want to recommend - # caching as it's sensitive or private - or at least select private. - # don't bother setting Expires as all our matrix clients are smart enough to - # be happy with Cache-Control (right?) - request.setHeader('Cache-Control', 'public,max-age=86400,s-maxage=86400') - + # XXX: we might want to turn this off for data we don't want to + # recommend caching as it's sensitive or private - or at least + # select private. don't bother setting Expires as all our matrix + # clients are smart enough to be happy with Cache-Control (right?) + request.setHeader( + "Cache-Control", "public,max-age=86400,s-maxage=86400" + ) + d = FileSender().beginFileTransfer(f, request) # after the file has been sent, clean up and finish the request @@ -179,7 +181,7 @@ class ContentRepoResource(resource.Resource): fname = yield self.map_request_to_name(request) - # TODO I have a suspcious feeling this is just going to block + # TODO I have a suspicious feeling this is just going to block with open(fname, "wb") as f: f.write(request.content.read()) @@ -188,7 +190,7 @@ class ContentRepoResource(resource.Resource): # FIXME: we can't assume what the repo's public mounted path is # ...plus self-signed SSL won't work to remote clients anyway # ...and we can't assume that it's SSL anyway, as we might want to - # server it via the non-SSL listener... + # serve it via the non-SSL listener... url = "%s/_matrix/content/%s" % ( self.external_addr, file_name ) diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 7018ee3458..9c8888f565 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -27,8 +27,8 @@ import random logger = logging.getLogger(__name__) -def matrix_endpoint(reactor, destination, ssl_context_factory=None, - timeout=None): +def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, + timeout=None): """Construct an endpoint for the given matrix destination. Args: diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py new file mode 100644 index 0000000000..510f07dd7b --- /dev/null +++ b/synapse/http/matrixfederationclient.py @@ -0,0 +1,308 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer, reactor +from twisted.internet.error import DNSLookupError +from twisted.web.client import readBody, _AgentBase, _URI +from twisted.web.http_headers import Headers + +from synapse.http.endpoint import matrix_federation_endpoint +from synapse.util.async import sleep +from synapse.util.logcontext import PreserveLoggingContext + +from syutil.jsonutil import encode_canonical_json + +from synapse.api.errors import CodeMessageException, SynapseError + +from syutil.crypto.jsonsign import sign_json + +import json +import logging +import urllib +import urlparse + + +logger = logging.getLogger(__name__) + + +class MatrixFederationHttpAgent(_AgentBase): + + def __init__(self, reactor, pool=None): + _AgentBase.__init__(self, reactor, pool) + + def request(self, destination, endpoint, method, path, params, query, + headers, body_producer): + + host = b"" + port = 0 + fragment = b"" + + parsed_URI = _URI(b"http", destination, host, port, path, params, + query, fragment) + + # Set the connection pool key to be the destination. + key = destination + + return self._requestWithEndpoint(key, endpoint, method, parsed_URI, + headers, body_producer, + parsed_URI.originForm) + + +class MatrixFederationHttpClient(object): + """HTTP client used to talk to other homeservers over the federation + protocol. Send client certificates and signs requests. + + Attributes: + agent (twisted.web.client.Agent): The twisted Agent used to send the + requests. + """ + + def __init__(self, hs): + self.hs = hs + self.signing_key = hs.config.signing_key[0] + self.server_name = hs.hostname + self.agent = MatrixFederationHttpAgent(reactor) + + @defer.inlineCallbacks + def _create_request(self, destination, method, path_bytes, + body_callback, headers_dict={}, param_bytes=b"", + query_bytes=b"", retry_on_dns_fail=True): + """ Creates and sends a request to the given url + """ + headers_dict[b"User-Agent"] = [b"Synapse"] + headers_dict[b"Host"] = [destination] + + url_bytes = urlparse.urlunparse( + ("", "", path_bytes, param_bytes, query_bytes, "",) + ) + + logger.debug("Sending request to %s: %s %s", + destination, method, url_bytes) + + logger.debug( + "Types: %s", + [ + type(destination), type(method), type(path_bytes), + type(param_bytes), + type(query_bytes) + ] + ) + + retries_left = 5 + + endpoint = self._getEndpoint(reactor, destination) + + while True: + producer = None + if body_callback: + producer = body_callback(method, url_bytes, headers_dict) + + try: + with PreserveLoggingContext(): + response = yield self.agent.request( + destination, + endpoint, + method, + path_bytes, + param_bytes, + query_bytes, + Headers(headers_dict), + producer + ) + + logger.debug("Got response to %s", method) + break + except Exception as e: + if not retry_on_dns_fail and isinstance(e, DNSLookupError): + logger.warn("DNS Lookup failed to %s with %s", destination, + e) + raise SynapseError(400, "Domain specified not found.") + + logger.exception("Got error in _create_request") + _print_ex(e) + + if retries_left: + yield sleep(2 ** (5 - retries_left)) + retries_left -= 1 + else: + raise + + if 200 <= response.code < 300: + # We need to update the transactions table to say it was sent? + pass + else: + # :'( + # Update transactions table? + logger.error( + "Got response %d %s", response.code, response.phrase + ) + raise CodeMessageException( + response.code, response.phrase + ) + + defer.returnValue(response) + + def sign_request(self, destination, method, url_bytes, headers_dict, + content=None): + request = { + "method": method, + "uri": url_bytes, + "origin": self.server_name, + "destination": destination, + } + + if content is not None: + request["content"] = content + + request = sign_json(request, self.server_name, self.signing_key) + + auth_headers = [] + + for key, sig in request["signatures"][self.server_name].items(): + auth_headers.append(bytes( + "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % ( + self.server_name, key, sig, + ) + )) + + headers_dict[b"Authorization"] = auth_headers + + @defer.inlineCallbacks + def put_json(self, destination, path, data={}, json_data_callback=None): + """ Sends the specifed json data using PUT + + Args: + destination (str): The remote server to send the HTTP request + to. + path (str): The HTTP path. + data (dict): A dict containing the data that will be used as + the request body. This will be encoded as JSON. + json_data_callback (callable): A callable returning the dict to + use as the request body. + + Returns: + Deferred: Succeeds when we get a 2xx HTTP response. The result + will be the decoded JSON body. On a 4xx or 5xx error response a + CodeMessageException is raised. + """ + + if not json_data_callback: + def json_data_callback(): + return data + + def body_callback(method, url_bytes, headers_dict): + json_data = json_data_callback() + self.sign_request( + destination, method, url_bytes, headers_dict, json_data + ) + producer = _JsonProducer(json_data) + return producer + + response = yield self._create_request( + destination.encode("ascii"), + "PUT", + path.encode("ascii"), + body_callback=body_callback, + headers_dict={"Content-Type": ["application/json"]}, + ) + + logger.debug("Getting resp body") + body = yield readBody(response) + logger.debug("Got resp body") + + defer.returnValue((response.code, body)) + + @defer.inlineCallbacks + def get_json(self, destination, path, args={}, retry_on_dns_fail=True): + """ Get's some json from the given host homeserver and path + + Args: + destination (str): The remote server to send the HTTP request + to. + path (str): The HTTP path. + args (dict): A dictionary used to create query strings, defaults to + None. + **Note**: The value of each key is assumed to be an iterable + and *not* a string. + + Returns: + Deferred: Succeeds when we get *any* HTTP response. + + The result of the deferred is a tuple of `(code, response)`, + where `response` is a dict representing the decoded JSON body. + """ + logger.debug("get_json args: %s", args) + + encoded_args = {} + for k, vs in args.items(): + if isinstance(vs, basestring): + vs = [vs] + encoded_args[k] = [v.encode("UTF-8") for v in vs] + + query_bytes = urllib.urlencode(encoded_args, True) + logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail) + + def body_callback(method, url_bytes, headers_dict): + self.sign_request(destination, method, url_bytes, headers_dict) + return None + + response = yield self._create_request( + destination.encode("ascii"), + "GET", + path.encode("ascii"), + query_bytes=query_bytes, + body_callback=body_callback, + retry_on_dns_fail=retry_on_dns_fail + ) + + body = yield readBody(response) + + defer.returnValue(json.loads(body)) + + def _getEndpoint(self, reactor, destination): + return matrix_federation_endpoint( + reactor, destination, timeout=10, + ssl_context_factory=self.hs.tls_context_factory + ) + + +def _print_ex(e): + if hasattr(e, "reasons") and e.reasons: + for ex in e.reasons: + _print_ex(ex) + else: + logger.exception(e) + + +class _JsonProducer(object): + """ Used by the twisted http client to create the HTTP body from json + """ + def __init__(self, jsn): + self.reset(jsn) + + def reset(self, jsn): + self.body = encode_canonical_json(jsn) + self.length = len(self.body) + + def startProducing(self, consumer): + consumer.write(self.body) + return defer.succeed(None) + + def pauseProducing(self): + pass + + def stopProducing(self): + pass diff --git a/synapse/http/server.py b/synapse/http/server.py index ed1f1170cb..8024ff5bde 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -138,8 +138,7 @@ class JsonResource(HttpServer, resource.Resource): ) except CodeMessageException as e: if isinstance(e, SynapseError): - logger.error("%s SynapseError: %s - %s", request, e.code, - e.msg) + logger.info("%s SynapseError: %s - %s", request, e.code, e.msg) else: logger.exception(e) self._send_response( diff --git a/synapse/notifier.py b/synapse/notifier.py index c310a9fed6..5e14950449 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.util.logutils import log_function from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.async import run_on_reactor import logging @@ -96,6 +97,7 @@ class Notifier(object): listening to the room, and any listeners for the users in the `extra_users` param. """ + yield run_on_reactor() room_id = event.room_id room_source = self.event_sources.sources["room"] @@ -143,6 +145,7 @@ class Notifier(object): Will wake up all listeners for the given users and rooms. """ + yield run_on_reactor() presence_source = self.event_sources.sources["presence"] listeners = set() @@ -211,6 +214,7 @@ class Notifier(object): timeout, deferred, ) + def _timeout_listener(): # TODO (erikj): We should probably set to_token to the current # max rather than reusing from_token. diff --git a/synapse/rest/events.py b/synapse/rest/events.py index 92ff5e5ca7..3c1b041bfe 100644 --- a/synapse/rest/events.py +++ b/synapse/rest/events.py @@ -26,7 +26,6 @@ import logging logger = logging.getLogger(__name__) - class EventStreamRestServlet(RestServlet): PATTERN = client_path_pattern("/events$") diff --git a/synapse/rest/presence.py b/synapse/rest/presence.py index 138cc88a05..502ed0d4ca 100644 --- a/synapse/rest/presence.py +++ b/synapse/rest/presence.py @@ -117,8 +117,6 @@ class PresenceListRestServlet(RestServlet): logger.exception("JSON parse error") raise SynapseError(400, "Unable to parse content") - deferreds = [] - if "invite" in content: for u in content["invite"]: if not isinstance(u, basestring): @@ -126,8 +124,9 @@ class PresenceListRestServlet(RestServlet): if len(u) == 0: continue invited_user = self.hs.parse_userid(u) - deferreds.append(self.handlers.presence_handler.send_invite( - observer_user=user, observed_user=invited_user)) + yield self.handlers.presence_handler.send_invite( + observer_user=user, observed_user=invited_user + ) if "drop" in content: for u in content["drop"]: @@ -136,10 +135,9 @@ class PresenceListRestServlet(RestServlet): if len(u) == 0: continue dropped_user = self.hs.parse_userid(u) - deferreds.append(self.handlers.presence_handler.drop( - observer_user=user, observed_user=dropped_user)) - - yield defer.DeferredList(deferreds) + yield self.handlers.presence_handler.drop( + observer_user=user, observed_user=dropped_user + ) defer.returnValue((200, {})) diff --git a/synapse/rest/register.py b/synapse/rest/register.py index 5c15614ea9..f25e23a158 100644 --- a/synapse/rest/register.py +++ b/synapse/rest/register.py @@ -222,6 +222,7 @@ class RegisterRestServlet(RestServlet): threepidCreds = register_json['threepidCreds'] handler = self.handlers.registration_handler + logger.debug("Registering email. threepidcreds: %s" % (threepidCreds)) yield handler.register_email(threepidCreds) session["threepidCreds"] = threepidCreds # store creds for next stage session[LoginType.EMAIL_IDENTITY] = True # mark email as done @@ -232,6 +233,7 @@ class RegisterRestServlet(RestServlet): @defer.inlineCallbacks def _do_password(self, request, register_json, session): + yield if (self.hs.config.enable_registration_captcha and not session[LoginType.RECAPTCHA]): # captcha should've been done by this stage! @@ -259,6 +261,9 @@ class RegisterRestServlet(RestServlet): ) if session[LoginType.EMAIL_IDENTITY]: + logger.debug("Binding emails %s to %s" % ( + session["threepidCreds"], user_id) + ) yield handler.bind_emails(user_id, session["threepidCreds"]) result = { diff --git a/synapse/rest/room.py b/synapse/rest/room.py index 4f6d039b61..cc6ffb9aff 100644 --- a/synapse/rest/room.py +++ b/synapse/rest/room.py @@ -148,7 +148,7 @@ class RoomStateEventRestServlet(RestServlet): content = _parse_json(request) event = self.event_factory.create_event( - etype=urllib.unquote(event_type), + etype=event_type, # already urldecoded content=content, room_id=urllib.unquote(room_id), user_id=user.to_string(), diff --git a/synapse/state.py b/synapse/state.py index 1c999e4d79..430665f7ba 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -82,7 +82,7 @@ class StateHandler(object): if hasattr(event, "outlier") and event.outlier: event.state_group = None event.old_state_events = None - event.state_events = {} + event.state_events = None defer.returnValue(False) return diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 330d3b793f..1fb33171e8 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -67,7 +67,7 @@ SCHEMAS = [ # Remember to update this number every time an incompatible change is made to # database schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 7 +SCHEMA_VERSION = 8 class _RollbackButIsFineException(Exception): @@ -93,7 +93,8 @@ class DataStore(RoomMemberStore, RoomStore, @defer.inlineCallbacks @log_function - def persist_event(self, event, backfilled=False, is_new_state=True): + def persist_event(self, event, backfilled=False, is_new_state=True, + current_state=None): stream_ordering = None if backfilled: if not self.min_token_deferred.called: @@ -109,6 +110,7 @@ class DataStore(RoomMemberStore, RoomStore, backfilled=backfilled, stream_ordering=stream_ordering, is_new_state=is_new_state, + current_state=current_state, ) except _RollbackButIsFineException: pass @@ -137,7 +139,7 @@ class DataStore(RoomMemberStore, RoomStore, @log_function def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None, - is_new_state=True): + is_new_state=True, current_state=None): if event.type == RoomMemberEvent.TYPE: self._store_room_member_txn(txn, event) elif event.type == FeedbackEvent.TYPE: @@ -206,8 +208,24 @@ class DataStore(RoomMemberStore, RoomStore, self._store_state_groups_txn(txn, event) + if current_state: + txn.execute("DELETE FROM current_state_events") + + for s in current_state: + self._simple_insert_txn( + txn, + "current_state_events", + { + "event_id": s.event_id, + "room_id": s.room_id, + "type": s.type, + "state_key": s.state_key, + }, + or_replace=True, + ) + is_state = hasattr(event, "state_key") and event.state_key is not None - if is_new_state and is_state: + if is_state: vals = { "event_id": event.event_id, "room_id": event.room_id, @@ -225,17 +243,18 @@ class DataStore(RoomMemberStore, RoomStore, or_replace=True, ) - self._simple_insert_txn( - txn, - "current_state_events", - { - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - }, - or_replace=True, - ) + if is_new_state: + self._simple_insert_txn( + txn, + "current_state_events", + { + "event_id": event.event_id, + "room_id": event.room_id, + "type": event.type, + "state_key": event.state_key, + }, + or_replace=True, + ) for e_id, h in event.prev_state: self._simple_insert_txn( @@ -312,7 +331,12 @@ class DataStore(RoomMemberStore, RoomStore, txn, event.event_id, ref_alg, ref_hash_bytes ) - self._update_min_depth_for_room_txn(txn, event.room_id, event.depth) + if not outlier: + self._update_min_depth_for_room_txn( + txn, + event.room_id, + event.depth + ) def _store_redaction(self, txn, event): txn.execute( @@ -508,7 +532,7 @@ def prepare_database(db_conn): "new for the server to understand" ) elif user_version < SCHEMA_VERSION: - logging.info( + logger.info( "Upgrading database from version %d", user_version ) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 5d4be09a82..4881f03368 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -57,7 +57,7 @@ class LoggingTransaction(object): if args and args[0]: values = args[0] sql_logger.debug( - "[SQL values] {%s} " + ", ".join(("<%s>",) * len(values)), + "[SQL values] {%s} " + ", ".join(("<%r>",) * len(values)), self.name, *values ) @@ -91,6 +91,7 @@ class SQLBaseStore(object): def runInteraction(self, desc, func, *args, **kwargs): """Wraps the .runInteraction() method on the underlying db_pool.""" current_context = LoggingContext.current_context() + def inner_func(txn, *args, **kwargs): with LoggingContext("runInteraction") as context: current_context.copy_to(context) @@ -115,7 +116,6 @@ class SQLBaseStore(object): "[TXN END] {%s} %f", name, end - start ) - with PreserveLoggingContext(): result = yield self._db_pool.runInteraction( inner_func, *args, **kwargs @@ -246,7 +246,10 @@ class SQLBaseStore(object): raise StoreError(404, "No row found") def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol): - sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % { + sql = ( + "SELECT %(retcol)s FROM %(table)s WHERE %(where)s " + "ORDER BY rowid asc" + ) % { "retcol": retcol, "table": table, "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()), @@ -299,7 +302,7 @@ class SQLBaseStore(object): keyvalues : dict of column names and values to select the rows with retcols : list of strings giving the names of the columns to return """ - sql = "SELECT %s FROM %s WHERE %s" % ( + sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % ( ", ".join(retcols), table, " AND ".join("%s = ?" % (k, ) for k in keyvalues) @@ -334,7 +337,7 @@ class SQLBaseStore(object): retcols=None, allow_none=False): """ Combined SELECT then UPDATE.""" if retcols: - select_sql = "SELECT %s FROM %s WHERE %s" % ( + select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % ( ", ".join(retcols), table, " AND ".join("%s = ?" % (k) for k in keyvalues) @@ -461,7 +464,7 @@ class SQLBaseStore(object): def _get_events_txn(self, txn, event_ids): # FIXME (erikj): This should be batched? - sql = "SELECT * FROM events WHERE event_id = ?" + sql = "SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc" event_rows = [] for e_id in event_ids: @@ -478,7 +481,9 @@ class SQLBaseStore(object): def _parse_events_txn(self, txn, rows): events = [self._parse_event_from_row(r) for r in rows] - select_event_sql = "SELECT * FROM events WHERE event_id = ?" + select_event_sql = ( + "SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc" + ) for i, ev in enumerate(events): signatures = self._get_event_signatures_txn( diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 1f89d77344..4d15005c9e 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -75,7 +75,9 @@ class RegistrationStore(SQLBaseStore): "VALUES (?,?,?)", [user_id, password_hash, now]) except IntegrityError: - raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) + raise StoreError( + 400, "User ID already taken.", errcode=Codes.USER_IN_USE + ) # it's possible for this to get a conflict, but only for a single user # since tokens are namespaced based on their user ID @@ -83,8 +85,8 @@ class RegistrationStore(SQLBaseStore): "VALUES (?,?)", [txn.lastrowid, token]) def get_user_by_id(self, user_id): - query = ("SELECT users.name, users.password_hash FROM users " - "WHERE users.name = ?") + query = ("SELECT users.name, users.password_hash FROM users" + " WHERE users.name = ?") return self._execute( self.cursor_to_dict, query, user_id @@ -120,10 +122,10 @@ class RegistrationStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.admin, access_tokens.device_id " - "FROM users " - "INNER JOIN access_tokens on users.id = access_tokens.user_id " - "WHERE token = ?" + "SELECT users.name, users.admin, access_tokens.device_id" + " FROM users" + " INNER JOIN access_tokens on users.id = access_tokens.user_id" + " WHERE token = ?" ) cursor = txn.execute(sql, (token,)) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index cc0513b8d2..2378d65943 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -27,7 +27,9 @@ import logging logger = logging.getLogger(__name__) -OpsLevel = collections.namedtuple("OpsLevel", ("ban_level", "kick_level", "redact_level")) +OpsLevel = collections.namedtuple("OpsLevel", ( + "ban_level", "kick_level", "redact_level") +) class RoomStore(SQLBaseStore): diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 93329703a2..c37df59d45 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -177,8 +177,8 @@ class RoomMemberStore(SQLBaseStore): return self._get_members_query(clause, vals) def _get_members_query(self, where_clause, where_values): - return self._db_pool.runInteraction( - self._get_members_query_txn, + return self.runInteraction( + "get_members_query", self._get_members_query_txn, where_clause, where_values ) diff --git a/synapse/storage/schema/delta/v8.sql b/synapse/storage/schema/delta/v8.sql new file mode 100644 index 0000000000..daf6646ed5 --- /dev/null +++ b/synapse/storage/schema/delta/v8.sql @@ -0,0 +1,34 @@ +/* Copyright 2014 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + CREATE TABLE IF NOT EXISTS event_signatures_2 ( + event_id TEXT, + signature_name TEXT, + key_id TEXT, + signature BLOB, + CONSTRAINT uniqueness UNIQUE (event_id, signature_name, key_id) +); + +INSERT INTO event_signatures_2 (event_id, signature_name, key_id, signature) +SELECT event_id, signature_name, key_id, signature FROM event_signatures; + +DROP TABLE event_signatures; +ALTER TABLE event_signatures_2 RENAME TO event_signatures; + +CREATE INDEX IF NOT EXISTS event_signatures_id ON event_signatures ( + event_id +); + +PRAGMA user_version = 8; \ No newline at end of file diff --git a/synapse/storage/schema/event_signatures.sql b/synapse/storage/schema/event_signatures.sql index 4efa8a3e63..b6b56b47a2 100644 --- a/synapse/storage/schema/event_signatures.sql +++ b/synapse/storage/schema/event_signatures.sql @@ -42,7 +42,7 @@ CREATE TABLE IF NOT EXISTS event_signatures ( signature_name TEXT, key_id TEXT, signature BLOB, - CONSTRAINT uniqueness UNIQUE (event_id, key_id) + CONSTRAINT uniqueness UNIQUE (event_id, signature_name, key_id) ); CREATE INDEX IF NOT EXISTS event_signatures_id ON event_signatures ( diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index d90e08fff1..eea4f21065 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -36,7 +36,7 @@ class SignatureStore(SQLBaseStore): return dict(txn.fetchall()) def _store_event_content_hash_txn(self, txn, event_id, algorithm, - hash_bytes): + hash_bytes): """Store a hash for a Event Args: txn (cursor): @@ -84,7 +84,7 @@ class SignatureStore(SQLBaseStore): return dict(txn.fetchall()) def _store_event_reference_hash_txn(self, txn, event_id, algorithm, - hash_bytes): + hash_bytes): """Store a hash for a PDU Args: txn (cursor): @@ -127,7 +127,7 @@ class SignatureStore(SQLBaseStore): return res def _store_event_signature_txn(self, txn, event_id, signature_name, key_id, - signature_bytes): + signature_bytes): """Store a signature from the origin server for a PDU. Args: txn (cursor): @@ -169,7 +169,7 @@ class SignatureStore(SQLBaseStore): return results def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id, - algorithm, hash_bytes): + algorithm, hash_bytes): self._simple_insert_txn( txn, "event_edge_hashes", @@ -180,4 +180,4 @@ class SignatureStore(SQLBaseStore): "hash": buffer(hash_bytes), }, or_ignore=True, - ) \ No newline at end of file + ) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 55ea567793..e0f44b3e59 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -87,7 +87,7 @@ class StateStore(SQLBaseStore): ) def _store_state_groups_txn(self, txn, event): - if not event.state_events: + if event.state_events is None: return state_group = event.state_group diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index a954024678..b84735e61c 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -213,8 +213,8 @@ class StreamStore(SQLBaseStore): # Tokens really represent positions between elements, but we use # the convention of pointing to the event before the gap. Hence # we have a bit of asymmetry when it comes to equalities. - from_comp = '<=' if direction =='b' else '>' - to_comp = '>' if direction =='b' else '<=' + from_comp = '<=' if direction == 'b' else '>' + to_comp = '>' if direction == 'b' else '<=' order = "DESC" if direction == 'b' else "ASC" args = [room_id] @@ -235,9 +235,10 @@ class StreamStore(SQLBaseStore): ) sql = ( - "SELECT *, (%(redacted)s) AS redacted FROM events " - "WHERE outlier = 0 AND room_id = ? AND %(bounds)s " - "ORDER BY topological_ordering %(order)s, stream_ordering %(order)s %(limit)s " + "SELECT *, (%(redacted)s) AS redacted FROM events" + " WHERE outlier = 0 AND room_id = ? AND %(bounds)s" + " ORDER BY topological_ordering %(order)s," + " stream_ordering %(order)s %(limit)s" ) % { "redacted": del_sql, "bounds": bounds, diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 527507e5cd..0317e78c08 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -28,11 +28,11 @@ class SourcePaginationConfig(object): specific event source.""" def __init__(self, from_key=None, to_key=None, direction='f', - limit=0): + limit=None): self.from_key = from_key self.to_key = to_key self.direction = 'f' if direction == 'f' else 'b' - self.limit = int(limit) + self.limit = int(limit) if limit is not None else None class PaginationConfig(object): @@ -40,11 +40,11 @@ class PaginationConfig(object): """A configuration object which stores pagination parameters.""" def __init__(self, from_token=None, to_token=None, direction='f', - limit=0): + limit=None): self.from_token = from_token self.to_token = to_token self.direction = 'f' if direction == 'f' else 'b' - self.limit = int(limit) + self.limit = int(limit) if limit is not None else None @classmethod def from_request(cls, request, raise_invalid_params=True): @@ -80,8 +80,8 @@ class PaginationConfig(object): except: raise SynapseError(400, "'to' paramater is invalid") - limit = get_param("limit", "0") - if not limit.isdigit(): + limit = get_param("limit", None) + if limit is not None and not limit.isdigit(): raise SynapseError(400, "'limit' parameter must be an integer.") try: diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index e57fb0e914..7ec5033ceb 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -37,6 +37,7 @@ class Clock(object): def call_later(self, delay, callback): current_context = LoggingContext.current_context() + def wrapped_callback(): LoggingContext.thread_local.current_context = current_context callback() diff --git a/synapse/util/async.py b/synapse/util/async.py index 1219d927db..7dd3ec3a72 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -18,6 +18,7 @@ from twisted.internet import defer, reactor from .logcontext import PreserveLoggingContext + @defer.inlineCallbacks def sleep(seconds): d = defer.Deferred() @@ -25,6 +26,7 @@ def sleep(seconds): with PreserveLoggingContext(): yield d + def run_on_reactor(): """ This will cause the rest of the function to be invoked upon the next iteration of the main loop diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index eddbe5837f..701ccdb781 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.util.logcontext import PreserveLoggingContext + from twisted.internet import defer import logging @@ -91,6 +93,7 @@ class Signal(object): Each observer callable may return a Deferred.""" self.observers.append(observer) + @defer.inlineCallbacks def fire(self, *args, **kwargs): """Invokes every callable in the observer list, passing in the args and kwargs. Exceptions thrown by observers are logged but ignored. It is @@ -98,22 +101,24 @@ class Signal(object): Returns a Deferred that will complete when all the observers have completed.""" - deferreds = [] - for observer in self.observers: - d = defer.maybeDeferred(observer, *args, **kwargs) - - def eb(failure): - logger.warning( - "%s signal observer %s failed: %r", - self.name, observer, failure, - exc_info=( - failure.type, - failure.value, - failure.getTracebackObject())) - if not self.suppress_failures: - raise failure - deferreds.append(d.addErrback(eb)) - - return defer.DeferredList( - deferreds, fireOnOneErrback=not self.suppress_failures - ) + with PreserveLoggingContext(): + deferreds = [] + for observer in self.observers: + d = defer.maybeDeferred(observer, *args, **kwargs) + + def eb(failure): + logger.warning( + "%s signal observer %s failed: %r", + self.name, observer, failure, + exc_info=( + failure.type, + failure.value, + failure.getTracebackObject())) + if not self.suppress_failures: + raise failure + deferreds.append(d.addErrback(eb)) + + result = yield defer.DeferredList( + deferreds, fireOnOneErrback=not self.suppress_failures + ) + defer.returnValue(result) diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index 2f430a0f19..7d85018d97 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -1,6 +1,8 @@ import threading import logging +logger = logging.getLogger(__name__) + class LoggingContext(object): """Additional context for log formatting. Contexts are scoped within a @@ -53,11 +55,14 @@ class LoggingContext(object): None to avoid suppressing any exeptions that were thrown. """ if self.thread_local.current_context is not self: - logging.error( - "Current logging context %s is not the expected context %s", - self.thread_local.current_context, - self - ) + if self.thread_local.current_context is self.sentinel: + logger.debug("Expected logging context %s has been lost", self) + else: + logger.warn( + "Current logging context %s is not expected context %s", + self.thread_local.current_context, + self + ) self.thread_local.current_context = self.parent_context self.parent_context = None diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 3487a090e9..98cfbe50b3 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -83,20 +83,22 @@ class FederationTestCase(unittest.TestCase): event_id="$a:b", user_id="@a:b", origin="b", + auth_events=[], hashes={"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"}, ) self.datastore.persist_event.return_value = defer.succeed(None) self.datastore.get_room.return_value = defer.succeed(True) - self.state_handler.annotate_event_with_state.return_value = ( - defer.succeed(False) - ) + def annotate(ev, old_state=None): + ev.old_state_events = [] + return defer.succeed(False) + self.state_handler.annotate_event_with_state.side_effect = annotate yield self.handlers.federation_handler.on_receive_pdu(pdu, False) self.datastore.persist_event.assert_called_once_with( - ANY, False, is_new_state=False + ANY, is_new_state=False, backfilled=False, current_state=None ) self.state_handler.annotate_event_with_state.assert_called_once_with( @@ -104,7 +106,7 @@ class FederationTestCase(unittest.TestCase): old_state=None, ) - self.auth.check.assert_called_once_with(ANY, raises=True) + self.auth.check.assert_called_once_with(ANY, auth_events={}) self.notifier.on_new_room_event.assert_called_once_with( ANY, diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py index cbe591ab90..0279ab703a 100644 --- a/tests/handlers/test_room.py +++ b/tests/handlers/test_room.py @@ -120,7 +120,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase): self.datastore.get_room_member.return_value = defer.succeed(None) - event.state_events = { + event.old_state_events = { (RoomMemberEvent.TYPE, "@alice:green"): self._create_member( user_id="@alice:green", room_id=room_id, @@ -129,9 +129,11 @@ class RoomMemberHandlerTestCase(unittest.TestCase): user_id="@bob:red", room_id=room_id, ), - (RoomMemberEvent.TYPE, target_user_id): event, } + event.state_events = event.old_state_events + event.state_events[(RoomMemberEvent.TYPE, target_user_id)] = event + # Actual invocation yield self.room_member_handler.change_membership(event) @@ -187,6 +189,16 @@ class RoomMemberHandlerTestCase(unittest.TestCase): (RoomMemberEvent.TYPE, user_id): event, } + event.old_state_events = { + (RoomMemberEvent.TYPE, "@alice:green"): self._create_member( + user_id="@alice:green", + room_id=room_id, + ), + } + + event.state_events = event.old_state_events + event.state_events[(RoomMemberEvent.TYPE, user_id)] = event + # Actual invocation yield self.room_member_handler.change_membership(event) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index fabd364be9..a6f1d6a333 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -84,7 +84,8 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.assertEquals("Value", value) self.mock_txn.execute.assert_called_with( - "SELECT retcol FROM tablename WHERE keycol = ?", + "SELECT retcol FROM tablename WHERE keycol = ? " + "ORDER BY rowid asc", ["TheKey"] ) @@ -101,7 +102,8 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret) self.mock_txn.execute.assert_called_with( - "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", + "SELECT colA, colB, colC FROM tablename WHERE keycol = ? " + "ORDER BY rowid asc", ["TheKey"] ) @@ -135,7 +137,8 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) self.mock_txn.execute.assert_called_with( - "SELECT colA FROM tablename WHERE keycol = ?", + "SELECT colA FROM tablename WHERE keycol = ? " + "ORDER BY rowid asc", ["A set"] ) @@ -184,7 +187,8 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.assertEquals({"columname": "Old Value"}, ret) self.mock_txn.execute.assert_has_calls([ - call('SELECT columname FROM tablename WHERE keycol = ?', + call('SELECT columname FROM tablename WHERE keycol = ? ' + 'ORDER BY rowid asc', ['TheKey']), call("UPDATE tablename SET columname = ? WHERE keycol = ?", ["New Value", "TheKey"]) |