diff options
30 files changed, 273 insertions, 610 deletions
diff --git a/AUTHORS.rst b/AUTHORS.rst index 8711a6ae5c..3dcb1c2a89 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -57,3 +57,6 @@ Florent Violleau <floviolleau at gmail dot com> Niklas Riekenbrauck <nikriek at gmail dot.com> * Add JWT support for registration and login + +Christoph Witzany <christoph at web.crofting.com> + * Add LDAP support for authentication diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index bc90605324..6da6a1b62e 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -100,11 +100,6 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning("push_bulk to %s threw exception %s", uri, ex) defer.returnValue(False) - @defer.inlineCallbacks - def push(self, service, event, txn_id=None): - response = yield self.push_bulk(service, [event], txn_id) - defer.returnValue(response) - def _serialize(self, events): time_now = self.clock.time_msec() return [ diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index acf74c8761..9a80ac39ec 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -30,13 +30,14 @@ from .saml2 import SAML2Config from .cas import CasConfig from .password import PasswordConfig from .jwt import JWTConfig +from .ldap import LDAPConfig class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig, - JWTConfig, PasswordConfig,): + JWTConfig, LDAPConfig, PasswordConfig,): pass diff --git a/synapse/config/ldap.py b/synapse/config/ldap.py new file mode 100644 index 0000000000..9c14593a99 --- /dev/null +++ b/synapse/config/ldap.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 Niklas Riekenbrauck +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + + +class LDAPConfig(Config): + def read_config(self, config): + ldap_config = config.get("ldap_config", None) + if ldap_config: + self.ldap_enabled = ldap_config.get("enabled", False) + self.ldap_server = ldap_config["server"] + self.ldap_port = ldap_config["port"] + self.ldap_tls = ldap_config.get("tls", False) + self.ldap_search_base = ldap_config["search_base"] + self.ldap_search_property = ldap_config["search_property"] + self.ldap_email_property = ldap_config["email_property"] + self.ldap_full_name_property = ldap_config["full_name_property"] + else: + self.ldap_enabled = False + self.ldap_server = None + self.ldap_port = None + self.ldap_tls = False + self.ldap_search_base = None + self.ldap_search_property = None + self.ldap_email_property = None + self.ldap_full_name_property = None + + def default_config(self, **kwargs): + return """\ + # ldap_config: + # enabled: true + # server: "ldap://localhost" + # port: 389 + # tls: false + # search_base: "ou=Users,dc=example,dc=com" + # search_property: "cn" + # email_property: "email" + # full_name_property: "givenName" + """ diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d5d6faa85f..7a13a8b11c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -49,6 +49,21 @@ class AuthHandler(BaseHandler): self.sessions = {} self.INVALID_TOKEN_HTTP_STATUS = 401 + self.ldap_enabled = hs.config.ldap_enabled + self.ldap_server = hs.config.ldap_server + self.ldap_port = hs.config.ldap_port + self.ldap_tls = hs.config.ldap_tls + self.ldap_search_base = hs.config.ldap_search_base + self.ldap_search_property = hs.config.ldap_search_property + self.ldap_email_property = hs.config.ldap_email_property + self.ldap_full_name_property = hs.config.ldap_full_name_property + + if self.ldap_enabled is True: + import ldap + logger.info("Import ldap version: %s", ldap.__version__) + + self.hs = hs # FIXME better possibility to access registrationHandler later? + @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): """ @@ -215,8 +230,10 @@ class AuthHandler(BaseHandler): if not user_id.startswith('@'): user_id = UserID.create(user_id, self.hs.hostname).to_string() - user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) - self._check_password(user_id, password, password_hash) + if not (yield self._check_password(user_id, password)): + logger.warn("Failed password login for user %s", user_id) + raise LoginError(403, "", errcode=Codes.FORBIDDEN) + defer.returnValue(user_id) @defer.inlineCallbacks @@ -340,8 +357,10 @@ class AuthHandler(BaseHandler): StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ - user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) - self._check_password(user_id, password, password_hash) + + if not (yield self._check_password(user_id, password)): + logger.warn("Failed password login for user %s", user_id) + raise LoginError(403, "", errcode=Codes.FORBIDDEN) logger.info("Logging in user %s", user_id) access_token = yield self.issue_access_token(user_id) @@ -407,11 +426,60 @@ class AuthHandler(BaseHandler): else: defer.returnValue(user_infos.popitem()) - def _check_password(self, user_id, password, stored_hash): - """Checks that user_id has passed password, raises LoginError if not.""" - if not self.validate_hash(password, stored_hash): - logger.warn("Failed password login for user %s", user_id) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) + @defer.inlineCallbacks + def _check_password(self, user_id, password): + defer.returnValue( + not ( + (yield self._check_ldap_password(user_id, password)) + or + (yield self._check_local_password(user_id, password)) + )) + + @defer.inlineCallbacks + def _check_local_password(self, user_id, password): + try: + user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) + defer.returnValue(not self.validate_hash(password, password_hash)) + except LoginError: + defer.returnValue(False) + + @defer.inlineCallbacks + def _check_ldap_password(self, user_id, password): + if self.ldap_enabled is not True: + logger.debug("LDAP not configured") + defer.returnValue(False) + + import ldap + + logger.info("Authenticating %s with LDAP" % user_id) + try: + ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port) + logger.debug("Connecting LDAP server at %s" % ldap_url) + l = ldap.initialize(ldap_url) + if self.ldap_tls: + logger.debug("Initiating TLS") + self._connection.start_tls_s() + + local_name = UserID.from_string(user_id).localpart + + dn = "%s=%s, %s" % ( + self.ldap_search_property, + local_name, + self.ldap_search_base) + logger.debug("DN for LDAP authentication: %s" % dn) + + l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8')) + + if not (yield self.does_user_exist(user_id)): + handler = self.hs.get_handlers().registration_handler + user_id, access_token = ( + yield handler.register(localpart=local_name) + ) + + defer.returnValue(True) + except ldap.LDAPError, e: + logger.warn("LDAP error: %s", e) + defer.returnValue(False) @defer.inlineCallbacks def issue_access_token(self, user_id): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index eb02f0e000..c28226f840 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -40,6 +40,7 @@ from synapse.events.utils import prune_event from synapse.util.retryutils import NotRetryingDestination from synapse.push.action_generator import ActionGenerator +from synapse.util.distributor import user_joined_room from twisted.internet import defer @@ -49,10 +50,6 @@ import logging logger = logging.getLogger(__name__) -def user_joined_room(distributor, user, room_id): - return distributor.fire("user_joined_room", user, room_id) - - class FederationHandler(BaseHandler): """Handles events that originated from federation. Responsible for: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 10608c0dd9..f51feda2f4 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -34,10 +34,6 @@ import logging logger = logging.getLogger(__name__) -def collect_presencelike_data(distributor, user, content): - return distributor.fire("collect_presencelike_data", user, content) - - class MessageHandler(BaseHandler): def __init__(self, hs): @@ -49,35 +45,6 @@ class MessageHandler(BaseHandler): self.snapshot_cache = SnapshotCache() @defer.inlineCallbacks - def get_message(self, msg_id=None, room_id=None, sender_id=None, - user_id=None): - """ Retrieve a message. - - Args: - msg_id (str): The message ID to obtain. - room_id (str): The room where the message resides. - sender_id (str): The user ID of the user who sent the message. - user_id (str): The user ID of the user making this request. - Returns: - The message, or None if no message exists. - Raises: - SynapseError if something went wrong. - """ - yield self.auth.check_joined_room(room_id, user_id) - - # Pull out the message from the db -# msg = yield self.store.get_message( -# room_id=room_id, -# msg_id=msg_id, -# user_id=sender_id -# ) - - # TODO (erikj): Once we work out the correct c-s api we need to think - # on how to do this. - - defer.returnValue(None) - - @defer.inlineCallbacks def get_messages(self, requester, room_id=None, pagin_config=None, as_client_event=True): """Get messages in a room. @@ -202,12 +169,8 @@ class MessageHandler(BaseHandler): membership = builder.content.get("membership", None) target = UserID.from_string(builder.state_key) - if membership == Membership.JOIN: + if membership in {Membership.JOIN, Membership.INVITE}: # If event doesn't include a display name, add one. - yield collect_presencelike_data( - self.distributor, target, builder.content - ) - elif membership == Membership.INVITE: profile = self.hs.get_handlers().profile_handler content = builder.content diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index b45eafbb49..e37409170d 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -17,7 +17,6 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.types import UserID, Requester -from synapse.util import unwrapFirstError from ._base import BaseHandler @@ -27,14 +26,6 @@ import logging logger = logging.getLogger(__name__) -def changed_presencelike_data(distributor, user, state): - return distributor.fire("changed_presencelike_data", user, state) - - -def collect_presencelike_data(distributor, user, content): - return distributor.fire("collect_presencelike_data", user, content) - - class ProfileHandler(BaseHandler): def __init__(self, hs): @@ -46,17 +37,9 @@ class ProfileHandler(BaseHandler): ) distributor = hs.get_distributor() - self.distributor = distributor - - distributor.declare("collect_presencelike_data") - distributor.declare("changed_presencelike_data") distributor.observe("registered_user", self.registered_user) - distributor.observe( - "collect_presencelike_data", self.collect_presencelike_data - ) - def registered_user(self, user): return self.store.create_profile(user.localpart) @@ -105,10 +88,6 @@ class ProfileHandler(BaseHandler): target_user.localpart, new_displayname ) - yield changed_presencelike_data(self.distributor, target_user, { - "displayname": new_displayname, - }) - yield self._update_join_states(requester) @defer.inlineCallbacks @@ -152,31 +131,9 @@ class ProfileHandler(BaseHandler): target_user.localpart, new_avatar_url ) - yield changed_presencelike_data(self.distributor, target_user, { - "avatar_url": new_avatar_url, - }) - yield self._update_join_states(requester) @defer.inlineCallbacks - def collect_presencelike_data(self, user, state): - if not self.hs.is_mine(user): - defer.returnValue(None) - - (displayname, avatar_url) = yield defer.gatherResults( - [ - self.store.get_profile_displayname(user.localpart), - self.store.get_profile_avatar_url(user.localpart), - ], - consumeErrors=True - ).addErrback(unwrapFirstError) - - state["displayname"] = displayname - state["avatar_url"] = avatar_url - - defer.returnValue(None) - - @defer.inlineCallbacks def on_profile_query(self, args): user = UserID.from_string(args["user_id"]) if not self.hs.is_mine(user): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index f287ee247b..b0862067e1 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -23,6 +23,7 @@ from synapse.api.errors import ( from ._base import BaseHandler from synapse.util.async import run_on_reactor from synapse.http.client import CaptchaServerHttpClient +from synapse.util.distributor import registered_user import logging import urllib @@ -30,10 +31,6 @@ import urllib logger = logging.getLogger(__name__) -def registered_user(distributor, user): - return distributor.fire("registered_user", user) - - class RegistrationHandler(BaseHandler): def __init__(self, hs): diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3e1d9282d7..ea306cd42a 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -25,7 +25,6 @@ from synapse.api.constants import ( from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.util import stringutils from synapse.util.async import concurrently_execute -from synapse.util.logcontext import preserve_context_over_fn from synapse.util.caches.response_cache import ResponseCache from collections import OrderedDict @@ -39,20 +38,6 @@ logger = logging.getLogger(__name__) id_server_scheme = "https://" -def user_left_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_left_room", user=user, room_id=room_id - ) - - -def user_joined_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_joined_room", user=user, room_id=room_id - ) - - class RoomCreationHandler(BaseHandler): PRESETS_DICT = { diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index b6ef3c91af..b69f36aefe 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -23,8 +23,8 @@ from synapse.api.constants import ( EventTypes, Membership, ) from synapse.api.errors import AuthError, SynapseError, Codes -from synapse.util.logcontext import preserve_context_over_fn from synapse.util.async import Linearizer +from synapse.util.distributor import user_left_room, user_joined_room from signedjson.sign import verify_signed_json from signedjson.key import decode_verify_key_bytes @@ -38,20 +38,6 @@ logger = logging.getLogger(__name__) id_server_scheme = "https://" -def user_left_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_left_room", user=user, room_id=room_id - ) - - -def user_joined_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_joined_room", user=user, room_id=room_id - ) - - class RoomMemberHandler(BaseHandler): # TODO(paul): This handler currently contains a messy conflation of # low-level API that works on UserID objects and so on, and REST-level @@ -406,19 +392,6 @@ class RoomMemberHandler(BaseHandler): and guest_access.content["guest_access"] == "can_join" ) - def _should_do_dance(self, current_state, inviter, room_hosts=None): - # TODO: Shouldn't this be remote_room_host? - room_hosts = room_hosts or [] - - is_host_in_room = self.is_host_in_room(current_state) - if is_host_in_room: - return False, room_hosts - - if inviter and not self.hs.is_mine(inviter): - room_hosts.append(inviter.domain) - - return True, room_hosts - @defer.inlineCallbacks def lookup_room_alias(self, room_alias): """ diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index de5c762f50..a456dc19da 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -22,6 +22,7 @@ from twisted.names.error import DNSNameError, DomainError import collections import logging import random +import time logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ SERVER_CACHE = {} _Server = collections.namedtuple( - "_Server", "priority weight host port" + "_Server", "priority weight host port expires" ) @@ -123,7 +124,8 @@ class SRVClientEndpoint(object): host=domain, port=default_port, priority=0, - weight=0 + weight=0, + expires=0, ) else: self.default_server = None @@ -184,7 +186,13 @@ class SRVClientEndpoint(object): @defer.inlineCallbacks -def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): +def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time): + cache_entry = cache.get(service_name, None) + if cache_entry: + if all(s.expires > int(clock.time()) for s in cache_entry): + servers = list(cache_entry) + defer.returnValue(servers) + servers = [] try: @@ -204,27 +212,26 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): continue payload = answer.payload - host = str(payload.target) + srv_ttl = answer.ttl try: answers, _, _ = yield dns_client.lookupAddress(host) except DNSNameError: continue - ips = [ - answer.payload.dottedQuad() - for answer in answers - if answer.type == dns.A and answer.payload - ] - - for ip in ips: - servers.append(_Server( - host=ip, - port=int(payload.port), - priority=int(payload.priority), - weight=int(payload.weight) - )) + for answer in answers: + if answer.type == dns.A and answer.payload: + ip = answer.payload.dottedQuad() + host_ttl = min(srv_ttl, answer.ttl) + + servers.append(_Server( + host=ip, + port=int(payload.port), + priority=int(payload.priority), + weight=int(payload.weight), + expires=int(clock.time()) + host_ttl, + )) servers.sort() cache[service_name] = list(servers) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 707ddd248a..cfc728a038 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -69,6 +69,7 @@ class SlavedEventStore(BaseSlavedStore): "_get_current_state_for_key" ] + get_event = DataStore.get_event.__func__ get_current_state = DataStore.get_current_state.__func__ get_current_state_for_key = DataStore.get_current_state_for_key.__func__ get_rooms_for_user_where_membership_is = ( @@ -103,7 +104,7 @@ class SlavedEventStore(BaseSlavedStore): def stream_positions(self): result = super(SlavedEventStore, self).stream_positions() result["events"] = self._stream_id_gen.get_current_token() - result["backfilled"] = self._backfill_id_gen.get_current_token() + result["backfill"] = self._backfill_id_gen.get_current_token() return result def process_replication(self, result): @@ -145,7 +146,6 @@ class SlavedEventStore(BaseSlavedStore): position = row[0] internal = json.loads(row[1]) event_json = json.loads(row[2]) - event = FrozenEvent(event_json, internal_metadata_dict=internal) self._invalidate_caches_for_event( event, backfilled, reset_state=position in state_resets diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 04d7fcf6d6..1e27c2c0ce 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -810,12 +810,6 @@ class SQLBaseStore(object): return txn.execute(sql, keyvalues.values()) - def get_next_stream_id(self): - with self._next_stream_id_lock: - i = self._next_stream_id - self._next_stream_id += 1 - return i - def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value): # Fetch a mapping of room_id -> max stream position for "recent" rooms. diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 00833422af..57f14fd12b 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -30,18 +30,6 @@ SCHEMA_VERSION = 31 dir_path = os.path.abspath(os.path.dirname(__file__)) -def read_schema(path): - """ Read the named database schema. - - Args: - path: Path of the database schema. - Returns: - A string containing the database schema. - """ - with open(path) as schema_file: - return schema_file.read() - - class PrepareDatabaseException(Exception): pass diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 59b4ef5ce6..07f5fae8dd 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -176,16 +176,6 @@ class PresenceStore(SQLBaseStore): desc="disallow_presence_visible", ) - def is_presence_visible(self, observed_localpart, observer_userid): - return self._simple_select_one( - table="presence_allow_inbound", - keyvalues={"observed_user_id": observed_localpart, - "observer_user_id": observer_userid}, - retcols=["observed_user_id"], - allow_none=True, - desc="is_presence_visible", - ) - def add_presence_list_pending(self, observer_localpart, observed_userid): return self._simple_insert( table="presence_list", diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 66e7a40e3c..77518e893f 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -121,26 +121,6 @@ class RoomMemberStore(SQLBaseStore): with self._stream_id_gen.get_next() as stream_ordering: yield self.runInteraction("locally_reject_invite", f, stream_ordering) - def get_room_member(self, user_id, room_id): - """Retrieve the current state of a room member. - - Args: - user_id (str): The member's user ID. - room_id (str): The room the member is in. - Returns: - Deferred: Results in a MembershipEvent or None. - """ - return self.runInteraction( - "get_room_member", - self._get_members_events_txn, - room_id, - user_id=user_id, - ).addCallback( - self._get_events - ).addCallback( - lambda events: events[0] if events else None - ) - @cached(max_entries=5000) def get_users_in_room(self, room_id): def f(txn): @@ -203,19 +183,6 @@ class RoomMemberStore(SQLBaseStore): defer.returnValue(invite) defer.returnValue(None) - def get_leave_and_ban_events_for_user(self, user_id): - """ Get all the leave events for a user - Args: - user_id (str): The user ID. - Returns: - A deferred list of event objects. - """ - return self.get_rooms_for_user_where_membership_is( - user_id, (Membership.LEAVE, Membership.BAN) - ).addCallback(lambda leaves: self._get_events([ - leave.event_id for leave in leaves - ])) - def get_rooms_for_user_where_membership_is(self, user_id, membership_list): """ Get all the rooms for this user where the membership for this user matches one in the membership list. diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 76bcd9cd00..95b12559a6 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -303,96 +303,6 @@ class StreamStore(SQLBaseStore): defer.returnValue(ret) - def get_room_events_stream( - self, - user_id, - from_key, - to_key, - limit=0, - is_guest=False, - room_ids=None - ): - room_ids = room_ids or [] - room_ids = [r for r in room_ids] - if is_guest: - current_room_membership_sql = ( - "SELECT c.room_id FROM history_visibility AS h" - " INNER JOIN current_state_events AS c" - " ON h.event_id = c.event_id" - " WHERE c.room_id IN (%s)" - " AND h.history_visibility = 'world_readable'" % ( - ",".join(map(lambda _: "?", room_ids)) - ) - ) - current_room_membership_args = room_ids - else: - current_room_membership_sql = ( - "SELECT m.room_id FROM room_memberships as m " - " INNER JOIN current_state_events as c" - " ON m.event_id = c.event_id AND c.state_key = m.user_id" - " WHERE m.user_id = ? AND m.membership = 'join'" - ) - current_room_membership_args = [user_id] - - # We also want to get any membership events about that user, e.g. - # invites or leave notifications. - membership_sql = ( - "SELECT m.event_id FROM room_memberships as m " - "INNER JOIN current_state_events as c ON m.event_id = c.event_id " - "WHERE m.user_id = ? " - ) - membership_args = [user_id] - - if limit: - limit = max(limit, MAX_STREAM_SIZE) - else: - limit = MAX_STREAM_SIZE - - # From and to keys should be integers from ordering. - from_id = RoomStreamToken.parse_stream_token(from_key) - to_id = RoomStreamToken.parse_stream_token(to_key) - - if from_key == to_key: - return defer.succeed(([], to_key)) - - sql = ( - "SELECT e.event_id, e.stream_ordering FROM events AS e WHERE " - "(e.outlier = ? AND (room_id IN (%(current)s)) OR " - "(event_id IN (%(invites)s))) " - "AND e.stream_ordering > ? AND e.stream_ordering <= ? " - "ORDER BY stream_ordering ASC LIMIT %(limit)d " - ) % { - "current": current_room_membership_sql, - "invites": membership_sql, - "limit": limit - } - - def f(txn): - args = ([False] + current_room_membership_args + membership_args + - [from_id.stream, to_id.stream]) - txn.execute(sql, args) - - rows = self.cursor_to_dict(txn) - - ret = self._get_events_txn( - txn, - [r["event_id"] for r in rows], - get_prev_content=True - ) - - self._set_before_and_after(ret, rows) - - if rows: - key = "s%d" % max(r["stream_ordering"] for r in rows) - else: - # Assume we didn't get anything because there was nothing to - # get. - key = to_key - - return ret, key - - return self.runInteraction("get_room_events_stream", f) - @defer.inlineCallbacks def paginate_room_events(self, room_id, from_key, to_key=None, direction='b', limit=-1): diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index f69f1cdad4..46cf93ff87 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -112,7 +112,7 @@ class StreamIdGenerator(object): self._current + self._step * (n + 1), self._step ) - self._current += n + self._current += n * self._step for next_id in next_ids: self._unfinished_ids.append(next_id) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 3b9da5b34a..b462495eb8 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -49,9 +49,6 @@ class Clock(object): l.start(msec / 1000.0, now=False) return l - def stop_looping_call(self, loop): - loop.stop() - def call_later(self, delay, callback, *args, **kwargs): """Call something later diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 8875813de4..d7cccc06b1 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -15,7 +15,9 @@ from twisted.internet import defer -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import ( + PreserveLoggingContext, preserve_context_over_fn +) from synapse.util import unwrapFirstError @@ -25,6 +27,24 @@ import logging logger = logging.getLogger(__name__) +def registered_user(distributor, user): + return distributor.fire("registered_user", user) + + +def user_left_room(distributor, user, room_id): + return preserve_context_over_fn( + distributor.fire, + "user_left_room", user=user, room_id=room_id + ) + + +def user_joined_room(distributor, user, room_id): + return preserve_context_over_fn( + distributor.fire, + "user_joined_room", user=user, room_id=room_id + ) + + class Distributor(object): """A central dispatch point for loosely-connected pieces of code to register, observe, and fire signals. diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 4076eed269..1101881a2d 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -100,20 +100,6 @@ class _PerHostRatelimiter(object): self.current_processing = set() self.request_times = [] - def is_empty(self): - time_now = self.clock.time_msec() - self.request_times[:] = [ - r for r in self.request_times - if time_now - r < self.window_size - ] - - return not ( - self.ready_request_queue - or self.sleeping_requests - or self.current_processing - or self.request_times - ) - @contextlib.contextmanager def ratelimit(self): # `contextlib.contextmanager` takes a generator and turns it into a diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index b490bb8725..a100f151d4 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -21,10 +21,6 @@ _string_with_symbols = ( ) -def origin_from_ucid(ucid): - return ucid.split("@", 1)[1] - - def random_string(length): return ''.join(random.choice(string.ascii_letters) for _ in xrange(length)) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 0f525a8943..983caafe8a 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -51,7 +51,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase): def check(self, method, args, expected_result=None): master_result = yield getattr(self.master_store, method)(*args) slaved_result = yield getattr(self.slaved_store, method)(*args) - self.assertEqual(master_result, slaved_result) if expected_result is not None: self.assertEqual(master_result, expected_result) self.assertEqual(slaved_result, expected_result) + self.assertEqual(master_result, slaved_result) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 9af62702b3..baa4a26eb5 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -205,13 +205,59 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): [join3] ) + @defer.inlineCallbacks + def test_redactions(self): + yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.persist(type="m.room.member", key=USER_ID, membership="join") + + msg = yield self.persist( + type="m.room.message", msgtype="m.text", body="Hello" + ) + yield self.replicate() + yield self.check("get_event", [msg.event_id], msg) + + redaction = yield self.persist( + type="m.room.redaction", redacts=msg.event_id + ) + yield self.replicate() + + msg_dict = msg.get_dict() + msg_dict["content"] = {} + msg_dict["unsigned"]["redacted_by"] = redaction.event_id + msg_dict["unsigned"]["redacted_because"] = redaction + redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) + yield self.check("get_event", [msg.event_id], redacted) + + @defer.inlineCallbacks + def test_backfilled_redactions(self): + yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.persist(type="m.room.member", key=USER_ID, membership="join") + + msg = yield self.persist( + type="m.room.message", msgtype="m.text", body="Hello" + ) + yield self.replicate() + yield self.check("get_event", [msg.event_id], msg) + + redaction = yield self.persist( + type="m.room.redaction", redacts=msg.event_id, backfill=True + ) + yield self.replicate() + + msg_dict = msg.get_dict() + msg_dict["content"] = {} + msg_dict["unsigned"]["redacted_by"] = redaction.event_id + msg_dict["unsigned"]["redacted_because"] = redaction + redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) + yield self.check("get_event", [msg.event_id], redacted) + event_id = 0 @defer.inlineCallbacks def persist( self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={}, state=None, reset_state=False, backfill=False, - depth=None, prev_events=[], auth_events=[], prev_state=[], + depth=None, prev_events=[], auth_events=[], prev_state=[], redacts=None, **content ): """ @@ -236,6 +282,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): event_dict["state_key"] = key event_dict["prev_state"] = prev_state + if redacts is not None: + event_dict["redacts"] = redacts + event = FrozenEvent(event_dict, internal_metadata_dict=internal) self.event_id += 1 diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py index ec78f007ca..63203cea35 100644 --- a/tests/storage/test_presence.py +++ b/tests/storage/test_presence.py @@ -35,33 +35,6 @@ class PresenceStoreTestCase(unittest.TestCase): self.u_banana = UserID.from_string("@banana:test") @defer.inlineCallbacks - def test_visibility(self): - self.assertFalse((yield self.store.is_presence_visible( - observed_localpart=self.u_apple.localpart, - observer_userid=self.u_banana.to_string(), - ))) - - yield self.store.allow_presence_visible( - observed_localpart=self.u_apple.localpart, - observer_userid=self.u_banana.to_string(), - ) - - self.assertTrue((yield self.store.is_presence_visible( - observed_localpart=self.u_apple.localpart, - observer_userid=self.u_banana.to_string(), - ))) - - yield self.store.disallow_presence_visible( - observed_localpart=self.u_apple.localpart, - observer_userid=self.u_banana.to_string(), - ) - - self.assertFalse((yield self.store.is_presence_visible( - observed_localpart=self.u_apple.localpart, - observer_userid=self.u_banana.to_string(), - ))) - - @defer.inlineCallbacks def test_presence_list(self): self.assertEquals( [], diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 5880409867..6afaca3a61 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -110,22 +110,10 @@ class RedactionTestCase(unittest.TestCase): self.room1, self.u_alice, Membership.JOIN ) - start = yield self.store.get_room_events_max_id() - msg_event = yield self.inject_message(self.room1, self.u_alice, u"t") - end = yield self.store.get_room_events_max_id() - - results, _ = yield self.store.get_room_events_stream( - self.u_alice.to_string(), - start, - end, - ) - - self.assertEqual(1, len(results)) - # Check event has not been redacted: - event = results[0] + event = yield self.store.get_event(msg_event.event_id) self.assertObjectHasAttributes( { @@ -144,17 +132,7 @@ class RedactionTestCase(unittest.TestCase): self.room1, msg_event.event_id, self.u_alice, reason ) - results, _ = yield self.store.get_room_events_stream( - self.u_alice.to_string(), - start, - end, - ) - - self.assertEqual(1, len(results)) - - # Check redaction - - event = results[0] + event = yield self.store.get_event(msg_event.event_id) self.assertEqual(msg_event.event_id, event.event_id) @@ -184,25 +162,12 @@ class RedactionTestCase(unittest.TestCase): self.room1, self.u_alice, Membership.JOIN ) - start = yield self.store.get_room_events_max_id() - msg_event = yield self.inject_room_member( self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}, ) - end = yield self.store.get_room_events_max_id() - - results, _ = yield self.store.get_room_events_stream( - self.u_alice.to_string(), - start, - end, - ) - - self.assertEqual(1, len(results)) - - # Check event has not been redacted: - event = results[0] + event = yield self.store.get_event(msg_event.event_id) self.assertObjectHasAttributes( { @@ -221,17 +186,9 @@ class RedactionTestCase(unittest.TestCase): self.room1, msg_event.event_id, self.u_alice, reason ) - results, _ = yield self.store.get_room_events_stream( - self.u_alice.to_string(), - start, - end, - ) - - self.assertEqual(1, len(results)) - # Check redaction - event = results[0] + event = yield self.store.get_event(msg_event.event_id) self.assertTrue("redacted_because" in event.unsigned) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index b029ff0584..997090fe35 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -71,13 +71,6 @@ class RoomMemberStoreTestCase(unittest.TestCase): yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN) self.assertEquals( - Membership.JOIN, - (yield self.store.get_room_member( - user_id=self.u_alice.to_string(), - room_id=self.room.to_string(), - )).membership - ) - self.assertEquals( [self.u_alice.to_string()], [m.user_id for m in ( yield self.store.get_room_members(self.room.to_string()) diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py deleted file mode 100644 index da322152c7..0000000000 --- a/tests/storage/test_stream.py +++ /dev/null @@ -1,185 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from tests import unittest -from twisted.internet import defer - -from synapse.api.constants import EventTypes, Membership -from synapse.types import UserID, RoomID -from tests.storage.event_injector import EventInjector - -from tests.utils import setup_test_homeserver - -from mock import Mock - - -class StreamStoreTestCase(unittest.TestCase): - - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver( - resource_for_federation=Mock(), - http_client=None, - ) - - self.store = hs.get_datastore() - self.event_builder_factory = hs.get_event_builder_factory() - self.event_injector = EventInjector(hs) - self.handlers = hs.get_handlers() - self.message_handler = self.handlers.message_handler - - self.u_alice = UserID.from_string("@alice:test") - self.u_bob = UserID.from_string("@bob:test") - - self.room1 = RoomID.from_string("!abc123:test") - self.room2 = RoomID.from_string("!xyx987:test") - - @defer.inlineCallbacks - def test_event_stream_get_other(self): - # Both bob and alice joins the room - yield self.event_injector.inject_room_member( - self.room1, self.u_alice, Membership.JOIN - ) - yield self.event_injector.inject_room_member( - self.room1, self.u_bob, Membership.JOIN - ) - - # Initial stream key: - start = yield self.store.get_room_events_max_id() - - yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") - - end = yield self.store.get_room_events_max_id() - - results, _ = yield self.store.get_room_events_stream( - self.u_bob.to_string(), - start, - end, - ) - - self.assertEqual(1, len(results)) - - event = results[0] - - self.assertObjectHasAttributes( - { - "type": EventTypes.Message, - "user_id": self.u_alice.to_string(), - "content": {"body": "test", "msgtype": "message"}, - }, - event, - ) - - @defer.inlineCallbacks - def test_event_stream_get_own(self): - # Both bob and alice joins the room - yield self.event_injector.inject_room_member( - self.room1, self.u_alice, Membership.JOIN - ) - yield self.event_injector.inject_room_member( - self.room1, self.u_bob, Membership.JOIN - ) - - # Initial stream key: - start = yield self.store.get_room_events_max_id() - - yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") - - end = yield self.store.get_room_events_max_id() - - results, _ = yield self.store.get_room_events_stream( - self.u_alice.to_string(), - start, - end, - ) - - self.assertEqual(1, len(results)) - - event = results[0] - - self.assertObjectHasAttributes( - { - "type": EventTypes.Message, - "user_id": self.u_alice.to_string(), - "content": {"body": "test", "msgtype": "message"}, - }, - event, - ) - - @defer.inlineCallbacks - def test_event_stream_join_leave(self): - # Both bob and alice joins the room - yield self.event_injector.inject_room_member( - self.room1, self.u_alice, Membership.JOIN - ) - yield self.event_injector.inject_room_member( - self.room1, self.u_bob, Membership.JOIN - ) - - # Then bob leaves again. - yield self.event_injector.inject_room_member( - self.room1, self.u_bob, Membership.LEAVE - ) - - # Initial stream key: - start = yield self.store.get_room_events_max_id() - - yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") - - end = yield self.store.get_room_events_max_id() - - results, _ = yield self.store.get_room_events_stream( - self.u_bob.to_string(), - start, - end, - ) - - # We should not get the message, as it happened *after* bob left. - self.assertEqual(0, len(results)) - - @defer.inlineCallbacks - def test_event_stream_prev_content(self): - yield self.event_injector.inject_room_member( - self.room1, self.u_bob, Membership.JOIN - ) - - yield self.event_injector.inject_room_member( - self.room1, self.u_alice, Membership.JOIN - ) - - start = yield self.store.get_room_events_max_id() - - yield self.event_injector.inject_room_member( - self.room1, self.u_alice, Membership.JOIN, - ) - - end = yield self.store.get_room_events_max_id() - - results, _ = yield self.store.get_room_events_stream( - self.u_bob.to_string(), - start, - end, - ) - - # We should not get the message, as it happened *after* bob left. - self.assertEqual(1, len(results)) - - event = results[0] - - self.assertTrue( - "prev_content" in event.unsigned, - msg="No prev_content key" - ) diff --git a/tests/test_dns.py b/tests/test_dns.py index 637b1606f8..c394c57ee7 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -21,6 +21,8 @@ from mock import Mock from synapse.http.endpoint import resolve_service +from tests.utils import MockClock + class DnsTestCase(unittest.TestCase): @@ -63,14 +65,17 @@ class DnsTestCase(unittest.TestCase): self.assertEquals(servers[0].host, ip_address) @defer.inlineCallbacks - def test_from_cache(self): + def test_from_cache_expired_and_dns_fail(self): dns_client_mock = Mock() dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) service_name = "test_service.examle.com" + entry = Mock(spec_set=["expires"]) + entry.expires = 0 + cache = { - service_name: [object()] + service_name: [entry] } servers = yield resolve_service( @@ -83,6 +88,31 @@ class DnsTestCase(unittest.TestCase): self.assertEquals(servers, cache[service_name]) @defer.inlineCallbacks + def test_from_cache(self): + clock = MockClock() + + dns_client_mock = Mock(spec_set=['lookupService']) + dns_client_mock.lookupService = Mock(spec_set=[]) + + service_name = "test_service.examle.com" + + entry = Mock(spec_set=["expires"]) + entry.expires = 999999999 + + cache = { + service_name: [entry] + } + + servers = yield resolve_service( + service_name, dns_client=dns_client_mock, cache=cache, clock=clock, + ) + + self.assertFalse(dns_client_mock.lookupService.called) + + self.assertEquals(len(servers), 1) + self.assertEquals(servers, cache[service_name]) + + @defer.inlineCallbacks def test_empty_cache(self): dns_client_mock = Mock() |