diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 05af54d31b..543bf28aec 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -70,11 +70,10 @@ class ApplicationServicesHandler(object):
with Measure(self.clock, "notify_interested_services"):
self.is_processing = True
try:
- upper_bound = self.current_max
limit = 100
while True:
upper_bound, events = yield self.store.get_new_events_for_appservice(
- upper_bound, limit
+ self.current_max, limit
)
if not events:
@@ -105,9 +104,6 @@ class ApplicationServicesHandler(object):
)
yield self.store.set_appservice_last_pos(upper_bound)
-
- if len(events) < limit:
- break
finally:
self.is_processing = False
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 9cef9d184b..7a0ba6ef35 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -13,13 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import LoginType
-from synapse.types import UserID
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
+from synapse.module_api import ModuleApi
+from synapse.types import UserID
from synapse.util.async import run_on_reactor
from synapse.util.caches.expiringcache import ExpiringCache
@@ -63,10 +63,7 @@ class AuthHandler(BaseHandler):
reset_expiry_on_get=True,
)
- account_handler = _AccountHandler(
- hs, check_user_exists=self.check_user_exists
- )
-
+ account_handler = ModuleApi(hs, self)
self.password_providers = [
module(config=config, account_handler=account_handler)
for module, config in hs.config.password_providers
@@ -75,14 +72,24 @@ class AuthHandler(BaseHandler):
logger.info("Extra password_providers: %r", self.password_providers)
self.hs = hs # FIXME better possibility to access registrationHandler later?
- self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
+ self._password_enabled = hs.config.password_enabled
+
+ login_types = set()
+ if self._password_enabled:
+ login_types.add(LoginType.PASSWORD)
+ for provider in self.password_providers:
+ if hasattr(provider, "get_supported_login_types"):
+ login_types.update(
+ provider.get_supported_login_types().keys()
+ )
+ self._supported_login_types = frozenset(login_types)
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
"""
Takes a dictionary sent by the client in the login / registration
- protocol and handles the login flow.
+ protocol and handles the User-Interactive Auth flow.
As a side effect, this function fills in the 'creds' key on the user's
session with a map, which maps each auth-type (str) to the relevant
@@ -260,16 +267,19 @@ class AuthHandler(BaseHandler):
sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default)
+ @defer.inlineCallbacks
def _check_password_auth(self, authdict, _):
if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
user_id = authdict["user"]
password = authdict["password"]
- if not user_id.startswith('@'):
- user_id = UserID(user_id, self.hs.hostname).to_string()
- return self._check_password(user_id, password)
+ (canonical_id, callback) = yield self.validate_login(user_id, {
+ "type": LoginType.PASSWORD,
+ "password": password,
+ })
+ defer.returnValue(canonical_id)
@defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip):
@@ -398,26 +408,8 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id]
- def validate_password_login(self, user_id, password):
- """
- Authenticates the user with their username and password.
-
- Used only by the v1 login API.
-
- Args:
- user_id (str): complete @user:id
- password (str): Password
- Returns:
- defer.Deferred: (str) canonical user id
- Raises:
- StoreError if there was a problem accessing the database
- LoginError if there was an authentication problem.
- """
- return self._check_password(user_id, password)
-
@defer.inlineCallbacks
- def get_access_token_for_user_id(self, user_id, device_id=None,
- initial_display_name=None):
+ def get_access_token_for_user_id(self, user_id, device_id=None):
"""
Creates a new access token for the user with the given user ID.
@@ -431,13 +423,10 @@ class AuthHandler(BaseHandler):
device_id (str|None): the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID)
- initial_display_name (str): display name to associate with the
- device if it needs re-registering
Returns:
The access token for the user's session.
Raises:
StoreError if there was a problem storing the token.
- LoginError if there was an authentication problem.
"""
logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id)
@@ -447,9 +436,11 @@ class AuthHandler(BaseHandler):
# really don't want is active access_tokens without a record of the
# device, so we double-check it here.
if device_id is not None:
- yield self.device_handler.check_device_registered(
- user_id, device_id, initial_display_name
- )
+ try:
+ yield self.store.get_device(user_id, device_id)
+ except StoreError:
+ yield self.store.delete_access_token(access_token)
+ raise StoreError(400, "Login raced against device deletion")
defer.returnValue(access_token)
@@ -501,29 +492,115 @@ class AuthHandler(BaseHandler):
)
defer.returnValue(result)
+ def get_supported_login_types(self):
+ """Get a the login types supported for the /login API
+
+ By default this is just 'm.login.password' (unless password_enabled is
+ False in the config file), but password auth providers can provide
+ other login types.
+
+ Returns:
+ Iterable[str]: login types
+ """
+ return self._supported_login_types
+
@defer.inlineCallbacks
- def _check_password(self, user_id, password):
- """Authenticate a user against the LDAP and local databases.
+ def validate_login(self, username, login_submission):
+ """Authenticates the user for the /login API
- user_id is checked case insensitively against the local database, but
- will throw if there are multiple inexact matches.
+ Also used by the user-interactive auth flow to validate
+ m.login.password auth types.
Args:
- user_id (str): complete @user:id
+ username (str): username supplied by the user
+ login_submission (dict): the whole of the login submission
+ (including 'type' and other relevant fields)
Returns:
- (str) the canonical_user_id
+ Deferred[str, func]: canonical user id, and optional callback
+ to be called once the access token and device id are issued
Raises:
- LoginError if login fails
+ StoreError if there was a problem accessing the database
+ SynapseError if there was a problem with the request
+ LoginError if there was an authentication problem.
"""
+
+ if username.startswith('@'):
+ qualified_user_id = username
+ else:
+ qualified_user_id = UserID(
+ username, self.hs.hostname
+ ).to_string()
+
+ login_type = login_submission.get("type")
+ known_login_type = False
+
+ # special case to check for "password" for the check_password interface
+ # for the auth providers
+ password = login_submission.get("password")
+ if login_type == LoginType.PASSWORD:
+ if not self._password_enabled:
+ raise SynapseError(400, "Password login has been disabled.")
+ if not password:
+ raise SynapseError(400, "Missing parameter: password")
+
for provider in self.password_providers:
- is_valid = yield provider.check_password(user_id, password)
- if is_valid:
- defer.returnValue(user_id)
+ if (hasattr(provider, "check_password")
+ and login_type == LoginType.PASSWORD):
+ known_login_type = True
+ is_valid = yield provider.check_password(
+ qualified_user_id, password,
+ )
+ if is_valid:
+ defer.returnValue(qualified_user_id)
+
+ if (not hasattr(provider, "get_supported_login_types")
+ or not hasattr(provider, "check_auth")):
+ # this password provider doesn't understand custom login types
+ continue
+
+ supported_login_types = provider.get_supported_login_types()
+ if login_type not in supported_login_types:
+ # this password provider doesn't understand this login type
+ continue
+
+ known_login_type = True
+ login_fields = supported_login_types[login_type]
+
+ missing_fields = []
+ login_dict = {}
+ for f in login_fields:
+ if f not in login_submission:
+ missing_fields.append(f)
+ else:
+ login_dict[f] = login_submission[f]
+ if missing_fields:
+ raise SynapseError(
+ 400, "Missing parameters for login type %s: %s" % (
+ login_type,
+ missing_fields,
+ ),
+ )
+
+ result = yield provider.check_auth(
+ username, login_type, login_dict,
+ )
+ if result:
+ if isinstance(result, str):
+ result = (result, None)
+ defer.returnValue(result)
+
+ if login_type == LoginType.PASSWORD:
+ known_login_type = True
+
+ canonical_user_id = yield self._check_local_password(
+ qualified_user_id, password,
+ )
- canonical_user_id = yield self._check_local_password(user_id, password)
+ if canonical_user_id:
+ defer.returnValue((canonical_user_id, None))
- if canonical_user_id:
- defer.returnValue(canonical_user_id)
+ if not known_login_type:
+ raise SynapseError(400, "Unknown login type %s" % login_type)
# unknown username or invalid password. We raise a 403 here, but note
# that if we're doing user-interactive login, it turns all LoginErrors
@@ -584,14 +661,81 @@ class AuthHandler(BaseHandler):
if e.code == 404:
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e
- yield self.store.user_delete_access_tokens(
- user_id, except_access_token_id
+ yield self.delete_access_tokens_for_user(
+ user_id, except_token_id=except_access_token_id,
)
yield self.hs.get_pusherpool().remove_pushers_by_user(
user_id, except_access_token_id
)
@defer.inlineCallbacks
+ def deactivate_account(self, user_id):
+ """Deactivate a user's account
+
+ Args:
+ user_id (str): ID of user to be deactivated
+
+ Returns:
+ Deferred
+ """
+ # FIXME: Theoretically there is a race here wherein user resets
+ # password using threepid.
+ yield self.delete_access_tokens_for_user(user_id)
+ yield self.store.user_delete_threepids(user_id)
+ yield self.store.user_set_password_hash(user_id, None)
+
+ @defer.inlineCallbacks
+ def delete_access_token(self, access_token):
+ """Invalidate a single access token
+
+ Args:
+ access_token (str): access token to be deleted
+
+ Returns:
+ Deferred
+ """
+ user_info = yield self.auth.get_user_by_access_token(access_token)
+ yield self.store.delete_access_token(access_token)
+
+ # see if any of our auth providers want to know about this
+ for provider in self.password_providers:
+ if hasattr(provider, "on_logged_out"):
+ yield provider.on_logged_out(
+ user_id=str(user_info["user"]),
+ device_id=user_info["device_id"],
+ access_token=access_token,
+ )
+
+ @defer.inlineCallbacks
+ def delete_access_tokens_for_user(self, user_id, except_token_id=None,
+ device_id=None):
+ """Invalidate access tokens belonging to a user
+
+ Args:
+ user_id (str): ID of user the tokens belong to
+ except_token_id (str|None): access_token ID which should *not* be
+ deleted
+ device_id (str|None): ID of device the tokens are associated with.
+ If None, tokens associated with any device (or no device) will
+ be deleted
+ Returns:
+ Deferred
+ """
+ tokens_and_devices = yield self.store.user_delete_access_tokens(
+ user_id, except_token_id=except_token_id, device_id=device_id,
+ )
+
+ # see if any of our auth providers want to know about this
+ for provider in self.password_providers:
+ if hasattr(provider, "on_logged_out"):
+ for token, device_id in tokens_and_devices:
+ yield provider.on_logged_out(
+ user_id=user_id,
+ device_id=device_id,
+ access_token=token,
+ )
+
+ @defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at):
# 'Canonicalise' email addresses down to lower case.
# We've now moving towards the Home Server being the entity that
@@ -696,30 +840,3 @@ class MacaroonGeneartor(object):
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
-
-
-class _AccountHandler(object):
- """A proxy object that gets passed to password auth providers so they
- can register new users etc if necessary.
- """
- def __init__(self, hs, check_user_exists):
- self.hs = hs
-
- self._check_user_exists = check_user_exists
-
- def check_user_exists(self, user_id):
- """Check if user exissts.
-
- Returns:
- Deferred(bool)
- """
- return self._check_user_exists(user_id)
-
- def register(self, localpart):
- """Registers a new user with given localpart
-
- Returns:
- Deferred: a 2-tuple of (user_id, access_token)
- """
- reg = self.hs.get_handlers().registration_handler
- return reg.register(localpart=localpart)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index dac4b3f4e0..579d8477ba 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -34,6 +34,7 @@ class DeviceHandler(BaseHandler):
self.hs = hs
self.state = hs.get_state_handler()
+ self._auth_handler = hs.get_auth_handler()
self.federation_sender = hs.get_federation_sender()
self.federation = hs.get_replication_layer()
@@ -159,9 +160,8 @@ class DeviceHandler(BaseHandler):
else:
raise
- yield self.store.user_delete_access_tokens(
+ yield self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id,
- delete_refresh_tokens=True,
)
yield self.store.delete_e2e_keys_by_device(
@@ -194,9 +194,8 @@ class DeviceHandler(BaseHandler):
# Delete access tokens and e2e keys for each device. Not optimised as it is not
# considered as part of a critical path.
for device_id in device_ids:
- yield self.store.user_delete_access_tokens(
+ yield self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id,
- delete_refresh_tokens=True,
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 8b1e606754..ac70730885 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1706,6 +1706,17 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def do_auth(self, origin, event, context, auth_events):
+ """
+
+ Args:
+ origin (str):
+ event (synapse.events.FrozenEvent):
+ context (synapse.events.snapshot.EventContext):
+ auth_events (dict[(str, str)->str]):
+
+ Returns:
+ defer.Deferred[None]
+ """
# Check if we have all the auth events.
current_state = set(e.event_id for e in auth_events.values())
event_auth_events = set(e_id for e_id, _ in event.auth_events)
@@ -1817,16 +1828,9 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state
- context.current_state_ids = dict(context.current_state_ids)
- context.current_state_ids.update({
- k: a.event_id for k, a in auth_events.items()
- if k != event_key
- })
- context.prev_state_ids = dict(context.prev_state_ids)
- context.prev_state_ids.update({
- k: a.event_id for k, a in auth_events.items()
- })
- context.state_group = self.store.get_next_state_group()
+ self._update_context_for_auth_events(
+ context, auth_events, event_key,
+ )
if different_auth and not event.internal_metadata.is_outlier():
logger.info("Different auth after resolution: %s", different_auth)
@@ -1906,16 +1910,9 @@ class FederationHandler(BaseHandler):
# 4. Look at rejects and their proofs.
# TODO.
- context.current_state_ids = dict(context.current_state_ids)
- context.current_state_ids.update({
- k: a.event_id for k, a in auth_events.items()
- if k != event_key
- })
- context.prev_state_ids = dict(context.prev_state_ids)
- context.prev_state_ids.update({
- k: a.event_id for k, a in auth_events.items()
- })
- context.state_group = self.store.get_next_state_group()
+ self._update_context_for_auth_events(
+ context, auth_events, event_key,
+ )
try:
self.auth.check(event, auth_events=auth_events)
@@ -1923,6 +1920,35 @@ class FederationHandler(BaseHandler):
logger.warn("Failed auth resolution for %r because %s", event, e)
raise e
+ def _update_context_for_auth_events(self, context, auth_events,
+ event_key):
+ """Update the state_ids in an event context after auth event resolution
+
+ Args:
+ context (synapse.events.snapshot.EventContext): event context
+ to be updated
+
+ auth_events (dict[(str, str)->str]): Events to update in the event
+ context.
+
+ event_key ((str, str)): (type, state_key) for the current event.
+ this will not be included in the current_state in the context.
+ """
+ state_updates = {
+ k: a.event_id for k, a in auth_events.iteritems()
+ if k != event_key
+ }
+ context.current_state_ids = dict(context.current_state_ids)
+ context.current_state_ids.update(state_updates)
+ if context.delta_ids is not None:
+ context.delta_ids = dict(context.delta_ids)
+ context.delta_ids.update(state_updates)
+ context.prev_state_ids = dict(context.prev_state_ids)
+ context.prev_state_ids.update({
+ k: a.event_id for k, a in auth_events.iteritems()
+ })
+ context.state_group = self.store.get_next_state_group()
+
@defer.inlineCallbacks
def construct_auth_difference(self, local_auth, remote_auth):
""" Given a local and remote auth chain, find the differences. This
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 6699d0888f..da00aeb0f4 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -71,6 +71,7 @@ class GroupsLocalHandler(object):
get_invited_users_in_group = _create_rerouter("get_invited_users_in_group")
add_room_to_group = _create_rerouter("add_room_to_group")
+ update_room_in_group = _create_rerouter("update_room_in_group")
remove_room_from_group = _create_rerouter("remove_room_from_group")
update_group_summary_room = _create_rerouter("update_group_summary_room")
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 62b9bd503e..5e5b1952dd 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -17,7 +17,6 @@ import logging
from twisted.internet import defer
-import synapse.types
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.types import UserID, get_domain_from_id
from ._base import BaseHandler
@@ -140,7 +139,7 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_displayname
)
- yield self._update_join_states(requester)
+ yield self._update_join_states(requester, target_user)
@defer.inlineCallbacks
def get_avatar_url(self, target_user):
@@ -184,7 +183,7 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_avatar_url
)
- yield self._update_join_states(requester)
+ yield self._update_join_states(requester, target_user)
@defer.inlineCallbacks
def on_profile_query(self, args):
@@ -209,28 +208,24 @@ class ProfileHandler(BaseHandler):
defer.returnValue(response)
@defer.inlineCallbacks
- def _update_join_states(self, requester):
- user = requester.user
- if not self.hs.is_mine(user):
+ def _update_join_states(self, requester, target_user):
+ if not self.hs.is_mine(target_user):
return
yield self.ratelimit(requester)
room_ids = yield self.store.get_rooms_for_user(
- user.to_string(),
+ target_user.to_string(),
)
for room_id in room_ids:
handler = self.hs.get_handlers().room_member_handler
try:
- # Assume the user isn't a guest because we don't let guests set
- # profile or avatar data.
- # XXX why are we recreating `requester` here for each room?
- # what was wrong with the `requester` we were passed?
- requester = synapse.types.create_requester(user)
+ # Assume the target_user isn't a guest,
+ # because we don't let guests set profile or avatar data.
yield handler.update_membership(
requester,
- user,
+ target_user,
room_id,
"join", # We treat a profile update like a join.
ratelimit=False, # Try to hide that these events aren't atomic.
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 49dc33c147..f6e7e58563 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -36,6 +36,7 @@ class RegistrationHandler(BaseHandler):
super(RegistrationHandler, self).__init__(hs)
self.auth = hs.get_auth()
+ self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler()
self.captcha_client = CaptchaServerHttpClient(hs)
@@ -416,7 +417,7 @@ class RegistrationHandler(BaseHandler):
create_profile_with_localpart=user.localpart,
)
else:
- yield self.store.user_delete_access_tokens(user_id=user_id)
+ yield self._auth_handler.delete_access_tokens_for_user(user_id)
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
if displayname is not None:
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 41e1781df7..2cf34e51cb 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -20,6 +20,7 @@ from ._base import BaseHandler
from synapse.api.constants import (
EventTypes, JoinRules,
)
+from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.async import concurrently_execute
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.response_cache import ResponseCache
@@ -70,6 +71,7 @@ class RoomListHandler(BaseHandler):
if search_filter:
# We explicitly don't bother caching searches or requests for
# appservice specific lists.
+ logger.info("Bypassing cache as search request.")
return self._get_public_room_list(
limit, since_token, search_filter, network_tuple=network_tuple,
)
@@ -77,13 +79,16 @@ class RoomListHandler(BaseHandler):
key = (limit, since_token, network_tuple)
result = self.response_cache.get(key)
if not result:
+ logger.info("No cached result, calculating one.")
result = self.response_cache.set(
key,
- self._get_public_room_list(
+ preserve_fn(self._get_public_room_list)(
limit, since_token, network_tuple=network_tuple
)
)
- return result
+ else:
+ logger.info("Using cached deferred result.")
+ return make_deferred_yieldable(result)
@defer.inlineCallbacks
def _get_public_room_list(self, limit=None, since_token=None,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 219529936f..b12988f3c9 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -15,7 +15,7 @@
from synapse.api.constants import Membership, EventTypes
from synapse.util.async import concurrently_execute
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, make_deferred_yieldable, preserve_fn
from synapse.util.metrics import Measure, measure_func
from synapse.util.caches.response_cache import ResponseCache
from synapse.push.clientformat import format_push_rules_for_user
@@ -184,11 +184,11 @@ class SyncHandler(object):
if not result:
result = self.response_cache.set(
sync_config.request_key,
- self._wait_for_sync_for_user(
+ preserve_fn(self._wait_for_sync_for_user)(
sync_config, since_token, timeout, full_state
)
)
- return result
+ return make_deferred_yieldable(result)
@defer.inlineCallbacks
def _wait_for_sync_for_user(self, sync_config, since_token, timeout,
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 2a49456bfc..b5be5d9623 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -152,7 +152,7 @@ class UserDirectoyHandler(object):
for room_id in room_ids:
logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids))
- yield self._handle_intial_room(room_id)
+ yield self._handle_initial_room(room_id)
num_processed_rooms += 1
yield sleep(self.INITIAL_SLEEP_MS / 1000.)
@@ -166,7 +166,7 @@ class UserDirectoyHandler(object):
yield self.store.update_user_directory_stream_pos(new_pos)
@defer.inlineCallbacks
- def _handle_intial_room(self, room_id):
+ def _handle_initial_room(self, room_id):
"""Called when we initially fill out user_directory one room at a time
"""
is_in_room = yield self.store.is_host_joined(room_id, self.server_name)
|