diff options
57 files changed, 2736 insertions, 550 deletions
diff --git a/demo/start.sh b/demo/start.sh index d90115ec97..dcc4d6f4fa 100755 --- a/demo/start.sh +++ b/demo/start.sh @@ -38,8 +38,12 @@ for port in 8080 8081 8082; do perl -p -i -e 's/^enable_registration:.*/enable_registration: true/g' $DIR/etc/$port.config - echo "full_twisted_stacktraces: true" >> $DIR/etc/$port.config - echo "report_stats: false" >> $DIR/etc/$port.config + if ! grep -F "full_twisted_stacktraces" -q $DIR/etc/$port.config; then + echo "full_twisted_stacktraces: true" >> $DIR/etc/$port.config + fi + if ! grep -F "report_stats" -q $DIR/etc/$port.config ; then + echo "report_stats: false" >> $DIR/etc/$port.config + fi python -m synapse.app.homeserver \ --config-path "$DIR/etc/$port.config" \ diff --git a/synapse/api/auth.py b/synapse/api/auth.py index e3b8c3099a..88445fe999 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -14,13 +14,18 @@ # limitations under the License. """This module contains classes for authenticating the user.""" +from canonicaljson import encode_canonical_json +from signedjson.key import decode_verify_key_bytes +from signedjson.sign import verify_signed_json, SignatureVerifyException from twisted.internet import defer from synapse.api.constants import EventTypes, Membership, JoinRules -from synapse.api.errors import AuthError, Codes, SynapseError -from synapse.util.logutils import log_function +from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError from synapse.types import RoomID, UserID, EventID +from synapse.util.logutils import log_function +from synapse.util import third_party_invites +from unpaddedbase64 import decode_base64 import logging import pymacaroons @@ -31,6 +36,7 @@ logger = logging.getLogger(__name__) AuthEventTypes = ( EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels, EventTypes.JoinRules, EventTypes.RoomHistoryVisibility, + EventTypes.ThirdPartyInvite, ) @@ -59,6 +65,8 @@ class Auth(object): Returns: True if the auth checks pass. """ + self.check_size_limits(event) + try: if not hasattr(event, "room_id"): raise AuthError(500, "Event has no room_id: %s" % event) @@ -126,6 +134,23 @@ class Auth(object): logger.info("Denying! %s", event) raise + def check_size_limits(self, event): + def too_big(field): + raise EventSizeError("%s too large" % (field,)) + + if len(event.user_id) > 255: + too_big("user_id") + if len(event.room_id) > 255: + too_big("room_id") + if event.is_state() and len(event.state_key) > 255: + too_big("state_key") + if len(event.type) > 255: + too_big("type") + if len(event.event_id) > 255: + too_big("event_id") + if len(encode_canonical_json(event.get_pdu_json())) > 65536: + too_big("event") + @defer.inlineCallbacks def check_joined_room(self, room_id, user_id, current_state=None): """Check if the user is currently joined in the room @@ -303,7 +328,11 @@ class Auth(object): ) if Membership.JOIN != membership: - # JOIN is the only action you can perform if you're not in the room + if (caller_invited + and Membership.LEAVE == membership + and target_user_id == event.user_id): + return True + if not caller_in_room: # caller isn't joined raise AuthError( 403, @@ -341,7 +370,8 @@ class Auth(object): pass elif join_rule == JoinRules.INVITE: if not caller_in_room and not caller_invited: - raise AuthError(403, "You are not invited to this room.") + if not self._verify_third_party_invite(event, auth_events): + raise AuthError(403, "You are not invited to this room.") else: # TODO (erikj): may_join list # TODO (erikj): private rooms @@ -367,6 +397,68 @@ class Auth(object): return True + def _verify_third_party_invite(self, event, auth_events): + """ + Validates that the join event is authorized by a previous third-party invite. + + Checks that the public key, and keyserver, match those in the invite, + and that the join event has a signature issued using that public key. + + Args: + event: The m.room.member join event being validated. + auth_events: All relevant previous context events which may be used + for authorization decisions. + + Return: + True if the event fulfills the expectations of a previous third party + invite event. + """ + if not third_party_invites.join_has_third_party_invite(event.content): + return False + join_third_party_invite = event.content["third_party_invite"] + token = join_third_party_invite["token"] + invite_event = auth_events.get( + (EventTypes.ThirdPartyInvite, token,) + ) + if not invite_event: + logger.info("Failing 3pid invite because no invite found for token %s", token) + return False + try: + public_key = join_third_party_invite["public_key"] + key_validity_url = join_third_party_invite["key_validity_url"] + if invite_event.content["public_key"] != public_key: + logger.info( + "Failing 3pid invite because public key invite: %s != join: %s", + invite_event.content["public_key"], + public_key + ) + return False + if invite_event.content["key_validity_url"] != key_validity_url: + logger.info( + "Failing 3pid invite because key_validity_url invite: %s != join: %s", + invite_event.content["key_validity_url"], + key_validity_url + ) + return False + signed = join_third_party_invite["signed"] + if signed["mxid"] != event.user_id: + return False + if signed["token"] != token: + return False + for server, signature_block in signed["signatures"].items(): + for key_name, encoded_signature in signature_block.items(): + if not key_name.startswith("ed25519:"): + return False + verify_key = decode_verify_key_bytes( + key_name, + decode_base64(public_key) + ) + verify_signed_json(signed, server, verify_key) + return True + return False + except (KeyError, SignatureVerifyException,): + return False + def _get_power_level_event(self, auth_events): key = (EventTypes.PowerLevels, "", ) return auth_events.get(key) @@ -646,6 +738,14 @@ class Auth(object): if e_type == Membership.JOIN: if member_event and not is_public: auth_ids.append(member_event.event_id) + if third_party_invites.join_has_third_party_invite(event.content): + key = ( + EventTypes.ThirdPartyInvite, + event.content["third_party_invite"]["token"] + ) + invite = current_state.get(key) + if invite: + auth_ids.append(invite.event_id) else: if member_event: auth_ids.append(member_event.event_id) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 008ee64727..41125e8719 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -63,6 +63,7 @@ class EventTypes(object): PowerLevels = "m.room.power_levels" Aliases = "m.room.aliases" Redaction = "m.room.redaction" + ThirdPartyInvite = "m.room.third_party_invite" RoomHistoryVisibility = "m.room.history_visibility" CanonicalAlias = "m.room.canonical_alias" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index d1356eb4d9..b3fea27d0e 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -119,6 +119,15 @@ class AuthError(SynapseError): super(AuthError, self).__init__(*args, **kwargs) +class EventSizeError(SynapseError): + """An error raised when an event is too big.""" + + def __init__(self, *args, **kwargs): + if "errcode" not in kwargs: + kwargs["errcode"] = Codes.TOO_LARGE + super(EventSizeError, self).__init__(413, *args, **kwargs) + + class EventStreamError(SynapseError): """An error raised when there a problem with the event stream.""" def __init__(self, *args, **kwargs): diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index e79e91e7eb..e4e3d1c59d 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -24,7 +24,7 @@ class Filtering(object): def get_user_filter(self, user_localpart, filter_id): result = self.store.get_user_filter(user_localpart, filter_id) - result.addCallback(Filter) + result.addCallback(FilterCollection) return result def add_user_filter(self, user_localpart, user_filter): @@ -131,125 +131,133 @@ class Filtering(object): raise SynapseError(400, "Bad bundle_updates: expected bool.") -class Filter(object): +class FilterCollection(object): def __init__(self, filter_json): self.filter_json = filter_json + self.room_timeline_filter = Filter( + self.filter_json.get("room", {}).get("timeline", {}) + ) + + self.room_state_filter = Filter( + self.filter_json.get("room", {}).get("state", {}) + ) + + self.room_ephemeral_filter = Filter( + self.filter_json.get("room", {}).get("ephemeral", {}) + ) + + self.room_private_user_data = Filter( + self.filter_json.get("room", {}).get("private_user_data", {}) + ) + + self.presence_filter = Filter( + self.filter_json.get("presence", {}) + ) + def timeline_limit(self): - return self.filter_json.get("room", {}).get("timeline", {}).get("limit", 10) + return self.room_timeline_filter.limit() def presence_limit(self): - return self.filter_json.get("presence", {}).get("limit", 10) + return self.presence_filter.limit() def ephemeral_limit(self): - return self.filter_json.get("room", {}).get("ephemeral", {}).get("limit", 10) + return self.room_ephemeral_filter.limit() def filter_presence(self, events): - return self._filter_on_key(events, ["presence"]) + return self.presence_filter.filter(events) def filter_room_state(self, events): - return self._filter_on_key(events, ["room", "state"]) + return self.room_state_filter.filter(events) def filter_room_timeline(self, events): - return self._filter_on_key(events, ["room", "timeline"]) + return self.room_timeline_filter.filter(events) def filter_room_ephemeral(self, events): - return self._filter_on_key(events, ["room", "ephemeral"]) - - def _filter_on_key(self, events, keys): - filter_json = self.filter_json - if not filter_json: - return events - - try: - # extract the right definition from the filter - definition = filter_json - for key in keys: - definition = definition[key] - return self._filter_with_definition(events, definition) - except KeyError: - # return all events if definition isn't specified. - return events - - def _filter_with_definition(self, events, definition): - return [e for e in events if self._passes_definition(definition, e)] - - def _passes_definition(self, definition, event): - """Check if the event passes the filter definition - Args: - definition(dict): The filter definition to check against - event(dict or Event): The event to check + return self.room_ephemeral_filter.filter(events) + + def filter_room_private_user_data(self, events): + return self.room_private_user_data.filter(events) + + +class Filter(object): + def __init__(self, filter_json): + self.filter_json = filter_json + + def check(self, event): + """Checks whether the filter matches the given event. + Returns: - True if the event passes the filter in the definition + bool: True if the event matches """ - if type(event) is dict: - room_id = event.get("room_id") - sender = event.get("sender") - event_type = event["type"] + if isinstance(event, dict): + return self.check_fields( + event.get("room_id", None), + event.get("sender", None), + event.get("type", None), + ) else: - room_id = getattr(event, "room_id", None) - sender = getattr(event, "sender", None) - event_type = event.type - return self._event_passes_definition( - definition, room_id, sender, event_type - ) + return self.check_fields( + getattr(event, "room_id", None), + getattr(event, "sender", None), + event.type, + ) - def _event_passes_definition(self, definition, room_id, sender, - event_type): - """Check if the event passes through the given definition. + def check_fields(self, room_id, sender, event_type): + """Checks whether the filter matches the given event fields. - Args: - definition(dict): The definition to check against. - room_id(str): The id of the room this event is in or None. - sender(str): The sender of the event - event_type(str): The type of the event. Returns: - True if the event passes through the filter. + bool: True if the event fields match """ - # Algorithm notes: - # For each key in the definition, check the event meets the criteria: - # * For types: Literal match or prefix match (if ends with wildcard) - # * For senders/rooms: Literal match only - # * "not_" checks take presedence (e.g. if "m.*" is in both 'types' - # and 'not_types' then it is treated as only being in 'not_types') - - # room checks - if room_id is not None: - allow_rooms = definition.get("rooms", None) - reject_rooms = definition.get("not_rooms", None) - if reject_rooms and room_id in reject_rooms: - return False - if allow_rooms and room_id not in allow_rooms: + literal_keys = { + "rooms": lambda v: room_id == v, + "senders": lambda v: sender == v, + "types": lambda v: _matches_wildcard(event_type, v) + } + + for name, match_func in literal_keys.items(): + not_name = "not_%s" % (name,) + disallowed_values = self.filter_json.get(not_name, []) + if any(map(match_func, disallowed_values)): return False - # sender checks - if sender is not None: - allow_senders = definition.get("senders", None) - reject_senders = definition.get("not_senders", None) - if reject_senders and sender in reject_senders: - return False - if allow_senders and sender not in allow_senders: - return False - - # type checks - if "not_types" in definition: - for def_type in definition["not_types"]: - if self._event_matches_type(event_type, def_type): + allowed_values = self.filter_json.get(name, None) + if allowed_values is not None: + if not any(map(match_func, allowed_values)): return False - if "types" in definition: - included = False - for def_type in definition["types"]: - if self._event_matches_type(event_type, def_type): - included = True - break - if not included: - return False return True - def _event_matches_type(self, event_type, def_type): - if def_type.endswith("*"): - type_prefix = def_type[:-1] - return event_type.startswith(type_prefix) - else: - return event_type == def_type + def filter_rooms(self, room_ids): + """Apply the 'rooms' filter to a given list of rooms. + + Args: + room_ids (list): A list of room_ids. + + Returns: + list: A list of room_ids that match the filter + """ + room_ids = set(room_ids) + + disallowed_rooms = set(self.filter_json.get("not_rooms", [])) + room_ids -= disallowed_rooms + + allowed_rooms = self.filter_json.get("rooms", None) + if allowed_rooms is not None: + room_ids &= set(allowed_rooms) + + return room_ids + + def filter(self, events): + return filter(self.check, events) + + def limit(self): + return self.filter_json.get("limit", 10) + + +def _matches_wildcard(actual_value, filter_value): + if filter_value.endswith("*"): + type_prefix = filter_value[:-1] + return actual_value.startswith(type_prefix) + else: + return actual_value == filter_value diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 59b0b1f4ac..44dc2c4744 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -224,8 +224,8 @@ class _Recoverer(object): self.clock.call_later((2 ** self.backoff_counter), self.retry) def _backoff(self): - # cap the backoff to be around 18h => (2^16) = 65536 secs - if self.backoff_counter < 16: + # cap the backoff to be around 8.5min => (2^9) = 512 secs + if self.backoff_counter < 9: self.backoff_counter += 1 self.recover() diff --git a/synapse/config/registration.py b/synapse/config/registration.py index fa98eced34..f5ef36a9f4 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -33,6 +33,7 @@ class RegistrationConfig(Config): self.registration_shared_secret = config.get("registration_shared_secret") self.macaroon_secret_key = config.get("macaroon_secret_key") + self.bcrypt_rounds = config.get("bcrypt_rounds", 12) def default_config(self, **kwargs): registration_shared_secret = random_string_with_symbols(50) @@ -48,6 +49,11 @@ class RegistrationConfig(Config): registration_shared_secret: "%(registration_shared_secret)s" macaroon_secret_key: "%(macaroon_secret_key)s" + + # Set the number of bcrypt rounds used to generate password hash. + # Larger numbers increase the work factor needed to generate the hash. + # The default number of rounds is 12. + bcrypt_rounds: 12 """ % locals() def add_arguments(self, parser): diff --git a/synapse/events/utils.py b/synapse/events/utils.py index b36eec0993..9989b76591 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -66,7 +66,6 @@ def prune_event(event): "users_default", "events", "events_default", - "events_default", "state_default", "ban", "kick", @@ -155,7 +154,8 @@ def serialize_event(e, time_now_ms, as_client_event=True, if "redacted_because" in e.unsigned: d["unsigned"]["redacted_because"] = serialize_event( - e.unsigned["redacted_because"], time_now_ms + e.unsigned["redacted_because"], time_now_ms, + event_format=event_format ) if token_id is not None: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index f5e346cdbc..723f571284 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -17,6 +17,7 @@ from twisted.internet import defer from .federation_base import FederationBase +from synapse.api.constants import Membership from .units import Edu from synapse.api.errors import ( @@ -25,6 +26,7 @@ from synapse.api.errors import ( from synapse.util import unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function +from synapse.util import third_party_invites from synapse.events import FrozenEvent import synapse.metrics @@ -356,19 +358,51 @@ class FederationClient(FederationBase): defer.returnValue(signed_auth) @defer.inlineCallbacks - def make_join(self, destinations, room_id, user_id): + def make_membership_event(self, destinations, room_id, user_id, membership, content): + """ + Creates an m.room.member event, with context, without participating in the room. + + Does so by asking one of the already participating servers to create an + event with proper context. + + Note that this does not append any events to any graphs. + + Args: + destinations (str): Candidate homeservers which are probably + participating in the room. + room_id (str): The room in which the event will happen. + user_id (str): The user whose membership is being evented. + membership (str): The "membership" property of the event. Must be + one of "join" or "leave". + content (object): Any additional data to put into the content field + of the event. + Return: + A tuple of (origin (str), event (object)) where origin is the remote + homeserver which generated the event. + """ + valid_memberships = {Membership.JOIN, Membership.LEAVE} + if membership not in valid_memberships: + raise RuntimeError( + "make_membership_event called with membership='%s', must be one of %s" % + (membership, ",".join(valid_memberships)) + ) for destination in destinations: if destination == self.server_name: continue + args = {} + if third_party_invites.join_has_third_party_invite(content): + args = third_party_invites.extract_join_keys( + content["third_party_invite"] + ) try: - ret = yield self.transport_layer.make_join( - destination, room_id, user_id + ret = yield self.transport_layer.make_membership_event( + destination, room_id, user_id, membership, args ) pdu_dict = ret["event"] - logger.debug("Got response to make_join: %s", pdu_dict) + logger.debug("Got response to make_%s: %s", membership, pdu_dict) defer.returnValue( (destination, self.event_from_pdu_json(pdu_dict)) @@ -378,8 +412,8 @@ class FederationClient(FederationBase): raise except Exception as e: logger.warn( - "Failed to make_join via %s: %s", - destination, e.message + "Failed to make_%s via %s: %s", + membership, destination, e.message ) raise RuntimeError("Failed to send to any server.") @@ -486,6 +520,33 @@ class FederationClient(FederationBase): defer.returnValue(pdu) @defer.inlineCallbacks + def send_leave(self, destinations, pdu): + for destination in destinations: + if destination == self.server_name: + continue + + try: + time_now = self._clock.time_msec() + _, content = yield self.transport_layer.send_leave( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + + logger.debug("Got content: %s", content) + defer.returnValue(None) + except CodeMessageException: + raise + except Exception as e: + logger.exception( + "Failed to send_leave via %s: %s", + destination, e.message + ) + + raise RuntimeError("Failed to send to any server.") + + @defer.inlineCallbacks def query_auth(self, destination, room_id, event_id, local_auth): """ Params: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 725c6f3fa5..9e2d9ee74c 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -23,10 +23,12 @@ from synapse.util.logutils import log_function from synapse.events import FrozenEvent import synapse.metrics -from synapse.api.errors import FederationError, SynapseError +from synapse.api.errors import FederationError, SynapseError, Codes from synapse.crypto.event_signing import compute_event_signature +from synapse.util import third_party_invites + import simplejson as json import logging @@ -228,8 +230,19 @@ class FederationServer(FederationBase): ) @defer.inlineCallbacks - def on_make_join_request(self, room_id, user_id): - pdu = yield self.handler.on_make_join_request(room_id, user_id) + def on_make_join_request(self, room_id, user_id, query): + threepid_details = {} + if third_party_invites.has_join_keys(query): + for k in third_party_invites.JOIN_KEYS: + if not isinstance(query[k], list) or len(query[k]) != 1: + raise FederationError( + "FATAL", + Codes.MISSING_PARAM, + "key %s value %s" % (k, query[k],), + None + ) + threepid_details[k] = query[k][0] + pdu = yield self.handler.on_make_join_request(room_id, user_id, threepid_details) time_now = self._clock.time_msec() defer.returnValue({"event": pdu.get_pdu_json(time_now)}) @@ -255,6 +268,20 @@ class FederationServer(FederationBase): })) @defer.inlineCallbacks + def on_make_leave_request(self, room_id, user_id): + pdu = yield self.handler.on_make_leave_request(room_id, user_id) + time_now = self._clock.time_msec() + defer.returnValue({"event": pdu.get_pdu_json(time_now)}) + + @defer.inlineCallbacks + def on_send_leave_request(self, origin, content): + logger.debug("on_send_leave_request: content: %s", content) + pdu = self.event_from_pdu_json(content) + logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) + yield self.handler.on_send_leave_request(origin, pdu) + defer.returnValue((200, {})) + + @defer.inlineCallbacks def on_event_auth(self, origin, room_id, event_id): time_now = self._clock.time_msec() auth_pdus = yield self.handler.on_event_auth(event_id) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 32fa5e8c15..aac6f1c167 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -202,6 +202,7 @@ class TransactionQueue(object): @defer.inlineCallbacks @log_function def _attempt_new_transaction(self, destination): + # list of (pending_pdu, deferred, order) if destination in self.pending_transactions: # XXX: pending_transactions can get stuck on by a never-ending # request at which point pending_pdus_by_dest just keeps growing. @@ -213,9 +214,6 @@ class TransactionQueue(object): ) return - logger.debug("TX [%s] _attempt_new_transaction", destination) - - # list of (pending_pdu, deferred, order) pending_pdus = self.pending_pdus_by_dest.pop(destination, []) pending_edus = self.pending_edus_by_dest.pop(destination, []) pending_failures = self.pending_failures_by_dest.pop(destination, []) @@ -228,20 +226,22 @@ class TransactionQueue(object): logger.debug("TX [%s] Nothing to send", destination) return - # Sort based on the order field - pending_pdus.sort(key=lambda t: t[2]) - - pdus = [x[0] for x in pending_pdus] - edus = [x[0] for x in pending_edus] - failures = [x[0].get_dict() for x in pending_failures] - deferreds = [ - x[1] - for x in pending_pdus + pending_edus + pending_failures - ] - try: self.pending_transactions[destination] = 1 + logger.debug("TX [%s] _attempt_new_transaction", destination) + + # Sort based on the order field + pending_pdus.sort(key=lambda t: t[2]) + + pdus = [x[0] for x in pending_pdus] + edus = [x[0] for x in pending_edus] + failures = [x[0].get_dict() for x in pending_failures] + deferreds = [ + x[1] + for x in pending_pdus + pending_edus + pending_failures + ] + txn_id = str(self._next_txn_id) limiter = yield get_retry_limiter( diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index ced703364b..a81b3c4345 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -14,6 +14,7 @@ # limitations under the License. from twisted.internet import defer +from synapse.api.constants import Membership from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.util.logutils import log_function @@ -160,13 +161,20 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def make_join(self, destination, room_id, user_id, retry_on_dns_fail=True): - path = PREFIX + "/make_join/%s/%s" % (room_id, user_id) + def make_membership_event(self, destination, room_id, user_id, membership, args={}): + valid_memberships = {Membership.JOIN, Membership.LEAVE} + if membership not in valid_memberships: + raise RuntimeError( + "make_membership_event called with membership='%s', must be one of %s" % + (membership, ",".join(valid_memberships)) + ) + path = PREFIX + "/make_%s/%s/%s" % (membership, room_id, user_id) content = yield self.client.get_json( destination=destination, path=path, - retry_on_dns_fail=retry_on_dns_fail, + args=args, + retry_on_dns_fail=True, ) defer.returnValue(content) @@ -186,6 +194,19 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function + def send_leave(self, destination, room_id, event_id, content): + path = PREFIX + "/send_leave/%s/%s" % (room_id, event_id) + + response = yield self.client.put_json( + destination=destination, + path=path, + data=content, + ) + + defer.returnValue(response) + + @defer.inlineCallbacks + @log_function def send_invite(self, destination, room_id, event_id, content): path = PREFIX + "/invite/%s/%s" % (room_id, event_id) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 36f250e1a3..8184159210 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -292,7 +292,25 @@ class FederationMakeJoinServlet(BaseFederationServlet): @defer.inlineCallbacks def on_GET(self, origin, content, query, context, user_id): - content = yield self.handler.on_make_join_request(context, user_id) + content = yield self.handler.on_make_join_request(context, user_id, query) + defer.returnValue((200, content)) + + +class FederationMakeLeaveServlet(BaseFederationServlet): + PATH = "/make_leave/([^/]*)/([^/]*)" + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, context, user_id): + content = yield self.handler.on_make_leave_request(context, user_id) + defer.returnValue((200, content)) + + +class FederationSendLeaveServlet(BaseFederationServlet): + PATH = "/send_leave/([^/]*)/([^/]*)" + + @defer.inlineCallbacks + def on_PUT(self, origin, content, query, room_id, txid): + content = yield self.handler.on_send_leave_request(origin, content) defer.returnValue((200, content)) @@ -385,8 +403,10 @@ SERVLET_CLASSES = ( FederationBackfillServlet, FederationQueryServlet, FederationMakeJoinServlet, + FederationMakeLeaveServlet, FederationEventServlet, FederationSendJoinServlet, + FederationSendLeaveServlet, FederationInviteServlet, FederationQueryAuthServlet, FederationGetMissingEventsServlet, diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 8725c3c420..6a2339f2eb 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -17,7 +17,7 @@ from synapse.appservice.scheduler import AppServiceScheduler from synapse.appservice.api import ApplicationServiceApi from .register import RegistrationHandler from .room import ( - RoomCreationHandler, RoomMemberHandler, RoomListHandler + RoomCreationHandler, RoomMemberHandler, RoomListHandler, RoomContextHandler, ) from .message import MessageHandler from .events import EventStreamHandler, EventHandler @@ -32,6 +32,7 @@ from .sync import SyncHandler from .auth import AuthHandler from .identity import IdentityHandler from .receipts import ReceiptsHandler +from .search import SearchHandler class Handlers(object): @@ -68,3 +69,5 @@ class Handlers(object): self.sync_handler = SyncHandler(hs) self.auth_handler = AuthHandler(hs) self.identity_handler = IdentityHandler(hs) + self.search_handler = SearchHandler(hs) + self.room_context_handler = RoomContextHandler(hs) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index c488ee0f6d..6a26cb1879 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -21,6 +21,7 @@ from synapse.api.constants import Membership, EventTypes from synapse.types import UserID, RoomAlias from synapse.util.logcontext import PreserveLoggingContext +from synapse.util import third_party_invites import logging @@ -45,6 +46,52 @@ class BaseHandler(object): self.event_builder_factory = hs.get_event_builder_factory() + @defer.inlineCallbacks + def _filter_events_for_client(self, user_id, events): + event_id_to_state = yield self.store.get_state_for_events( + frozenset(e.event_id for e in events), + types=( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, user_id), + ) + ) + + def allowed(event, state): + if event.type == EventTypes.RoomHistoryVisibility: + return True + + membership_ev = state.get((EventTypes.Member, user_id), None) + if membership_ev: + membership = membership_ev.membership + else: + membership = Membership.LEAVE + + if membership == Membership.JOIN: + return True + + history = state.get((EventTypes.RoomHistoryVisibility, ''), None) + if history: + visibility = history.content.get("history_visibility", "shared") + else: + visibility = "shared" + + if visibility == "public": + return True + elif visibility == "shared": + return True + elif visibility == "joined": + return membership == Membership.JOIN + elif visibility == "invited": + return membership == Membership.INVITE + + return True + + defer.returnValue([ + event + for event in events + if allowed(event, event_id_to_state[event.event_id]) + ]) + def ratelimit(self, user_id): time_now = self.clock.time() allowed, time_allowed = self.ratelimiter.send_message( @@ -123,6 +170,16 @@ class BaseHandler(object): ) ) + if ( + event.type == EventTypes.Member and + event.content["membership"] == Membership.JOIN and + third_party_invites.join_has_third_party_invite(event.content) + ): + yield third_party_invites.check_key_valid( + self.hs.get_simple_http_client(), + event + ) + federation_handler = self.hs.get_handlers().federation_handler if event.type == EventTypes.Member: diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 484f719253..055d395b20 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -44,6 +44,7 @@ class AuthHandler(BaseHandler): LoginType.EMAIL_IDENTITY: self._check_email_identity, LoginType.DUMMY: self._check_dummy_auth, } + self.bcrypt_rounds = hs.config.bcrypt_rounds self.sessions = {} @defer.inlineCallbacks @@ -432,7 +433,7 @@ class AuthHandler(BaseHandler): Returns: Hashed password (str). """ - return bcrypt.hashpw(password, bcrypt.gensalt()) + return bcrypt.hashpw(password, bcrypt.gensalt(self.bcrypt_rounds)) def validate_hash(self, password, stored_hash): """Validates that self.hash(password) == stored_hash. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index a710bdcfdb..ae9d227586 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -39,7 +39,7 @@ from twisted.internet import defer import itertools import logging - +from synapse.util import third_party_invites logger = logging.getLogger(__name__) @@ -565,7 +565,7 @@ class FederationHandler(BaseHandler): @log_function @defer.inlineCallbacks - def do_invite_join(self, target_hosts, room_id, joinee, content, snapshot): + def do_invite_join(self, target_hosts, room_id, joinee, content): """ Attempts to join the `joinee` to the room `room_id` via the server `target_host`. @@ -581,49 +581,19 @@ class FederationHandler(BaseHandler): yield self.store.clean_room_for_join(room_id) - origin, pdu = yield self.replication_layer.make_join( + origin, event = yield self._make_and_verify_event( target_hosts, room_id, - joinee + joinee, + "join", + content ) - logger.debug("Got response to make_join: %s", pdu) - - event = pdu - - # We should assert some things. - # FIXME: Do this in a nicer way - assert(event.type == EventTypes.Member) - assert(event.user_id == joinee) - assert(event.state_key == joinee) - assert(event.room_id == room_id) - - event.internal_metadata.outlier = False - self.room_queues[room_id] = [] - - builder = self.event_builder_factory.new( - unfreeze(event.get_pdu_json()) - ) - handled_events = set() try: - builder.event_id = self.event_builder_factory.create_event_id() - builder.origin = self.hs.hostname - builder.content = content - - if not hasattr(event, "signatures"): - builder.signatures = {} - - add_hashes_and_signatures( - builder, - self.hs.hostname, - self.hs.config.signing_key[0], - ) - - new_event = builder.build() - + new_event = self._sign_event(event) # Try the host we successfully got a response to /make_join/ # request first. try: @@ -631,11 +601,7 @@ class FederationHandler(BaseHandler): target_hosts.insert(0, origin) except ValueError: pass - - ret = yield self.replication_layer.send_join( - target_hosts, - new_event - ) + ret = yield self.replication_layer.send_join(target_hosts, new_event) origin = ret["origin"] state = ret["state"] @@ -697,14 +663,20 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks @log_function - def on_make_join_request(self, room_id, user_id): + def on_make_join_request(self, room_id, user_id, query): """ We've received a /make_join/ request, so we create a partial - join event for the room and return that. We don *not* persist or + join event for the room and return that. We do *not* persist or process it until the other server has signed it and sent it back. """ + event_content = {"membership": Membership.JOIN} + if third_party_invites.has_join_keys(query): + event_content["third_party_invite"] = ( + third_party_invites.extract_join_keys(query) + ) + builder = self.event_builder_factory.new({ "type": EventTypes.Member, - "content": {"membership": Membership.JOIN}, + "content": event_content, "room_id": room_id, "sender": user_id, "state_key": user_id, @@ -716,6 +688,9 @@ class FederationHandler(BaseHandler): self.auth.check(event, auth_events=context.current_state) + if third_party_invites.join_has_third_party_invite(event.content): + third_party_invites.check_key_valid(self.hs.get_simple_http_client(), event) + defer.returnValue(event) @defer.inlineCallbacks @@ -850,6 +825,168 @@ class FederationHandler(BaseHandler): defer.returnValue(event) @defer.inlineCallbacks + def do_remotely_reject_invite(self, target_hosts, room_id, user_id): + origin, event = yield self._make_and_verify_event( + target_hosts, + room_id, + user_id, + "leave", + {} + ) + signed_event = self._sign_event(event) + + # Try the host we successfully got a response to /make_join/ + # request first. + try: + target_hosts.remove(origin) + target_hosts.insert(0, origin) + except ValueError: + pass + + yield self.replication_layer.send_leave( + target_hosts, + signed_event + ) + defer.returnValue(None) + + @defer.inlineCallbacks + def _make_and_verify_event(self, target_hosts, room_id, user_id, membership, content): + origin, pdu = yield self.replication_layer.make_membership_event( + target_hosts, + room_id, + user_id, + membership, + content + ) + + logger.debug("Got response to make_%s: %s", membership, pdu) + + event = pdu + + # We should assert some things. + # FIXME: Do this in a nicer way + assert(event.type == EventTypes.Member) + assert(event.user_id == user_id) + assert(event.state_key == user_id) + assert(event.room_id == room_id) + defer.returnValue((origin, event)) + + def _sign_event(self, event): + event.internal_metadata.outlier = False + + builder = self.event_builder_factory.new( + unfreeze(event.get_pdu_json()) + ) + + builder.event_id = self.event_builder_factory.create_event_id() + builder.origin = self.hs.hostname + + if not hasattr(event, "signatures"): + builder.signatures = {} + + add_hashes_and_signatures( + builder, + self.hs.hostname, + self.hs.config.signing_key[0], + ) + + return builder.build() + + @defer.inlineCallbacks + @log_function + def on_make_leave_request(self, room_id, user_id): + """ We've received a /make_leave/ request, so we create a partial + join event for the room and return that. We do *not* persist or + process it until the other server has signed it and sent it back. + """ + builder = self.event_builder_factory.new({ + "type": EventTypes.Member, + "content": {"membership": Membership.LEAVE}, + "room_id": room_id, + "sender": user_id, + "state_key": user_id, + }) + + event, context = yield self._create_new_client_event( + builder=builder, + ) + + self.auth.check(event, auth_events=context.current_state) + + defer.returnValue(event) + + @defer.inlineCallbacks + @log_function + def on_send_leave_request(self, origin, pdu): + """ We have received a leave event for a room. Fully process it.""" + event = pdu + + logger.debug( + "on_send_leave_request: Got event: %s, signatures: %s", + event.event_id, + event.signatures, + ) + + event.internal_metadata.outlier = False + + context, event_stream_id, max_stream_id = yield self._handle_new_event( + origin, event + ) + + logger.debug( + "on_send_leave_request: After _handle_new_event: %s, sigs: %s", + event.event_id, + event.signatures, + ) + + extra_users = [] + if event.type == EventTypes.Member: + target_user_id = event.state_key + target_user = UserID.from_string(target_user_id) + extra_users.append(target_user) + + with PreserveLoggingContext(): + d = self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, extra_users=extra_users + ) + + def log_failure(f): + logger.warn( + "Failed to notify about %s: %s", + event.event_id, f.value + ) + + d.addErrback(log_failure) + + new_pdu = event + + destinations = set() + + for k, s in context.current_state.items(): + try: + if k[0] == EventTypes.Member: + if s.content["membership"] == Membership.LEAVE: + destinations.add( + UserID.from_string(s.state_key).domain + ) + except: + logger.warn( + "Failed to get destination from event %s", s.event_id + ) + + destinations.discard(origin) + + logger.debug( + "on_send_leave_request: Sending event: %s, signatures: %s", + event.event_id, + event.signatures, + ) + + self.replication_layer.send_pdu(new_pdu, destinations) + + defer.returnValue(None) + + @defer.inlineCallbacks def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True): yield run_on_reactor() diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index dfeeae76db..0f947993d1 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -146,7 +146,7 @@ class MessageHandler(BaseHandler): "end": next_token.to_string(), }) - events = yield self._filter_events_for_client(user_id, room_id, events) + events = yield self._filter_events_for_client(user_id, events) time_now = self.clock.time_msec() @@ -162,52 +162,6 @@ class MessageHandler(BaseHandler): defer.returnValue(chunk) @defer.inlineCallbacks - def _filter_events_for_client(self, user_id, room_id, events): - event_id_to_state = yield self.store.get_state_for_events( - frozenset(e.event_id for e in events), - types=( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, user_id), - ) - ) - - def allowed(event, state): - if event.type == EventTypes.RoomHistoryVisibility: - return True - - membership_ev = state.get((EventTypes.Member, user_id), None) - if membership_ev: - membership = membership_ev.membership - else: - membership = Membership.LEAVE - - if membership == Membership.JOIN: - return True - - history = state.get((EventTypes.RoomHistoryVisibility, ''), None) - if history: - visibility = history.content.get("history_visibility", "shared") - else: - visibility = "shared" - - if visibility == "public": - return True - elif visibility == "shared": - return True - elif visibility == "joined": - return membership == Membership.JOIN - elif visibility == "invited": - return membership == Membership.INVITE - - return True - - defer.returnValue([ - event - for event in events - if allowed(event, event_id_to_state[event.event_id]) - ]) - - @defer.inlineCallbacks def create_and_send_event(self, event_dict, ratelimit=True, token_id=None, txn_id=None): """ Given a dict from a client, create and handle a new event. @@ -368,6 +322,8 @@ class MessageHandler(BaseHandler): user, pagination_config.get_source_config("receipt"), None ) + tags_by_room = yield self.store.get_tags_for_user(user_id) + public_room_ids = yield self.store.get_public_room_ids() limit = pagin_config.limit @@ -424,7 +380,7 @@ class MessageHandler(BaseHandler): ).addErrback(unwrapFirstError) messages = yield self._filter_events_for_client( - user_id, event.room_id, messages + user_id, messages ) start_token = now_token.copy_and_replace("room_key", token[0]) @@ -444,6 +400,15 @@ class MessageHandler(BaseHandler): serialize_event(c, time_now, as_client_event) for c in current_state.values() ] + + private_user_data = [] + tags = tags_by_room.get(event.room_id) + if tags: + private_user_data.append({ + "type": "m.tag", + "content": {"tags": tags}, + }) + d["private_user_data"] = private_user_data except: logger.exception("Failed to get snapshot") @@ -493,6 +458,16 @@ class MessageHandler(BaseHandler): result = yield self._room_initial_sync_parted( user_id, room_id, pagin_config, member_event ) + + private_user_data = [] + tags = yield self.store.get_tags_for_room(user_id, room_id) + if tags: + private_user_data.append({ + "type": "m.tag", + "content": {"tags": tags}, + }) + result["private_user_data"] = private_user_data + defer.returnValue(result) @defer.inlineCallbacks @@ -519,11 +494,11 @@ class MessageHandler(BaseHandler): ) messages = yield self._filter_events_for_client( - user_id, room_id, messages + user_id, messages ) - start_token = StreamToken(token[0], 0, 0, 0) - end_token = StreamToken(token[1], 0, 0, 0) + start_token = StreamToken(token[0], 0, 0, 0, 0) + end_token = StreamToken(token[1], 0, 0, 0, 0) time_now = self.clock.time_msec() @@ -599,7 +574,7 @@ class MessageHandler(BaseHandler): ).addErrback(unwrapFirstError) messages = yield self._filter_events_for_client( - user_id, room_id, messages + user_id, messages ) start_token = now_token.copy_and_replace("room_key", token[0]) diff --git a/synapse/handlers/private_user_data.py b/synapse/handlers/private_user_data.py new file mode 100644 index 0000000000..1778c71325 --- /dev/null +++ b/synapse/handlers/private_user_data.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + + +class PrivateUserDataEventSource(object): + def __init__(self, hs): + self.store = hs.get_datastore() + + def get_current_key(self, direction='f'): + return self.store.get_max_private_user_data_stream_id() + + @defer.inlineCallbacks + def get_new_events_for_user(self, user, from_key, limit): + user_id = user.to_string() + last_stream_id = from_key + + current_stream_id = yield self.store.get_max_private_user_data_stream_id() + tags = yield self.store.get_updated_tags(user_id, last_stream_id) + + results = [] + for room_id, room_tags in tags.items(): + results.append({ + "type": "m.tag", + "content": {"tags": room_tags}, + "room_id": room_id, + }) + + defer.returnValue((results, current_stream_id)) + + @defer.inlineCallbacks + def get_pagination_rows(self, user, config, key): + defer.returnValue(([], config.to_id)) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3364a5de14..36878a6c20 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -22,12 +22,18 @@ from synapse.types import UserID, RoomAlias, RoomID from synapse.api.constants import ( EventTypes, Membership, JoinRules, RoomCreationPreset, ) -from synapse.api.errors import StoreError, SynapseError +from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.util import stringutils, unwrapFirstError from synapse.util.async import run_on_reactor +from signedjson.sign import verify_signed_json +from signedjson.key import decode_verify_key_bytes + from collections import OrderedDict +from unpaddedbase64 import decode_base64 + import logging +import math import string logger = logging.getLogger(__name__) @@ -384,7 +390,22 @@ class RoomMemberHandler(BaseHandler): if event.membership == Membership.JOIN: yield self._do_join(event, context, do_auth=do_auth) else: - # This is not a JOIN, so we can handle it normally. + if event.membership == Membership.LEAVE: + is_host_in_room = yield self.is_host_in_room(room_id, context) + if not is_host_in_room: + # Rejecting an invite, rather than leaving a joined room + handler = self.hs.get_handlers().federation_handler + inviter = yield self.get_inviter(event) + if not inviter: + # return the same error as join_room_alias does + raise SynapseError(404, "No known servers") + yield handler.do_remotely_reject_invite( + [inviter.domain], + room_id, + event.user_id + ) + defer.returnValue({"room_id": room_id}) + return # FIXME: This isn't idempotency. if prev_state and prev_state.membership == event.membership: @@ -408,7 +429,7 @@ class RoomMemberHandler(BaseHandler): defer.returnValue({"room_id": room_id}) @defer.inlineCallbacks - def join_room_alias(self, joinee, room_alias, do_auth=True, content={}): + def join_room_alias(self, joinee, room_alias, content={}): directory_handler = self.hs.get_handlers().directory_handler mapping = yield directory_handler.get_association(room_alias) @@ -442,8 +463,6 @@ class RoomMemberHandler(BaseHandler): @defer.inlineCallbacks def _do_join(self, event, context, room_hosts=None, do_auth=True): - joinee = UserID.from_string(event.state_key) - # room_id = RoomID.from_string(event.room_id, self.hs) room_id = event.room_id # XXX: We don't do an auth check if we are doing an invite @@ -451,41 +470,18 @@ class RoomMemberHandler(BaseHandler): # that we are allowed to join when we decide whether or not we # need to do the invite/join dance. - is_host_in_room = yield self.auth.check_host_in_room( - event.room_id, - self.hs.hostname - ) - if not is_host_in_room: - # is *anyone* in the room? - room_member_keys = [ - v for (k, v) in context.current_state.keys() if ( - k == "m.room.member" - ) - ] - if len(room_member_keys) == 0: - # has the room been created so we can join it? - create_event = context.current_state.get(("m.room.create", "")) - if create_event: - is_host_in_room = True - + is_host_in_room = yield self.is_host_in_room(room_id, context) if is_host_in_room: should_do_dance = False elif room_hosts: # TODO: Shouldn't this be remote_room_host? should_do_dance = True else: - # TODO(markjh): get prev_state from snapshot - prev_state = yield self.store.get_room_member( - joinee.to_string(), room_id - ) - - if prev_state and prev_state.membership == Membership.INVITE: - inviter = UserID.from_string(prev_state.user_id) - - should_do_dance = not self.hs.is_mine(inviter) - room_hosts = [inviter.domain] - else: + inviter = yield self.get_inviter(event) + if not inviter: # return the same error as join_room_alias does raise SynapseError(404, "No known servers") + should_do_dance = not self.hs.is_mine(inviter) + room_hosts = [inviter.domain] if should_do_dance: handler = self.hs.get_handlers().federation_handler @@ -493,8 +489,7 @@ class RoomMemberHandler(BaseHandler): room_hosts, room_id, event.user_id, - event.content, # FIXME To get a non-frozen dict - context + event.content # FIXME To get a non-frozen dict ) else: logger.debug("Doing normal join") @@ -512,6 +507,44 @@ class RoomMemberHandler(BaseHandler): ) @defer.inlineCallbacks + def get_inviter(self, event): + # TODO(markjh): get prev_state from snapshot + prev_state = yield self.store.get_room_member( + event.user_id, event.room_id + ) + + if prev_state and prev_state.membership == Membership.INVITE: + defer.returnValue(UserID.from_string(prev_state.user_id)) + return + elif "third_party_invite" in event.content: + if "sender" in event.content["third_party_invite"]: + inviter = UserID.from_string( + event.content["third_party_invite"]["sender"] + ) + defer.returnValue(inviter) + defer.returnValue(None) + + @defer.inlineCallbacks + def is_host_in_room(self, room_id, context): + is_host_in_room = yield self.auth.check_host_in_room( + room_id, + self.hs.hostname + ) + if not is_host_in_room: + # is *anyone* in the room? + room_member_keys = [ + v for (k, v) in context.current_state.keys() if ( + k == "m.room.member" + ) + ] + if len(room_member_keys) == 0: + # has the room been created so we can join it? + create_event = context.current_state.get(("m.room.create", "")) + if create_event: + is_host_in_room = True + defer.returnValue(is_host_in_room) + + @defer.inlineCallbacks def get_joined_rooms_for_user(self, user): """Returns a list of roomids that the user has any of the given membership states in.""" @@ -540,6 +573,160 @@ class RoomMemberHandler(BaseHandler): suppress_auth=(not do_auth), ) + @defer.inlineCallbacks + def do_3pid_invite( + self, + room_id, + inviter, + medium, + address, + id_server, + display_name, + token_id, + txn_id + ): + invitee = yield self._lookup_3pid( + id_server, medium, address + ) + + if invitee: + # make sure it looks like a user ID; it'll throw if it's invalid. + UserID.from_string(invitee) + yield self.hs.get_handlers().message_handler.create_and_send_event( + { + "type": EventTypes.Member, + "content": { + "membership": unicode("invite") + }, + "room_id": room_id, + "sender": inviter.to_string(), + "state_key": invitee, + }, + token_id=token_id, + txn_id=txn_id, + ) + else: + yield self._make_and_store_3pid_invite( + id_server, + display_name, + medium, + address, + room_id, + inviter, + token_id, + txn_id=txn_id + ) + + @defer.inlineCallbacks + def _lookup_3pid(self, id_server, medium, address): + """Looks up a 3pid in the passed identity server. + + Args: + id_server (str): The server name (including port, if required) + of the identity server to use. + medium (str): The type of the third party identifier (e.g. "email"). + address (str): The third party identifier (e.g. "foo@example.com"). + + Returns: + (str) the matrix ID of the 3pid, or None if it is not recognized. + """ + try: + data = yield self.hs.get_simple_http_client().get_json( + "https://%s/_matrix/identity/api/v1/lookup" % (id_server,), + { + "medium": medium, + "address": address, + } + ) + + if "mxid" in data: + if "signatures" not in data: + raise AuthError(401, "No signatures on 3pid binding") + self.verify_any_signature(data, id_server) + defer.returnValue(data["mxid"]) + + except IOError as e: + logger.warn("Error from identity server lookup: %s" % (e,)) + defer.returnValue(None) + + @defer.inlineCallbacks + def verify_any_signature(self, data, server_hostname): + if server_hostname not in data["signatures"]: + raise AuthError(401, "No signature from server %s" % (server_hostname,)) + for key_name, signature in data["signatures"][server_hostname].items(): + key_data = yield self.hs.get_simple_http_client().get_json( + "https://%s/_matrix/identity/api/v1/pubkey/%s" % + (server_hostname, key_name,), + ) + if "public_key" not in key_data: + raise AuthError(401, "No public key named %s from %s" % + (key_name, server_hostname,)) + verify_signed_json( + data, + server_hostname, + decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"])) + ) + return + + @defer.inlineCallbacks + def _make_and_store_3pid_invite( + self, + id_server, + display_name, + medium, + address, + room_id, + user, + token_id, + txn_id + ): + token, public_key, key_validity_url = ( + yield self._ask_id_server_for_third_party_invite( + id_server, + medium, + address, + room_id, + user.to_string() + ) + ) + msg_handler = self.hs.get_handlers().message_handler + yield msg_handler.create_and_send_event( + { + "type": EventTypes.ThirdPartyInvite, + "content": { + "display_name": display_name, + "key_validity_url": key_validity_url, + "public_key": public_key, + }, + "room_id": room_id, + "sender": user.to_string(), + "state_key": token, + }, + token_id=token_id, + txn_id=txn_id, + ) + + @defer.inlineCallbacks + def _ask_id_server_for_third_party_invite( + self, id_server, medium, address, room_id, sender): + is_url = "https://%s/_matrix/identity/api/v1/store-invite" % (id_server,) + data = yield self.hs.get_simple_http_client().post_urlencoded_get_json( + is_url, + { + "medium": medium, + "address": address, + "room_id": room_id, + "sender": sender, + } + ) + # TODO: Check for success + token = data["token"] + public_key = data["public_key"] + key_validity_url = "https://%s/_matrix/identity/api/v1/pubkey/isvalid" % ( + id_server, + ) + defer.returnValue((token, public_key, key_validity_url)) + class RoomListHandler(BaseHandler): @@ -561,6 +748,60 @@ class RoomListHandler(BaseHandler): defer.returnValue({"start": "START", "end": "END", "chunk": chunk}) +class RoomContextHandler(BaseHandler): + @defer.inlineCallbacks + def get_event_context(self, user, room_id, event_id, limit): + """Retrieves events, pagination tokens and state around a given event + in a room. + + Args: + user (UserID) + room_id (str) + event_id (str) + limit (int): The maximum number of events to return in total + (excluding state). + + Returns: + dict + """ + before_limit = math.floor(limit/2.) + after_limit = limit - before_limit + + now_token = yield self.hs.get_event_sources().get_current_token() + + results = yield self.store.get_events_around( + room_id, event_id, before_limit, after_limit + ) + + results["events_before"] = yield self._filter_events_for_client( + user.to_string(), results["events_before"] + ) + + results["events_after"] = yield self._filter_events_for_client( + user.to_string(), results["events_after"] + ) + + if results["events_after"]: + last_event_id = results["events_after"][-1].event_id + else: + last_event_id = event_id + + state = yield self.store.get_state_for_events( + [last_event_id], None + ) + results["state"] = state[last_event_id].values() + + results["start"] = now_token.copy_and_replace( + "room_key", results["start"] + ).to_string() + + results["end"] = now_token.copy_and_replace( + "room_key", results["end"] + ).to_string() + + defer.returnValue(results) + + class RoomEventSource(object): def __init__(self, hs): self.store = hs.get_datastore() diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py new file mode 100644 index 0000000000..2718e9482e --- /dev/null +++ b/synapse/handlers/search.py @@ -0,0 +1,154 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from ._base import BaseHandler + +from synapse.api.constants import Membership +from synapse.api.filtering import Filter +from synapse.api.errors import SynapseError +from synapse.events.utils import serialize_event + +import logging + + +logger = logging.getLogger(__name__) + + +class SearchHandler(BaseHandler): + + def __init__(self, hs): + super(SearchHandler, self).__init__(hs) + + @defer.inlineCallbacks + def search(self, user, content): + """Performs a full text search for a user. + + Args: + user (UserID) + content (dict): Search parameters + + Returns: + dict to be returned to the client with results of search + """ + + try: + search_term = content["search_categories"]["room_events"]["search_term"] + keys = content["search_categories"]["room_events"].get("keys", [ + "content.body", "content.name", "content.topic", + ]) + filter_dict = content["search_categories"]["room_events"].get("filter", {}) + event_context = content["search_categories"]["room_events"].get( + "event_context", None + ) + + if event_context is not None: + before_limit = int(event_context.get( + "before_limit", 5 + )) + after_limit = int(event_context.get( + "after_limit", 5 + )) + except KeyError: + raise SynapseError(400, "Invalid search query") + + search_filter = Filter(filter_dict) + + # TODO: Search through left rooms too + rooms = yield self.store.get_rooms_for_user_where_membership_is( + user.to_string(), + membership_list=[Membership.JOIN], + # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban], + ) + room_ids = set(r.room_id for r in rooms) + + room_ids = search_filter.filter_rooms(room_ids) + + rank_map, event_map, _ = yield self.store.search_msgs( + room_ids, search_term, keys + ) + + filtered_events = search_filter.filter(event_map.values()) + + allowed_events = yield self._filter_events_for_client( + user.to_string(), filtered_events + ) + + allowed_events.sort(key=lambda e: -rank_map[e.event_id]) + allowed_events = allowed_events[:search_filter.limit()] + + if event_context is not None: + now_token = yield self.hs.get_event_sources().get_current_token() + + contexts = {} + for event in allowed_events: + res = yield self.store.get_events_around( + event.room_id, event.event_id, before_limit, after_limit + ) + + res["events_before"] = yield self._filter_events_for_client( + user.to_string(), res["events_before"] + ) + + res["events_after"] = yield self._filter_events_for_client( + user.to_string(), res["events_after"] + ) + + res["start"] = now_token.copy_and_replace( + "room_key", res["start"] + ).to_string() + + res["end"] = now_token.copy_and_replace( + "room_key", res["end"] + ).to_string() + + contexts[event.event_id] = res + else: + contexts = {} + + # TODO: Add a limit + + time_now = self.clock.time_msec() + + for context in contexts.values(): + context["events_before"] = [ + serialize_event(e, time_now) + for e in context["events_before"] + ] + context["events_after"] = [ + serialize_event(e, time_now) + for e in context["events_after"] + ] + + results = { + e.event_id: { + "rank": rank_map[e.event_id], + "result": serialize_event(e, time_now), + "context": contexts.get(e.event_id, {}), + } + for e in allowed_events + } + + logger.info("Found %d results", len(results)) + + defer.returnValue({ + "search_categories": { + "room_events": { + "results": results, + "count": len(results) + } + } + }) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 21cf50101a..d6527c1ae8 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -51,6 +51,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ "timeline", "state", "ephemeral", + "private_user_data", ])): __slots__ = [] @@ -58,7 +59,31 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ """Make the result appear empty if there are no updates. This is used to tell if room needs to be part of the sync result. """ - return bool(self.timeline or self.state or self.ephemeral) + return bool( + self.timeline + or self.state + or self.ephemeral + or self.private_user_data + ) + + +class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [ + "room_id", + "timeline", + "state", + "private_user_data", +])): + __slots__ = [] + + def __nonzero__(self): + """Make the result appear empty if there are no updates. This is used + to tell if room needs to be part of the sync result. + """ + return bool( + self.timeline + or self.state + or self.private_user_data + ) class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ @@ -67,12 +92,17 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ ])): __slots__ = [] + def __nonzero__(self): + """Invited rooms should always be reported to the client""" + return True + class SyncResult(collections.namedtuple("SyncResult", [ "next_batch", # Token for the next sync "presence", # List of presence events for the user. "joined", # JoinedSyncResult for each joined room. "invited", # InvitedSyncResult for each invited room. + "archived", # ArchivedSyncResult for each archived room. ])): __slots__ = [] @@ -94,15 +124,20 @@ class SyncHandler(BaseHandler): self.clock = hs.get_clock() @defer.inlineCallbacks - def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0): + def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, + full_state=False): """Get the sync for a client if we have new data for it now. Otherwise wait for new data to arrive on the server. If the timeout expires, then return an empty sync result. Returns: A Deferred SyncResult. """ - if timeout == 0 or since_token is None: - result = yield self.current_sync_for_user(sync_config, since_token) + + if timeout == 0 or since_token is None or full_state: + # we are going to return immediately, so don't bother calling + # notifier.wait_for_events. + result = yield self.current_sync_for_user(sync_config, since_token, + full_state=full_state) defer.returnValue(result) else: def current_sync_callback(before_token, after_token): @@ -127,24 +162,33 @@ class SyncHandler(BaseHandler): ) defer.returnValue(result) - def current_sync_for_user(self, sync_config, since_token=None): + def current_sync_for_user(self, sync_config, since_token=None, + full_state=False): """Get the sync for client needed to match what the server has now. Returns: A Deferred SyncResult. """ - if since_token is None: - return self.initial_sync(sync_config) + if since_token is None or full_state: + return self.full_state_sync(sync_config, since_token) else: return self.incremental_sync_with_gap(sync_config, since_token) @defer.inlineCallbacks - def initial_sync(self, sync_config): - """Get a sync for a client which is starting without any state + def full_state_sync(self, sync_config, timeline_since_token): + """Get a sync for a client which is starting without any state. + + If a 'message_since_token' is given, only timeline events which have + happened since that token will be returned. + Returns: A Deferred SyncResult. """ now_token = yield self.event_sources.get_current_token() + now_token, ephemeral_by_room = yield self.ephemeral_by_room( + sync_config, now_token + ) + presence_stream = self.event_sources.sources["presence"] # TODO (mjark): This looks wrong, shouldn't we be getting the presence # UP to the present rather than after the present? @@ -156,15 +200,30 @@ class SyncHandler(BaseHandler): ) room_list = yield self.store.get_rooms_for_user_where_membership_is( user_id=sync_config.user.to_string(), - membership_list=[Membership.INVITE, Membership.JOIN] + membership_list=( + Membership.INVITE, + Membership.JOIN, + Membership.LEAVE, + Membership.BAN + ) + ) + + tags_by_room = yield self.store.get_tags_for_user( + sync_config.user.to_string() ) joined = [] invited = [] + archived = [] for event in room_list: if event.membership == Membership.JOIN: - room_sync = yield self.initial_sync_for_joined_room( - event.room_id, sync_config, now_token, + room_sync = yield self.full_state_sync_for_joined_room( + room_id=event.room_id, + sync_config=sync_config, + now_token=now_token, + timeline_since_token=timeline_since_token, + ephemeral_by_room=ephemeral_by_room, + tags_by_room=tags_by_room, ) joined.append(room_sync) elif event.membership == Membership.INVITE: @@ -173,23 +232,39 @@ class SyncHandler(BaseHandler): room_id=event.room_id, invite=invite, )) + elif event.membership in (Membership.LEAVE, Membership.BAN): + leave_token = now_token.copy_and_replace( + "room_key", "s%d" % (event.stream_ordering,) + ) + room_sync = yield self.full_state_sync_for_archived_room( + sync_config=sync_config, + room_id=event.room_id, + leave_event_id=event.event_id, + leave_token=leave_token, + timeline_since_token=timeline_since_token, + tags_by_room=tags_by_room, + ) + archived.append(room_sync) defer.returnValue(SyncResult( presence=presence, joined=joined, invited=invited, + archived=archived, next_batch=now_token, )) @defer.inlineCallbacks - def initial_sync_for_joined_room(self, room_id, sync_config, now_token): + def full_state_sync_for_joined_room(self, room_id, sync_config, + now_token, timeline_since_token, + ephemeral_by_room, tags_by_room): """Sync a room for a client which is starting without any state Returns: A Deferred JoinedSyncResult. """ batch = yield self.load_filtered_recents( - room_id, sync_config, now_token, + room_id, sync_config, now_token, since_token=timeline_since_token ) current_state = yield self.state_handler.get_current_state( @@ -201,7 +276,92 @@ class SyncHandler(BaseHandler): room_id=room_id, timeline=batch, state=current_state_events, - ephemeral=[], + ephemeral=ephemeral_by_room.get(room_id, []), + private_user_data=self.private_user_data_for_room( + room_id, tags_by_room + ), + )) + + def private_user_data_for_room(self, room_id, tags_by_room): + private_user_data = [] + tags = tags_by_room.get(room_id) + if tags: + private_user_data.append({ + "type": "m.tag", + "content": {"tags": tags}, + }) + return private_user_data + + @defer.inlineCallbacks + def ephemeral_by_room(self, sync_config, now_token, since_token=None): + """Get the ephemeral events for each room the user is in + Args: + sync_config (SyncConfig): The flags, filters and user for the sync. + now_token (StreamToken): Where the server is currently up to. + since_token (StreamToken): Where the server was when the client + last synced. + Returns: + A tuple of the now StreamToken, updated to reflect the which typing + events are included, and a dict mapping from room_id to a list of + typing events for that room. + """ + + typing_key = since_token.typing_key if since_token else "0" + + typing_source = self.event_sources.sources["typing"] + typing, typing_key = yield typing_source.get_new_events_for_user( + user=sync_config.user, + from_key=typing_key, + limit=sync_config.filter.ephemeral_limit(), + ) + now_token = now_token.copy_and_replace("typing_key", typing_key) + + ephemeral_by_room = {} + + for event in typing: + room_id = event.pop("room_id") + ephemeral_by_room.setdefault(room_id, []).append(event) + + receipt_key = since_token.receipt_key if since_token else "0" + + receipt_source = self.event_sources.sources["receipt"] + receipts, receipt_key = yield receipt_source.get_new_events_for_user( + user=sync_config.user, + from_key=receipt_key, + limit=sync_config.filter.ephemeral_limit(), + ) + now_token = now_token.copy_and_replace("receipt_key", receipt_key) + + for event in receipts: + room_id = event.pop("room_id") + ephemeral_by_room.setdefault(room_id, []).append(event) + + defer.returnValue((now_token, ephemeral_by_room)) + + @defer.inlineCallbacks + def full_state_sync_for_archived_room(self, room_id, sync_config, + leave_event_id, leave_token, + timeline_since_token, tags_by_room): + """Sync a room for a client which is starting without any state + Returns: + A Deferred JoinedSyncResult. + """ + + batch = yield self.load_filtered_recents( + room_id, sync_config, leave_token, since_token=timeline_since_token + ) + + leave_state = yield self.store.get_state_for_events( + [leave_event_id], None + ) + + defer.returnValue(ArchivedSyncResult( + room_id=room_id, + timeline=batch, + state=leave_state[leave_event_id].values(), + private_user_data=self.private_user_data_for_room( + room_id, tags_by_room + ), )) @defer.inlineCallbacks @@ -221,18 +381,9 @@ class SyncHandler(BaseHandler): ) now_token = now_token.copy_and_replace("presence_key", presence_key) - typing_source = self.event_sources.sources["typing"] - typing, typing_key = yield typing_source.get_new_events_for_user( - user=sync_config.user, - from_key=since_token.typing_key, - limit=sync_config.filter.ephemeral_limit(), + now_token, ephemeral_by_room = yield self.ephemeral_by_room( + sync_config, now_token, since_token ) - now_token = now_token.copy_and_replace("typing_key", typing_key) - - typing_by_room = {event["room_id"]: [event] for event in typing} - for event in typing: - event.pop("room_id") - logger.debug("Typing %r", typing_by_room) rm_handler = self.hs.get_handlers().room_member_handler app_service = yield self.store.get_app_service_by_user_id( @@ -256,19 +407,28 @@ class SyncHandler(BaseHandler): limit=timeline_limit + 1, ) + tags_by_room = yield self.store.get_updated_tags( + sync_config.user.to_string(), + since_token.private_user_data_key, + ) + joined = [] + archived = [] if len(room_events) <= timeline_limit: # There is no gap in any of the rooms. Therefore we can just # partition the new events by room and return them. invite_events = [] + leave_events = [] events_by_room_id = {} for event in room_events: events_by_room_id.setdefault(event.room_id, []).append(event) if event.room_id not in joined_room_ids: if (event.type == EventTypes.Member - and event.membership == Membership.INVITE and event.state_key == sync_config.user.to_string()): - invite_events.append(event) + if event.membership == Membership.INVITE: + invite_events.append(event) + elif event.membership in (Membership.LEAVE, Membership.BAN): + leave_events.append(event) for room_id in joined_room_ids: recents = events_by_room_id.get(room_id, []) @@ -280,7 +440,7 @@ class SyncHandler(BaseHandler): else: prev_batch = now_token - state = yield self.check_joined_room( + state, limited = yield self.check_joined_room( sync_config, room_id, state ) @@ -289,26 +449,40 @@ class SyncHandler(BaseHandler): timeline=TimelineBatch( events=recents, prev_batch=prev_batch, - limited=False, + limited=limited, ), state=state, - ephemeral=typing_by_room.get(room_id, []) + ephemeral=ephemeral_by_room.get(room_id, []), + private_user_data=self.private_user_data_for_room( + room_id, tags_by_room + ), ) if room_sync: joined.append(room_sync) + else: invite_events = yield self.store.get_invites_for_user( sync_config.user.to_string() ) + leave_events = yield self.store.get_leave_and_ban_events_for_user( + sync_config.user.to_string() + ) + for room_id in joined_room_ids: room_sync = yield self.incremental_sync_with_gap_for_room( room_id, sync_config, since_token, now_token, - typing_by_room + ephemeral_by_room, tags_by_room ) if room_sync: joined.append(room_sync) + for leave_event in leave_events: + room_sync = yield self.incremental_sync_for_archived_room( + sync_config, leave_event, since_token, tags_by_room + ) + archived.append(room_sync) + invited = [ InvitedSyncResult(room_id=event.room_id, invite=event) for event in invite_events @@ -318,56 +492,11 @@ class SyncHandler(BaseHandler): presence=presence, joined=joined, invited=invited, + archived=archived, next_batch=now_token, )) @defer.inlineCallbacks - def _filter_events_for_client(self, user_id, room_id, events): - event_id_to_state = yield self.store.get_state_for_events( - frozenset(e.event_id for e in events), - types=( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, user_id), - ) - ) - - def allowed(event, state): - if event.type == EventTypes.RoomHistoryVisibility: - return True - - membership_ev = state.get((EventTypes.Member, user_id), None) - if membership_ev: - membership = membership_ev.membership - else: - membership = Membership.LEAVE - - if membership == Membership.JOIN: - return True - - history = state.get((EventTypes.RoomHistoryVisibility, ''), None) - if history: - visibility = history.content.get("history_visibility", "shared") - else: - visibility = "shared" - - if visibility == "public": - return True - elif visibility == "shared": - return True - elif visibility == "joined": - return membership == Membership.JOIN - elif visibility == "invited": - return membership == Membership.INVITE - - return True - - defer.returnValue([ - event - for event in events - if allowed(event, event_id_to_state[event.event_id]) - ]) - - @defer.inlineCallbacks def load_filtered_recents(self, room_id, sync_config, now_token, since_token=None): limited = True @@ -390,7 +519,7 @@ class SyncHandler(BaseHandler): end_key = "s" + room_key.split('-')[-1] loaded_recents = sync_config.filter.filter_room_timeline(events) loaded_recents = yield self._filter_events_for_client( - sync_config.user.to_string(), room_id, loaded_recents, + sync_config.user.to_string(), loaded_recents, ) loaded_recents.extend(recents) recents = loaded_recents @@ -414,7 +543,7 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def incremental_sync_with_gap_for_room(self, room_id, sync_config, since_token, now_token, - typing_by_room): + ephemeral_by_room, tags_by_room): """ Get the incremental delta needed to bring the client up to date for the room. Gives the client the most recent events and the changes to state. @@ -447,7 +576,7 @@ class SyncHandler(BaseHandler): current_state=current_state_events, ) - state_events_delta = yield self.check_joined_room( + state_events_delta, _ = yield self.check_joined_room( sync_config, room_id, state_events_delta ) @@ -455,7 +584,62 @@ class SyncHandler(BaseHandler): room_id=room_id, timeline=batch, state=state_events_delta, - ephemeral=typing_by_room.get(room_id, []) + ephemeral=ephemeral_by_room.get(room_id, []), + private_user_data=self.private_user_data_for_room( + room_id, tags_by_room + ), + ) + + logging.debug("Room sync: %r", room_sync) + + defer.returnValue(room_sync) + + @defer.inlineCallbacks + def incremental_sync_for_archived_room(self, sync_config, leave_event, + since_token, tags_by_room): + """ Get the incremental delta needed to bring the client up to date for + the archived room. + Returns: + A Deferred ArchivedSyncResult + """ + + stream_token = yield self.store.get_stream_token_for_event( + leave_event.event_id + ) + + leave_token = since_token.copy_and_replace("room_key", stream_token) + + batch = yield self.load_filtered_recents( + leave_event.room_id, sync_config, leave_token, since_token, + ) + + logging.debug("Recents %r", batch) + + # TODO(mjark): This seems racy since this isn't being passed a + # token to indicate what point in the stream this is + leave_state = yield self.store.get_state_for_events( + [leave_event.event_id], None + ) + + state_events_at_leave = leave_state[leave_event.event_id].values() + + state_at_previous_sync = yield self.get_state_at_previous_sync( + leave_event.room_id, since_token=since_token + ) + + state_events_delta = yield self.compute_state_delta( + since_token=since_token, + previous_state=state_at_previous_sync, + current_state=state_events_at_leave, + ) + + room_sync = ArchivedSyncResult( + room_id=leave_event.room_id, + timeline=batch, + state=state_events_delta, + private_user_data=self.private_user_data_for_room( + leave_event.room_id, tags_by_room + ), ) logging.debug("Room sync: %r", room_sync) @@ -505,6 +689,7 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def check_joined_room(self, sync_config, room_id, state_delta): joined = False + limited = False for event in state_delta: if ( event.type == EventTypes.Member @@ -516,5 +701,6 @@ class SyncHandler(BaseHandler): if joined: res = yield self.state_handler.get_current_state(room_id) state_delta = res.values() + limited = True - defer.returnValue(state_delta) + defer.returnValue((state_delta, limited)) diff --git a/synapse/http/client.py b/synapse/http/client.py index 9a5869abee..27e5190224 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -24,7 +24,6 @@ from canonicaljson import encode_canonical_json from twisted.internet import defer, reactor, ssl from twisted.web.client import ( Agent, readBody, FileBodyProducer, PartialDownloadError, - HTTPConnectionPool, ) from twisted.web.http_headers import Headers @@ -59,11 +58,8 @@ class SimpleHttpClient(object): # The default context factory in Twisted 14.0.0 (which we require) is # BrowserLikePolicyForHTTPS which will do regular cert validation # 'like a browser' - pool = HTTPConnectionPool(reactor) - pool.maxPersistentPerHost = 10 self.agent = Agent( reactor, - pool=pool, connectTimeout=15, contextFactory=hs.get_http_client_context_factory() ) diff --git a/synapse/notifier.py b/synapse/notifier.py index f998fc83bf..a78ee3c1e7 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -270,7 +270,7 @@ class Notifier(object): @defer.inlineCallbacks def wait_for_events(self, user, rooms, timeout, callback, - from_token=StreamToken("s0", "0", "0", "0")): + from_token=StreamToken("s0", "0", "0", "0", "0")): """Wait until the callback returns a non empty response or the timeout fires. """ diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 00ec8fcd74..4ea06c1434 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -107,6 +107,8 @@ class LoginRestServlet(ClientV1RestServlet): user_id = yield self.hs.get_datastore().get_user_id_by_threepid( login_submission['medium'], login_submission['address'] ) + if not user_id: + raise LoginError(403, "", errcode=Codes.FORBIDDEN) else: user_id = login_submission['user'] @@ -198,36 +200,6 @@ class LoginRestServlet(ClientV1RestServlet): return (user, attributes) -class LoginFallbackRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/login/fallback$") - - def on_GET(self, request): - # TODO(kegan): This should be returning some HTML which is capable of - # hitting LoginRestServlet - return (200, {}) - - -class PasswordResetRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/login/reset") - - @defer.inlineCallbacks - def on_POST(self, request): - reset_info = _parse_json(request) - try: - email = reset_info["email"] - user_id = reset_info["user_id"] - handler = self.handlers.login_handler - yield handler.reset_password(user_id, email) - # purposefully give no feedback to avoid people hammering different - # combinations. - defer.returnValue((200, {})) - except KeyError: - raise SynapseError( - 400, - "Missing keys. Requires 'email' and 'user_id'." - ) - - class SAML2RestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/login/saml2") diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 23871f161e..2dcaee86cd 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -26,7 +26,7 @@ from synapse.events.utils import serialize_event import simplejson as json import logging import urllib - +from synapse.util import third_party_invites logger = logging.getLogger(__name__) @@ -397,6 +397,41 @@ class RoomTriggerBackfill(ClientV1RestServlet): defer.returnValue((200, res)) +class RoomEventContext(ClientV1RestServlet): + PATTERN = client_path_pattern( + "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$" + ) + + def __init__(self, hs): + super(RoomEventContext, self).__init__(hs) + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def on_GET(self, request, room_id, event_id): + user, _ = yield self.auth.get_user_by_req(request) + + limit = int(request.args.get("limit", [10])[0]) + + results = yield self.handlers.room_context_handler.get_event_context( + user, room_id, event_id, limit, + ) + + time_now = self.clock.time_msec() + results["events_before"] = [ + serialize_event(event, time_now) for event in results["events_before"] + ] + results["events_after"] = [ + serialize_event(event, time_now) for event in results["events_after"] + ] + results["state"] = [ + serialize_event(event, time_now) for event in results["state"] + ] + + logger.info("Responding with %r", results) + + defer.returnValue((200, results)) + + # TODO: Needs unit testing class RoomMembershipRestServlet(ClientV1RestServlet): @@ -414,10 +449,26 @@ class RoomMembershipRestServlet(ClientV1RestServlet): # target user is you unless it is an invite state_key = user.to_string() - if membership_action in ["invite", "ban", "kick"]: - if "user_id" not in content: + + if membership_action == "invite" and third_party_invites.has_invite_keys(content): + yield self.handlers.room_member_handler.do_3pid_invite( + room_id, + user, + content["medium"], + content["address"], + content["id_server"], + content["display_name"], + token_id, + txn_id + ) + defer.returnValue((200, {})) + return + elif membership_action in ["invite", "ban", "kick"]: + if "user_id" in content: + state_key = content["user_id"] + else: raise SynapseError(400, "Missing user_id key.") - state_key = content["user_id"] + # make sure it looks like a user ID; it'll throw if it's invalid. UserID.from_string(state_key) @@ -425,10 +476,20 @@ class RoomMembershipRestServlet(ClientV1RestServlet): membership_action = "leave" msg_handler = self.handlers.message_handler + + event_content = { + "membership": unicode(membership_action), + } + + if membership_action == "join" and third_party_invites.has_join_keys(content): + event_content["third_party_invite"] = ( + third_party_invites.extract_join_keys(content) + ) + yield msg_handler.create_and_send_event( { "type": EventTypes.Member, - "content": {"membership": unicode(membership_action)}, + "content": event_content, "room_id": room_id, "sender": user.to_string(), "state_key": state_key, @@ -529,6 +590,22 @@ class RoomTypingRestServlet(ClientV1RestServlet): defer.returnValue((200, {})) +class SearchRestServlet(ClientV1RestServlet): + PATTERN = client_path_pattern( + "/search$" + ) + + @defer.inlineCallbacks + def on_POST(self, request): + auth_user, _ = yield self.auth.get_user_by_req(request) + + content = _parse_json(request) + + results = yield self.handlers.search_handler.search(auth_user, content) + + defer.returnValue((200, results)) + + def _parse_json(request): try: content = json.loads(request.content.read()) @@ -585,3 +662,5 @@ def register_servlets(hs, http_server): RoomInitialSyncRestServlet(hs).register(http_server) RoomRedactEventRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server) + SearchRestServlet(hs).register(http_server) + RoomEventContext(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py index 5831ff0e62..a108132346 100644 --- a/synapse/rest/client/v2_alpha/__init__.py +++ b/synapse/rest/client/v2_alpha/__init__.py @@ -22,6 +22,7 @@ from . import ( receipts, keys, tokenrefresh, + tags, ) from synapse.http.server import JsonResource @@ -44,3 +45,4 @@ class ClientV2AlphaRestResource(JsonResource): receipts.register_servlets(hs, client_resource) keys.register_servlets(hs, client_resource) tokenrefresh.register_servlets(hs, client_resource) + tags.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index fffecb24f5..32a1087c91 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -16,14 +16,14 @@ from twisted.internet import defer from synapse.http.servlet import ( - RestServlet, parse_string, parse_integer + RestServlet, parse_string, parse_integer, parse_boolean ) from synapse.handlers.sync import SyncConfig from synapse.types import StreamToken from synapse.events.utils import ( serialize_event, format_event_for_client_v2_without_event_id, ) -from synapse.api.filtering import Filter +from synapse.api.filtering import FilterCollection from ._base import client_v2_pattern import copy @@ -90,6 +90,7 @@ class SyncRestServlet(RestServlet): allowed_values=self.ALLOWED_PRESENCE ) filter_id = parse_string(request, "filter", default=None) + full_state = parse_boolean(request, "full_state", default=False) logger.info( "/sync: user=%r, timeout=%r, since=%r," @@ -103,7 +104,7 @@ class SyncRestServlet(RestServlet): user.localpart, filter_id ) except: - filter = Filter({}) + filter = FilterCollection({}) sync_config = SyncConfig( user=user, @@ -120,7 +121,8 @@ class SyncRestServlet(RestServlet): try: sync_result = yield self.sync_handler.wait_for_sync_for_user( - sync_config, since_token=since_token, timeout=timeout + sync_config, since_token=since_token, timeout=timeout, + full_state=full_state ) finally: if set_presence == "online": @@ -136,6 +138,10 @@ class SyncRestServlet(RestServlet): sync_result.invited, filter, time_now, token_id ) + archived = self.encode_archived( + sync_result.archived, filter, time_now, token_id + ) + response_content = { "presence": self.encode_presence( sync_result.presence, filter, time_now @@ -143,7 +149,7 @@ class SyncRestServlet(RestServlet): "rooms": { "joined": joined, "invited": invited, - "archived": {}, + "archived": archived, }, "next_batch": sync_result.next_batch.to_string(), } @@ -182,14 +188,20 @@ class SyncRestServlet(RestServlet): return invited + def encode_archived(self, rooms, filter, time_now, token_id): + joined = {} + for room in rooms: + joined[room.room_id] = self.encode_room( + room, filter, time_now, token_id, joined=False + ) + + return joined + @staticmethod - def encode_room(room, filter, time_now, token_id): + def encode_room(room, filter, time_now, token_id, joined=True): event_map = {} state_events = filter.filter_room_state(room.state) - timeline_events = filter.filter_room_timeline(room.timeline.events) - ephemeral_events = filter.filter_room_ephemeral(room.ephemeral) state_event_ids = [] - timeline_event_ids = [] for event in state_events: # TODO(mjark): Respect formatting requirements in the filter. event_map[event.event_id] = serialize_event( @@ -198,6 +210,8 @@ class SyncRestServlet(RestServlet): ) state_event_ids.append(event.event_id) + timeline_events = filter.filter_room_timeline(room.timeline.events) + timeline_event_ids = [] for event in timeline_events: # TODO(mjark): Respect formatting requirements in the filter. event_map[event.event_id] = serialize_event( @@ -205,6 +219,11 @@ class SyncRestServlet(RestServlet): event_format=format_event_for_client_v2_without_event_id, ) timeline_event_ids.append(event.event_id) + + private_user_data = filter.filter_room_private_user_data( + room.private_user_data + ) + result = { "event_map": event_map, "timeline": { @@ -213,8 +232,13 @@ class SyncRestServlet(RestServlet): "limited": room.timeline.limited, }, "state": {"events": state_event_ids}, - "ephemeral": {"events": ephemeral_events}, + "private_user_data": {"events": private_user_data}, } + + if joined: + ephemeral_events = filter.filter_room_ephemeral(room.ephemeral) + result["ephemeral"] = {"events": ephemeral_events} + return result diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py new file mode 100644 index 0000000000..dcfe6bd20e --- /dev/null +++ b/synapse/rest/client/v2_alpha/tags.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import client_v2_pattern + +from synapse.http.servlet import RestServlet +from synapse.api.errors import AuthError, SynapseError + +from twisted.internet import defer + +import logging + +import simplejson as json + +logger = logging.getLogger(__name__) + + +class TagListServlet(RestServlet): + """ + GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1 + """ + PATTERN = client_v2_pattern( + "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags" + ) + + def __init__(self, hs): + super(TagListServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def on_GET(self, request, user_id, room_id): + auth_user, _ = yield self.auth.get_user_by_req(request) + if user_id != auth_user.to_string(): + raise AuthError(403, "Cannot get tags for other users.") + + tags = yield self.store.get_tags_for_room(user_id, room_id) + + defer.returnValue((200, {"tags": tags})) + + +class TagServlet(RestServlet): + """ + PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 + DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 + """ + PATTERN = client_v2_pattern( + "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)" + ) + + def __init__(self, hs): + super(TagServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + + @defer.inlineCallbacks + def on_PUT(self, request, user_id, room_id, tag): + auth_user, _ = yield self.auth.get_user_by_req(request) + if user_id != auth_user.to_string(): + raise AuthError(403, "Cannot add tags for other users.") + + try: + content_bytes = request.content.read() + body = json.loads(content_bytes) + except: + raise SynapseError(400, "Invalid tag JSON") + + max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) + + yield self.notifier.on_new_event( + "private_user_data_key", max_id, users=[user_id] + ) + + defer.returnValue((200, {})) + + @defer.inlineCallbacks + def on_DELETE(self, request, user_id, room_id, tag): + auth_user, _ = yield self.auth.get_user_by_req(request) + if user_id != auth_user.to_string(): + raise AuthError(403, "Cannot add tags for other users.") + + max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) + + yield self.notifier.on_new_event( + "private_user_data_key", max_id, users=[user_id] + ) + + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + TagListServlet(hs).register(http_server) + TagServlet(hs).register(http_server) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 48a0633746..e7443f2838 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -40,6 +40,8 @@ from .filtering import FilteringStore from .end_to_end_keys import EndToEndKeyStore from .receipts import ReceiptsStore +from .search import SearchStore +from .tags import TagsStore import logging @@ -69,6 +71,8 @@ class DataStore(RoomMemberStore, RoomStore, EventsStore, ReceiptsStore, EndToEndKeyStore, + SearchStore, + TagsStore, ): def __init__(self, hs): diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 693784ad38..218e708054 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -519,7 +519,7 @@ class SQLBaseStore(object): allow_none=False, desc="_simple_select_one_onecol"): """Executes a SELECT query on the named table, which is expected to - return a single row, returning a single column from it." + return a single row, returning a single column from it. Args: table : string giving the table name diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index bad3b5c5ac..a5a54ec011 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -17,6 +17,8 @@ from synapse.storage.prepare_database import ( prepare_database, prepare_sqlite3_database ) +import struct + class Sqlite3Engine(object): single_threaded = True @@ -32,6 +34,7 @@ class Sqlite3Engine(object): def on_new_connection(self, db_conn): self.prepare_database(db_conn) + db_conn.create_function("rank", 1, _rank) def prepare_database(self, db_conn): prepare_sqlite3_database(db_conn) @@ -45,3 +48,27 @@ class Sqlite3Engine(object): def lock_table(self, txn, table): return + + +# Following functions taken from: https://github.com/coleifer/peewee + +def _parse_match_info(buf): + bufsize = len(buf) + return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)] + + +def _rank(raw_match_info): + """Handle match_info called w/default args 'pcx' - based on the example rank + function http://sqlite.org/fts3.html#appendix_a + """ + match_info = _parse_match_info(raw_match_info) + score = 0.0 + p, c = match_info[:2] + for phrase_num in range(p): + phrase_info_idx = 2 + (phrase_num * c * 3) + for col_num in range(c): + col_idx = phrase_info_idx + (col_num * 3) + x1, x2 = match_info[col_idx:col_idx + 2] + if x1 > 0: + score += float(x1) / x2 + return score diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 416ef6af93..e6c1abfc27 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -307,6 +307,8 @@ class EventsStore(SQLBaseStore): self._store_room_name_txn(txn, event) elif event.type == EventTypes.Topic: self._store_room_topic_txn(txn, event) + elif event.type == EventTypes.Message: + self._store_room_message_txn(txn, event) elif event.type == EventTypes.Redaction: self._store_redaction(txn, event) diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py index 8800116570..fcd43c7fdd 100644 --- a/synapse/storage/filtering.py +++ b/synapse/storage/filtering.py @@ -34,10 +34,10 @@ class FilteringStore(SQLBaseStore): desc="get_user_filter", ) - defer.returnValue(json.loads(def_json)) + defer.returnValue(json.loads(str(def_json).decode("utf-8"))) def add_user_filter(self, user_localpart, user_filter): - def_json = json.dumps(user_filter) + def_json = json.dumps(user_filter).encode("utf-8") # Need an atomic transaction to SELECT the maximal ID so far then # INSERT a new one diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 1ddf55be4d..1a74d6e360 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 24 +SCHEMA_VERSION = 25 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 5e07b7e0e5..13441fcdce 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -19,6 +19,7 @@ from synapse.api.errors import StoreError from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks +from .engines import PostgresEngine, Sqlite3Engine import collections import logging @@ -175,6 +176,10 @@ class RoomStore(SQLBaseStore): }, ) + self._store_event_search_txn( + txn, event, "content.topic", event.content["topic"] + ) + def _store_room_name_txn(self, txn, event): if hasattr(event, "content") and "name" in event.content: self._simple_insert_txn( @@ -187,6 +192,33 @@ class RoomStore(SQLBaseStore): } ) + self._store_event_search_txn( + txn, event, "content.name", event.content["name"] + ) + + def _store_room_message_txn(self, txn, event): + if hasattr(event, "content") and "body" in event.content: + self._store_event_search_txn( + txn, event, "content.body", event.content["body"] + ) + + def _store_event_search_txn(self, txn, event, key, value): + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "INSERT INTO event_search (event_id, room_id, key, vector)" + " VALUES (?,?,?,to_tsvector('english', ?))" + ) + elif isinstance(self.database_engine, Sqlite3Engine): + sql = ( + "INSERT INTO event_search (event_id, room_id, key, value)" + " VALUES (?,?,?,?)" + ) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + txn.execute(sql, (event.event_id, event.room_id, key, value,)) + @cachedInlineCallbacks() def get_room_name_and_aliases(self, room_id): def f(txn): diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index dd98dcfda8..ae1ad56d9a 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -124,6 +124,19 @@ class RoomMemberStore(SQLBaseStore): invites.event_id for invite in invites ])) + def get_leave_and_ban_events_for_user(self, user_id): + """ Get all the leave events for a user + Args: + user_id (str): The user ID. + Returns: + A deferred list of event objects. + """ + return self.get_rooms_for_user_where_membership_is( + user_id, (Membership.LEAVE, Membership.BAN) + ).addCallback(lambda leaves: self._get_events([ + leave.event_id for leave in leaves + ])) + def get_rooms_for_user_where_membership_is(self, user_id, membership_list): """ Get all the rooms for this user where the membership for this user matches one in the membership list. diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py new file mode 100644 index 0000000000..b7cd0ce3b8 --- /dev/null +++ b/synapse/storage/schema/delta/25/fts.py @@ -0,0 +1,127 @@ +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage.prepare_database import get_statements +from synapse.storage.engines import PostgresEngine, Sqlite3Engine + +import ujson + +logger = logging.getLogger(__name__) + + +POSTGRES_SQL = """ +CREATE TABLE IF NOT EXISTS event_search ( + event_id TEXT, + room_id TEXT, + sender TEXT, + key TEXT, + vector tsvector +); + +INSERT INTO event_search SELECT + event_id, room_id, json::json->>'sender', 'content.body', + to_tsvector('english', json::json->'content'->>'body') + FROM events NATURAL JOIN event_json WHERE type = 'm.room.message'; + +INSERT INTO event_search SELECT + event_id, room_id, json::json->>'sender', 'content.name', + to_tsvector('english', json::json->'content'->>'name') + FROM events NATURAL JOIN event_json WHERE type = 'm.room.name'; + +INSERT INTO event_search SELECT + event_id, room_id, json::json->>'sender', 'content.topic', + to_tsvector('english', json::json->'content'->>'topic') + FROM events NATURAL JOIN event_json WHERE type = 'm.room.topic'; + + +CREATE INDEX event_search_fts_idx ON event_search USING gin(vector); +CREATE INDEX event_search_ev_idx ON event_search(event_id); +CREATE INDEX event_search_ev_ridx ON event_search(room_id); +""" + + +SQLITE_TABLE = ( + "CREATE VIRTUAL TABLE IF NOT EXISTS event_search" + " USING fts4 ( event_id, room_id, sender, key, value )" +) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + run_postgres_upgrade(cur) + return + + if isinstance(database_engine, Sqlite3Engine): + run_sqlite_upgrade(cur) + return + + +def run_postgres_upgrade(cur): + for statement in get_statements(POSTGRES_SQL.splitlines()): + cur.execute(statement) + + +def run_sqlite_upgrade(cur): + cur.execute(SQLITE_TABLE) + + rowid = -1 + while True: + cur.execute( + "SELECT rowid, json FROM event_json" + " WHERE rowid > ?" + " ORDER BY rowid ASC LIMIT 100", + (rowid,) + ) + + res = cur.fetchall() + + if not res: + break + + events = [ + ujson.loads(js) + for _, js in res + ] + + rowid = max(rid for rid, _ in res) + + rows = [] + for ev in events: + content = ev.get("content", {}) + body = content.get("body", None) + name = content.get("name", None) + topic = content.get("topic", None) + sender = ev.get("sender", None) + if ev["type"] == "m.room.message" and body: + rows.append(( + ev["event_id"], ev["room_id"], sender, "content.body", body + )) + if ev["type"] == "m.room.name" and name: + rows.append(( + ev["event_id"], ev["room_id"], sender, "content.name", name + )) + if ev["type"] == "m.room.topic" and topic: + rows.append(( + ev["event_id"], ev["room_id"], sender, "content.topic", topic + )) + + if rows: + logger.info(rows) + cur.executemany( + "INSERT INTO event_search (event_id, room_id, sender, key, value)" + " VALUES (?,?,?,?,?)", + rows + ) diff --git a/synapse/storage/schema/delta/25/tags.sql b/synapse/storage/schema/delta/25/tags.sql new file mode 100644 index 0000000000..527424c998 --- /dev/null +++ b/synapse/storage/schema/delta/25/tags.sql @@ -0,0 +1,38 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE IF NOT EXISTS room_tags( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + tag TEXT NOT NULL, -- The name of the tag. + content TEXT NOT NULL, -- The JSON content of the tag. + CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag) +); + +CREATE TABLE IF NOT EXISTS room_tags_revisions ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + stream_id BIGINT NOT NULL, -- The current version of the room tags. + CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id) +); + +CREATE TABLE IF NOT EXISTS private_user_data_max_stream_id( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_id BIGINT NOT NULL, + CHECK (Lock='X') +); + +INSERT INTO private_user_data_max_stream_id (stream_id) VALUES (0); diff --git a/synapse/storage/search.py b/synapse/storage/search.py new file mode 100644 index 0000000000..cdf003502f --- /dev/null +++ b/synapse/storage/search.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from _base import SQLBaseStore +from synapse.storage.engines import PostgresEngine, Sqlite3Engine + +from collections import namedtuple + +"""The result of a search. + +Fields: + rank_map (dict): Mapping event_id -> rank + event_map (dict): Mapping event_id -> event + pagination_token (str): Pagination token +""" +SearchResult = namedtuple("SearchResult", ("rank_map", "event_map", "pagination_token")) + + +class SearchStore(SQLBaseStore): + @defer.inlineCallbacks + def search_msgs(self, room_ids, search_term, keys): + """Performs a full text search over events with given keys. + + Args: + room_ids (list): List of room ids to search in + search_term (str): Search term to search for + keys (list): List of keys to search in, currently supports + "content.body", "content.name", "content.topic" + + Returns: + SearchResult + """ + clauses = [] + args = [] + + # Make sure we don't explode because the person is in too many rooms. + # We filter the results below regardless. + if len(room_ids) < 500: + clauses.append( + "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),) + ) + args.extend(room_ids) + + local_clauses = [] + for key in keys: + local_clauses.append("key = ?") + args.append(key) + + clauses.append( + "(%s)" % (" OR ".join(local_clauses),) + ) + + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "SELECT ts_rank_cd(vector, query) AS rank, room_id, event_id" + " FROM plainto_tsquery('english', ?) as query, event_search" + " WHERE vector @@ query" + ) + elif isinstance(self.database_engine, Sqlite3Engine): + sql = ( + "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id" + " FROM event_search" + " WHERE value MATCH ?" + ) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + for clause in clauses: + sql += " AND " + clause + + # We add an arbitrary limit here to ensure we don't try to pull the + # entire table from the database. + sql += " ORDER BY rank DESC LIMIT 500" + + results = yield self._execute( + "search_msgs", self.cursor_to_dict, sql, *([search_term] + args) + ) + + results = filter(lambda row: row["room_id"] in room_ids, results) + + events = yield self._get_events([r["event_id"] for r in results]) + + event_map = { + ev.event_id: ev + for ev in events + } + + defer.returnValue(SearchResult( + { + r["event_id"]: r["rank"] + for r in results + if r["event_id"] in event_map + }, + event_map, + None + )) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 6f2a50d585..acfb322a53 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -214,7 +214,6 @@ class StateStore(SQLBaseStore): that are in the `types` list. Args: - room_id (str) event_ids (list) types (list): List of (type, state_key) tuples which are used to filter the state fetched. `state_key` may be None, which matches diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 3cab06fdef..15d4c2bf68 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -23,7 +23,7 @@ paginate bacwards. This is implemented by keeping two ordering columns: stream_ordering and topological_ordering. Stream ordering is basically insertion/received order -(except for events from backfill requests). The topolgical_ordering is a +(except for events from backfill requests). The topological_ordering is a weak ordering of events based on the pdu graph. This means that we have to have two different types of tokens, depending on @@ -436,3 +436,138 @@ class StreamStore(SQLBaseStore): internal = event.internal_metadata internal.before = str(RoomStreamToken(topo, stream - 1)) internal.after = str(RoomStreamToken(topo, stream)) + + @defer.inlineCallbacks + def get_events_around(self, room_id, event_id, before_limit, after_limit): + """Retrieve events and pagination tokens around a given event in a + room. + + Args: + room_id (str) + event_id (str) + before_limit (int) + after_limit (int) + + Returns: + dict + """ + + results = yield self.runInteraction( + "get_events_around", self._get_events_around_txn, + room_id, event_id, before_limit, after_limit + ) + + events_before = yield self._get_events( + [e for e in results["before"]["event_ids"]], + get_prev_content=True + ) + + events_after = yield self._get_events( + [e for e in results["after"]["event_ids"]], + get_prev_content=True + ) + + defer.returnValue({ + "events_before": events_before, + "events_after": events_after, + "start": results["before"]["token"], + "end": results["after"]["token"], + }) + + def _get_events_around_txn(self, txn, room_id, event_id, before_limit, after_limit): + """Retrieves event_ids and pagination tokens around a given event in a + room. + + Args: + room_id (str) + event_id (str) + before_limit (int) + after_limit (int) + + Returns: + dict + """ + + results = self._simple_select_one_txn( + txn, + "events", + keyvalues={ + "event_id": event_id, + "room_id": room_id, + }, + retcols=["stream_ordering", "topological_ordering"], + ) + + stream_ordering = results["stream_ordering"] + topological_ordering = results["topological_ordering"] + + query_before = ( + "SELECT topological_ordering, stream_ordering, event_id FROM events" + " WHERE room_id = ? AND (topological_ordering < ?" + " OR (topological_ordering = ? AND stream_ordering < ?))" + " ORDER BY topological_ordering DESC, stream_ordering DESC" + " LIMIT ?" + ) + + query_after = ( + "SELECT topological_ordering, stream_ordering, event_id FROM events" + " WHERE room_id = ? AND (topological_ordering > ?" + " OR (topological_ordering = ? AND stream_ordering > ?))" + " ORDER BY topological_ordering ASC, stream_ordering ASC" + " LIMIT ?" + ) + + txn.execute( + query_before, + ( + room_id, topological_ordering, topological_ordering, + stream_ordering, before_limit, + ) + ) + + rows = self.cursor_to_dict(txn) + events_before = [r["event_id"] for r in rows] + + if rows: + start_token = str(RoomStreamToken( + rows[0]["topological_ordering"], + rows[0]["stream_ordering"] - 1, + )) + else: + start_token = str(RoomStreamToken( + topological_ordering, + stream_ordering - 1, + )) + + txn.execute( + query_after, + ( + room_id, topological_ordering, topological_ordering, + stream_ordering, after_limit, + ) + ) + + rows = self.cursor_to_dict(txn) + events_after = [r["event_id"] for r in rows] + + if rows: + end_token = str(RoomStreamToken( + rows[-1]["topological_ordering"], + rows[-1]["stream_ordering"], + )) + else: + end_token = str(RoomStreamToken( + topological_ordering, + stream_ordering, + )) + + return { + "before": { + "event_ids": events_before, + "token": start_token, + }, + "after": { + "event_ids": events_after, + "token": end_token, + }, + } diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py new file mode 100644 index 0000000000..641ea250f0 --- /dev/null +++ b/synapse/storage/tags.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached +from twisted.internet import defer +from .util.id_generators import StreamIdGenerator + +import ujson as json +import logging + +logger = logging.getLogger(__name__) + + +class TagsStore(SQLBaseStore): + def __init__(self, hs): + super(TagsStore, self).__init__(hs) + + self._private_user_data_id_gen = StreamIdGenerator( + "private_user_data_max_stream_id", "stream_id" + ) + + def get_max_private_user_data_stream_id(self): + """Get the current max stream id for the private user data stream + + Returns: + A deferred int. + """ + return self._private_user_data_id_gen.get_max_token(self) + + @cached() + def get_tags_for_user(self, user_id): + """Get all the tags for a user. + + + Args: + user_id(str): The user to get the tags for. + Returns: + A deferred dict mapping from room_id strings to lists of tag + strings. + """ + + deferred = self._simple_select_list( + "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] + ) + + @deferred.addCallback + def tags_by_room(rows): + tags_by_room = {} + for row in rows: + room_tags = tags_by_room.setdefault(row["room_id"], {}) + room_tags[row["tag"]] = json.loads(row["content"]) + return tags_by_room + + return deferred + + @defer.inlineCallbacks + def get_updated_tags(self, user_id, stream_id): + """Get all the tags for the rooms where the tags have changed since the + given version + + Args: + user_id(str): The user to get the tags for. + stream_id(int): The earliest update to get for the user. + Returns: + A deferred dict mapping from room_id strings to lists of tag + strings for all the rooms that changed since the stream_id token. + """ + def get_updated_tags_txn(txn): + sql = ( + "SELECT room_id from room_tags_revisions" + " WHERE user_id = ? AND stream_id > ?" + ) + txn.execute(sql, (user_id, stream_id)) + room_ids = [row[0] for row in txn.fetchall()] + return room_ids + + room_ids = yield self.runInteraction( + "get_updated_tags", get_updated_tags_txn + ) + + results = {} + if room_ids: + tags_by_room = yield self.get_tags_for_user(user_id) + for room_id in room_ids: + results[room_id] = tags_by_room[room_id] + + defer.returnValue(results) + + def get_tags_for_room(self, user_id, room_id): + """Get all the tags for the given room + Args: + user_id(str): The user to get tags for + room_id(str): The room to get tags for + Returns: + A deferred list of string tags. + """ + return self._simple_select_list( + table="room_tags", + keyvalues={"user_id": user_id, "room_id": room_id}, + retcols=("tag", "content"), + desc="get_tags_for_room", + ).addCallback(lambda rows: { + row["tag"]: json.loads(row["content"]) for row in rows + }) + + @defer.inlineCallbacks + def add_tag_to_room(self, user_id, room_id, tag, content): + """Add a tag to a room for a user. + Args: + user_id(str): The user to add a tag for. + room_id(str): The room to add a tag for. + tag(str): The tag name to add. + content(dict): A json object to associate with the tag. + Returns: + A deferred that completes once the tag has been added. + """ + content_json = json.dumps(content) + + def add_tag_txn(txn, next_id): + self._simple_upsert_txn( + txn, + table="room_tags", + keyvalues={ + "user_id": user_id, + "room_id": room_id, + "tag": tag, + }, + values={ + "content": content_json, + } + ) + self._update_revision_txn(txn, user_id, room_id, next_id) + + with (yield self._private_user_data_id_gen.get_next(self)) as next_id: + yield self.runInteraction("add_tag", add_tag_txn, next_id) + + self.get_tags_for_user.invalidate((user_id,)) + + result = yield self._private_user_data_id_gen.get_max_token(self) + defer.returnValue(result) + + @defer.inlineCallbacks + def remove_tag_from_room(self, user_id, room_id, tag): + """Remove a tag from a room for a user. + Returns: + A deferred that completes once the tag has been removed + """ + def remove_tag_txn(txn, next_id): + sql = ( + "DELETE FROM room_tags " + " WHERE user_id = ? AND room_id = ? AND tag = ?" + ) + txn.execute(sql, (user_id, room_id, tag)) + self._update_revision_txn(txn, user_id, room_id, next_id) + + with (yield self._private_user_data_id_gen.get_next(self)) as next_id: + yield self.runInteraction("remove_tag", remove_tag_txn, next_id) + + self.get_tags_for_user.invalidate((user_id,)) + + result = yield self._private_user_data_id_gen.get_max_token(self) + defer.returnValue(result) + + def _update_revision_txn(self, txn, user_id, room_id, next_id): + """Update the latest revision of the tags for the given user and room. + + Args: + txn: The database cursor + user_id(str): The ID of the user. + room_id(str): The ID of the room. + next_id(int): The the revision to advance to. + """ + + update_max_id_sql = ( + "UPDATE private_user_data_max_stream_id" + " SET stream_id = ?" + " WHERE stream_id < ?" + ) + txn.execute(update_max_id_sql, (next_id, next_id)) + + update_sql = ( + "UPDATE room_tags_revisions" + " SET stream_id = ?" + " WHERE user_id = ?" + " AND room_id = ?" + ) + txn.execute(update_sql, (next_id, user_id, room_id)) + + if txn.rowcount == 0: + insert_sql = ( + "INSERT INTO room_tags_revisions (user_id, room_id, stream_id)" + " VALUES (?, ?, ?)" + ) + try: + txn.execute(insert_sql, (user_id, room_id, next_id)) + except self.database_engine.module.IntegrityError: + # Ignore insertion errors. It doesn't matter if the row wasn't + # inserted because if two updates happend concurrently the one + # with the higher stream_id will not be reported to a client + # unless the previous update has completed. It doesn't matter + # which stream_id ends up in the table, as long as it is higher + # than the id that the client has. + pass diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 15695e9831..4e0d7c9774 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -253,16 +253,6 @@ class TransactionStore(SQLBaseStore): retry_interval (int) - how long until next retry in ms """ - # As this is the new value, we might as well prefill the cache - self.get_destination_retry_timings.prefill( - destination, - { - "destination": destination, - "retry_last_ts": retry_last_ts, - "retry_interval": retry_interval - }, - ) - # XXX: we could chose to not bother persisting this if our cache thinks # this is a NOOP return self.runInteraction( @@ -275,31 +265,25 @@ class TransactionStore(SQLBaseStore): def _set_destination_retry_timings(self, txn, destination, retry_last_ts, retry_interval): - query = ( - "UPDATE destinations" - " SET retry_last_ts = ?, retry_interval = ?" - " WHERE destination = ?" - ) + txn.call_after(self.get_destination_retry_timings.invalidate, (destination,)) - txn.execute( - query, - ( - retry_last_ts, retry_interval, destination, - ) + self._simple_upsert_txn( + txn, + "destinations", + keyvalues={ + "destination": destination, + }, + values={ + "retry_last_ts": retry_last_ts, + "retry_interval": retry_interval, + }, + insertion_values={ + "destination": destination, + "retry_last_ts": retry_last_ts, + "retry_interval": retry_interval, + } ) - if txn.rowcount == 0: - # destination wasn't already in table. Insert it. - self._simple_insert_txn( - txn, - table="destinations", - values={ - "destination": destination, - "retry_last_ts": retry_last_ts, - "retry_interval": retry_interval, - } - ) - def get_destinations_needing_retry(self): """Get all destinations which are due a retry for sending a transaction. diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 699083ae12..f0d68b5bf2 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -21,6 +21,7 @@ from synapse.handlers.presence import PresenceEventSource from synapse.handlers.room import RoomEventSource from synapse.handlers.typing import TypingNotificationEventSource from synapse.handlers.receipts import ReceiptEventSource +from synapse.handlers.private_user_data import PrivateUserDataEventSource class EventSources(object): @@ -29,6 +30,7 @@ class EventSources(object): "presence": PresenceEventSource, "typing": TypingNotificationEventSource, "receipt": ReceiptEventSource, + "private_user_data": PrivateUserDataEventSource, } def __init__(self, hs): @@ -52,5 +54,8 @@ class EventSources(object): receipt_key=( yield self.sources["receipt"].get_current_key() ), + private_user_data_key=( + yield self.sources["private_user_data"].get_current_key() + ), ) defer.returnValue(token) diff --git a/synapse/types.py b/synapse/types.py index 9cffc33d27..28344d8b36 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -47,7 +47,7 @@ class DomainSpecificString( @classmethod def from_string(cls, s): """Parse the string given by 's' into a structure object.""" - if s[0] != cls.SIGIL: + if len(s) < 1 or s[0] != cls.SIGIL: raise SynapseError(400, "Expected %s string to start with '%s'" % ( cls.__name__, cls.SIGIL, )) @@ -98,10 +98,13 @@ class EventID(DomainSpecificString): class StreamToken( - namedtuple( - "Token", - ("room_key", "presence_key", "typing_key", "receipt_key") - ) + namedtuple("Token", ( + "room_key", + "presence_key", + "typing_key", + "receipt_key", + "private_user_data_key", + )) ): _SEPARATOR = "_" @@ -109,7 +112,7 @@ class StreamToken( def from_string(cls, string): try: keys = string.split(cls._SEPARATOR) - if len(keys) == len(cls._fields) - 1: + while len(keys) < len(cls._fields): # i.e. old token from before receipt_key keys.append("0") return cls(*keys) @@ -128,13 +131,14 @@ class StreamToken( else: return int(self.room_key[1:].split("-")[-1]) - def is_after(self, other_token): + def is_after(self, other): """Does this token contain events that the other doesn't?""" return ( - (other_token.room_stream_id < self.room_stream_id) - or (int(other_token.presence_key) < int(self.presence_key)) - or (int(other_token.typing_key) < int(self.typing_key)) - or (int(other_token.receipt_key) < int(self.receipt_key)) + (other.room_stream_id < self.room_stream_id) + or (int(other.presence_key) < int(self.presence_key)) + or (int(other.typing_key) < int(self.typing_key)) + or (int(other.receipt_key) < int(self.receipt_key)) + or (int(other.private_user_data_key) < int(self.private_user_data_key)) ) def copy_and_advance(self, key, new_value): diff --git a/synapse/util/emailutils.py b/synapse/util/emailutils.py deleted file mode 100644 index 7f9a77bf44..0000000000 --- a/synapse/util/emailutils.py +++ /dev/null @@ -1,71 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" This module allows you to send out emails. -""" -import email.utils -import smtplib -import twisted.python.log -from email.mime.text import MIMEText -from email.mime.multipart import MIMEMultipart - -import logging - -logger = logging.getLogger(__name__) - - -class EmailException(Exception): - pass - - -def send_email(smtp_server, from_addr, to_addr, subject, body): - """Sends an email. - - Args: - smtp_server(str): The SMTP server to use. - from_addr(str): The address to send from. - to_addr(str): The address to send to. - subject(str): The subject of the email. - body(str): The plain text body of the email. - Raises: - EmailException if there was a problem sending the mail. - """ - if not smtp_server or not from_addr or not to_addr: - raise EmailException("Need SMTP server, from and to addresses. Check" - " the config to set these.") - - msg = MIMEMultipart('alternative') - msg['Subject'] = subject - msg['From'] = from_addr - msg['To'] = to_addr - plain_part = MIMEText(body) - msg.attach(plain_part) - - raw_from = email.utils.parseaddr(from_addr)[1] - raw_to = email.utils.parseaddr(to_addr)[1] - if not raw_from or not raw_to: - raise EmailException("Couldn't parse from/to address.") - - logger.info("Sending email to %s on server %s with subject %s", - to_addr, smtp_server, subject) - - try: - smtp = smtplib.SMTP(smtp_server) - smtp.sendmail(raw_from, raw_to, msg.as_string()) - smtp.quit() - except Exception as origException: - twisted.python.log.err() - ese = EmailException() - ese.cause = origException - raise ese diff --git a/synapse/util/third_party_invites.py b/synapse/util/third_party_invites.py new file mode 100644 index 0000000000..31d186740d --- /dev/null +++ b/synapse/util/third_party_invites.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer +from synapse.api.errors import AuthError + + +INVITE_KEYS = {"id_server", "medium", "address", "display_name"} + +JOIN_KEYS = { + "token", + "public_key", + "key_validity_url", + "sender", + "signed", +} + + +def has_invite_keys(content): + for key in INVITE_KEYS: + if key not in content: + return False + return True + + +def has_join_keys(content): + for key in JOIN_KEYS: + if key not in content: + return False + return True + + +def join_has_third_party_invite(content): + if "third_party_invite" not in content: + return False + return has_join_keys(content["third_party_invite"]) + + +def extract_join_keys(src): + return { + key: value + for key, value in src.items() + if key in JOIN_KEYS + } + + +@defer.inlineCallbacks +def check_key_valid(http_client, event): + try: + response = yield http_client.get_json( + event.content["third_party_invite"]["key_validity_url"], + {"public_key": event.content["third_party_invite"]["public_key"]} + ) + except Exception: + raise AuthError(502, "Third party certificate could not be checked") + if "valid" not in response or not response["valid"]: + raise AuthError(403, "Third party certificate was invalid") diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 6942cdac51..9f9af2d783 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -23,10 +23,17 @@ from tests.utils import ( ) from synapse.types import UserID -from synapse.api.filtering import Filter +from synapse.api.filtering import FilterCollection, Filter user_localpart = "test_user" -MockEvent = namedtuple("MockEvent", "sender type room_id") +# MockEvent = namedtuple("MockEvent", "sender type room_id") + + +def MockEvent(**kwargs): + ev = NonCallableMock(spec_set=kwargs.keys()) + ev.configure_mock(**kwargs) + return ev + class FilteringTestCase(unittest.TestCase): @@ -44,7 +51,6 @@ class FilteringTestCase(unittest.TestCase): ) self.filtering = hs.get_filtering() - self.filter = Filter({}) self.datastore = hs.get_datastore() @@ -57,8 +63,9 @@ class FilteringTestCase(unittest.TestCase): type="m.room.message", room_id="!foo:bar" ) + self.assertTrue( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_types_works_with_wildcards(self): @@ -71,7 +78,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!foo:bar" ) self.assertTrue( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_types_works_with_unknowns(self): @@ -84,7 +91,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!foo:bar" ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_not_types_works_with_literals(self): @@ -97,7 +104,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!foo:bar" ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_not_types_works_with_wildcards(self): @@ -110,7 +117,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!foo:bar" ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_not_types_works_with_unknowns(self): @@ -123,7 +130,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!foo:bar" ) self.assertTrue( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_not_types_takes_priority_over_types(self): @@ -137,7 +144,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!foo:bar" ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_senders_works_with_literals(self): @@ -150,7 +157,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!foo:bar" ) self.assertTrue( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_senders_works_with_unknowns(self): @@ -163,7 +170,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!foo:bar" ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_not_senders_works_with_literals(self): @@ -176,7 +183,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!foo:bar" ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_not_senders_works_with_unknowns(self): @@ -189,7 +196,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!foo:bar" ) self.assertTrue( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_not_senders_takes_priority_over_senders(self): @@ -203,7 +210,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!foo:bar" ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_rooms_works_with_literals(self): @@ -216,7 +223,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!secretbase:unknown" ) self.assertTrue( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_rooms_works_with_unknowns(self): @@ -229,7 +236,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!anothersecretbase:unknown" ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_not_rooms_works_with_literals(self): @@ -242,7 +249,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!anothersecretbase:unknown" ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_not_rooms_works_with_unknowns(self): @@ -255,7 +262,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!anothersecretbase:unknown" ) self.assertTrue( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_not_rooms_takes_priority_over_rooms(self): @@ -269,7 +276,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!secretbase:unknown" ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_combined_event(self): @@ -287,7 +294,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!stage:unknown" # yup ) self.assertTrue( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_combined_event_bad_sender(self): @@ -305,7 +312,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!stage:unknown" # yup ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_combined_event_bad_room(self): @@ -323,7 +330,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!piggyshouse:muppets" # nope ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_combined_event_bad_type(self): @@ -341,7 +348,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!stage:unknown" # yup ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) @defer.inlineCallbacks @@ -359,7 +366,6 @@ class FilteringTestCase(unittest.TestCase): event = MockEvent( sender="@foo:bar", type="m.profile", - room_id="!foo:bar" ) events = [event] @@ -386,7 +392,6 @@ class FilteringTestCase(unittest.TestCase): event = MockEvent( sender="@foo:bar", type="custom.avatar.3d.crazy", - room_id="!foo:bar" ) events = [event] diff --git a/tests/crypto/__init__.py b/tests/crypto/__init__.py new file mode 100644 index 0000000000..9bff9ec169 --- /dev/null +++ b/tests/crypto/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py new file mode 100644 index 0000000000..7913472941 --- /dev/null +++ b/tests/crypto/test_event_signing.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tests import unittest + +from synapse.events.builder import EventBuilder +from synapse.crypto.event_signing import add_hashes_and_signatures + +from unpaddedbase64 import decode_base64 + +import nacl.signing + + +# Perform these tests using given secret key so we get entirely deterministic +# signatures output that we can test against. +SIGNING_KEY_SEED = decode_base64( + "YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1" +) + +KEY_ALG = "ed25519" +KEY_VER = 1 +KEY_NAME = "%s:%d" % (KEY_ALG, KEY_VER) + +HOSTNAME = "domain" + + +class EventSigningTestCase(unittest.TestCase): + + def setUp(self): + self.signing_key = nacl.signing.SigningKey(SIGNING_KEY_SEED) + self.signing_key.alg = KEY_ALG + self.signing_key.version = KEY_VER + + def test_sign_minimal(self): + builder = EventBuilder( + { + 'event_id': "$0:domain", + 'origin': "domain", + 'origin_server_ts': 1000000, + 'signatures': {}, + 'type': "X", + 'unsigned': {'age_ts': 1000000}, + }, + ) + + add_hashes_and_signatures(builder, HOSTNAME, self.signing_key) + + event = builder.build() + + self.assertTrue(hasattr(event, 'hashes')) + self.assertIn('sha256', event.hashes) + self.assertEquals( + event.hashes['sha256'], + "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI", + ) + + self.assertTrue(hasattr(event, 'signatures')) + self.assertIn(HOSTNAME, event.signatures) + self.assertIn(KEY_NAME, event.signatures["domain"]) + self.assertEquals( + event.signatures[HOSTNAME][KEY_NAME], + "2Wptgo4CwmLo/Y8B8qinxApKaCkBG2fjTWB7AbP5Uy+" + "aIbygsSdLOFzvdDjww8zUVKCmI02eP9xtyJxc/cLiBA", + ) + + def test_sign_message(self): + builder = EventBuilder( + { + 'content': { + 'body': "Here is the message content", + }, + 'event_id': "$0:domain", + 'origin': "domain", + 'origin_server_ts': 1000000, + 'type': "m.room.message", + 'room_id': "!r:domain", + 'sender': "@u:domain", + 'signatures': {}, + 'unsigned': {'age_ts': 1000000}, + } + ) + + add_hashes_and_signatures(builder, HOSTNAME, self.signing_key) + + event = builder.build() + + self.assertTrue(hasattr(event, 'hashes')) + self.assertIn('sha256', event.hashes) + self.assertEquals( + event.hashes['sha256'], + "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g", + ) + + self.assertTrue(hasattr(event, 'signatures')) + self.assertIn(HOSTNAME, event.signatures) + self.assertIn(KEY_NAME, event.signatures["domain"]) + self.assertEquals( + event.signatures[HOSTNAME][KEY_NAME], + "Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUw" + "u6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA" + ) diff --git a/tests/events/__init__.py b/tests/events/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tests/events/__init__.py diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py new file mode 100644 index 0000000000..16179921f0 --- /dev/null +++ b/tests/events/test_utils.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .. import unittest + +from synapse.events import FrozenEvent +from synapse.events.utils import prune_event + +class PruneEventTestCase(unittest.TestCase): + """ Asserts that a new event constructed with `evdict` will look like + `matchdict` when it is redacted. """ + def run_test(self, evdict, matchdict): + self.assertEquals( + prune_event(FrozenEvent(evdict)).get_dict(), + matchdict + ) + + def test_minimal(self): + self.run_test( + {'type': 'A'}, + { + 'type': 'A', + 'content': {}, + 'signatures': {}, + 'unsigned': {}, + } + ) + + def test_basic_keys(self): + self.run_test( + { + 'type': 'A', + 'room_id': '!1:domain', + 'sender': '@2:domain', + 'event_id': '$3:domain', + 'origin': 'domain', + }, + { + 'type': 'A', + 'room_id': '!1:domain', + 'sender': '@2:domain', + 'event_id': '$3:domain', + 'origin': 'domain', + 'content': {}, + 'signatures': {}, + 'unsigned': {}, + } + ) + + def test_unsigned_age_ts(self): + self.run_test( + { + 'type': 'B', + 'unsigned': {'age_ts': 20}, + }, + { + 'type': 'B', + 'content': {}, + 'signatures': {}, + 'unsigned': {'age_ts': 20}, + } + ) + + self.run_test( + { + 'type': 'B', + 'unsigned': {'other_key': 'here'}, + }, + { + 'type': 'B', + 'content': {}, + 'signatures': {}, + 'unsigned': {}, + } + ) + + def test_content(self): + self.run_test( + { + 'type': 'C', + 'content': {'things': 'here'}, + }, + { + 'type': 'C', + 'content': {}, + 'signatures': {}, + 'unsigned': {}, + } + ) + + self.run_test( + { + 'type': 'm.room.create', + 'content': {'creator': '@2:domain', 'other_field': 'here'}, + }, + { + 'type': 'm.room.create', + 'content': {'creator': '@2:domain'}, + 'signatures': {}, + 'unsigned': {}, + } + ) diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 29d9bbaad4..0e3b922246 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -369,7 +369,7 @@ class PresenceEventStreamTestCase(unittest.TestCase): # all be ours # I'll already get my own presence state change - self.assertEquals({"start": "0_1_0_0", "end": "0_1_0_0", "chunk": []}, + self.assertEquals({"start": "0_1_0_0_0", "end": "0_1_0_0_0", "chunk": []}, response ) @@ -388,7 +388,7 @@ class PresenceEventStreamTestCase(unittest.TestCase): "/events?from=s0_1_0&timeout=0", None) self.assertEquals(200, code) - self.assertEquals({"start": "s0_1_0_0", "end": "s0_2_0_0", "chunk": [ + self.assertEquals({"start": "s0_1_0_0_0", "end": "s0_2_0_0_0", "chunk": [ {"type": "m.presence", "content": { "user_id": "@banana:test", diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index a2123be81b..93896dd076 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -277,10 +277,10 @@ class RoomPermissionsTestCase(RestTestCase): expect_code=403) # set [invite/join/left] of self, set [invite/join/left] of other, - # expect all 403s + # expect all 404s because room doesn't exist on any server for usr in [self.user_id, self.rmcreator_id]: yield self.join(room=room, user=usr, expect_code=404) - yield self.leave(room=room, user=usr, expect_code=403) + yield self.leave(room=room, user=usr, expect_code=404) @defer.inlineCallbacks def test_membership_private_room_perms(self): diff --git a/tests/test_types.py b/tests/test_types.py index b29a8415b1..495cd20f02 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -15,13 +15,14 @@ from tests import unittest +from synapse.api.errors import SynapseError from synapse.server import BaseHomeServer from synapse.types import UserID, RoomAlias mock_homeserver = BaseHomeServer(hostname="my.domain") -class UserIDTestCase(unittest.TestCase): +class UserIDTestCase(unittest.TestCase): def test_parse(self): user = UserID.from_string("@1234abcd:my.domain") @@ -29,6 +30,11 @@ class UserIDTestCase(unittest.TestCase): self.assertEquals("my.domain", user.domain) self.assertEquals(True, mock_homeserver.is_mine(user)) + def test_pase_empty(self): + with self.assertRaises(SynapseError): + UserID.from_string("") + + def test_build(self): user = UserID("5678efgh", "my.domain") @@ -44,7 +50,6 @@ class UserIDTestCase(unittest.TestCase): class RoomAliasTestCase(unittest.TestCase): - def test_parse(self): room = RoomAlias.from_string("#channel:my.domain") diff --git a/tox.ini b/tox.ini index a69948484f..01b23e6bd9 100644 --- a/tox.ini +++ b/tox.ini @@ -19,6 +19,7 @@ commands = check-manifest [testenv:pep8] +skip_install = True basepython = python2.7 deps = flake8 |