diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index 5ad408f549..413425fed1 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -13,17 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .register import RegistrationHandler
-from .room import (
- RoomCreationHandler, RoomContextHandler,
-)
-from .room_member import RoomMemberHandler
-from .message import MessageHandler
-from .federation import FederationHandler
-from .profile import ProfileHandler
-from .directory import DirectoryHandler
from .admin import AdminHandler
+from .directory import DirectoryHandler
+from .federation import FederationHandler
from .identity import IdentityHandler
+from .register import RegistrationHandler
from .search import SearchHandler
@@ -48,13 +42,8 @@ class Handlers(object):
def __init__(self, hs):
self.registration_handler = RegistrationHandler(hs)
- self.message_handler = MessageHandler(hs)
- self.room_creation_handler = RoomCreationHandler(hs)
- self.room_member_handler = RoomMemberHandler(hs)
self.federation_handler = FederationHandler(hs)
- self.profile_handler = ProfileHandler(hs)
self.directory_handler = DirectoryHandler(hs)
self.admin_handler = AdminHandler(hs)
self.identity_handler = IdentityHandler(hs)
self.search_handler = SearchHandler(hs)
- self.room_context_handler = RoomContextHandler(hs)
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index faa5609c0c..704181d2d3 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -18,11 +18,10 @@ import logging
from twisted.internet import defer
import synapse.types
-from synapse.api.constants import Membership, EventTypes
+from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import LimitExceededError
from synapse.types import UserID
-
logger = logging.getLogger(__name__)
@@ -113,15 +112,16 @@ class BaseHandler(object):
guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join":
if context:
+ current_state_ids = yield context.get_current_state_ids(self.store)
current_state = yield self.store.get_events(
- context.current_state_ids.values()
+ list(current_state_ids.values())
)
else:
current_state = yield self.state_handler.get_current_state(
event.room_id
)
- current_state = current_state.values()
+ current_state = list(current_state.values())
logger.info("maybe_kick_guest_users %r", current_state)
yield self.kick_guest_users(current_state)
@@ -158,7 +158,7 @@ class BaseHandler(object):
# homeserver.
requester = synapse.types.create_requester(
target_user, is_guest=True)
- handler = self.hs.get_handlers().room_member_handler
+ handler = self.hs.get_room_member_handler()
yield handler.update_membership(
requester,
target_user,
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index f36b358b45..5d629126fc 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
from ._base import BaseHandler
-import logging
-
logger = logging.getLogger(__name__)
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 05af54d31b..ee41aed69e 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -13,16 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
+from six import itervalues
+
+from prometheus_client import Counter
+
from twisted.internet import defer
+import synapse
from synapse.api.constants import EventTypes
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.metrics import Measure
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-
-import logging
logger = logging.getLogger(__name__)
+events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
+
def log_failure(failure):
logger.error(
@@ -70,21 +78,25 @@ class ApplicationServicesHandler(object):
with Measure(self.clock, "notify_interested_services"):
self.is_processing = True
try:
- upper_bound = self.current_max
limit = 100
while True:
upper_bound, events = yield self.store.get_new_events_for_appservice(
- upper_bound, limit
+ self.current_max, limit
)
if not events:
break
+ events_by_room = {}
for event in events:
+ events_by_room.setdefault(event.room_id, []).append(event)
+
+ @defer.inlineCallbacks
+ def handle_event(event):
# Gather interested services
services = yield self._get_services_for_event(event)
if len(services) == 0:
- continue # no services need notifying
+ return # no services need notifying
# Do we know this user exists? If not, poke the user
# query API for all services which match that user regex.
@@ -95,19 +107,39 @@ class ApplicationServicesHandler(object):
yield self._check_user_exists(event.state_key)
if not self.started_scheduler:
- self.scheduler.start().addErrback(log_failure)
+ def start_scheduler():
+ return self.scheduler.start().addErrback(log_failure)
+ run_as_background_process("as_scheduler", start_scheduler)
self.started_scheduler = True
# Fork off pushes to these services
for service in services:
- preserve_fn(self.scheduler.submit_event_for_as)(
- service, event
- )
+ self.scheduler.submit_event_for_as(service, event)
+
+ @defer.inlineCallbacks
+ def handle_room_events(events):
+ for event in events:
+ yield handle_event(event)
+
+ yield make_deferred_yieldable(defer.gatherResults([
+ run_in_background(handle_room_events, evs)
+ for evs in itervalues(events_by_room)
+ ], consumeErrors=True))
yield self.store.set_appservice_last_pos(upper_bound)
- if len(events) < limit:
- break
+ now = self.clock.time_msec()
+ ts = yield self.store.get_received_ts(events[-1].event_id)
+
+ synapse.metrics.event_processing_positions.labels(
+ "appservice_sender").set(upper_bound)
+
+ events_processed_counter.inc(len(events))
+
+ synapse.metrics.event_processing_lag.labels(
+ "appservice_sender").set(now - ts)
+ synapse.metrics.event_processing_last_ts.labels(
+ "appservice_sender").set(ts)
finally:
self.is_processing = False
@@ -163,8 +195,11 @@ class ApplicationServicesHandler(object):
def query_3pe(self, kind, protocol, fields):
services = yield self._get_services_for_3pn(protocol)
- results = yield preserve_context_over_deferred(defer.DeferredList([
- preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
+ results = yield make_deferred_yieldable(defer.DeferredList([
+ run_in_background(
+ self.appservice_api.query_3pe,
+ service, kind, protocol, fields,
+ )
for service in services
], consumeErrors=True))
@@ -225,11 +260,15 @@ class ApplicationServicesHandler(object):
event based on the service regex.
"""
services = self.store.get_app_services()
- interested_list = [
- s for s in services if (
- yield s.is_interested(event, self.store)
- )
- ]
+
+ # we can't use a list comprehension here. Since python 3, list
+ # comprehensions use a generator internally. This means you can't yield
+ # inside of a list comprehension anymore.
+ interested_list = []
+ for s in services:
+ if (yield s.is_interested(event, self.store)):
+ interested_list.append(s)
+
defer.returnValue(interested_list)
def _get_services_for_user(self, user_id):
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index b00446bec0..402e44cdef 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -14,24 +14,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
-from ._base import BaseHandler
-from synapse.api.constants import LoginType
-from synapse.types import UserID
-from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
-from synapse.util.async import run_on_reactor
-from synapse.util.caches.expiringcache import ExpiringCache
-
-from twisted.web.client import PartialDownloadError
-
import logging
+
+import attr
import bcrypt
import pymacaroons
-import simplejson
+from canonicaljson import json
+
+from twisted.internet import defer, threads
+from twisted.web.client import PartialDownloadError
import synapse.util.stringutils as stringutils
+from synapse.api.constants import LoginType
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ InteractiveAuthIncompleteError,
+ LoginError,
+ StoreError,
+ SynapseError,
+)
+from synapse.module_api import ModuleApi
+from synapse.types import UserID
+from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.logcontext import make_deferred_yieldable
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -46,7 +54,6 @@ class AuthHandler(BaseHandler):
"""
super(AuthHandler, self).__init__(hs)
self.checkers = {
- LoginType.PASSWORD: self._check_password_auth,
LoginType.RECAPTCHA: self._check_recaptcha,
LoginType.EMAIL_IDENTITY: self._check_email_identity,
LoginType.MSISDN: self._check_msisdn,
@@ -63,10 +70,7 @@ class AuthHandler(BaseHandler):
reset_expiry_on_get=True,
)
- account_handler = _AccountHandler(
- hs, check_user_exists=self.check_user_exists
- )
-
+ account_handler = ModuleApi(hs, self)
self.password_providers = [
module(config=config, account_handler=account_handler)
for module, config in hs.config.password_providers
@@ -75,39 +79,120 @@ class AuthHandler(BaseHandler):
logger.info("Extra password_providers: %r", self.password_providers)
self.hs = hs # FIXME better possibility to access registrationHandler later?
- self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
+ self._password_enabled = hs.config.password_enabled
+
+ # we keep this as a list despite the O(N^2) implication so that we can
+ # keep PASSWORD first and avoid confusing clients which pick the first
+ # type in the list. (NB that the spec doesn't require us to do so and
+ # clients which favour types that they don't understand over those that
+ # they do are technically broken)
+ login_types = []
+ if self._password_enabled:
+ login_types.append(LoginType.PASSWORD)
+ for provider in self.password_providers:
+ if hasattr(provider, "get_supported_login_types"):
+ for t in provider.get_supported_login_types().keys():
+ if t not in login_types:
+ login_types.append(t)
+ self._supported_login_types = login_types
+
+ @defer.inlineCallbacks
+ def validate_user_via_ui_auth(self, requester, request_body, clientip):
+ """
+ Checks that the user is who they claim to be, via a UI auth.
+
+ This is used for things like device deletion and password reset where
+ the user already has a valid access token, but we want to double-check
+ that it isn't stolen by re-authenticating them.
+
+ Args:
+ requester (Requester): The user, as given by the access token
+
+ request_body (dict): The body of the request sent by the client
+
+ clientip (str): The IP address of the client.
+
+ Returns:
+ defer.Deferred[dict]: the parameters for this request (which may
+ have been given only in a previous call).
+
+ Raises:
+ InteractiveAuthIncompleteError if the client has not yet completed
+ any of the permitted login flows
+
+ AuthError if the client has completed a login flow, and it gives
+ a different user to `requester`
+ """
+
+ # build a list of supported flows
+ flows = [
+ [login_type] for login_type in self._supported_login_types
+ ]
+
+ result, params, _ = yield self.check_auth(
+ flows, request_body, clientip,
+ )
+
+ # find the completed login type
+ for login_type in self._supported_login_types:
+ if login_type not in result:
+ continue
+
+ user_id = result[login_type]
+ break
+ else:
+ # this can't happen
+ raise Exception(
+ "check_auth returned True but no successful login type",
+ )
+
+ # check that the UI auth matched the access token
+ if user_id != requester.user.to_string():
+ raise AuthError(403, "Invalid auth")
+
+ defer.returnValue(params)
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
"""
Takes a dictionary sent by the client in the login / registration
- protocol and handles the login flow.
+ protocol and handles the User-Interactive Auth flow.
As a side effect, this function fills in the 'creds' key on the user's
session with a map, which maps each auth-type (str) to the relevant
identity authenticated by that auth-type (mostly str, but for captcha, bool).
+ If no auth flows have been completed successfully, raises an
+ InteractiveAuthIncompleteError. To handle this, you can use
+ synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
+ decorator.
+
Args:
flows (list): A list of login flows. Each flow is an ordered list of
strings representing auth-types. At least one full
flow must be completed in order for auth to be successful.
+
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
+
clientip (str): The IP address of the client.
+
Returns:
- A tuple of (authed, dict, dict, session_id) where authed is true if
- the client has successfully completed an auth flow. If it is true
- the first dict contains the authenticated credentials of each stage.
+ defer.Deferred[dict, dict, str]: a deferred tuple of
+ (creds, params, session_id).
- If authed is false, the first dictionary is the server response to
- the login request and should be passed back to the client.
+ 'creds' contains the authenticated credentials of each stage.
- In either case, the second dict contains the parameters for this
- request (which may have been given only in a previous call).
+ 'params' contains the parameters for this request (which may
+ have been given only in a previous call).
- session_id is the ID of this session, either passed in by the client
- or assigned by the call to check_auth
+ 'session_id' is the ID of this session, either passed in by the
+ client or assigned by this call
+
+ Raises:
+ InteractiveAuthIncompleteError if the client has not yet completed
+ all the stages in any of the permitted flows.
"""
authdict = None
@@ -135,11 +220,8 @@ class AuthHandler(BaseHandler):
clientdict = session['clientdict']
if not authdict:
- defer.returnValue(
- (
- False, self._auth_dict_for_flows(flows, session),
- clientdict, session['id']
- )
+ raise InteractiveAuthIncompleteError(
+ self._auth_dict_for_flows(flows, session),
)
if 'creds' not in session:
@@ -150,14 +232,12 @@ class AuthHandler(BaseHandler):
errordict = {}
if 'type' in authdict:
login_type = authdict['type']
- if login_type not in self.checkers:
- raise LoginError(400, "", Codes.UNRECOGNIZED)
try:
- result = yield self.checkers[login_type](authdict, clientip)
+ result = yield self._check_auth_dict(authdict, clientip)
if result:
creds[login_type] = result
self._save_session(session)
- except LoginError, e:
+ except LoginError as e:
if login_type == LoginType.EMAIL_IDENTITY:
# riot used to have a bug where it would request a new
# validation token (thus sending a new email) each time it
@@ -166,14 +246,14 @@ class AuthHandler(BaseHandler):
#
# Grandfather in the old behaviour for now to avoid
# breaking old riot deployments.
- raise e
+ raise
# this step failed. Merge the error dict into the response
# so that the client can have another go.
errordict = e.error_dict()
for f in flows:
- if len(set(f) - set(creds.keys())) == 0:
+ if len(set(f) - set(creds)) == 0:
# it's very useful to know what args are stored, but this can
# include the password in the case of registering, so only log
# the keys (confusingly, clientdict may contain a password
@@ -181,14 +261,16 @@ class AuthHandler(BaseHandler):
# and is not sensitive).
logger.info(
"Auth completed with creds: %r. Client dict has keys: %r",
- creds, clientdict.keys()
+ creds, list(clientdict)
)
- defer.returnValue((True, creds, clientdict, session['id']))
+ defer.returnValue((creds, clientdict, session['id']))
ret = self._auth_dict_for_flows(flows, session)
- ret['completed'] = creds.keys()
+ ret['completed'] = list(creds)
ret.update(errordict)
- defer.returnValue((False, ret, clientdict, session['id']))
+ raise InteractiveAuthIncompleteError(
+ ret,
+ )
@defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip):
@@ -260,16 +342,37 @@ class AuthHandler(BaseHandler):
sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default)
- def _check_password_auth(self, authdict, _):
- if "user" not in authdict or "password" not in authdict:
- raise LoginError(400, "", Codes.MISSING_PARAM)
+ @defer.inlineCallbacks
+ def _check_auth_dict(self, authdict, clientip):
+ """Attempt to validate the auth dict provided by a client
- user_id = authdict["user"]
- password = authdict["password"]
- if not user_id.startswith('@'):
- user_id = UserID.create(user_id, self.hs.hostname).to_string()
+ Args:
+ authdict (object): auth dict provided by the client
+ clientip (str): IP address of the client
+
+ Returns:
+ Deferred: result of the stage verification.
+
+ Raises:
+ StoreError if there was a problem accessing the database
+ SynapseError if there was a problem with the request
+ LoginError if there was an authentication problem.
+ """
+ login_type = authdict['type']
+ checker = self.checkers.get(login_type)
+ if checker is not None:
+ res = yield checker(authdict, clientip)
+ defer.returnValue(res)
+
+ # build a v1-login-style dict out of the authdict and fall back to the
+ # v1 code
+ user_id = authdict.get("user")
- return self._check_password(user_id, password)
+ if user_id is None:
+ raise SynapseError(400, "", Codes.MISSING_PARAM)
+
+ (canonical_id, callback) = yield self.validate_login(user_id, authdict)
+ defer.returnValue(canonical_id)
@defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip):
@@ -303,7 +406,7 @@ class AuthHandler(BaseHandler):
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
- resp_body = simplejson.loads(data)
+ resp_body = json.loads(data)
if 'success' in resp_body:
# Note that we do NOT check the hostname here: we explicitly
@@ -324,15 +427,11 @@ class AuthHandler(BaseHandler):
def _check_msisdn(self, authdict, _):
return self._check_threepid('msisdn', authdict)
- @defer.inlineCallbacks
def _check_dummy_auth(self, authdict, _):
- yield run_on_reactor()
- defer.returnValue(True)
+ return defer.succeed(True)
@defer.inlineCallbacks
def _check_threepid(self, medium, authdict):
- yield run_on_reactor()
-
if 'threepid_creds' not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
@@ -398,26 +497,8 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id]
- def validate_password_login(self, user_id, password):
- """
- Authenticates the user with their username and password.
-
- Used only by the v1 login API.
-
- Args:
- user_id (str): complete @user:id
- password (str): Password
- Returns:
- defer.Deferred: (str) canonical user id
- Raises:
- StoreError if there was a problem accessing the database
- LoginError if there was an authentication problem.
- """
- return self._check_password(user_id, password)
-
@defer.inlineCallbacks
- def get_access_token_for_user_id(self, user_id, device_id=None,
- initial_display_name=None):
+ def get_access_token_for_user_id(self, user_id, device_id=None):
"""
Creates a new access token for the user with the given user ID.
@@ -431,13 +512,10 @@ class AuthHandler(BaseHandler):
device_id (str|None): the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID)
- initial_display_name (str): display name to associate with the
- device if it needs re-registering
Returns:
The access token for the user's session.
Raises:
StoreError if there was a problem storing the token.
- LoginError if there was an authentication problem.
"""
logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id)
@@ -447,9 +525,11 @@ class AuthHandler(BaseHandler):
# really don't want is active access_tokens without a record of the
# device, so we double-check it here.
if device_id is not None:
- yield self.device_handler.check_device_registered(
- user_id, device_id, initial_display_name
- )
+ try:
+ yield self.store.get_device(user_id, device_id)
+ except StoreError:
+ yield self.store.delete_access_token(access_token)
+ raise StoreError(400, "Login raced against device deletion")
defer.returnValue(access_token)
@@ -501,29 +581,115 @@ class AuthHandler(BaseHandler):
)
defer.returnValue(result)
+ def get_supported_login_types(self):
+ """Get a the login types supported for the /login API
+
+ By default this is just 'm.login.password' (unless password_enabled is
+ False in the config file), but password auth providers can provide
+ other login types.
+
+ Returns:
+ Iterable[str]: login types
+ """
+ return self._supported_login_types
+
@defer.inlineCallbacks
- def _check_password(self, user_id, password):
- """Authenticate a user against the LDAP and local databases.
+ def validate_login(self, username, login_submission):
+ """Authenticates the user for the /login API
- user_id is checked case insensitively against the local database, but
- will throw if there are multiple inexact matches.
+ Also used by the user-interactive auth flow to validate
+ m.login.password auth types.
Args:
- user_id (str): complete @user:id
+ username (str): username supplied by the user
+ login_submission (dict): the whole of the login submission
+ (including 'type' and other relevant fields)
Returns:
- (str) the canonical_user_id
+ Deferred[str, func]: canonical user id, and optional callback
+ to be called once the access token and device id are issued
Raises:
- LoginError if login fails
+ StoreError if there was a problem accessing the database
+ SynapseError if there was a problem with the request
+ LoginError if there was an authentication problem.
"""
+
+ if username.startswith('@'):
+ qualified_user_id = username
+ else:
+ qualified_user_id = UserID(
+ username, self.hs.hostname
+ ).to_string()
+
+ login_type = login_submission.get("type")
+ known_login_type = False
+
+ # special case to check for "password" for the check_password interface
+ # for the auth providers
+ password = login_submission.get("password")
+ if login_type == LoginType.PASSWORD:
+ if not self._password_enabled:
+ raise SynapseError(400, "Password login has been disabled.")
+ if not password:
+ raise SynapseError(400, "Missing parameter: password")
+
for provider in self.password_providers:
- is_valid = yield provider.check_password(user_id, password)
- if is_valid:
- defer.returnValue(user_id)
+ if (hasattr(provider, "check_password")
+ and login_type == LoginType.PASSWORD):
+ known_login_type = True
+ is_valid = yield provider.check_password(
+ qualified_user_id, password,
+ )
+ if is_valid:
+ defer.returnValue((qualified_user_id, None))
+
+ if (not hasattr(provider, "get_supported_login_types")
+ or not hasattr(provider, "check_auth")):
+ # this password provider doesn't understand custom login types
+ continue
+
+ supported_login_types = provider.get_supported_login_types()
+ if login_type not in supported_login_types:
+ # this password provider doesn't understand this login type
+ continue
+
+ known_login_type = True
+ login_fields = supported_login_types[login_type]
+
+ missing_fields = []
+ login_dict = {}
+ for f in login_fields:
+ if f not in login_submission:
+ missing_fields.append(f)
+ else:
+ login_dict[f] = login_submission[f]
+ if missing_fields:
+ raise SynapseError(
+ 400, "Missing parameters for login type %s: %s" % (
+ login_type,
+ missing_fields,
+ ),
+ )
+
+ result = yield provider.check_auth(
+ username, login_type, login_dict,
+ )
+ if result:
+ if isinstance(result, str):
+ result = (result, None)
+ defer.returnValue(result)
+
+ if login_type == LoginType.PASSWORD:
+ known_login_type = True
+
+ canonical_user_id = yield self._check_local_password(
+ qualified_user_id, password,
+ )
- canonical_user_id = yield self._check_local_password(user_id, password)
+ if canonical_user_id:
+ defer.returnValue((canonical_user_id, None))
- if canonical_user_id:
- defer.returnValue(canonical_user_id)
+ if not known_login_type:
+ raise SynapseError(400, "Unknown login type %s" % login_type)
# unknown username or invalid password. We raise a 403 here, but note
# that if we're doing user-interactive login, it turns all LoginErrors
@@ -549,7 +715,7 @@ class AuthHandler(BaseHandler):
if not lookupres:
defer.returnValue(None)
(user_id, password_hash) = lookupres
- result = self.validate_hash(password, password_hash)
+ result = yield self.validate_hash(password, password_hash)
if not result:
logger.warn("Failed password login for user %s", user_id)
defer.returnValue(None)
@@ -573,22 +739,65 @@ class AuthHandler(BaseHandler):
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
- def set_password(self, user_id, newpassword, requester=None):
- password_hash = self.hash(newpassword)
+ def delete_access_token(self, access_token):
+ """Invalidate a single access token
- except_access_token_id = requester.access_token_id if requester else None
+ Args:
+ access_token (str): access token to be deleted
- try:
- yield self.store.user_set_password_hash(user_id, password_hash)
- except StoreError as e:
- if e.code == 404:
- raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
- raise e
- yield self.store.user_delete_access_tokens(
- user_id, except_access_token_id
+ Returns:
+ Deferred
+ """
+ user_info = yield self.auth.get_user_by_access_token(access_token)
+ yield self.store.delete_access_token(access_token)
+
+ # see if any of our auth providers want to know about this
+ for provider in self.password_providers:
+ if hasattr(provider, "on_logged_out"):
+ yield provider.on_logged_out(
+ user_id=str(user_info["user"]),
+ device_id=user_info["device_id"],
+ access_token=access_token,
+ )
+
+ # delete pushers associated with this access token
+ if user_info["token_id"] is not None:
+ yield self.hs.get_pusherpool().remove_pushers_by_access_token(
+ str(user_info["user"]), (user_info["token_id"], )
+ )
+
+ @defer.inlineCallbacks
+ def delete_access_tokens_for_user(self, user_id, except_token_id=None,
+ device_id=None):
+ """Invalidate access tokens belonging to a user
+
+ Args:
+ user_id (str): ID of user the tokens belong to
+ except_token_id (str|None): access_token ID which should *not* be
+ deleted
+ device_id (str|None): ID of device the tokens are associated with.
+ If None, tokens associated with any device (or no device) will
+ be deleted
+ Returns:
+ Deferred
+ """
+ tokens_and_devices = yield self.store.user_delete_access_tokens(
+ user_id, except_token_id=except_token_id, device_id=device_id,
)
- yield self.hs.get_pusherpool().remove_pushers_by_user(
- user_id, except_access_token_id
+
+ # see if any of our auth providers want to know about this
+ for provider in self.password_providers:
+ if hasattr(provider, "on_logged_out"):
+ for token, token_id, device_id in tokens_and_devices:
+ yield provider.on_logged_out(
+ user_id=user_id,
+ device_id=device_id,
+ access_token=token,
+ )
+
+ # delete pushers associated with the access tokens
+ yield self.hs.get_pusherpool().remove_pushers_by_access_token(
+ user_id, (token_id for _, token_id, _ in tokens_and_devices),
)
@defer.inlineCallbacks
@@ -616,6 +825,15 @@ class AuthHandler(BaseHandler):
if medium == 'email':
address = address.lower()
+ identity_handler = self.hs.get_handlers().identity_handler
+ yield identity_handler.unbind_threepid(
+ user_id,
+ {
+ 'medium': medium,
+ 'address': address,
+ },
+ )
+
ret = yield self.store.user_delete_threepid(
user_id, medium, address,
)
@@ -634,10 +852,17 @@ class AuthHandler(BaseHandler):
password (str): Password to hash.
Returns:
- Hashed password (str).
+ Deferred(str): Hashed password.
"""
- return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
- bcrypt.gensalt(self.bcrypt_rounds))
+ def _do_hash():
+ return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
+ bcrypt.gensalt(self.bcrypt_rounds))
+
+ return make_deferred_yieldable(
+ threads.deferToThreadPool(
+ self.hs.get_reactor(), self.hs.get_reactor().getThreadPool(), _do_hash
+ ),
+ )
def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash.
@@ -647,20 +872,31 @@ class AuthHandler(BaseHandler):
stored_hash (str): Expected hash value.
Returns:
- Whether self.hash(password) == stored_hash (bool).
+ Deferred(bool): Whether self.hash(password) == stored_hash.
"""
+
+ def _do_validate_hash():
+ return bcrypt.checkpw(
+ password.encode('utf8') + self.hs.config.password_pepper,
+ stored_hash.encode('utf8')
+ )
+
if stored_hash:
- return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
- stored_hash.encode('utf8')) == stored_hash
+ return make_deferred_yieldable(
+ threads.deferToThreadPool(
+ self.hs.get_reactor(),
+ self.hs.get_reactor().getThreadPool(),
+ _do_validate_hash,
+ ),
+ )
else:
- return False
+ return defer.succeed(False)
-class MacaroonGeneartor(object):
- def __init__(self, hs):
- self.clock = hs.get_clock()
- self.server_name = hs.config.server_name
- self.macaroon_secret_key = hs.config.macaroon_secret_key
+@attr.s
+class MacaroonGenerator(object):
+
+ hs = attr.ib()
def generate_access_token(self, user_id, extra_caveats=None):
extra_caveats = extra_caveats or []
@@ -678,7 +914,7 @@ class MacaroonGeneartor(object):
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
- now = self.clock.time_msec()
+ now = self.hs.get_clock().time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
@@ -690,36 +926,9 @@ class MacaroonGeneartor(object):
def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon(
- location=self.server_name,
+ location=self.hs.config.server_name,
identifier="key",
- key=self.macaroon_secret_key)
+ key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
-
-
-class _AccountHandler(object):
- """A proxy object that gets passed to password auth providers so they
- can register new users etc if necessary.
- """
- def __init__(self, hs, check_user_exists):
- self.hs = hs
-
- self._check_user_exists = check_user_exists
-
- def check_user_exists(self, user_id):
- """Check if user exissts.
-
- Returns:
- Deferred(bool)
- """
- return self._check_user_exists(user_id)
-
- def register(self, localpart):
- """Registers a new user with given localpart
-
- Returns:
- Deferred: a 2-tuple of (user_id, access_token)
- """
- reg = self.hs.get_handlers().registration_handler
- return reg.register(localpart=localpart)
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
new file mode 100644
index 0000000000..b3c5a9ee64
--- /dev/null
+++ b/synapse/handlers/deactivate_account.py
@@ -0,0 +1,163 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017, 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.types import UserID, create_requester
+from synapse.util.logcontext import run_in_background
+
+from ._base import BaseHandler
+
+logger = logging.getLogger(__name__)
+
+
+class DeactivateAccountHandler(BaseHandler):
+ """Handler which deals with deactivating user accounts."""
+ def __init__(self, hs):
+ super(DeactivateAccountHandler, self).__init__(hs)
+ self._auth_handler = hs.get_auth_handler()
+ self._device_handler = hs.get_device_handler()
+ self._room_member_handler = hs.get_room_member_handler()
+ self._identity_handler = hs.get_handlers().identity_handler
+ self.user_directory_handler = hs.get_user_directory_handler()
+
+ # Flag that indicates whether the process to part users from rooms is running
+ self._user_parter_running = False
+
+ # Start the user parter loop so it can resume parting users from rooms where
+ # it left off (if it has work left to do).
+ hs.get_reactor().callWhenRunning(self._start_user_parting)
+
+ @defer.inlineCallbacks
+ def deactivate_account(self, user_id, erase_data):
+ """Deactivate a user's account
+
+ Args:
+ user_id (str): ID of user to be deactivated
+ erase_data (bool): whether to GDPR-erase the user's data
+
+ Returns:
+ Deferred
+ """
+ # FIXME: Theoretically there is a race here wherein user resets
+ # password using threepid.
+
+ # delete threepids first. We remove these from the IS so if this fails,
+ # leave the user still active so they can try again.
+ # Ideally we would prevent password resets and then do this in the
+ # background thread.
+ threepids = yield self.store.user_get_threepids(user_id)
+ for threepid in threepids:
+ try:
+ yield self._identity_handler.unbind_threepid(
+ user_id,
+ {
+ 'medium': threepid['medium'],
+ 'address': threepid['address'],
+ },
+ )
+ except Exception:
+ # Do we want this to be a fatal error or should we carry on?
+ logger.exception("Failed to remove threepid from ID server")
+ raise SynapseError(400, "Failed to remove threepid from ID server")
+ yield self.store.user_delete_threepid(
+ user_id, threepid['medium'], threepid['address'],
+ )
+
+ # delete any devices belonging to the user, which will also
+ # delete corresponding access tokens.
+ yield self._device_handler.delete_all_devices_for_user(user_id)
+ # then delete any remaining access tokens which weren't associated with
+ # a device.
+ yield self._auth_handler.delete_access_tokens_for_user(user_id)
+
+ yield self.store.user_set_password_hash(user_id, None)
+
+ # Add the user to a table of users pending deactivation (ie.
+ # removal from all the rooms they're a member of)
+ yield self.store.add_user_pending_deactivation(user_id)
+
+ # delete from user directory
+ yield self.user_directory_handler.handle_user_deactivated(user_id)
+
+ # Mark the user as erased, if they asked for that
+ if erase_data:
+ logger.info("Marking %s as erased", user_id)
+ yield self.store.mark_user_erased(user_id)
+
+ # Now start the process that goes through that list and
+ # parts users from rooms (if it isn't already running)
+ self._start_user_parting()
+
+ def _start_user_parting(self):
+ """
+ Start the process that goes through the table of users
+ pending deactivation, if it isn't already running.
+
+ Returns:
+ None
+ """
+ if not self._user_parter_running:
+ run_in_background(self._user_parter_loop)
+
+ @defer.inlineCallbacks
+ def _user_parter_loop(self):
+ """Loop that parts deactivated users from rooms
+
+ Returns:
+ None
+ """
+ self._user_parter_running = True
+ logger.info("Starting user parter")
+ try:
+ while True:
+ user_id = yield self.store.get_user_pending_deactivation()
+ if user_id is None:
+ break
+ logger.info("User parter parting %r", user_id)
+ yield self._part_user(user_id)
+ yield self.store.del_user_pending_deactivation(user_id)
+ logger.info("User parter finished parting %r", user_id)
+ logger.info("User parter finished: stopping")
+ finally:
+ self._user_parter_running = False
+
+ @defer.inlineCallbacks
+ def _part_user(self, user_id):
+ """Causes the given user_id to leave all the rooms they're joined to
+
+ Returns:
+ None
+ """
+ user = UserID.from_string(user_id)
+
+ rooms_for_user = yield self.store.get_rooms_for_user(user_id)
+ for room_id in rooms_for_user:
+ logger.info("User parter parting %r from %r", user_id, room_id)
+ try:
+ yield self._room_member_handler.update_membership(
+ create_requester(user),
+ user,
+ room_id,
+ "leave",
+ ratelimit=False,
+ )
+ except Exception:
+ logger.exception(
+ "Failed to part user %r from room %r: ignoring and continuing",
+ user_id, room_id,
+ )
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index ed60d494ff..2d44f15da3 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -12,18 +12,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
+from six import iteritems, itervalues
+
+from twisted.internet import defer
+
from synapse.api import errors
from synapse.api.constants import EventTypes
+from synapse.api.errors import FederationDeniedError
+from synapse.types import RoomStreamToken, get_domain_from_id
from synapse.util import stringutils
from synapse.util.async import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.retryutils import NotRetryingDestination
from synapse.util.metrics import measure_func
-from synapse.types import get_domain_from_id, RoomStreamToken
-from twisted.internet import defer
-from ._base import BaseHandler
+from synapse.util.retryutils import NotRetryingDestination
-import logging
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -34,15 +39,17 @@ class DeviceHandler(BaseHandler):
self.hs = hs
self.state = hs.get_state_handler()
+ self._auth_handler = hs.get_auth_handler()
self.federation_sender = hs.get_federation_sender()
- self.federation = hs.get_replication_layer()
self._edu_updater = DeviceListEduUpdater(hs, self)
- self.federation.register_edu_handler(
+ federation_registry = hs.get_federation_registry()
+
+ federation_registry.register_edu_handler(
"m.device_list_update", self._edu_updater.incoming_device_list_update,
)
- self.federation.register_query_handler(
+ federation_registry.register_query_handler(
"user_devices", self.on_federation_query_user_devices,
)
@@ -109,7 +116,7 @@ class DeviceHandler(BaseHandler):
user_id, device_id=None
)
- devices = device_map.values()
+ devices = list(device_map.values())
for device in devices:
_update_device_from_client_ips(device, ips)
@@ -152,16 +159,15 @@ class DeviceHandler(BaseHandler):
try:
yield self.store.delete_device(user_id, device_id)
- except errors.StoreError, e:
+ except errors.StoreError as e:
if e.code == 404:
# no match
pass
else:
raise
- yield self.store.user_delete_access_tokens(
+ yield self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id,
- delete_refresh_tokens=True,
)
yield self.store.delete_e2e_keys_by_device(
@@ -171,12 +177,30 @@ class DeviceHandler(BaseHandler):
yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks
+ def delete_all_devices_for_user(self, user_id, except_device_id=None):
+ """Delete all of the user's devices
+
+ Args:
+ user_id (str):
+ except_device_id (str|None): optional device id which should not
+ be deleted
+
+ Returns:
+ defer.Deferred:
+ """
+ device_map = yield self.store.get_devices_by_user(user_id)
+ device_ids = list(device_map)
+ if except_device_id is not None:
+ device_ids = [d for d in device_ids if d != except_device_id]
+ yield self.delete_devices(user_id, device_ids)
+
+ @defer.inlineCallbacks
def delete_devices(self, user_id, device_ids):
""" Delete several devices
Args:
user_id (str):
- device_ids (str): The list of device IDs to delete
+ device_ids (List[str]): The list of device IDs to delete
Returns:
defer.Deferred:
@@ -184,7 +208,7 @@ class DeviceHandler(BaseHandler):
try:
yield self.store.delete_devices(user_id, device_ids)
- except errors.StoreError, e:
+ except errors.StoreError as e:
if e.code == 404:
# no match
pass
@@ -194,9 +218,8 @@ class DeviceHandler(BaseHandler):
# Delete access tokens and e2e keys for each device. Not optimised as it is not
# considered as part of a critical path.
for device_id in device_ids:
- yield self.store.user_delete_access_tokens(
+ yield self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id,
- delete_refresh_tokens=True,
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
@@ -224,7 +247,7 @@ class DeviceHandler(BaseHandler):
new_display_name=content.get("display_name")
)
yield self.notify_device_update(user_id, [device_id])
- except errors.StoreError, e:
+ except errors.StoreError as e:
if e.code == 404:
raise errors.NotFoundError()
else:
@@ -270,6 +293,8 @@ class DeviceHandler(BaseHandler):
user_id (str)
from_token (StreamToken)
"""
+ now_token = yield self.hs.get_event_sources().get_current_token()
+
room_ids = yield self.store.get_rooms_for_user(user_id)
# First we check if any devices have changed
@@ -280,11 +305,30 @@ class DeviceHandler(BaseHandler):
# Then work out if any users have since joined
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
+ member_events = yield self.store.get_membership_changes_for_user(
+ user_id, from_token.room_key, now_token.room_key
+ )
+ rooms_changed.update(event.room_id for event in member_events)
+
stream_ordering = RoomStreamToken.parse_stream_token(
- from_token.room_key).stream
+ from_token.room_key
+ ).stream
possibly_changed = set(changed)
+ possibly_left = set()
for room_id in rooms_changed:
+ current_state_ids = yield self.store.get_current_state_ids(room_id)
+
+ # The user may have left the room
+ # TODO: Check if they actually did or if we were just invited.
+ if room_id not in room_ids:
+ for key, event_id in iteritems(current_state_ids):
+ etype, state_key = key
+ if etype != EventTypes.Member:
+ continue
+ possibly_left.add(state_key)
+ continue
+
# Fetch the current state at the time.
try:
event_ids = yield self.store.get_forward_extremeties_for_room(
@@ -295,44 +339,69 @@ class DeviceHandler(BaseHandler):
# ordering: treat it the same as a new room
event_ids = []
- current_state_ids = yield self.store.get_current_state_ids(room_id)
-
# special-case for an empty prev state: include all members
# in the changed list
if not event_ids:
- for key, event_id in current_state_ids.iteritems():
+ for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
possibly_changed.add(state_key)
continue
+ current_member_id = current_state_ids.get((EventTypes.Member, user_id))
+ if not current_member_id:
+ continue
+
# mapping from event_id -> state_dict
prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
+ # Check if we've joined the room? If so we just blindly add all the users to
+ # the "possibly changed" users.
+ for state_dict in itervalues(prev_state_ids):
+ member_event = state_dict.get((EventTypes.Member, user_id), None)
+ if not member_event or member_event != current_member_id:
+ for key, event_id in iteritems(current_state_ids):
+ etype, state_key = key
+ if etype != EventTypes.Member:
+ continue
+ possibly_changed.add(state_key)
+ break
+
# If there has been any change in membership, include them in the
# possibly changed list. We'll check if they are joined below,
# and we're not toooo worried about spuriously adding users.
- for key, event_id in current_state_ids.iteritems():
+ for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
# check if this member has changed since any of the extremities
# at the stream_ordering, and add them to the list if so.
- for state_dict in prev_state_ids.values():
+ for state_dict in itervalues(prev_state_ids):
prev_event_id = state_dict.get(key, None)
if not prev_event_id or prev_event_id != event_id:
- possibly_changed.add(state_key)
+ if state_key != user_id:
+ possibly_changed.add(state_key)
break
- users_who_share_room = yield self.store.get_users_who_share_room_with_user(
- user_id
- )
+ if possibly_changed or possibly_left:
+ users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+ user_id
+ )
- # Take the intersection of the users whose devices may have changed
- # and those that actually still share a room with the user
- defer.returnValue(users_who_share_room & possibly_changed)
+ # Take the intersection of the users whose devices may have changed
+ # and those that actually still share a room with the user
+ possibly_joined = possibly_changed & users_who_share_room
+ possibly_left = (possibly_changed | possibly_left) - users_who_share_room
+ else:
+ possibly_joined = []
+ possibly_left = []
+
+ defer.returnValue({
+ "changed": list(possibly_joined),
+ "left": list(possibly_left),
+ })
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
@@ -366,7 +435,7 @@ class DeviceListEduUpdater(object):
def __init__(self, hs, device_handler):
self.store = hs.get_datastore()
- self.federation = hs.get_replication_layer()
+ self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
self.device_handler = device_handler
@@ -450,6 +519,9 @@ class DeviceListEduUpdater(object):
# This makes it more likely that the device lists will
# eventually become consistent.
return
+ except FederationDeniedError as e:
+ logger.info(e)
+ return
except Exception:
# TODO: Remember that we are now out of sync and try again
# later
@@ -467,7 +539,7 @@ class DeviceListEduUpdater(object):
yield self.device_handler.notify_device_update(user_id, device_ids)
else:
# Simply update the single device, since we know that is the only
- # change (becuase of the single prev_id matching the current cache)
+ # change (because of the single prev_id matching the current cache)
for device_id, stream_id, prev_ids, content in pending_updates:
yield self.store.update_remote_device_list_cache_entry(
user_id, device_id, content, stream_id,
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index f7fad15c62..2e2e5261de 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -17,10 +17,10 @@ import logging
from twisted.internet import defer
-from synapse.types import get_domain_from_id
+from synapse.api.errors import SynapseError
+from synapse.types import UserID, get_domain_from_id
from synapse.util.stringutils import random_string
-
logger = logging.getLogger(__name__)
@@ -33,10 +33,10 @@ class DeviceMessageHandler(object):
"""
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
- self.is_mine_id = hs.is_mine_id
+ self.is_mine = hs.is_mine
self.federation = hs.get_federation_sender()
- hs.get_replication_layer().register_edu_handler(
+ hs.get_federation_registry().register_edu_handler(
"m.direct_to_device", self.on_direct_to_device_edu
)
@@ -52,6 +52,12 @@ class DeviceMessageHandler(object):
message_type = content["type"]
message_id = content["message_id"]
for user_id, by_device in content["messages"].items():
+ # we use UserID.from_string to catch invalid user ids
+ if not self.is_mine(UserID.from_string(user_id)):
+ logger.warning("Request for keys for non-local user %s",
+ user_id)
+ raise SynapseError(400, "Not a user here")
+
messages_by_device = {
device_id: {
"content": message_content,
@@ -77,7 +83,8 @@ class DeviceMessageHandler(object):
local_messages = {}
remote_messages = {}
for user_id, by_device in messages.items():
- if self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if self.is_mine(UserID.from_string(user_id)):
messages_by_device = {
device_id: {
"content": message_content,
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 943554ce98..ef866da1b6 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -14,15 +14,16 @@
# limitations under the License.
+import logging
+import string
+
from twisted.internet import defer
-from ._base import BaseHandler
-from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError
from synapse.api.constants import EventTypes
+from synapse.api.errors import AuthError, CodeMessageException, Codes, SynapseError
from synapse.types import RoomAlias, UserID, get_domain_from_id
-import logging
-import string
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -34,12 +35,15 @@ class DirectoryHandler(BaseHandler):
self.state = hs.get_state_handler()
self.appservice_handler = hs.get_application_service_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
- self.federation = hs.get_replication_layer()
- self.federation.register_query_handler(
+ self.federation = hs.get_federation_client()
+ hs.get_federation_registry().register_query_handler(
"directory", self.on_directory_query
)
+ self.spam_checker = hs.get_spam_checker()
+
@defer.inlineCallbacks
def _create_association(self, room_alias, room_id, servers=None, creator=None):
# general association creation for both human users and app services
@@ -73,6 +77,11 @@ class DirectoryHandler(BaseHandler):
# association creation for human users
# TODO(erikj): Do user auth.
+ if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
+ raise SynapseError(
+ 403, "This user is not permitted to create this alias",
+ )
+
can_create = yield self.can_modify_alias(
room_alias,
user_id=user_id
@@ -242,8 +251,7 @@ class DirectoryHandler(BaseHandler):
def send_room_alias_update_event(self, requester, user_id, room_id):
aliases = yield self.store.get_aliases_for_room(room_id)
- msg_handler = self.hs.get_handlers().message_handler
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Aliases,
@@ -265,8 +273,7 @@ class DirectoryHandler(BaseHandler):
if not alias_event or alias_event.content.get("alias", "") != alias_str:
return
- msg_handler = self.hs.get_handlers().message_handler
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.CanonicalAlias,
@@ -327,6 +334,14 @@ class DirectoryHandler(BaseHandler):
room_id (str)
visibility (str): "public" or "private"
"""
+ if not self.spam_checker.user_may_publish_room(
+ requester.user.to_string(), room_id
+ ):
+ raise AuthError(
+ 403,
+ "This user is not permitted to publish rooms to the room list"
+ )
+
if requester.is_guest:
raise AuthError(403, "Guests cannot edit the published room list")
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 668a90e495..5816bf8b4f 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,15 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import ujson as json
import logging
-from canonicaljson import encode_canonical_json
+from six import iteritems
+
+from canonicaljson import encode_canonical_json, json
+
from twisted.internet import defer
-from synapse.api.errors import SynapseError, CodeMessageException
-from synapse.types import get_domain_from_id
-from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
+from synapse.api.errors import CodeMessageException, FederationDeniedError, SynapseError
+from synapse.types import UserID, get_domain_from_id
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
@@ -30,15 +33,15 @@ logger = logging.getLogger(__name__)
class E2eKeysHandler(object):
def __init__(self, hs):
self.store = hs.get_datastore()
- self.federation = hs.get_replication_layer()
+ self.federation = hs.get_federation_client()
self.device_handler = hs.get_device_handler()
- self.is_mine_id = hs.is_mine_id
+ self.is_mine = hs.is_mine
self.clock = hs.get_clock()
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
- self.federation.register_query_handler(
+ hs.get_federation_registry().register_query_handler(
"client_keys", self.on_federation_query_client_keys
)
@@ -70,12 +73,13 @@ class E2eKeysHandler(object):
remote_queries = {}
for user_id, device_ids in device_keys_query.items():
- if self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if self.is_mine(UserID.from_string(user_id)):
local_query[user_id] = device_ids
else:
remote_queries[user_id] = device_ids
- # Firt get local devices.
+ # First get local devices.
failures = {}
results = {}
if local_query:
@@ -88,7 +92,7 @@ class E2eKeysHandler(object):
remote_queries_not_in_cache = {}
if remote_queries:
query_list = []
- for user_id, device_ids in remote_queries.iteritems():
+ for user_id, device_ids in iteritems(remote_queries):
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
else:
@@ -99,9 +103,9 @@ class E2eKeysHandler(object):
query_list
)
)
- for user_id, devices in remote_results.iteritems():
+ for user_id, devices in iteritems(remote_results):
user_devices = results.setdefault(user_id, {})
- for device_id, device in devices.iteritems():
+ for device_id, device in iteritems(devices):
keys = device.get("keys", None)
device_display_name = device.get("device_display_name", None)
if keys:
@@ -131,24 +135,13 @@ class E2eKeysHandler(object):
if user_id in destination_query:
results[user_id] = keys
- except CodeMessageException as e:
- failures[destination] = {
- "status": e.code, "message": e.message
- }
- except NotRetryingDestination as e:
- failures[destination] = {
- "status": 503, "message": "Not ready for retry",
- }
except Exception as e:
- # include ConnectionRefused and other errors
- failures[destination] = {
- "status": 503, "message": e.message
- }
+ failures[destination] = _exception_to_failure(e)
yield make_deferred_yieldable(defer.gatherResults([
- preserve_fn(do_remote_query)(destination)
+ run_in_background(do_remote_query, destination)
for destination in remote_queries_not_in_cache
- ]))
+ ], consumeErrors=True))
defer.returnValue({
"device_keys": results, "failures": failures,
@@ -170,7 +163,8 @@ class E2eKeysHandler(object):
result_dict = {}
for user_id, device_ids in query.items():
- if not self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if not self.is_mine(UserID.from_string(user_id)):
logger.warning("Request for keys for non-local user %s",
user_id)
raise SynapseError(400, "Not a user here")
@@ -213,7 +207,8 @@ class E2eKeysHandler(object):
remote_queries = {}
for user_id, device_keys in query.get("one_time_keys", {}).items():
- if self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if self.is_mine(UserID.from_string(user_id)):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
@@ -243,32 +238,21 @@ class E2eKeysHandler(object):
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
- except CodeMessageException as e:
- failures[destination] = {
- "status": e.code, "message": e.message
- }
- except NotRetryingDestination as e:
- failures[destination] = {
- "status": 503, "message": "Not ready for retry",
- }
except Exception as e:
- # include ConnectionRefused and other errors
- failures[destination] = {
- "status": 503, "message": e.message
- }
+ failures[destination] = _exception_to_failure(e)
yield make_deferred_yieldable(defer.gatherResults([
- preserve_fn(claim_client_keys)(destination)
+ run_in_background(claim_client_keys, destination)
for destination in remote_queries
- ]))
+ ], consumeErrors=True))
logger.info(
"Claimed one-time-keys: %s",
",".join((
"%s for %s:%s" % (key_id, user_id, device_id)
- for user_id, user_keys in json_result.iteritems()
- for device_id, device_keys in user_keys.iteritems()
- for key_id, _ in device_keys.iteritems()
+ for user_id, user_keys in iteritems(json_result)
+ for device_id, device_keys in iteritems(user_keys)
+ for key_id, _ in iteritems(device_keys)
)),
)
@@ -353,6 +337,31 @@ class E2eKeysHandler(object):
)
+def _exception_to_failure(e):
+ if isinstance(e, CodeMessageException):
+ return {
+ "status": e.code, "message": e.message,
+ }
+
+ if isinstance(e, NotRetryingDestination):
+ return {
+ "status": 503, "message": "Not ready for retry",
+ }
+
+ if isinstance(e, FederationDeniedError):
+ return {
+ "status": 403, "message": "Federation Denied",
+ }
+
+ # include ConnectionRefused and other errors
+ #
+ # Note that some Exceptions (notably twisted's ResponseFailed etc) don't
+ # give a string for e.message, which json then fails to serialize.
+ return {
+ "status": 503, "message": str(e.message),
+ }
+
+
def _one_time_keys_match(old_key_json, new_key):
old_key = json.loads(old_key_json)
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index d3685fb12a..c3f2d7feff 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -13,20 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import random
+
from twisted.internet import defer
-from synapse.util.logutils import log_function
-from synapse.types import UserID
-from synapse.events.utils import serialize_event
-from synapse.api.constants import Membership, EventTypes
+from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
+from synapse.events.utils import serialize_event
+from synapse.types import UserID
+from synapse.util.logutils import log_function
from ._base import BaseHandler
-import logging
-import random
-
-
logger = logging.getLogger(__name__)
@@ -48,6 +47,7 @@ class EventStreamHandler(BaseHandler):
self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
+ self._server_notices_sender = hs.get_server_notices_sender()
@defer.inlineCallbacks
@log_function
@@ -58,6 +58,10 @@ class EventStreamHandler(BaseHandler):
If `only_keys` is not None, events from keys will be sent down.
"""
+
+ # send any outstanding server notices to the user.
+ yield self._server_notices_sender.on_user_syncing(auth_user_id)
+
auth_user = UserID.from_string(auth_user_id)
presence_handler = self.hs.get_presence_handler()
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 483cb8eac6..145c1a21d4 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,41 +15,46 @@
# limitations under the License.
"""Contains handlers for federation events."""
-import synapse.util.logcontext
+
+import itertools
+import logging
+import sys
+
+import six
+from six import iteritems, itervalues
+from six.moves import http_client, zip
+
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
-from ._base import BaseHandler
+from twisted.internet import defer
-from synapse.api.errors import (
- AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
-)
from synapse.api.constants import EventTypes, Membership, RejectedReason
-from synapse.events.validator import EventValidator
-from synapse.util import unwrapFirstError
-from synapse.util.logcontext import (
- preserve_fn, preserve_context_over_deferred
+from synapse.api.errors import (
+ AuthError,
+ CodeMessageException,
+ FederationDeniedError,
+ FederationError,
+ StoreError,
+ SynapseError,
)
-from synapse.util.metrics import measure_func
-from synapse.util.logutils import log_function
-from synapse.util.async import run_on_reactor, Linearizer
-from synapse.util.frozenutils import unfreeze
from synapse.crypto.event_signing import (
- compute_event_signature, add_hashes_and_signatures,
+ add_hashes_and_signatures,
+ compute_event_signature,
)
+from synapse.events.validator import EventValidator
+from synapse.state import resolve_events_with_factory
from synapse.types import UserID, get_domain_from_id
-
-from synapse.events.utils import prune_event
-
-from synapse.util.retryutils import NotRetryingDestination
-
+from synapse.util import logcontext, unwrapFirstError
+from synapse.util.async import Linearizer
from synapse.util.distributor import user_joined_room
+from synapse.util.frozenutils import unfreeze
+from synapse.util.logutils import log_function
+from synapse.util.retryutils import NotRetryingDestination
+from synapse.visibility import filter_events_for_server
-from twisted.internet import defer
-
-import itertools
-import logging
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -70,14 +76,16 @@ class FederationHandler(BaseHandler):
self.hs = hs
self.store = hs.get_datastore()
- self.replication_layer = hs.get_replication_layer()
+ self.replication_layer = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
self.action_generator = hs.get_action_generator()
self.is_mine_id = hs.is_mine_id
-
- self.replication_layer.set_handler(self)
+ self.pusher_pool = hs.get_pusherpool()
+ self.spam_checker = hs.get_spam_checker()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self._server_notices_mxid = hs.config.server_notices_mxid
# When joining a room we need to queue any events for that room up
self.room_queues = {}
@@ -85,7 +93,9 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
- def on_receive_pdu(self, origin, pdu, get_missing=True):
+ def on_receive_pdu(
+ self, origin, pdu, get_missing=True, sent_to_us_directly=False,
+ ):
""" Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
@@ -99,8 +109,10 @@ class FederationHandler(BaseHandler):
"""
# We reprocess pdus when we have seen them only as outliers
- existing = yield self.get_persisted_pdu(
- origin, pdu.event_id, do_auth=False
+ existing = yield self.store.get_event(
+ pdu.event_id,
+ allow_none=True,
+ allow_rejected=True,
)
# FIXME: Currently we fetch an event again when we already have it
@@ -116,6 +128,19 @@ class FederationHandler(BaseHandler):
logger.debug("Already seen pdu %s", pdu.event_id)
return
+ # do some initial sanity-checking of the event. In particular, make
+ # sure it doesn't have hundreds of prev_events or auth_events, which
+ # could cause a huge state resolution or cascade of event fetches.
+ try:
+ self._sanity_check_event(pdu)
+ except SynapseError as err:
+ raise FederationError(
+ "ERROR",
+ err.code,
+ err.msg,
+ affected=pdu.event_id,
+ )
+
# If we are currently in the process of joining this room, then we
# queue up events for later processing.
if pdu.room_id in self.room_queues:
@@ -124,15 +149,30 @@ class FederationHandler(BaseHandler):
self.room_queues[pdu.room_id].append((pdu, origin))
return
- state = None
-
- auth_chain = []
-
- have_seen = yield self.store.have_events(
- [ev for ev, _ in pdu.prev_events]
+ # If we're no longer in the room just ditch the event entirely. This
+ # is probably an old server that has come back and thinks we're still
+ # in the room (or we've been rejoined to the room by a state reset).
+ #
+ # If we were never in the room then maybe our database got vaped and
+ # we should check if we *are* in fact in the room. If we are then we
+ # can magically rejoin the room.
+ is_in_room = yield self.auth.check_host_in_room(
+ pdu.room_id,
+ self.server_name
)
+ if not is_in_room:
+ was_in_room = yield self.store.was_host_joined(
+ pdu.room_id, self.server_name,
+ )
+ if was_in_room:
+ logger.info(
+ "Ignoring PDU %s for room %s from %s as we've left the room!",
+ pdu.event_id, pdu.room_id, origin,
+ )
+ defer.returnValue(None)
- fetch_state = False
+ state = None
+ auth_chain = []
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
@@ -147,7 +187,7 @@ class FederationHandler(BaseHandler):
)
prevs = {e_id for e_id, _ in pdu.prev_events}
- seen = set(have_seen.keys())
+ seen = yield self.store.have_seen_events(prevs)
if min_depth and pdu.depth < min_depth:
# This is so that we don't notify the user about this
@@ -175,8 +215,7 @@ class FederationHandler(BaseHandler):
# Update the set of things we've seen after trying to
# fetch the missing stuff
- have_seen = yield self.store.have_events(prevs)
- seen = set(have_seen.iterkeys())
+ seen = yield self.store.have_seen_events(prevs)
if not prevs - seen:
logger.info(
@@ -189,26 +228,60 @@ class FederationHandler(BaseHandler):
list(prevs - seen)[:5],
)
- if prevs - seen:
- logger.info(
- "Still missing %d events for room %r: %r...",
- len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
+ if sent_to_us_directly and prevs - seen:
+ # If they have sent it to us directly, and the server
+ # isn't telling us about the auth events that it's
+ # made a message referencing, we explode
+ raise FederationError(
+ "ERROR",
+ 403,
+ (
+ "Your server isn't divulging details about prev_events "
+ "referenced in this event."
+ ),
+ affected=pdu.event_id,
)
- fetch_state = True
+ elif prevs - seen:
+ # Calculate the state of the previous events, and
+ # de-conflict them to find the current state.
+ state_groups = []
+ auth_chains = set()
+ try:
+ # Get the state of the events we know about
+ ours = yield self.store.get_state_groups(pdu.room_id, list(seen))
+ state_groups.append(ours)
+
+ # Ask the remote server for the states we don't
+ # know about
+ for p in prevs - seen:
+ state, got_auth_chain = (
+ yield self.replication_layer.get_state_for_room(
+ origin, pdu.room_id, p
+ )
+ )
+ auth_chains.update(got_auth_chain)
+ state_group = {(x.type, x.state_key): x.event_id for x in state}
+ state_groups.append(state_group)
+
+ # Resolve any conflicting state
+ def fetch(ev_ids):
+ return self.store.get_events(
+ ev_ids, get_prev_content=False, check_redacted=False
+ )
- if fetch_state:
- # We need to get the state at this event, since we haven't
- # processed all the prev events.
- logger.debug(
- "_handle_new_pdu getting state for %s",
- pdu.room_id
- )
- try:
- state, auth_chain = yield self.replication_layer.get_state_for_room(
- origin, pdu.room_id, pdu.event_id,
- )
- except:
- logger.exception("Failed to get state for event: %s", pdu.event_id)
+ state_map = yield resolve_events_with_factory(
+ state_groups, {pdu.event_id: pdu}, fetch
+ )
+
+ state = (yield self.store.get_events(state_map.values())).values()
+ auth_chain = list(auth_chains)
+ except Exception:
+ raise FederationError(
+ "ERROR",
+ 403,
+ "We can't get valid state history.",
+ affected=pdu.event_id,
+ )
yield self._process_received_pdu(
origin,
@@ -227,8 +300,7 @@ class FederationHandler(BaseHandler):
min_depth (int): Minimum depth of events to return.
"""
# We recalculate seen, since it may have changed.
- have_seen = yield self.store.have_events(prevs)
- seen = set(have_seen.keys())
+ seen = yield self.store.have_seen_events(prevs)
if not prevs - seen:
return
@@ -287,11 +359,17 @@ class FederationHandler(BaseHandler):
for e in missing_events:
logger.info("Handling found event %s", e.event_id)
- yield self.on_receive_pdu(
- origin,
- e,
- get_missing=False
- )
+ try:
+ yield self.on_receive_pdu(
+ origin,
+ e,
+ get_missing=False
+ )
+ except FederationError as e:
+ if e.code == 403:
+ logger.warn("Event %s failed history check.")
+ else:
+ raise
@log_function
@defer.inlineCallbacks
@@ -340,9 +418,7 @@ class FederationHandler(BaseHandler):
if auth_chain:
event_ids |= {e.event_id for e in auth_chain}
- seen_ids = set(
- (yield self.store.have_events(event_ids)).keys()
- )
+ seen_ids = yield self.store.have_seen_events(event_ids)
if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
@@ -410,7 +486,10 @@ class FederationHandler(BaseHandler):
# joined the room. Don't bother if the user is just
# changing their profile info.
newly_joined = True
- prev_state_id = context.prev_state_ids.get(
+
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+
+ prev_state_id = prev_state_ids.get(
(event.type, event.state_key)
)
if prev_state_id:
@@ -424,91 +503,21 @@ class FederationHandler(BaseHandler):
user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id)
- @measure_func("_filter_events_for_server")
- @defer.inlineCallbacks
- def _filter_events_for_server(self, server_name, room_id, events):
- event_to_state_ids = yield self.store.get_state_ids_for_events(
- frozenset(e.event_id for e in events),
- types=(
- (EventTypes.RoomHistoryVisibility, ""),
- (EventTypes.Member, None),
- )
- )
-
- # We only want to pull out member events that correspond to the
- # server's domain.
-
- def check_match(id):
- try:
- return server_name == get_domain_from_id(id)
- except:
- return False
-
- # Parses mapping `event_id -> (type, state_key) -> state event_id`
- # to get all state ids that we're interested in.
- event_map = yield self.store.get_events([
- e_id
- for key_to_eid in event_to_state_ids.values()
- for key, e_id in key_to_eid.items()
- if key[0] != EventTypes.Member or check_match(key[1])
- ])
-
- event_to_state = {
- e_id: {
- key: event_map[inner_e_id]
- for key, inner_e_id in key_to_eid.items()
- if inner_e_id in event_map
- }
- for e_id, key_to_eid in event_to_state_ids.items()
- }
-
- def redact_disallowed(event, state):
- if not state:
- return event
-
- history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
- if history:
- visibility = history.content.get("history_visibility", "shared")
- if visibility in ["invited", "joined"]:
- # We now loop through all state events looking for
- # membership states for the requesting server to determine
- # if the server is either in the room or has been invited
- # into the room.
- for ev in state.values():
- if ev.type != EventTypes.Member:
- continue
- try:
- domain = get_domain_from_id(ev.state_key)
- except:
- continue
-
- if domain != server_name:
- continue
-
- memtype = ev.membership
- if memtype == Membership.JOIN:
- return event
- elif memtype == Membership.INVITE:
- if visibility == "invited":
- return event
- else:
- return prune_event(event)
-
- return event
-
- defer.returnValue([
- redact_disallowed(e, event_to_state[e.event_id])
- for e in events
- ])
-
@log_function
@defer.inlineCallbacks
def backfill(self, dest, room_id, limit, extremities):
""" Trigger a backfill request to `dest` for the given `room_id`
- This will attempt to get more events from the remote. This may return
- be successfull and still return no events if the other side has no new
- events to offer.
+ This will attempt to get more events from the remote. If the other side
+ has no new events to offer, this will return an empty list.
+
+ As the events are received, we check their signatures, and also do some
+ sanity-checking on them. If any of the backfilled events are invalid,
+ this method throws a SynapseError.
+
+ TODO: make this more useful to distinguish failures of the remote
+ server from invalid events (there is probably no point in trying to
+ re-fetch invalid events from every other HS in the room.)
"""
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
@@ -520,6 +529,16 @@ class FederationHandler(BaseHandler):
extremities=extremities,
)
+ # ideally we'd sanity check the events here for excess prev_events etc,
+ # but it's hard to reject events at this point without completely
+ # breaking backfill in the same way that it is currently broken by
+ # events whose signature we cannot verify (#3121).
+ #
+ # So for now we accept the events anyway. #3124 tracks this.
+ #
+ # for ev in events:
+ # self._sanity_check_event(ev)
+
# Don't bother processing events we already have.
seen_events = yield self.store.have_events_in_timeline(
set(e.event_id for e in events)
@@ -590,9 +609,10 @@ class FederationHandler(BaseHandler):
missing_auth - failed_to_fetch
)
- results = yield preserve_context_over_deferred(defer.gatherResults(
+ results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.replication_layer.get_pdu)(
+ logcontext.run_in_background(
+ self.replication_layer.get_pdu,
[dest],
event_id,
outlier=True,
@@ -612,7 +632,7 @@ class FederationHandler(BaseHandler):
failed_to_fetch = missing_auth - set(auth_events)
- seen_events = yield self.store.have_events(
+ seen_events = yield self.store.have_seen_events(
set(auth_events.keys()) | set(state_events.keys())
)
@@ -702,9 +722,19 @@ class FederationHandler(BaseHandler):
curr_state = yield self.state_handler.get_current_state(room_id)
def get_domains_from_state(state):
+ """Get joined domains from state
+
+ Args:
+ state (dict[tuple, FrozenEvent]): State map from type/state
+ key to event.
+
+ Returns:
+ list[tuple[str, int]]: Returns a list of servers with the
+ lowest depth of their joins. Sorted by lowest depth first.
+ """
joined_users = [
(state_key, int(event.depth))
- for (e_type, state_key), event in state.items()
+ for (e_type, state_key), event in iteritems(state)
if e_type == EventTypes.Member
and event.membership == Membership.JOIN
]
@@ -718,7 +748,7 @@ class FederationHandler(BaseHandler):
joined_domains[dom] = min(d, old_d)
else:
joined_domains[dom] = d
- except:
+ except Exception:
pass
return sorted(joined_domains.items(), key=lambda d: d[1])
@@ -738,7 +768,7 @@ class FederationHandler(BaseHandler):
yield self.backfill(
dom, room_id,
limit=100,
- extremities=[e for e in extremities.keys()]
+ extremities=extremities,
)
# If this succeeded then we probably already have the
# appropriate stuff.
@@ -762,6 +792,9 @@ class FederationHandler(BaseHandler):
except NotRetryingDestination as e:
logger.info(e.message)
continue
+ except FederationDeniedError as e:
+ logger.info(e)
+ continue
except Exception as e:
logger.exception(
"Failed to backfill from %s because %s",
@@ -784,38 +817,76 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys())
logger.debug("calling resolve_state_groups in _maybe_backfill")
- states = yield preserve_context_over_deferred(defer.gatherResults([
- preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
- for e in event_ids
- ]))
+ resolve = logcontext.preserve_fn(
+ self.state_handler.resolve_state_groups_for_events
+ )
+ states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
+ [resolve(room_id, [e]) for e in event_ids],
+ consumeErrors=True,
+ ))
+
+ # dict[str, dict[tuple, str]], a map from event_id to state map of
+ # event_ids.
states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events(
- [e_id for ids in states.values() for e_id in ids],
+ [e_id for ids in itervalues(states) for e_id in itervalues(ids)],
get_prev_content=False
)
states = {
key: {
k: state_map[e_id]
- for k, e_id in state_dict.items()
+ for k, e_id in iteritems(state_dict)
if e_id in state_map
- } for key, state_dict in states.items()
+ } for key, state_dict in iteritems(states)
}
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
success = yield try_backfill([
- dom for dom in likely_domains
+ dom for dom, _ in likely_domains
if dom not in tried_domains
])
if success:
defer.returnValue(True)
- tried_domains.update(likely_domains)
+ tried_domains.update(dom for dom, _ in likely_domains)
defer.returnValue(False)
+ def _sanity_check_event(self, ev):
+ """
+ Do some early sanity checks of a received event
+
+ In particular, checks it doesn't have an excessive number of
+ prev_events or auth_events, which could cause a huge state resolution
+ or cascade of event fetches.
+
+ Args:
+ ev (synapse.events.EventBase): event to be checked
+
+ Returns: None
+
+ Raises:
+ SynapseError if the event does not pass muster
+ """
+ if len(ev.prev_events) > 20:
+ logger.warn("Rejecting event %s which has %i prev_events",
+ ev.event_id, len(ev.prev_events))
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Too many prev_events",
+ )
+
+ if len(ev.auth_events) > 10:
+ logger.warn("Rejecting event %s which has %i auth_events",
+ ev.event_id, len(ev.auth_events))
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Too many auth_events",
+ )
+
@defer.inlineCallbacks
def send_invite(self, target_host, event):
""" Sends the invite to the remote server for signing.
@@ -838,16 +909,6 @@ class FederationHandler(BaseHandler):
[auth_id for auth_id, _ in event.auth_events],
include_given=True
)
-
- for event in auth:
- event.signatures.update(
- compute_event_signature(
- event,
- self.hs.hostname,
- self.hs.config.signing_key[0]
- )
- )
-
defer.returnValue([e for e in auth])
@log_function
@@ -916,7 +977,7 @@ class FederationHandler(BaseHandler):
room_creator_user_id="",
is_public=False
)
- except:
+ except Exception:
# FIXME
pass
@@ -940,9 +1001,7 @@ class FederationHandler(BaseHandler):
# lots of requests for missing prev_events which we do actually
# have. Hence we fire off the deferred, but don't wait for it.
- synapse.util.logcontext.preserve_fn(self._handle_queued_pdus)(
- room_queue
- )
+ logcontext.run_in_background(self._handle_queued_pdus, room_queue)
defer.returnValue(True)
@@ -982,8 +1041,7 @@ class FederationHandler(BaseHandler):
})
try:
- message_handler = self.hs.get_handlers().message_handler
- event, context = yield message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
)
except AuthError as e:
@@ -1051,13 +1109,15 @@ class FederationHandler(BaseHandler):
user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id)
- state_ids = context.prev_state_ids.values()
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+
+ state_ids = list(prev_state_ids.values())
auth_chain = yield self.store.get_auth_chain(state_ids)
- state = yield self.store.get_events(context.prev_state_ids.values())
+ state = yield self.store.get_events(list(prev_state_ids.values()))
defer.returnValue({
- "state": state.values(),
+ "state": list(state.values()),
"auth_chain": auth_chain,
})
@@ -1069,10 +1129,23 @@ class FederationHandler(BaseHandler):
"""
event = pdu
+ if event.state_key is None:
+ raise SynapseError(400, "The invite event did not have a state key")
+
is_blocked = yield self.store.is_room_blocked(event.room_id)
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
+ if self.hs.config.block_non_admin_invites:
+ raise SynapseError(403, "This server does not accept room invites")
+
+ if not self.spam_checker.user_may_invite(
+ event.sender, event.state_key, event.room_id,
+ ):
+ raise SynapseError(
+ 403, "This user is not permitted to send invites to this server/user"
+ )
+
membership = event.content.get("membership")
if event.type != EventTypes.Member or membership != Membership.INVITE:
raise SynapseError(400, "The event was not an m.room.member invite event")
@@ -1081,12 +1154,16 @@ class FederationHandler(BaseHandler):
if sender_domain != origin:
raise SynapseError(400, "The invite event was not from the server sending it")
- if event.state_key is None:
- raise SynapseError(400, "The invite event did not have a state key")
-
if not self.is_mine_id(event.state_key):
raise SynapseError(400, "The invite event must be for this server")
+ # block any attempts to invite the server notices mxid
+ if event.state_key == self._server_notices_mxid:
+ raise SynapseError(
+ http_client.FORBIDDEN,
+ "Cannot invite this user",
+ )
+
event.internal_metadata.outlier = True
event.internal_metadata.invite_from_remote = True
@@ -1213,8 +1290,7 @@ class FederationHandler(BaseHandler):
"state_key": user_id,
})
- message_handler = self.hs.get_handlers().message_handler
- event, context = yield message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
)
@@ -1268,14 +1344,12 @@ class FederationHandler(BaseHandler):
def get_state_for_pdu(self, room_id, event_id):
"""Returns the state at the event. i.e. not including said event.
"""
- yield run_on_reactor()
-
state_groups = yield self.store.get_state_groups(
room_id, [event_id]
)
if state_groups:
- _, state = state_groups.items().pop()
+ _, state = list(iteritems(state_groups)).pop()
results = {
(e.type, e.state_key): e for e in state
}
@@ -1291,19 +1365,7 @@ class FederationHandler(BaseHandler):
else:
del results[(event.type, event.state_key)]
- res = results.values()
- for event in res:
- # We sign these again because there was a bug where we
- # incorrectly signed things the first time round
- if self.is_mine_id(event.event_id):
- event.signatures.update(
- compute_event_signature(
- event,
- self.hs.hostname,
- self.hs.config.signing_key[0]
- )
- )
-
+ res = list(results.values())
defer.returnValue(res)
else:
defer.returnValue([])
@@ -1312,8 +1374,6 @@ class FederationHandler(BaseHandler):
def get_state_ids_for_pdu(self, room_id, event_id):
"""Returns the state at the event. i.e. not including said event.
"""
- yield run_on_reactor()
-
state_groups = yield self.store.get_state_groups_ids(
room_id, [event_id]
)
@@ -1332,7 +1392,7 @@ class FederationHandler(BaseHandler):
else:
results.pop((event.type, event.state_key), None)
- defer.returnValue(results.values())
+ defer.returnValue(list(results.values()))
else:
defer.returnValue([])
@@ -1349,17 +1409,26 @@ class FederationHandler(BaseHandler):
limit
)
- events = yield self._filter_events_for_server(origin, room_id, events)
+ events = yield filter_events_for_server(self.store, origin, events)
defer.returnValue(events)
@defer.inlineCallbacks
@log_function
- def get_persisted_pdu(self, origin, event_id, do_auth=True):
- """ Get a PDU from the database with given origin and id.
+ def get_persisted_pdu(self, origin, event_id):
+ """Get an event from the database for the given server.
+
+ Args:
+ origin [str]: hostname of server which is requesting the event; we
+ will check that the server is allowed to see it.
+ event_id [str]: id of the event being requested
Returns:
- Deferred: Results in a `Pdu`.
+ Deferred[EventBase|None]: None if we know nothing about the event;
+ otherwise the (possibly-redacted) event.
+
+ Raises:
+ AuthError if the server is not currently in the room
"""
event = yield self.store.get_event(
event_id,
@@ -1368,32 +1437,17 @@ class FederationHandler(BaseHandler):
)
if event:
- if self.is_mine_id(event.event_id):
- # FIXME: This is a temporary work around where we occasionally
- # return events slightly differently than when they were
- # originally signed
- event.signatures.update(
- compute_event_signature(
- event,
- self.hs.hostname,
- self.hs.config.signing_key[0]
- )
- )
-
- if do_auth:
- in_room = yield self.auth.check_host_in_room(
- event.room_id,
- origin
- )
- if not in_room:
- raise AuthError(403, "Host not in room.")
-
- events = yield self._filter_events_for_server(
- origin, event.room_id, [event]
- )
-
- event = events[0]
+ in_room = yield self.auth.check_host_in_room(
+ event.room_id,
+ origin
+ )
+ if not in_room:
+ raise AuthError(403, "Host not in room.")
+ events = yield filter_events_for_server(
+ self.store, origin, [event],
+ )
+ event = events[0]
defer.returnValue(event)
else:
defer.returnValue(None)
@@ -1412,22 +1466,33 @@ class FederationHandler(BaseHandler):
auth_events=auth_events,
)
- if not event.internal_metadata.is_outlier():
- yield self.action_generator.handle_push_actions_for_event(
- event, context
+ try:
+ if not event.internal_metadata.is_outlier() and not backfilled:
+ yield self.action_generator.handle_push_actions_for_event(
+ event, context
+ )
+
+ event_stream_id, max_stream_id = yield self.store.persist_event(
+ event,
+ context=context,
+ backfilled=backfilled,
)
+ except: # noqa: E722, as we reraise the exception this is fine.
+ tp, value, tb = sys.exc_info()
- event_stream_id, max_stream_id = yield self.store.persist_event(
- event,
- context=context,
- backfilled=backfilled,
- )
+ logcontext.run_in_background(
+ self.store.remove_push_actions_from_staging,
+ event.event_id,
+ )
+
+ six.reraise(tp, value, tb)
if not backfilled:
# this intentionally does not yield: we don't care about the result
# and don't need to wait for it.
- preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
- event_stream_id, max_stream_id
+ logcontext.run_in_background(
+ self.pusher_pool.on_new_notifications,
+ event_stream_id, max_stream_id,
)
defer.returnValue((context, event_stream_id, max_stream_id))
@@ -1439,22 +1504,23 @@ class FederationHandler(BaseHandler):
a bunch of outliers, but not a chunk of individual events that depend
on each other for state calculations.
"""
- contexts = yield preserve_context_over_deferred(defer.gatherResults(
+ contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self._prep_event)(
+ logcontext.run_in_background(
+ self._prep_event,
origin,
ev_info["event"],
state=ev_info.get("state"),
auth_events=ev_info.get("auth_events"),
)
for ev_info in event_infos
- ]
+ ], consumeErrors=True,
))
yield self.store.persist_events(
[
(ev_info["event"], context)
- for ev_info, context in itertools.izip(event_infos, contexts)
+ for ev_info, context in zip(event_infos, contexts)
],
backfilled=backfilled,
)
@@ -1574,8 +1640,9 @@ class FederationHandler(BaseHandler):
)
if not auth_events:
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.auth.compute_auth_events(
- event, context.prev_state_ids, for_verification=True,
+ event, prev_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
@@ -1605,7 +1672,7 @@ class FederationHandler(BaseHandler):
context.rejected = RejectedReason.AUTH_ERROR
- if event.type == EventTypes.GuestAccess:
+ if event.type == EventTypes.GuestAccess and not context.rejected:
yield self.maybe_kick_guest_users(event)
defer.returnValue(context)
@@ -1635,15 +1702,6 @@ class FederationHandler(BaseHandler):
local_auth_chain, remote_auth_chain
)
- for event in ret["auth_chain"]:
- event.signatures.update(
- compute_event_signature(
- event,
- self.hs.hostname,
- self.hs.config.signing_key[0]
- )
- )
-
logger.debug("on_query_auth returning: %s", ret)
defer.returnValue(ret)
@@ -1669,11 +1727,26 @@ class FederationHandler(BaseHandler):
min_depth=min_depth,
)
+ missing_events = yield filter_events_for_server(
+ self.store, origin, missing_events,
+ )
+
defer.returnValue(missing_events)
@defer.inlineCallbacks
@log_function
def do_auth(self, origin, event, context, auth_events):
+ """
+
+ Args:
+ origin (str):
+ event (synapse.events.FrozenEvent):
+ context (synapse.events.snapshot.EventContext):
+ auth_events (dict[(str, str)->str]):
+
+ Returns:
+ defer.Deferred[None]
+ """
# Check if we have all the auth events.
current_state = set(e.event_id for e in auth_events.values())
event_auth_events = set(e_id for e_id, _ in event.auth_events)
@@ -1684,7 +1757,8 @@ class FederationHandler(BaseHandler):
event_key = None
if event_auth_events - current_state:
- have_events = yield self.store.have_events(
+ # TODO: can we use store.have_seen_events here instead?
+ have_events = yield self.store.get_seen_events_with_rejections(
event_auth_events - current_state
)
else:
@@ -1707,12 +1781,12 @@ class FederationHandler(BaseHandler):
origin, event.room_id, event.event_id
)
- seen_remotes = yield self.store.have_events(
+ seen_remotes = yield self.store.have_seen_events(
[e.event_id for e in remote_auth_chain]
)
for e in remote_auth_chain:
- if e.event_id in seen_remotes.keys():
+ if e.event_id in seen_remotes:
continue
if e.event_id == event.event_id:
@@ -1739,11 +1813,11 @@ class FederationHandler(BaseHandler):
except AuthError:
pass
- have_events = yield self.store.have_events(
+ have_events = yield self.store.get_seen_events_with_rejections(
[e_id for e_id, _ in event.auth_events]
)
seen_events = set(have_events.keys())
- except:
+ except Exception:
# FIXME:
logger.exception("Failed to get auth chain")
@@ -1756,18 +1830,18 @@ class FederationHandler(BaseHandler):
# Do auth conflict res.
logger.info("Different auth: %s", different_auth)
- different_events = yield preserve_context_over_deferred(defer.gatherResults(
- [
- preserve_fn(self.store.get_event)(
+ different_events = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults([
+ logcontext.run_in_background(
+ self.store.get_event,
d,
allow_none=True,
allow_rejected=False,
)
for d in different_auth
if d in have_events and not have_events[d]
- ],
- consumeErrors=True
- )).addErrback(unwrapFirstError)
+ ], consumeErrors=True)
+ ).addErrback(unwrapFirstError)
if different_events:
local_view = dict(auth_events)
@@ -1777,7 +1851,7 @@ class FederationHandler(BaseHandler):
})
new_state = self.state_handler.resolve_events(
- [local_view.values(), remote_view.values()],
+ [list(local_view.values()), list(remote_view.values())],
event
)
@@ -1786,16 +1860,9 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state
- context.current_state_ids = dict(context.current_state_ids)
- context.current_state_ids.update({
- k: a.event_id for k, a in auth_events.items()
- if k != event_key
- })
- context.prev_state_ids = dict(context.prev_state_ids)
- context.prev_state_ids.update({
- k: a.event_id for k, a in auth_events.items()
- })
- context.state_group = self.store.get_next_state_group()
+ yield self._update_context_for_auth_events(
+ event, context, auth_events, event_key,
+ )
if different_auth and not event.internal_metadata.is_outlier():
logger.info("Different auth after resolution: %s", different_auth)
@@ -1815,9 +1882,10 @@ class FederationHandler(BaseHandler):
break
if do_resolution:
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
# 1. Get what we think is the auth chain.
auth_ids = yield self.auth.compute_auth_events(
- event, context.prev_state_ids
+ event, prev_state_ids
)
local_auth_chain = yield self.store.get_auth_chain(
auth_ids, include_given=True
@@ -1832,13 +1900,13 @@ class FederationHandler(BaseHandler):
local_auth_chain,
)
- seen_remotes = yield self.store.have_events(
+ seen_remotes = yield self.store.have_seen_events(
[e.event_id for e in result["auth_chain"]]
)
# 3. Process any remote auth chain events we haven't seen.
for ev in result["auth_chain"]:
- if ev.event_id in seen_remotes.keys():
+ if ev.event_id in seen_remotes:
continue
if ev.event_id == event.event_id:
@@ -1868,23 +1936,16 @@ class FederationHandler(BaseHandler):
except AuthError:
pass
- except:
+ except Exception:
# FIXME:
logger.exception("Failed to query auth chain")
# 4. Look at rejects and their proofs.
# TODO.
- context.current_state_ids = dict(context.current_state_ids)
- context.current_state_ids.update({
- k: a.event_id for k, a in auth_events.items()
- if k != event_key
- })
- context.prev_state_ids = dict(context.prev_state_ids)
- context.prev_state_ids.update({
- k: a.event_id for k, a in auth_events.items()
- })
- context.state_group = self.store.get_next_state_group()
+ yield self._update_context_for_auth_events(
+ event, context, auth_events, event_key,
+ )
try:
self.auth.check(event, auth_events=auth_events)
@@ -1893,6 +1954,58 @@ class FederationHandler(BaseHandler):
raise e
@defer.inlineCallbacks
+ def _update_context_for_auth_events(self, event, context, auth_events,
+ event_key):
+ """Update the state_ids in an event context after auth event resolution,
+ storing the changes as a new state group.
+
+ Args:
+ event (Event): The event we're handling the context for
+
+ context (synapse.events.snapshot.EventContext): event context
+ to be updated
+
+ auth_events (dict[(str, str)->str]): Events to update in the event
+ context.
+
+ event_key ((str, str)): (type, state_key) for the current event.
+ this will not be included in the current_state in the context.
+ """
+ state_updates = {
+ k: a.event_id for k, a in iteritems(auth_events)
+ if k != event_key
+ }
+ current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = dict(current_state_ids)
+
+ current_state_ids.update(state_updates)
+
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = dict(prev_state_ids)
+
+ prev_state_ids.update({
+ k: a.event_id for k, a in iteritems(auth_events)
+ })
+
+ # create a new state group as a delta from the existing one.
+ prev_group = context.state_group
+ state_group = yield self.store.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=prev_group,
+ delta_ids=state_updates,
+ current_state_ids=current_state_ids,
+ )
+
+ yield context.update_state(
+ state_group=state_group,
+ current_state_ids=current_state_ids,
+ prev_state_ids=prev_state_ids,
+ prev_group=prev_group,
+ delta_ids=state_updates,
+ )
+
+ @defer.inlineCallbacks
def construct_auth_difference(self, local_auth, remote_auth):
""" Given a local and remote auth chain, find the differences. This
assumes that we have already processed all events in remote_auth
@@ -1934,8 +2047,8 @@ class FederationHandler(BaseHandler):
def get_next(it, opt=None):
try:
- return it.next()
- except:
+ return next(it)
+ except Exception:
return opt
current_local = get_next(local_iter)
@@ -2060,8 +2173,7 @@ class FederationHandler(BaseHandler):
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder)
- message_handler = self.hs.get_handlers().message_handler
- event, context = yield message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder
)
@@ -2076,7 +2188,7 @@ class FederationHandler(BaseHandler):
raise e
yield self._check_signature(event, context)
- member_handler = self.hs.get_handlers().room_member_handler
+ member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context)
else:
destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
@@ -2089,10 +2201,17 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
+ """Handle an exchange_third_party_invite request from a remote server
+
+ The remote server will call this when it wants to turn a 3pid invite
+ into a normal m.room.member invite.
+
+ Returns:
+ Deferred: resolves (to None)
+ """
builder = self.event_builder_factory.new(event_dict)
- message_handler = self.hs.get_handlers().message_handler
- event, context = yield message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
)
@@ -2107,10 +2226,13 @@ class FederationHandler(BaseHandler):
raise e
yield self._check_signature(event, context)
+ # XXX we send the invite here, but send_membership_event also sends it,
+ # so we end up making two requests. I think this is redundant.
returned_invite = yield self.send_invite(origin, event)
# TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures)
- member_handler = self.hs.get_handlers().room_member_handler
+
+ member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context)
@defer.inlineCallbacks
@@ -2120,7 +2242,8 @@ class FederationHandler(BaseHandler):
event.content["third_party_invite"]["signed"]["token"]
)
original_invite = None
- original_invite_id = context.prev_state_ids.get(key)
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+ original_invite_id = prev_state_ids.get(key)
if original_invite_id:
original_invite = yield self.store.get_event(
original_invite_id, allow_none=True
@@ -2139,8 +2262,9 @@ class FederationHandler(BaseHandler):
builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder)
- message_handler = self.hs.get_handlers().message_handler
- event, context = yield message_handler._create_new_client_event(builder=builder)
+ event, context = yield self.event_creation_handler.create_new_client_event(
+ builder=builder,
+ )
defer.returnValue((event, context))
@defer.inlineCallbacks
@@ -2161,7 +2285,8 @@ class FederationHandler(BaseHandler):
signed = event.content["third_party_invite"]["signed"]
token = signed["token"]
- invite_event_id = context.prev_state_ids.get(
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+ invite_event_id = prev_state_ids.get(
(EventTypes.ThirdPartyInvite, token,)
)
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
new file mode 100644
index 0000000000..53e5e2648b
--- /dev/null
+++ b/synapse/handlers/groups_local.py
@@ -0,0 +1,473 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from six import iteritems
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.types import get_domain_from_id
+
+logger = logging.getLogger(__name__)
+
+
+def _create_rerouter(func_name):
+ """Returns a function that looks at the group id and calls the function
+ on federation or the local group server if the group is local
+ """
+ def f(self, group_id, *args, **kwargs):
+ if self.is_mine_id(group_id):
+ return getattr(self.groups_server_handler, func_name)(
+ group_id, *args, **kwargs
+ )
+ else:
+ destination = get_domain_from_id(group_id)
+ return getattr(self.transport_client, func_name)(
+ destination, group_id, *args, **kwargs
+ )
+ return f
+
+
+class GroupsLocalHandler(object):
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.room_list_handler = hs.get_room_list_handler()
+ self.groups_server_handler = hs.get_groups_server_handler()
+ self.transport_client = hs.get_federation_transport_client()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.keyring = hs.get_keyring()
+ self.is_mine_id = hs.is_mine_id
+ self.signing_key = hs.config.signing_key[0]
+ self.server_name = hs.hostname
+ self.notifier = hs.get_notifier()
+ self.attestations = hs.get_groups_attestation_signing()
+
+ self.profile_handler = hs.get_profile_handler()
+
+ # Ensure attestations get renewed
+ hs.get_groups_attestation_renewer()
+
+ # The following functions merely route the query to the local groups server
+ # or federation depending on if the group is local or remote
+
+ get_group_profile = _create_rerouter("get_group_profile")
+ update_group_profile = _create_rerouter("update_group_profile")
+ get_rooms_in_group = _create_rerouter("get_rooms_in_group")
+
+ get_invited_users_in_group = _create_rerouter("get_invited_users_in_group")
+
+ add_room_to_group = _create_rerouter("add_room_to_group")
+ update_room_in_group = _create_rerouter("update_room_in_group")
+ remove_room_from_group = _create_rerouter("remove_room_from_group")
+
+ update_group_summary_room = _create_rerouter("update_group_summary_room")
+ delete_group_summary_room = _create_rerouter("delete_group_summary_room")
+
+ update_group_category = _create_rerouter("update_group_category")
+ delete_group_category = _create_rerouter("delete_group_category")
+ get_group_category = _create_rerouter("get_group_category")
+ get_group_categories = _create_rerouter("get_group_categories")
+
+ update_group_summary_user = _create_rerouter("update_group_summary_user")
+ delete_group_summary_user = _create_rerouter("delete_group_summary_user")
+
+ update_group_role = _create_rerouter("update_group_role")
+ delete_group_role = _create_rerouter("delete_group_role")
+ get_group_role = _create_rerouter("get_group_role")
+ get_group_roles = _create_rerouter("get_group_roles")
+
+ set_group_join_policy = _create_rerouter("set_group_join_policy")
+
+ @defer.inlineCallbacks
+ def get_group_summary(self, group_id, requester_user_id):
+ """Get the group summary for a group.
+
+ If the group is remote we check that the users have valid attestations.
+ """
+ if self.is_mine_id(group_id):
+ res = yield self.groups_server_handler.get_group_summary(
+ group_id, requester_user_id
+ )
+ else:
+ res = yield self.transport_client.get_group_summary(
+ get_domain_from_id(group_id), group_id, requester_user_id,
+ )
+
+ group_server_name = get_domain_from_id(group_id)
+
+ # Loop through the users and validate the attestations.
+ chunk = res["users_section"]["users"]
+ valid_users = []
+ for entry in chunk:
+ g_user_id = entry["user_id"]
+ attestation = entry.pop("attestation", {})
+ try:
+ if get_domain_from_id(g_user_id) != group_server_name:
+ yield self.attestations.verify_attestation(
+ attestation,
+ group_id=group_id,
+ user_id=g_user_id,
+ server_name=get_domain_from_id(g_user_id),
+ )
+ valid_users.append(entry)
+ except Exception as e:
+ logger.info("Failed to verify user is in group: %s", e)
+
+ res["users_section"]["users"] = valid_users
+
+ res["users_section"]["users"].sort(key=lambda e: e.get("order", 0))
+ res["rooms_section"]["rooms"].sort(key=lambda e: e.get("order", 0))
+
+ # Add `is_publicised` flag to indicate whether the user has publicised their
+ # membership of the group on their profile
+ result = yield self.store.get_publicised_groups_for_user(requester_user_id)
+ is_publicised = group_id in result
+
+ res.setdefault("user", {})["is_publicised"] = is_publicised
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def create_group(self, group_id, user_id, content):
+ """Create a group
+ """
+
+ logger.info("Asking to create group with ID: %r", group_id)
+
+ if self.is_mine_id(group_id):
+ res = yield self.groups_server_handler.create_group(
+ group_id, user_id, content
+ )
+ local_attestation = None
+ remote_attestation = None
+ else:
+ local_attestation = self.attestations.create_attestation(group_id, user_id)
+ content["attestation"] = local_attestation
+
+ content["user_profile"] = yield self.profile_handler.get_profile(user_id)
+
+ res = yield self.transport_client.create_group(
+ get_domain_from_id(group_id), group_id, user_id, content,
+ )
+
+ remote_attestation = res["attestation"]
+ yield self.attestations.verify_attestation(
+ remote_attestation,
+ group_id=group_id,
+ user_id=user_id,
+ server_name=get_domain_from_id(group_id),
+ )
+
+ is_publicised = content.get("publicise", False)
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="join",
+ is_admin=True,
+ local_attestation=local_attestation,
+ remote_attestation=remote_attestation,
+ is_publicised=is_publicised,
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def get_users_in_group(self, group_id, requester_user_id):
+ """Get users in a group
+ """
+ if self.is_mine_id(group_id):
+ res = yield self.groups_server_handler.get_users_in_group(
+ group_id, requester_user_id
+ )
+ defer.returnValue(res)
+
+ group_server_name = get_domain_from_id(group_id)
+
+ res = yield self.transport_client.get_users_in_group(
+ get_domain_from_id(group_id), group_id, requester_user_id,
+ )
+
+ chunk = res["chunk"]
+ valid_entries = []
+ for entry in chunk:
+ g_user_id = entry["user_id"]
+ attestation = entry.pop("attestation", {})
+ try:
+ if get_domain_from_id(g_user_id) != group_server_name:
+ yield self.attestations.verify_attestation(
+ attestation,
+ group_id=group_id,
+ user_id=g_user_id,
+ server_name=get_domain_from_id(g_user_id),
+ )
+ valid_entries.append(entry)
+ except Exception as e:
+ logger.info("Failed to verify user is in group: %s", e)
+
+ res["chunk"] = valid_entries
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def join_group(self, group_id, user_id, content):
+ """Request to join a group
+ """
+ if self.is_mine_id(group_id):
+ yield self.groups_server_handler.join_group(
+ group_id, user_id, content
+ )
+ local_attestation = None
+ remote_attestation = None
+ else:
+ local_attestation = self.attestations.create_attestation(group_id, user_id)
+ content["attestation"] = local_attestation
+
+ res = yield self.transport_client.join_group(
+ get_domain_from_id(group_id), group_id, user_id, content,
+ )
+
+ remote_attestation = res["attestation"]
+
+ yield self.attestations.verify_attestation(
+ remote_attestation,
+ group_id=group_id,
+ user_id=user_id,
+ server_name=get_domain_from_id(group_id),
+ )
+
+ # TODO: Check that the group is public and we're being added publically
+ is_publicised = content.get("publicise", False)
+
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="join",
+ is_admin=False,
+ local_attestation=local_attestation,
+ remote_attestation=remote_attestation,
+ is_publicised=is_publicised,
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def accept_invite(self, group_id, user_id, content):
+ """Accept an invite to a group
+ """
+ if self.is_mine_id(group_id):
+ yield self.groups_server_handler.accept_invite(
+ group_id, user_id, content
+ )
+ local_attestation = None
+ remote_attestation = None
+ else:
+ local_attestation = self.attestations.create_attestation(group_id, user_id)
+ content["attestation"] = local_attestation
+
+ res = yield self.transport_client.accept_group_invite(
+ get_domain_from_id(group_id), group_id, user_id, content,
+ )
+
+ remote_attestation = res["attestation"]
+
+ yield self.attestations.verify_attestation(
+ remote_attestation,
+ group_id=group_id,
+ user_id=user_id,
+ server_name=get_domain_from_id(group_id),
+ )
+
+ # TODO: Check that the group is public and we're being added publically
+ is_publicised = content.get("publicise", False)
+
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="join",
+ is_admin=False,
+ local_attestation=local_attestation,
+ remote_attestation=remote_attestation,
+ is_publicised=is_publicised,
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def invite(self, group_id, user_id, requester_user_id, config):
+ """Invite a user to a group
+ """
+ content = {
+ "requester_user_id": requester_user_id,
+ "config": config,
+ }
+ if self.is_mine_id(group_id):
+ res = yield self.groups_server_handler.invite_to_group(
+ group_id, user_id, requester_user_id, content,
+ )
+ else:
+ res = yield self.transport_client.invite_to_group(
+ get_domain_from_id(group_id), group_id, user_id, requester_user_id,
+ content,
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def on_invite(self, group_id, user_id, content):
+ """One of our users were invited to a group
+ """
+ # TODO: Support auto join and rejection
+
+ if not self.is_mine_id(user_id):
+ raise SynapseError(400, "User not on this server")
+
+ local_profile = {}
+ if "profile" in content:
+ if "name" in content["profile"]:
+ local_profile["name"] = content["profile"]["name"]
+ if "avatar_url" in content["profile"]:
+ local_profile["avatar_url"] = content["profile"]["avatar_url"]
+
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="invite",
+ content={"profile": local_profile, "inviter": content["inviter"]},
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+ try:
+ user_profile = yield self.profile_handler.get_profile(user_id)
+ except Exception as e:
+ logger.warn("No profile for user %s: %s", user_id, e)
+ user_profile = {}
+
+ defer.returnValue({"state": "invite", "user_profile": user_profile})
+
+ @defer.inlineCallbacks
+ def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
+ """Remove a user from a group
+ """
+ if user_id == requester_user_id:
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="leave",
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+
+ # TODO: Should probably remember that we tried to leave so that we can
+ # retry if the group server is currently down.
+
+ if self.is_mine_id(group_id):
+ res = yield self.groups_server_handler.remove_user_from_group(
+ group_id, user_id, requester_user_id, content,
+ )
+ else:
+ content["requester_user_id"] = requester_user_id
+ res = yield self.transport_client.remove_user_from_group(
+ get_domain_from_id(group_id), group_id, requester_user_id,
+ user_id, content,
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def user_removed_from_group(self, group_id, user_id, content):
+ """One of our users was removed/kicked from a group
+ """
+ # TODO: Check if user in group
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="leave",
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+
+ @defer.inlineCallbacks
+ def get_joined_groups(self, user_id):
+ group_ids = yield self.store.get_joined_groups(user_id)
+ defer.returnValue({"groups": group_ids})
+
+ @defer.inlineCallbacks
+ def get_publicised_groups_for_user(self, user_id):
+ if self.hs.is_mine_id(user_id):
+ result = yield self.store.get_publicised_groups_for_user(user_id)
+
+ # Check AS associated groups for this user - this depends on the
+ # RegExps in the AS registration file (under `users`)
+ for app_service in self.store.get_app_services():
+ result.extend(app_service.get_groups_for_user(user_id))
+
+ defer.returnValue({"groups": result})
+ else:
+ bulk_result = yield self.transport_client.bulk_get_publicised_groups(
+ get_domain_from_id(user_id), [user_id],
+ )
+ result = bulk_result.get("users", {}).get(user_id)
+ # TODO: Verify attestations
+ defer.returnValue({"groups": result})
+
+ @defer.inlineCallbacks
+ def bulk_get_publicised_groups(self, user_ids, proxy=True):
+ destinations = {}
+ local_users = set()
+
+ for user_id in user_ids:
+ if self.hs.is_mine_id(user_id):
+ local_users.add(user_id)
+ else:
+ destinations.setdefault(
+ get_domain_from_id(user_id), set()
+ ).add(user_id)
+
+ if not proxy and destinations:
+ raise SynapseError(400, "Some user_ids are not local")
+
+ results = {}
+ failed_results = []
+ for destination, dest_user_ids in iteritems(destinations):
+ try:
+ r = yield self.transport_client.bulk_get_publicised_groups(
+ destination, list(dest_user_ids),
+ )
+ results.update(r["users"])
+ except Exception:
+ failed_results.extend(dest_user_ids)
+
+ for uid in local_users:
+ results[uid] = yield self.store.get_publicised_groups_for_user(
+ uid
+ )
+
+ # Check AS associated groups for this user - this depends on the
+ # RegExps in the AS registration file (under `users`)
+ for app_service in self.store.get_app_services():
+ results[uid].extend(app_service.get_groups_for_user(uid))
+
+ defer.returnValue({"users": results})
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 9efcdff1d6..8c8aedb2b8 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,17 +16,21 @@
# limitations under the License.
"""Utilities for interacting with Identity Servers"""
+
+import logging
+
+from canonicaljson import json
+
from twisted.internet import defer
from synapse.api.errors import (
- MatrixCodeMessageException, CodeMessageException
+ CodeMessageException,
+ Codes,
+ MatrixCodeMessageException,
+ SynapseError,
)
-from ._base import BaseHandler
-from synapse.util.async import run_on_reactor
-from synapse.api.errors import SynapseError, Codes
-import json
-import logging
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -36,6 +41,7 @@ class IdentityHandler(BaseHandler):
super(IdentityHandler, self).__init__(hs)
self.http_client = hs.get_simple_http_client()
+ self.federation_http_client = hs.get_http_client()
self.trusted_id_servers = set(hs.config.trusted_third_party_id_servers)
self.trust_any_id_server_just_for_testing_do_not_use = (
@@ -58,8 +64,6 @@ class IdentityHandler(BaseHandler):
@defer.inlineCallbacks
def threepid_from_creds(self, creds):
- yield run_on_reactor()
-
if 'id_server' in creds:
id_server = creds['id_server']
elif 'idServer' in creds:
@@ -102,7 +106,6 @@ class IdentityHandler(BaseHandler):
@defer.inlineCallbacks
def bind_threepid(self, creds, mxid):
- yield run_on_reactor()
logger.debug("binding threepid %r to %s", creds, mxid)
data = None
@@ -137,9 +140,53 @@ class IdentityHandler(BaseHandler):
defer.returnValue(data)
@defer.inlineCallbacks
- def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
- yield run_on_reactor()
+ def unbind_threepid(self, mxid, threepid):
+ """
+ Removes a binding from an identity server
+ Args:
+ mxid (str): Matrix user ID of binding to be removed
+ threepid (dict): Dict with medium & address of binding to be removed
+
+ Returns:
+ Deferred[bool]: True on success, otherwise False
+ """
+ logger.debug("unbinding threepid %r from %s", threepid, mxid)
+ if not self.trusted_id_servers:
+ logger.warn("Can't unbind threepid: no trusted ID servers set in config")
+ defer.returnValue(False)
+
+ # We don't track what ID server we added 3pids on (perhaps we ought to)
+ # but we assume that any of the servers in the trusted list are in the
+ # same ID server federation, so we can pick any one of them to send the
+ # deletion request to.
+ id_server = next(iter(self.trusted_id_servers))
+
+ url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
+ content = {
+ "mxid": mxid,
+ "threepid": threepid,
+ }
+ headers = {}
+ # we abuse the federation http client to sign the request, but we have to send it
+ # using the normal http client since we don't want the SRV lookup and want normal
+ # 'browser-like' HTTPS.
+ self.federation_http_client.sign_request(
+ destination=None,
+ method='POST',
+ url_bytes='/_matrix/identity/api/v1/3pid/unbind'.encode('ascii'),
+ headers_dict=headers,
+ content=content,
+ destination_is=id_server,
+ )
+ yield self.http_client.post_json_get_json(
+ url,
+ content,
+ headers,
+ )
+ defer.returnValue(True)
+ @defer.inlineCallbacks
+ def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
if not self._should_trust_id_server(id_server):
raise SynapseError(
400, "Untrusted ID server '%s'" % id_server,
@@ -174,8 +221,6 @@ class IdentityHandler(BaseHandler):
self, id_server, country, phone_number,
client_secret, send_attempt, **kwargs
):
- yield run_on_reactor()
-
if not self._should_trust_id_server(id_server):
raise SynapseError(
400, "Untrusted ID server '%s'" % id_server,
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 10f5f35a69..40e7580a61 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
@@ -21,20 +23,15 @@ from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state
from synapse.streams.config import PaginationConfig
-from synapse.types import (
- UserID, StreamToken,
-)
+from synapse.types import StreamToken, UserID
from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.snapshot_cache import SnapshotCache
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
-import logging
-
-
logger = logging.getLogger(__name__)
@@ -151,22 +148,25 @@ class InitialSyncHandler(BaseHandler):
try:
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
- deferred_room_state = self.state_handler.get_current_state(
- event.room_id
+ deferred_room_state = run_in_background(
+ self.state_handler.get_current_state,
+ event.room_id,
)
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
- deferred_room_state = self.store.get_state_for_events(
- [event.event_id], None
+ deferred_room_state = run_in_background(
+ self.store.get_state_for_events,
+ [event.event_id], None,
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
)
- (messages, token), current_state = yield preserve_context_over_deferred(
+ (messages, token), current_state = yield make_deferred_yieldable(
defer.gatherResults(
[
- preserve_fn(self.store.get_recent_events_for_room)(
+ run_in_background(
+ self.store.get_recent_events_for_room,
event.room_id,
limit=limit,
end_token=room_end_token,
@@ -180,8 +180,8 @@ class InitialSyncHandler(BaseHandler):
self.store, user_id, messages
)
- start_token = now_token.copy_and_replace("room_key", token[0])
- end_token = now_token.copy_and_replace("room_key", token[1])
+ start_token = now_token.copy_and_replace("room_key", token)
+ end_token = now_token.copy_and_replace("room_key", room_end_token)
time_now = self.clock.time_msec()
d["messages"] = {
@@ -214,7 +214,7 @@ class InitialSyncHandler(BaseHandler):
})
d["account_data"] = account_data_events
- except:
+ except Exception:
logger.exception("Failed to get snapshot")
yield concurrently_execute(handle_room, room_list, 10)
@@ -324,8 +324,8 @@ class InitialSyncHandler(BaseHandler):
self.store, user_id, messages, is_peeking=is_peeking
)
- start_token = StreamToken.START.copy_and_replace("room_key", token[0])
- end_token = StreamToken.START.copy_and_replace("room_key", token[1])
+ start_token = StreamToken.START.copy_and_replace("room_key", token)
+ end_token = StreamToken.START.copy_and_replace("room_key", stream_token)
time_now = self.clock.time_msec()
@@ -389,25 +389,28 @@ class InitialSyncHandler(BaseHandler):
receipts = []
defer.returnValue(receipts)
- presence, receipts, (messages, token) = yield defer.gatherResults(
- [
- preserve_fn(get_presence)(),
- preserve_fn(get_receipts)(),
- preserve_fn(self.store.get_recent_events_for_room)(
- room_id,
- limit=limit,
- end_token=now_token.room_key,
- )
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError)
+ presence, receipts, (messages, token) = yield make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(get_presence),
+ run_in_background(get_receipts),
+ run_in_background(
+ self.store.get_recent_events_for_room,
+ room_id,
+ limit=limit,
+ end_token=now_token.room_key,
+ )
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError),
+ )
messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking,
)
- start_token = now_token.copy_and_replace("room_key", token[0])
- end_token = now_token.copy_and_replace("room_key", token[1])
+ start_token = now_token.copy_and_replace("room_key", token)
+ end_token = now_token
time_now = self.clock.time_msec()
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 24c9ffdb20..39d7724778 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2017 - 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,173 +13,185 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import sys
+
+import six
+from six import iteritems, itervalues, string_types
+
+from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
+from twisted.internet.defer import succeed
-from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
+from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
+from synapse.api.urls import ConsentURIBuilder
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
-from synapse.types import (
- UserID, RoomAlias, RoomStreamToken,
-)
-from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
-from synapse.util.logcontext import preserve_fn
+from synapse.replication.http.send_event import send_event_to_master
+from synapse.types import RoomAlias, UserID
+from synapse.util.async import Linearizer
+from synapse.util.frozenutils import frozendict_json_encoder
+from synapse.util.logcontext import run_in_background
from synapse.util.metrics import measure_func
-from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
-from canonicaljson import encode_canonical_json
-
-import logging
-import random
-import ujson
-
logger = logging.getLogger(__name__)
-class MessageHandler(BaseHandler):
+class MessageHandler(object):
+ """Contains some read only APIs to get state about a room
+ """
def __init__(self, hs):
- super(MessageHandler, self).__init__(hs)
- self.hs = hs
- self.state = hs.get_state_handler()
+ self.auth = hs.get_auth()
self.clock = hs.get_clock()
- self.validator = EventValidator()
-
- self.pagination_lock = ReadWriteLock()
-
- # We arbitrarily limit concurrent event creation for a room to 5.
- # This is to stop us from diverging history *too* much.
- self.limiter = Limiter(max_count=5)
-
- self.action_generator = hs.get_action_generator()
+ self.state = hs.get_state_handler()
+ self.store = hs.get_datastore()
@defer.inlineCallbacks
- def purge_history(self, room_id, event_id):
- event = yield self.store.get_event(event_id)
+ def get_room_data(self, user_id=None, room_id=None,
+ event_type=None, state_key="", is_guest=False):
+ """ Get data from a room.
- if event.room_id != room_id:
- raise SynapseError(400, "Event is for wrong room.")
+ Args:
+ event : The room path event
+ Returns:
+ The path data content.
+ Raises:
+ SynapseError if something went wrong.
+ """
+ membership, membership_event_id = yield self.auth.check_in_room_or_world_readable(
+ room_id, user_id
+ )
- depth = event.depth
+ if membership == Membership.JOIN:
+ data = yield self.state.get_current_state(
+ room_id, event_type, state_key
+ )
+ elif membership == Membership.LEAVE:
+ key = (event_type, state_key)
+ room_state = yield self.store.get_state_for_events(
+ [membership_event_id], [key]
+ )
+ data = room_state[membership_event_id].get(key)
- with (yield self.pagination_lock.write(room_id)):
- yield self.store.delete_old_state(room_id, depth)
+ defer.returnValue(data)
@defer.inlineCallbacks
- def get_messages(self, requester, room_id=None, pagin_config=None,
- as_client_event=True, event_filter=None):
- """Get messages in a room.
+ def get_state_events(self, user_id, room_id, is_guest=False):
+ """Retrieve all state events for a given room. If the user is
+ joined to the room then return the current state. If the user has
+ left the room return the state events from when they left.
Args:
- requester (Requester): The user requesting messages.
- room_id (str): The room they want messages from.
- pagin_config (synapse.api.streams.PaginationConfig): The pagination
- config rules to apply, if any.
- as_client_event (bool): True to get events in client-server format.
- event_filter (Filter): Filter to apply to results or None
+ user_id(str): The user requesting state events.
+ room_id(str): The room ID to get all state events from.
Returns:
- dict: Pagination API results
+ A list of dicts representing state events. [{}, {}, {}]
"""
- user_id = requester.user.to_string()
+ membership, membership_event_id = yield self.auth.check_in_room_or_world_readable(
+ room_id, user_id
+ )
- if pagin_config.from_token:
- room_token = pagin_config.from_token.room_key
- else:
- pagin_config.from_token = (
- yield self.hs.get_event_sources().get_current_token_for_room(
- room_id=room_id
- )
+ if membership == Membership.JOIN:
+ room_state = yield self.state.get_current_state(room_id)
+ elif membership == Membership.LEAVE:
+ room_state = yield self.store.get_state_for_events(
+ [membership_event_id], None
)
- room_token = pagin_config.from_token.room_key
-
- room_token = RoomStreamToken.parse(room_token)
+ room_state = room_state[membership_event_id]
- pagin_config.from_token = pagin_config.from_token.copy_and_replace(
- "room_key", str(room_token)
+ now = self.clock.time_msec()
+ defer.returnValue(
+ [serialize_event(c, now) for c in room_state.values()]
)
- source_config = pagin_config.get_source_config("room")
+ @defer.inlineCallbacks
+ def get_joined_members(self, requester, room_id):
+ """Get all the joined members in the room and their profile information.
+
+ If the user has left the room return the state events from when they left.
- with (yield self.pagination_lock.read(room_id)):
- membership, member_event_id = yield self._check_in_room_or_world_readable(
+ Args:
+ requester(Requester): The user requesting state events.
+ room_id(str): The room ID to get all state events from.
+ Returns:
+ A dict of user_id to profile info
+ """
+ user_id = requester.user.to_string()
+ if not requester.app_service:
+ # We check AS auth after fetching the room membership, as it
+ # requires us to pull out all joined members anyway.
+ membership, _ = yield self.auth.check_in_room_or_world_readable(
room_id, user_id
)
+ if membership != Membership.JOIN:
+ raise NotImplementedError(
+ "Getting joined members after leaving is not implemented"
+ )
- if source_config.direction == 'b':
- # if we're going backwards, we might need to backfill. This
- # requires that we have a topo token.
- if room_token.topological:
- max_topo = room_token.topological
- else:
- max_topo = yield self.store.get_max_topological_token(
- room_id, room_token.stream
- )
+ users_with_profile = yield self.state.get_current_user_in_room(room_id)
- if membership == Membership.LEAVE:
- # If they have left the room then clamp the token to be before
- # they left the room, to save the effort of loading from the
- # database.
- leave_token = yield self.store.get_topological_token_for_event(
- member_event_id
- )
- leave_token = RoomStreamToken.parse(leave_token)
- if leave_token.topological < max_topo:
- source_config.from_key = str(leave_token)
+ # If this is an AS, double check that they are allowed to see the members.
+ # This can either be because the AS user is in the room or because there
+ # is a user in the room that the AS is "interested in"
+ if requester.app_service and user_id not in users_with_profile:
+ for uid in users_with_profile:
+ if requester.app_service.is_interested_in_user(uid):
+ break
+ else:
+ # Loop fell through, AS has no interested users in room
+ raise AuthError(403, "Appservice not in room")
- yield self.hs.get_handlers().federation_handler.maybe_backfill(
- room_id, max_topo
- )
+ defer.returnValue({
+ user_id: {
+ "avatar_url": profile.avatar_url,
+ "display_name": profile.display_name,
+ }
+ for user_id, profile in iteritems(users_with_profile)
+ })
- events, next_key = yield self.store.paginate_room_events(
- room_id=room_id,
- from_key=source_config.from_key,
- to_key=source_config.to_key,
- direction=source_config.direction,
- limit=source_config.limit,
- event_filter=event_filter,
- )
- next_token = pagin_config.from_token.copy_and_replace(
- "room_key", next_key
- )
+class EventCreationHandler(object):
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.state = hs.get_state_handler()
+ self.clock = hs.get_clock()
+ self.validator = EventValidator()
+ self.profile_handler = hs.get_profile_handler()
+ self.event_builder_factory = hs.get_event_builder_factory()
+ self.server_name = hs.hostname
+ self.ratelimiter = hs.get_ratelimiter()
+ self.notifier = hs.get_notifier()
+ self.config = hs.config
- if not events:
- defer.returnValue({
- "chunk": [],
- "start": pagin_config.from_token.to_string(),
- "end": next_token.to_string(),
- })
-
- if event_filter:
- events = event_filter.filter(events)
-
- events = yield filter_events_for_client(
- self.store,
- user_id,
- events,
- is_peeking=(member_event_id is None),
- )
+ self.http_client = hs.get_simple_http_client()
- time_now = self.clock.time_msec()
+ # This is only used to get at ratelimit function, and maybe_kick_guest_users
+ self.base_handler = BaseHandler(hs)
- chunk = {
- "chunk": [
- serialize_event(e, time_now, as_client_event)
- for e in events
- ],
- "start": pagin_config.from_token.to_string(),
- "end": next_token.to_string(),
- }
+ self.pusher_pool = hs.get_pusherpool()
- defer.returnValue(chunk)
+ # We arbitrarily limit concurrent event creation for a room to 5.
+ # This is to stop us from diverging history *too* much.
+ self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")
+
+ self.action_generator = hs.get_action_generator()
+
+ self.spam_checker = hs.get_spam_checker()
+
+ if self.config.block_events_without_consent_error is not None:
+ self._consent_uri_builder = ConsentURIBuilder(self.config)
@defer.inlineCallbacks
def create_event(self, requester, event_dict, token_id=None, txn_id=None,
- prev_event_ids=None):
+ prev_events_and_hashes=None):
"""
Given a dict from a client, create a new event.
@@ -192,50 +205,143 @@ class MessageHandler(BaseHandler):
event_dict (dict): An entire event
token_id (str)
txn_id (str)
- prev_event_ids (list): The prev event ids to use when creating the event
+
+ prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
+ the forward extremities to use as the prev_events for the
+ new event. For each event, a tuple of (event_id, hashes, depth)
+ where *hashes* is a map from algorithm to hash.
+
+ If None, they will be requested from the database.
Returns:
Tuple of created event (FrozenEvent), Context
"""
builder = self.event_builder_factory.new(event_dict)
- with (yield self.limiter.queue(builder.room_id)):
- self.validator.validate_new(builder)
-
- if builder.type == EventTypes.Member:
- membership = builder.content.get("membership", None)
- target = UserID.from_string(builder.state_key)
-
- if membership in {Membership.JOIN, Membership.INVITE}:
- # If event doesn't include a display name, add one.
- profile = self.hs.get_handlers().profile_handler
- content = builder.content
-
- try:
- if "displayname" not in content:
- content["displayname"] = yield profile.get_displayname(target)
- if "avatar_url" not in content:
- content["avatar_url"] = yield profile.get_avatar_url(target)
- except Exception as e:
- logger.info(
- "Failed to get profile information for %r: %s",
- target, e
- )
+ self.validator.validate_new(builder)
+
+ if builder.type == EventTypes.Member:
+ membership = builder.content.get("membership", None)
+ target = UserID.from_string(builder.state_key)
+
+ if membership in {Membership.JOIN, Membership.INVITE}:
+ # If event doesn't include a display name, add one.
+ profile = self.profile_handler
+ content = builder.content
+
+ try:
+ if "displayname" not in content:
+ content["displayname"] = yield profile.get_displayname(target)
+ if "avatar_url" not in content:
+ content["avatar_url"] = yield profile.get_avatar_url(target)
+ except Exception as e:
+ logger.info(
+ "Failed to get profile information for %r: %s",
+ target, e
+ )
- if token_id is not None:
- builder.internal_metadata.token_id = token_id
+ is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester)
+ if not is_exempt:
+ yield self.assert_accepted_privacy_policy(requester)
- if txn_id is not None:
- builder.internal_metadata.txn_id = txn_id
+ if token_id is not None:
+ builder.internal_metadata.token_id = token_id
- event, context = yield self._create_new_client_event(
- builder=builder,
- requester=requester,
- prev_event_ids=prev_event_ids,
- )
+ if txn_id is not None:
+ builder.internal_metadata.txn_id = txn_id
+
+ event, context = yield self.create_new_client_event(
+ builder=builder,
+ requester=requester,
+ prev_events_and_hashes=prev_events_and_hashes,
+ )
defer.returnValue((event, context))
+ def _is_exempt_from_privacy_policy(self, builder, requester):
+ """"Determine if an event to be sent is exempt from having to consent
+ to the privacy policy
+
+ Args:
+ builder (synapse.events.builder.EventBuilder): event being created
+ requester (Requster): user requesting this event
+
+ Returns:
+ Deferred[bool]: true if the event can be sent without the user
+ consenting
+ """
+ # the only thing the user can do is join the server notices room.
+ if builder.type == EventTypes.Member:
+ membership = builder.content.get("membership", None)
+ if membership == Membership.JOIN:
+ return self._is_server_notices_room(builder.room_id)
+ elif membership == Membership.LEAVE:
+ # the user is always allowed to leave (but not kick people)
+ return builder.state_key == requester.user.to_string()
+ return succeed(False)
+
+ @defer.inlineCallbacks
+ def _is_server_notices_room(self, room_id):
+ if self.config.server_notices_mxid is None:
+ defer.returnValue(False)
+ user_ids = yield self.store.get_users_in_room(room_id)
+ defer.returnValue(self.config.server_notices_mxid in user_ids)
+
+ @defer.inlineCallbacks
+ def assert_accepted_privacy_policy(self, requester):
+ """Check if a user has accepted the privacy policy
+
+ Called when the given user is about to do something that requires
+ privacy consent. We see if the user is exempt and otherwise check that
+ they have given consent. If they have not, a ConsentNotGiven error is
+ raised.
+
+ Args:
+ requester (synapse.types.Requester):
+ The user making the request
+
+ Returns:
+ Deferred[None]: returns normally if the user has consented or is
+ exempt
+
+ Raises:
+ ConsentNotGivenError: if the user has not given consent yet
+ """
+ if self.config.block_events_without_consent_error is None:
+ return
+
+ # exempt AS users from needing consent
+ if requester.app_service is not None:
+ return
+
+ user_id = requester.user.to_string()
+
+ # exempt the system notices user
+ if (
+ self.config.server_notices_mxid is not None and
+ user_id == self.config.server_notices_mxid
+ ):
+ return
+
+ u = yield self.store.get_user_by_id(user_id)
+ assert u is not None
+ if u["appservice_id"] is not None:
+ # users registered by an appservice are exempt
+ return
+ if u["consent_version"] == self.config.user_consent_version:
+ return
+
+ consent_uri = self._consent_uri_builder.build_user_consent_uri(
+ requester.user.localpart,
+ )
+ msg = self.config.block_events_without_consent_error % {
+ 'consent_uri': consent_uri,
+ }
+ raise ConsentNotGivenError(
+ msg=msg,
+ consent_uri=consent_uri,
+ )
+
@defer.inlineCallbacks
def send_nonmember_event(self, requester, event, context, ratelimit=True):
"""
@@ -253,11 +359,6 @@ class MessageHandler(BaseHandler):
"Tried to send member event through non-member codepath"
)
- # We check here if we are currently being rate limited, so that we
- # don't do unnecessary work. We check again just before we actually
- # send the event.
- yield self.ratelimit(requester, update=False)
-
user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
@@ -274,12 +375,6 @@ class MessageHandler(BaseHandler):
ratelimit=ratelimit,
)
- if event.type == EventTypes.Message:
- presence = self.hs.get_presence_handler()
- # We don't want to block sending messages on any presence code. This
- # matters as sometimes presence code can take a while.
- preserve_fn(presence.bump_presence_active_time)(user)
-
@defer.inlineCallbacks
def deduplicate_state_event(self, event, context):
"""
@@ -288,7 +383,8 @@ class MessageHandler(BaseHandler):
If so, returns the version of the event in context.
Otherwise, returns None.
"""
- prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_event_id = prev_state_ids.get((event.type, event.state_key))
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
@@ -313,145 +409,85 @@ class MessageHandler(BaseHandler):
See self.create_event and self.send_nonmember_event.
"""
- event, context = yield self.create_event(
- requester,
- event_dict,
- token_id=requester.access_token_id,
- txn_id=txn_id
- )
- yield self.send_nonmember_event(
- requester,
- event,
- context,
- ratelimit=ratelimit,
- )
- defer.returnValue(event)
- @defer.inlineCallbacks
- def get_room_data(self, user_id=None, room_id=None,
- event_type=None, state_key="", is_guest=False):
- """ Get data from a room.
-
- Args:
- event : The room path event
- Returns:
- The path data content.
- Raises:
- SynapseError if something went wrong.
- """
- membership, membership_event_id = yield self._check_in_room_or_world_readable(
- room_id, user_id
- )
-
- if membership == Membership.JOIN:
- data = yield self.state_handler.get_current_state(
- room_id, event_type, state_key
- )
- elif membership == Membership.LEAVE:
- key = (event_type, state_key)
- room_state = yield self.store.get_state_for_events(
- [membership_event_id], [key]
+ # We limit the number of concurrent event sends in a room so that we
+ # don't fork the DAG too much. If we don't limit then we can end up in
+ # a situation where event persistence can't keep up, causing
+ # extremities to pile up, which in turn leads to state resolution
+ # taking longer.
+ with (yield self.limiter.queue(event_dict["room_id"])):
+ event, context = yield self.create_event(
+ requester,
+ event_dict,
+ token_id=requester.access_token_id,
+ txn_id=txn_id
)
- data = room_state[membership_event_id].get(key)
- defer.returnValue(data)
+ spam_error = self.spam_checker.check_event_for_spam(event)
+ if spam_error:
+ if not isinstance(spam_error, string_types):
+ spam_error = "Spam is not permitted here"
+ raise SynapseError(
+ 403, spam_error, Codes.FORBIDDEN
+ )
- @defer.inlineCallbacks
- def _check_in_room_or_world_readable(self, room_id, user_id):
- try:
- # check_user_was_in_room will return the most recent membership
- # event for the user if:
- # * The user is a non-guest user, and was ever in the room
- # * The user is a guest user, and has joined the room
- # else it will throw.
- member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
- defer.returnValue((member_event.membership, member_event.event_id))
- return
- except AuthError:
- visibility = yield self.state_handler.get_current_state(
- room_id, EventTypes.RoomHistoryVisibility, ""
- )
- if (
- visibility and
- visibility.content["history_visibility"] == "world_readable"
- ):
- defer.returnValue((Membership.JOIN, None))
- return
- raise AuthError(
- 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
+ yield self.send_nonmember_event(
+ requester,
+ event,
+ context,
+ ratelimit=ratelimit,
)
+ defer.returnValue(event)
+ @measure_func("create_new_client_event")
@defer.inlineCallbacks
- def get_state_events(self, user_id, room_id, is_guest=False):
- """Retrieve all state events for a given room. If the user is
- joined to the room then return the current state. If the user has
- left the room return the state events from when they left.
+ def create_new_client_event(self, builder, requester=None,
+ prev_events_and_hashes=None):
+ """Create a new event for a local client
Args:
- user_id(str): The user requesting state events.
- room_id(str): The room ID to get all state events from.
+ builder (EventBuilder):
+
+ requester (synapse.types.Requester|None):
+
+ prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
+ the forward extremities to use as the prev_events for the
+ new event. For each event, a tuple of (event_id, hashes, depth)
+ where *hashes* is a map from algorithm to hash.
+
+ If None, they will be requested from the database.
+
Returns:
- A list of dicts representing state events. [{}, {}, {}]
+ Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)]
"""
- membership, membership_event_id = yield self._check_in_room_or_world_readable(
- room_id, user_id
- )
- if membership == Membership.JOIN:
- room_state = yield self.state_handler.get_current_state(room_id)
- elif membership == Membership.LEAVE:
- room_state = yield self.store.get_state_for_events(
- [membership_event_id], None
+ if prev_events_and_hashes is not None:
+ assert len(prev_events_and_hashes) <= 10, \
+ "Attempting to create an event with %i prev_events" % (
+ len(prev_events_and_hashes),
)
- room_state = room_state[membership_event_id]
-
- now = self.clock.time_msec()
- defer.returnValue(
- [serialize_event(c, now) for c in room_state.values()]
- )
-
- @measure_func("_create_new_client_event")
- @defer.inlineCallbacks
- def _create_new_client_event(self, builder, requester=None, prev_event_ids=None):
- if prev_event_ids:
- prev_events = yield self.store.add_event_hashes(prev_event_ids)
- prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
- depth = prev_max_depth + 1
else:
- latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
- builder.room_id,
- )
-
- # We want to limit the max number of prev events we point to in our
- # new event
- if len(latest_ret) > 10:
- # Sort by reverse depth, so we point to the most recent.
- latest_ret.sort(key=lambda a: -a[2])
- new_latest_ret = latest_ret[:5]
-
- # We also randomly point to some of the older events, to make
- # sure that we don't completely ignore the older events.
- if latest_ret[5:]:
- sample_size = min(5, len(latest_ret[5:]))
- new_latest_ret.extend(random.sample(latest_ret[5:], sample_size))
- latest_ret = new_latest_ret
-
- if latest_ret:
- depth = max([d for _, _, d in latest_ret]) + 1
- else:
- depth = 1
+ prev_events_and_hashes = \
+ yield self.store.get_prev_events_for_room(builder.room_id)
+
+ if prev_events_and_hashes:
+ depth = max([d for _, _, d in prev_events_and_hashes]) + 1
+ # we cap depth of generated events, to ensure that they are not
+ # rejected by other servers (and so that they can be persisted in
+ # the db)
+ depth = min(depth, MAX_DEPTH)
+ else:
+ depth = 1
- prev_events = [
- (event_id, prev_hashes)
- for event_id, prev_hashes, _ in latest_ret
- ]
+ prev_events = [
+ (event_id, prev_hashes)
+ for event_id, prev_hashes, _ in prev_events_and_hashes
+ ]
builder.prev_events = prev_events
builder.depth = depth
- state_handler = self.state_handler
-
- context = yield state_handler.compute_event_context(builder)
+ context = yield self.state.compute_event_context(builder)
if requester:
context.app_service = requester.app_service
@@ -470,8 +506,8 @@ class MessageHandler(BaseHandler):
event = builder.build()
logger.debug(
- "Created event %s with state: %s",
- event.event_id, context.prev_state_ids,
+ "Created event %s",
+ event.event_id,
)
defer.returnValue(
@@ -486,12 +522,21 @@ class MessageHandler(BaseHandler):
event,
context,
ratelimit=True,
- extra_users=[]
+ extra_users=[],
):
- # We now need to go and hit out to wherever we need to hit out to.
+ """Processes a new event. This includes checking auth, persisting it,
+ notifying users, sending to remote servers, etc.
- if ratelimit:
- yield self.ratelimit(requester)
+ If called from a worker will hit out to the master process for final
+ processing.
+
+ Args:
+ requester (Requester)
+ event (FrozenEvent)
+ context (EventContext)
+ ratelimit (bool)
+ extra_users (list(UserID)): Any extra users to notify about event
+ """
try:
yield self.auth.check_from_context(event, context)
@@ -501,13 +546,72 @@ class MessageHandler(BaseHandler):
# Ensure that we can round trip before trying to persist in db
try:
- dump = ujson.dumps(event.content)
- ujson.loads(dump)
- except:
+ dump = frozendict_json_encoder.encode(event.content)
+ json.loads(dump)
+ except Exception:
logger.exception("Failed to encode content: %r", event.content)
raise
- yield self.maybe_kick_guest_users(event, context)
+ yield self.action_generator.handle_push_actions_for_event(
+ event, context
+ )
+
+ try:
+ # If we're a worker we need to hit out to the master.
+ if self.config.worker_app:
+ yield send_event_to_master(
+ clock=self.hs.get_clock(),
+ store=self.store,
+ client=self.http_client,
+ host=self.config.worker_replication_host,
+ port=self.config.worker_replication_http_port,
+ requester=requester,
+ event=event,
+ context=context,
+ ratelimit=ratelimit,
+ extra_users=extra_users,
+ )
+ return
+
+ yield self.persist_and_notify_client_event(
+ requester,
+ event,
+ context,
+ ratelimit=ratelimit,
+ extra_users=extra_users,
+ )
+ except: # noqa: E722, as we reraise the exception this is fine.
+ # Ensure that we actually remove the entries in the push actions
+ # staging area, if we calculated them.
+ tp, value, tb = sys.exc_info()
+
+ run_in_background(
+ self.store.remove_push_actions_from_staging,
+ event.event_id,
+ )
+
+ six.reraise(tp, value, tb)
+
+ @defer.inlineCallbacks
+ def persist_and_notify_client_event(
+ self,
+ requester,
+ event,
+ context,
+ ratelimit=True,
+ extra_users=[],
+ ):
+ """Called when we have fully built the event, have already
+ calculated the push actions for the event, and checked auth.
+
+ This should only be run on master.
+ """
+ assert not self.config.worker_app
+
+ if ratelimit:
+ yield self.base_handler.ratelimit(requester)
+
+ yield self.base_handler.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least)
@@ -535,9 +639,11 @@ class MessageHandler(BaseHandler):
e.sender == event.sender
)
+ current_state_ids = yield context.get_current_state_ids(self.store)
+
state_to_include_ids = [
e_id
- for k, e_id in context.current_state_ids.iteritems()
+ for k, e_id in iteritems(current_state_ids)
if k[0] in self.hs.config.room_invite_state_types
or k == (EventTypes.Member, event.sender)
]
@@ -551,7 +657,7 @@ class MessageHandler(BaseHandler):
"content": e.content,
"sender": e.sender,
}
- for e in state_to_include.itervalues()
+ for e in itervalues(state_to_include)
]
invitee = UserID.from_string(event.state_key)
@@ -573,8 +679,9 @@ class MessageHandler(BaseHandler):
)
if event.type == EventTypes.Redaction:
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.auth.compute_auth_events(
- event, context.prev_state_ids, for_verification=True,
+ event, prev_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
@@ -594,15 +701,13 @@ class MessageHandler(BaseHandler):
"You don't have permission to redact events"
)
- if event.type == EventTypes.Create and context.prev_state_ids:
- raise AuthError(
- 403,
- "Changing the room create event is forbidden",
- )
-
- yield self.action_generator.handle_push_actions_for_event(
- event, context
- )
+ if event.type == EventTypes.Create:
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+ if prev_state_ids:
+ raise AuthError(
+ 403,
+ "Changing the room create event is forbidden",
+ )
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
@@ -610,16 +715,31 @@ class MessageHandler(BaseHandler):
# this intentionally does not yield: we don't care about the result
# and don't need to wait for it.
- preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
+ run_in_background(
+ self.pusher_pool.on_new_notifications,
event_stream_id, max_stream_id
)
- @defer.inlineCallbacks
def _notify():
- yield run_on_reactor()
- self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id,
- extra_users=extra_users
- )
+ try:
+ self.notifier.on_new_room_event(
+ event, event_stream_id, max_stream_id,
+ extra_users=extra_users
+ )
+ except Exception:
+ logger.exception("Error notifying about new room event")
+
+ run_in_background(_notify)
- preserve_fn(_notify)()
+ if event.type == EventTypes.Message:
+ # We don't want to block sending messages on any presence code. This
+ # matters as sometimes presence code can take a while.
+ run_in_background(self._bump_active_time, requester.user)
+
+ @defer.inlineCallbacks
+ def _bump_active_time(self, user):
+ try:
+ presence = self.hs.get_presence_handler()
+ yield presence.bump_presence_active_time(user)
+ except Exception:
+ logger.exception("Error bumping presence active time")
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
new file mode 100644
index 0000000000..b2849783ed
--- /dev/null
+++ b/synapse/handlers/pagination.py
@@ -0,0 +1,265 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2017 - 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+from twisted.python.failure import Failure
+
+from synapse.api.constants import Membership
+from synapse.api.errors import SynapseError
+from synapse.events.utils import serialize_event
+from synapse.types import RoomStreamToken
+from synapse.util.async import ReadWriteLock
+from synapse.util.logcontext import run_in_background
+from synapse.util.stringutils import random_string
+from synapse.visibility import filter_events_for_client
+
+logger = logging.getLogger(__name__)
+
+
+class PurgeStatus(object):
+ """Object tracking the status of a purge request
+
+ This class contains information on the progress of a purge request, for
+ return by get_purge_status.
+
+ Attributes:
+ status (int): Tracks whether this request has completed. One of
+ STATUS_{ACTIVE,COMPLETE,FAILED}
+ """
+
+ STATUS_ACTIVE = 0
+ STATUS_COMPLETE = 1
+ STATUS_FAILED = 2
+
+ STATUS_TEXT = {
+ STATUS_ACTIVE: "active",
+ STATUS_COMPLETE: "complete",
+ STATUS_FAILED: "failed",
+ }
+
+ def __init__(self):
+ self.status = PurgeStatus.STATUS_ACTIVE
+
+ def asdict(self):
+ return {
+ "status": PurgeStatus.STATUS_TEXT[self.status]
+ }
+
+
+class PaginationHandler(object):
+ """Handles pagination and purge history requests.
+
+ These are in the same handler due to the fact we need to block clients
+ paginating during a purge.
+ """
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+
+ self.pagination_lock = ReadWriteLock()
+ self._purges_in_progress_by_room = set()
+ # map from purge id to PurgeStatus
+ self._purges_by_id = {}
+
+ def start_purge_history(self, room_id, token,
+ delete_local_events=False):
+ """Start off a history purge on a room.
+
+ Args:
+ room_id (str): The room to purge from
+
+ token (str): topological token to delete events before
+ delete_local_events (bool): True to delete local events as well as
+ remote ones
+
+ Returns:
+ str: unique ID for this purge transaction.
+ """
+ if room_id in self._purges_in_progress_by_room:
+ raise SynapseError(
+ 400,
+ "History purge already in progress for %s" % (room_id, ),
+ )
+
+ purge_id = random_string(16)
+
+ # we log the purge_id here so that it can be tied back to the
+ # request id in the log lines.
+ logger.info("[purge] starting purge_id %s", purge_id)
+
+ self._purges_by_id[purge_id] = PurgeStatus()
+ run_in_background(
+ self._purge_history,
+ purge_id, room_id, token, delete_local_events,
+ )
+ return purge_id
+
+ @defer.inlineCallbacks
+ def _purge_history(self, purge_id, room_id, token,
+ delete_local_events):
+ """Carry out a history purge on a room.
+
+ Args:
+ purge_id (str): The id for this purge
+ room_id (str): The room to purge from
+ token (str): topological token to delete events before
+ delete_local_events (bool): True to delete local events as well as
+ remote ones
+
+ Returns:
+ Deferred
+ """
+ self._purges_in_progress_by_room.add(room_id)
+ try:
+ with (yield self.pagination_lock.write(room_id)):
+ yield self.store.purge_history(
+ room_id, token, delete_local_events,
+ )
+ logger.info("[purge] complete")
+ self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE
+ except Exception:
+ logger.error("[purge] failed: %s", Failure().getTraceback().rstrip())
+ self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
+ finally:
+ self._purges_in_progress_by_room.discard(room_id)
+
+ # remove the purge from the list 24 hours after it completes
+ def clear_purge():
+ del self._purges_by_id[purge_id]
+ self.hs.get_reactor().callLater(24 * 3600, clear_purge)
+
+ def get_purge_status(self, purge_id):
+ """Get the current status of an active purge
+
+ Args:
+ purge_id (str): purge_id returned by start_purge_history
+
+ Returns:
+ PurgeStatus|None
+ """
+ return self._purges_by_id.get(purge_id)
+
+ @defer.inlineCallbacks
+ def get_messages(self, requester, room_id=None, pagin_config=None,
+ as_client_event=True, event_filter=None):
+ """Get messages in a room.
+
+ Args:
+ requester (Requester): The user requesting messages.
+ room_id (str): The room they want messages from.
+ pagin_config (synapse.api.streams.PaginationConfig): The pagination
+ config rules to apply, if any.
+ as_client_event (bool): True to get events in client-server format.
+ event_filter (Filter): Filter to apply to results or None
+ Returns:
+ dict: Pagination API results
+ """
+ user_id = requester.user.to_string()
+
+ if pagin_config.from_token:
+ room_token = pagin_config.from_token.room_key
+ else:
+ pagin_config.from_token = (
+ yield self.hs.get_event_sources().get_current_token_for_room(
+ room_id=room_id
+ )
+ )
+ room_token = pagin_config.from_token.room_key
+
+ room_token = RoomStreamToken.parse(room_token)
+
+ pagin_config.from_token = pagin_config.from_token.copy_and_replace(
+ "room_key", str(room_token)
+ )
+
+ source_config = pagin_config.get_source_config("room")
+
+ with (yield self.pagination_lock.read(room_id)):
+ membership, member_event_id = yield self.auth.check_in_room_or_world_readable(
+ room_id, user_id
+ )
+
+ if source_config.direction == 'b':
+ # if we're going backwards, we might need to backfill. This
+ # requires that we have a topo token.
+ if room_token.topological:
+ max_topo = room_token.topological
+ else:
+ max_topo = yield self.store.get_max_topological_token(
+ room_id, room_token.stream
+ )
+
+ if membership == Membership.LEAVE:
+ # If they have left the room then clamp the token to be before
+ # they left the room, to save the effort of loading from the
+ # database.
+ leave_token = yield self.store.get_topological_token_for_event(
+ member_event_id
+ )
+ leave_token = RoomStreamToken.parse(leave_token)
+ if leave_token.topological < max_topo:
+ source_config.from_key = str(leave_token)
+
+ yield self.hs.get_handlers().federation_handler.maybe_backfill(
+ room_id, max_topo
+ )
+
+ events, next_key = yield self.store.paginate_room_events(
+ room_id=room_id,
+ from_key=source_config.from_key,
+ to_key=source_config.to_key,
+ direction=source_config.direction,
+ limit=source_config.limit,
+ event_filter=event_filter,
+ )
+
+ next_token = pagin_config.from_token.copy_and_replace(
+ "room_key", next_key
+ )
+
+ if not events:
+ defer.returnValue({
+ "chunk": [],
+ "start": pagin_config.from_token.to_string(),
+ "end": next_token.to_string(),
+ })
+
+ if event_filter:
+ events = event_filter.filter(events)
+
+ events = yield filter_events_for_client(
+ self.store,
+ user_id,
+ events,
+ is_peeking=(member_event_id is None),
+ )
+
+ time_now = self.clock.time_msec()
+
+ chunk = {
+ "chunk": [
+ serialize_event(e, time_now, as_client_event)
+ for e in events
+ ],
+ "start": pagin_config.from_token.to_string(),
+ "end": next_token.to_string(),
+ }
+
+ defer.returnValue(chunk)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index c7c0b0a1e2..3732830194 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -22,41 +22,44 @@ The methods that define policy are:
- should_notify
"""
-from twisted.internet import defer, reactor
+import logging
from contextlib import contextmanager
-from synapse.api.errors import SynapseError
+from six import iteritems, itervalues
+
+from prometheus_client import Counter
+
+from twisted.internet import defer
+
from synapse.api.constants import PresenceState
+from synapse.api.errors import SynapseError
+from synapse.metrics import LaterGauge
from synapse.storage.presence import UserPresenceState
-
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.types import UserID, get_domain_from_id
from synapse.util.async import Linearizer
-from synapse.util.logcontext import preserve_fn
+from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.logcontext import run_in_background
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
-from synapse.types import UserID, get_domain_from_id
-import synapse.metrics
-
-import logging
-
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
-notified_presence_counter = metrics.register_counter("notified_presence")
-federation_presence_out_counter = metrics.register_counter("federation_presence_out")
-presence_updates_counter = metrics.register_counter("presence_updates")
-timers_fired_counter = metrics.register_counter("timers_fired")
-federation_presence_counter = metrics.register_counter("federation_presence")
-bump_active_time_counter = metrics.register_counter("bump_active_time")
+notified_presence_counter = Counter("synapse_handler_presence_notified_presence", "")
+federation_presence_out_counter = Counter(
+ "synapse_handler_presence_federation_presence_out", "")
+presence_updates_counter = Counter("synapse_handler_presence_presence_updates", "")
+timers_fired_counter = Counter("synapse_handler_presence_timers_fired", "")
+federation_presence_counter = Counter("synapse_handler_presence_federation_presence", "")
+bump_active_time_counter = Counter("synapse_handler_presence_bump_active_time", "")
-get_updates_counter = metrics.register_counter("get_updates", labels=["type"])
+get_updates_counter = Counter("synapse_handler_presence_get_updates", "", ["type"])
-notify_reason_counter = metrics.register_counter("notify_reason", labels=["reason"])
-state_transition_counter = metrics.register_counter(
- "state_transition", labels=["from", "to"]
+notify_reason_counter = Counter(
+ "synapse_handler_presence_notify_reason", "", ["reason"])
+state_transition_counter = Counter(
+ "synapse_handler_presence_state_transition", "", ["from", "to"]
)
@@ -87,35 +90,40 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
class PresenceHandler(object):
def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
self.is_mine = hs.is_mine
self.is_mine_id = hs.is_mine_id
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.wheel_timer = WheelTimer()
self.notifier = hs.get_notifier()
- self.replication = hs.get_replication_layer()
self.federation = hs.get_federation_sender()
-
self.state = hs.get_state_handler()
- self.replication.register_edu_handler(
+ federation_registry = hs.get_federation_registry()
+
+ federation_registry.register_edu_handler(
"m.presence", self.incoming_presence
)
- self.replication.register_edu_handler(
+ federation_registry.register_edu_handler(
"m.presence_invite",
lambda origin, content: self.invite_presence(
observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]),
)
)
- self.replication.register_edu_handler(
+ federation_registry.register_edu_handler(
"m.presence_accept",
lambda origin, content: self.accept_presence(
observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]),
)
)
- self.replication.register_edu_handler(
+ federation_registry.register_edu_handler(
"m.presence_deny",
lambda origin, content: self.deny_presence(
observed_user=UserID.from_string(content["observed_user"]),
@@ -136,8 +144,9 @@ class PresenceHandler(object):
for state in active_presence
}
- metrics.register_callback(
- "user_to_current_state_size", lambda: len(self.user_to_current_state)
+ LaterGauge(
+ "synapse_handlers_presence_user_to_current_state_size", "", [],
+ lambda: len(self.user_to_current_state)
)
now = self.clock.time_msec()
@@ -169,7 +178,7 @@ class PresenceHandler(object):
# have not yet been persisted
self.unpersisted_users_changes = set()
- reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
+ hs.get_reactor().addSystemEventTrigger("before", "shutdown", self._on_shutdown)
self.serial_to_user = {}
self._next_serial = 1
@@ -207,7 +216,8 @@ class PresenceHandler(object):
60 * 1000,
)
- metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer))
+ LaterGauge("synapse_handlers_presence_wheel_timer_size", "", [],
+ lambda: len(self.wheel_timer))
@defer.inlineCallbacks
def _on_shutdown(self):
@@ -254,6 +264,14 @@ class PresenceHandler(object):
logger.info("Finished _persist_unpersisted_changes")
@defer.inlineCallbacks
+ def _update_states_and_catch_exception(self, new_states):
+ try:
+ res = yield self._update_states(new_states)
+ defer.returnValue(res)
+ except Exception:
+ logger.exception("Error updating presence")
+
+ @defer.inlineCallbacks
def _update_states(self, new_states):
"""Updates presence of users. Sets the appropriate timeouts. Pokes
the notifier and federation if and only if the changed presence state
@@ -302,11 +320,11 @@ class PresenceHandler(object):
# TODO: We should probably ensure there are no races hereafter
- presence_updates_counter.inc_by(len(new_states))
+ presence_updates_counter.inc(len(new_states))
if to_notify:
- notified_presence_counter.inc_by(len(to_notify))
- yield self._persist_and_notify(to_notify.values())
+ notified_presence_counter.inc(len(to_notify))
+ yield self._persist_and_notify(list(to_notify.values()))
self.unpersisted_users_changes |= set(s.user_id for s in new_states)
self.unpersisted_users_changes -= set(to_notify.keys())
@@ -316,7 +334,7 @@ class PresenceHandler(object):
if user_id not in to_notify
}
if to_federation_ping:
- federation_presence_out_counter.inc_by(len(to_federation_ping))
+ federation_presence_out_counter.inc(len(to_federation_ping))
self._push_to_remotes(to_federation_ping.values())
@@ -354,7 +372,7 @@ class PresenceHandler(object):
for user_id in users_to_check
]
- timers_fired_counter.inc_by(len(states))
+ timers_fired_counter.inc(len(states))
changes = handle_timeouts(
states,
@@ -363,8 +381,8 @@ class PresenceHandler(object):
now=now,
)
- preserve_fn(self._update_states)(changes)
- except:
+ run_in_background(self._update_states_and_catch_exception, changes)
+ except Exception:
logger.exception("Exception in _handle_timeouts loop")
@defer.inlineCallbacks
@@ -421,20 +439,23 @@ class PresenceHandler(object):
@defer.inlineCallbacks
def _end():
- if affect_presence:
+ try:
self.user_to_num_current_syncs[user_id] -= 1
prev_state = yield self.current_state_for_user(user_id)
yield self._update_states([prev_state.copy_and_replace(
last_user_sync_ts=self.clock.time_msec(),
)])
+ except Exception:
+ logger.exception("Error updating presence after sync")
@contextmanager
def _user_syncing():
try:
yield
finally:
- preserve_fn(_end)()
+ if affect_presence:
+ run_in_background(_end)
defer.returnValue(_user_syncing())
@@ -452,61 +473,6 @@ class PresenceHandler(object):
return syncing_user_ids
@defer.inlineCallbacks
- def update_external_syncs(self, process_id, syncing_user_ids):
- """Update the syncing users for an external process
-
- Args:
- process_id(str): An identifier for the process the users are
- syncing against. This allows synapse to process updates
- as user start and stop syncing against a given process.
- syncing_user_ids(set(str)): The set of user_ids that are
- currently syncing on that server.
- """
-
- # Grab the previous list of user_ids that were syncing on that process
- prev_syncing_user_ids = (
- self.external_process_to_current_syncs.get(process_id, set())
- )
- # Grab the current presence state for both the users that are syncing
- # now and the users that were syncing before this update.
- prev_states = yield self.current_state_for_users(
- syncing_user_ids | prev_syncing_user_ids
- )
- updates = []
- time_now_ms = self.clock.time_msec()
-
- # For each new user that is syncing check if we need to mark them as
- # being online.
- for new_user_id in syncing_user_ids - prev_syncing_user_ids:
- prev_state = prev_states[new_user_id]
- if prev_state.state == PresenceState.OFFLINE:
- updates.append(prev_state.copy_and_replace(
- state=PresenceState.ONLINE,
- last_active_ts=time_now_ms,
- last_user_sync_ts=time_now_ms,
- ))
- else:
- updates.append(prev_state.copy_and_replace(
- last_user_sync_ts=time_now_ms,
- ))
-
- # For each user that is still syncing or stopped syncing update the
- # last sync time so that we will correctly apply the grace period when
- # they stop syncing.
- for old_user_id in prev_syncing_user_ids:
- prev_state = prev_states[old_user_id]
- updates.append(prev_state.copy_and_replace(
- last_user_sync_ts=time_now_ms,
- ))
-
- yield self._update_states(updates)
-
- # Update the last updated time for the process. We expire the entries
- # if we don't receive an update in the given timeframe.
- self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
- self.external_process_to_current_syncs[process_id] = syncing_user_ids
-
- @defer.inlineCallbacks
def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec):
"""Update the syncing users for an external process as a delta.
@@ -569,7 +535,7 @@ class PresenceHandler(object):
prev_state.copy_and_replace(
last_user_sync_ts=time_now_ms,
)
- for prev_state in prev_states.itervalues()
+ for prev_state in itervalues(prev_states)
])
self.external_process_last_updated_ms.pop(process_id, None)
@@ -592,14 +558,14 @@ class PresenceHandler(object):
for user_id in user_ids
}
- missing = [user_id for user_id, state in states.iteritems() if not state]
+ missing = [user_id for user_id, state in iteritems(states) if not state]
if missing:
# There are things not in our in memory cache. Lets pull them out of
# the database.
res = yield self.store.get_presence_for_users(missing)
states.update(res)
- missing = [user_id for user_id, state in states.iteritems() if not state]
+ missing = [user_id for user_id, state in iteritems(states) if not state]
if missing:
new = {
user_id: UserPresenceState.default(user_id)
@@ -695,7 +661,7 @@ class PresenceHandler(object):
updates.append(prev_state.copy_and_replace(**new_fields))
if updates:
- federation_presence_counter.inc_by(len(updates))
+ federation_presence_counter.inc(len(updates))
yield self._update_states(updates)
@defer.inlineCallbacks
@@ -720,7 +686,7 @@ class PresenceHandler(object):
"""
updates = yield self.current_state_for_users(target_user_ids)
- updates = updates.values()
+ updates = list(updates.values())
for user_id in set(target_user_ids) - set(u.user_id for u in updates):
updates.append(UserPresenceState.default(user_id))
@@ -786,11 +752,11 @@ class PresenceHandler(object):
self._push_to_remotes([state])
else:
user_ids = yield self.store.get_users_in_room(room_id)
- user_ids = filter(self.is_mine_id, user_ids)
+ user_ids = list(filter(self.is_mine_id, user_ids))
states = yield self.current_state_for_users(user_ids)
- self._push_to_remotes(states.values())
+ self._push_to_remotes(list(states.values()))
@defer.inlineCallbacks
def get_presence_list(self, observer_user, accepted=None):
@@ -970,28 +936,28 @@ def should_notify(old_state, new_state):
return False
if old_state.status_msg != new_state.status_msg:
- notify_reason_counter.inc("status_msg_change")
+ notify_reason_counter.labels("status_msg_change").inc()
return True
if old_state.state != new_state.state:
- notify_reason_counter.inc("state_change")
- state_transition_counter.inc(old_state.state, new_state.state)
+ notify_reason_counter.labels("state_change").inc()
+ state_transition_counter.labels(old_state.state, new_state.state).inc()
return True
if old_state.state == PresenceState.ONLINE:
if new_state.currently_active != old_state.currently_active:
- notify_reason_counter.inc("current_active_change")
+ notify_reason_counter.labels("current_active_change").inc()
return True
if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# Only notify about last active bumps if we're not currently acive
if not new_state.currently_active:
- notify_reason_counter.inc("last_active_change_online")
+ notify_reason_counter.labels("last_active_change_online").inc()
return True
elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# Always notify for a transition where last active gets bumped.
- notify_reason_counter.inc("last_active_change_not_online")
+ notify_reason_counter.labels("last_active_change_not_online").inc()
return True
return False
@@ -1065,14 +1031,14 @@ class PresenceEventSource(object):
if changed is not None and len(changed) < 500:
# For small deltas, its quicker to get all changes and then
# work out if we share a room or they're in our presence list
- get_updates_counter.inc("stream")
+ get_updates_counter.labels("stream").inc()
for other_user_id in changed:
if other_user_id in users_interested_in:
user_ids_changed.add(other_user_id)
else:
# Too many possible updates. Find all users we can see and check
# if any of them have changed.
- get_updates_counter.inc("full")
+ get_updates_counter.labels("full").inc()
if from_key:
user_ids_changed = stream_change_cache.get_entities_changed(
@@ -1084,10 +1050,10 @@ class PresenceEventSource(object):
updates = yield presence.current_state_for_users(user_ids_changed)
if include_offline:
- defer.returnValue((updates.values(), max_token))
+ defer.returnValue((list(updates.values()), max_token))
else:
defer.returnValue(([
- s for s in updates.itervalues()
+ s for s in itervalues(updates)
if s.state != PresenceState.OFFLINE
], max_token))
@@ -1145,7 +1111,7 @@ def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
if new_state:
changes[state.user_id] = new_state
- return changes.values()
+ return list(changes.values())
def handle_timeout(state, is_mine, syncing_user_ids, now):
@@ -1199,7 +1165,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
)
changed = True
else:
- # We expect to be poked occaisonally by the other side.
+ # We expect to be poked occasionally by the other side.
# This is to protect against forgetful/buggy servers, so that
# no one gets stuck online forever.
if now - state.last_federation_update_ts > FEDERATION_TIMEOUT:
@@ -1344,11 +1310,11 @@ def get_interested_remotes(store, states, state_handler):
# hosts in those rooms.
room_ids_to_states, users_to_states = yield get_interested_parties(store, states)
- for room_id, states in room_ids_to_states.iteritems():
+ for room_id, states in iteritems(room_ids_to_states):
hosts = yield state_handler.get_current_hosts_in_room(room_id)
hosts_and_states.append((hosts, states))
- for user_id, states in users_to_states.iteritems():
+ for user_id, states in iteritems(users_to_states):
host = get_domain_from_id(user_id)
hosts_and_states.append(([host], states))
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 7abee98dea..859f6d2b2e 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -17,25 +17,88 @@ import logging
from twisted.internet import defer
-import synapse.types
-from synapse.api.errors import SynapseError, AuthError, CodeMessageException
-from synapse.types import UserID
-from ._base import BaseHandler
+from synapse.api.errors import AuthError, CodeMessageException, SynapseError
+from synapse.types import UserID, get_domain_from_id
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
class ProfileHandler(BaseHandler):
+ PROFILE_UPDATE_MS = 60 * 1000
+ PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
def __init__(self, hs):
super(ProfileHandler, self).__init__(hs)
- self.federation = hs.get_replication_layer()
- self.federation.register_query_handler(
+ self.federation = hs.get_federation_client()
+ hs.get_federation_registry().register_query_handler(
"profile", self.on_profile_query
)
+ self.user_directory_handler = hs.get_user_directory_handler()
+
+ if hs.config.worker_app is None:
+ self.clock.looping_call(
+ self._update_remote_profile_cache, self.PROFILE_UPDATE_MS,
+ )
+
+ @defer.inlineCallbacks
+ def get_profile(self, user_id):
+ target_user = UserID.from_string(user_id)
+ if self.hs.is_mine(target_user):
+ displayname = yield self.store.get_profile_displayname(
+ target_user.localpart
+ )
+ avatar_url = yield self.store.get_profile_avatar_url(
+ target_user.localpart
+ )
+
+ defer.returnValue({
+ "displayname": displayname,
+ "avatar_url": avatar_url,
+ })
+ else:
+ try:
+ result = yield self.federation.make_query(
+ destination=target_user.domain,
+ query_type="profile",
+ args={
+ "user_id": user_id,
+ },
+ ignore_backoff=True,
+ )
+ defer.returnValue(result)
+ except CodeMessageException as e:
+ if e.code != 404:
+ logger.exception("Failed to get displayname")
+
+ raise
+
+ @defer.inlineCallbacks
+ def get_profile_from_cache(self, user_id):
+ """Get the profile information from our local cache. If the user is
+ ours then the profile information will always be corect. Otherwise,
+ it may be out of date/missing.
+ """
+ target_user = UserID.from_string(user_id)
+ if self.hs.is_mine(target_user):
+ displayname = yield self.store.get_profile_displayname(
+ target_user.localpart
+ )
+ avatar_url = yield self.store.get_profile_avatar_url(
+ target_user.localpart
+ )
+
+ defer.returnValue({
+ "displayname": displayname,
+ "avatar_url": avatar_url,
+ })
+ else:
+ profile = yield self.store.get_from_remote_profile_cache(user_id)
+ defer.returnValue(profile or {})
+
@defer.inlineCallbacks
def get_displayname(self, target_user):
if self.hs.is_mine(target_user):
@@ -60,7 +123,7 @@ class ProfileHandler(BaseHandler):
logger.exception("Failed to get displayname")
raise
- except:
+ except Exception:
logger.exception("Failed to get displayname")
else:
defer.returnValue(result["displayname"])
@@ -82,7 +145,13 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_displayname
)
- yield self._update_join_states(requester)
+ if self.hs.config.user_directory_search_all_users:
+ profile = yield self.store.get_profileinfo(target_user.localpart)
+ yield self.user_directory_handler.handle_local_profile_change(
+ target_user.to_string(), profile
+ )
+
+ yield self._update_join_states(requester, target_user)
@defer.inlineCallbacks
def get_avatar_url(self, target_user):
@@ -107,7 +176,7 @@ class ProfileHandler(BaseHandler):
if e.code != 404:
logger.exception("Failed to get avatar_url")
raise
- except:
+ except Exception:
logger.exception("Failed to get avatar_url")
defer.returnValue(result["avatar_url"])
@@ -126,7 +195,13 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_avatar_url
)
- yield self._update_join_states(requester)
+ if self.hs.config.user_directory_search_all_users:
+ profile = yield self.store.get_profileinfo(target_user.localpart)
+ yield self.user_directory_handler.handle_local_profile_change(
+ target_user.to_string(), profile
+ )
+
+ yield self._update_join_states(requester, target_user)
@defer.inlineCallbacks
def on_profile_query(self, args):
@@ -151,28 +226,24 @@ class ProfileHandler(BaseHandler):
defer.returnValue(response)
@defer.inlineCallbacks
- def _update_join_states(self, requester):
- user = requester.user
- if not self.hs.is_mine(user):
+ def _update_join_states(self, requester, target_user):
+ if not self.hs.is_mine(target_user):
return
yield self.ratelimit(requester)
room_ids = yield self.store.get_rooms_for_user(
- user.to_string(),
+ target_user.to_string(),
)
for room_id in room_ids:
- handler = self.hs.get_handlers().room_member_handler
+ handler = self.hs.get_room_member_handler()
try:
- # Assume the user isn't a guest because we don't let guests set
- # profile or avatar data.
- # XXX why are we recreating `requester` here for each room?
- # what was wrong with the `requester` we were passed?
- requester = synapse.types.create_requester(user)
+ # Assume the target_user isn't a guest,
+ # because we don't let guests set profile or avatar data.
yield handler.update_membership(
requester,
- user,
+ target_user,
room_id,
"join", # We treat a profile update like a join.
ratelimit=False, # Try to hide that these events aren't atomic.
@@ -182,3 +253,44 @@ class ProfileHandler(BaseHandler):
"Failed to update join event for room %s - %s",
room_id, str(e.message)
)
+
+ def _update_remote_profile_cache(self):
+ """Called periodically to check profiles of remote users we haven't
+ checked in a while.
+ """
+ entries = yield self.store.get_remote_profile_cache_entries_that_expire(
+ last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
+ )
+
+ for user_id, displayname, avatar_url in entries:
+ is_subscribed = yield self.store.is_subscribed_remote_profile_for_user(
+ user_id,
+ )
+ if not is_subscribed:
+ yield self.store.maybe_delete_remote_profile_cache(user_id)
+ continue
+
+ try:
+ profile = yield self.federation.make_query(
+ destination=get_domain_from_id(user_id),
+ query_type="profile",
+ args={
+ "user_id": user_id,
+ },
+ ignore_backoff=True,
+ )
+ except Exception:
+ logger.exception("Failed to get avatar_url")
+
+ yield self.store.update_remote_profile_cache(
+ user_id, displayname, avatar_url
+ )
+ continue
+
+ new_name = profile.get("displayname")
+ new_avatar = profile.get("avatar_url")
+
+ # We always hit update to update the last_check timestamp
+ yield self.store.update_remote_profile_cache(
+ user_id, new_name, new_avatar
+ )
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index b5b0303d54..995460f82a 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -13,13 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseHandler
+import logging
from twisted.internet import defer
from synapse.util.async import Linearizer
-import logging
+from ._base import BaseHandler
+
logger = logging.getLogger(__name__)
@@ -41,9 +42,9 @@ class ReadMarkerHandler(BaseHandler):
"""
with (yield self.read_marker_linearizer.queue((room_id, user_id))):
- account_data = yield self.store.get_account_data_for_room(user_id, room_id)
-
- existing_read_marker = account_data.get("m.fully_read", None)
+ existing_read_marker = yield self.store.get_account_data_for_room_and_type(
+ user_id, room_id, "m.fully_read",
+ )
should_update = True
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index e1cd3a48e9..cb905a3903 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -12,16 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from ._base import BaseHandler
+import logging
from twisted.internet import defer
-from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import get_domain_from_id
+from synapse.util import logcontext
+from synapse.util.logcontext import PreserveLoggingContext
-import logging
-
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -34,7 +33,7 @@ class ReceiptsHandler(BaseHandler):
self.store = hs.get_datastore()
self.hs = hs
self.federation = hs.get_federation_sender()
- hs.get_replication_layer().register_edu_handler(
+ hs.get_federation_registry().register_edu_handler(
"m.receipt", self._received_remote_receipt
)
self.clock = self.hs.get_clock()
@@ -59,6 +58,8 @@ class ReceiptsHandler(BaseHandler):
is_new = yield self._handle_new_receipts([receipt])
if is_new:
+ # fire off a process in the background to send the receipt to
+ # remote servers
self._push_remotes([receipt])
@defer.inlineCallbacks
@@ -126,42 +127,46 @@ class ReceiptsHandler(BaseHandler):
defer.returnValue(True)
+ @logcontext.preserve_fn # caller should not yield on this
@defer.inlineCallbacks
def _push_remotes(self, receipts):
"""Given a list of receipts, works out which remote servers should be
poked and pokes them.
"""
- # TODO: Some of this stuff should be coallesced.
- for receipt in receipts:
- room_id = receipt["room_id"]
- receipt_type = receipt["receipt_type"]
- user_id = receipt["user_id"]
- event_ids = receipt["event_ids"]
- data = receipt["data"]
-
- users = yield self.state.get_current_user_in_room(room_id)
- remotedomains = set(get_domain_from_id(u) for u in users)
- remotedomains = remotedomains.copy()
- remotedomains.discard(self.server_name)
-
- logger.debug("Sending receipt to: %r", remotedomains)
-
- for domain in remotedomains:
- self.federation.send_edu(
- destination=domain,
- edu_type="m.receipt",
- content={
- room_id: {
- receipt_type: {
- user_id: {
- "event_ids": event_ids,
- "data": data,
+ try:
+ # TODO: Some of this stuff should be coallesced.
+ for receipt in receipts:
+ room_id = receipt["room_id"]
+ receipt_type = receipt["receipt_type"]
+ user_id = receipt["user_id"]
+ event_ids = receipt["event_ids"]
+ data = receipt["data"]
+
+ users = yield self.state.get_current_user_in_room(room_id)
+ remotedomains = set(get_domain_from_id(u) for u in users)
+ remotedomains = remotedomains.copy()
+ remotedomains.discard(self.server_name)
+
+ logger.debug("Sending receipt to: %r", remotedomains)
+
+ for domain in remotedomains:
+ self.federation.send_edu(
+ destination=domain,
+ edu_type="m.receipt",
+ content={
+ room_id: {
+ receipt_type: {
+ user_id: {
+ "event_ids": event_ids,
+ "data": data,
+ }
}
- }
+ },
},
- },
- key=(room_id, receipt_type, user_id),
- )
+ key=(room_id, receipt_type, user_id),
+ )
+ except Exception:
+ logger.exception("Error pushing receipts to remote servers")
@defer.inlineCallbacks
def get_receipts_for_room(self, room_id, to_key):
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index ee3a2269a8..7caff0cbc8 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -15,16 +15,22 @@
"""Contains functions for registering clients."""
import logging
-import urllib
from twisted.internet import defer
+from synapse import types
from synapse.api.errors import (
- AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
+ AuthError,
+ Codes,
+ InvalidCaptchaError,
+ RegistrationError,
+ SynapseError,
)
from synapse.http.client import CaptchaServerHttpClient
-from synapse.types import UserID
-from synapse.util.async import run_on_reactor
+from synapse.types import RoomAlias, RoomID, UserID, create_requester
+from synapse.util.async import Linearizer
+from synapse.util.threepids import check_3pid_allowed
+
from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -33,24 +39,35 @@ logger = logging.getLogger(__name__)
class RegistrationHandler(BaseHandler):
def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
super(RegistrationHandler, self).__init__(hs)
self.auth = hs.get_auth()
+ self._auth_handler = hs.get_auth_handler()
+ self.profile_handler = hs.get_profile_handler()
+ self.user_directory_handler = hs.get_user_directory_handler()
self.captcha_client = CaptchaServerHttpClient(hs)
self._next_generated_user_id = None
self.macaroon_gen = hs.get_macaroon_generator()
+ self._generate_user_id_linearizer = Linearizer(
+ name="_generate_user_id_linearizer",
+ )
+ self._server_notices_mxid = hs.config.server_notices_mxid
+
@defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None,
assigned_user_id=None):
- yield run_on_reactor()
-
- if urllib.quote(localpart.encode('utf-8')) != localpart:
+ if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,
- "User ID can only contain characters a-z, 0-9, or '_-./'",
+ "User ID can only contain characters a-z, 0-9, or '=_-./'",
Codes.INVALID_USERNAME
)
@@ -80,7 +97,7 @@ class RegistrationHandler(BaseHandler):
"A different user ID has already been registered for this session",
)
- yield self.check_user_id_not_appservice_exclusive(user_id)
+ self.check_user_id_not_appservice_exclusive(user_id)
users = yield self.store.get_users_by_id_case_insensitive(user_id)
if users:
@@ -127,10 +144,9 @@ class RegistrationHandler(BaseHandler):
Raises:
RegistrationError if there was a problem registering.
"""
- yield run_on_reactor()
password_hash = None
if password:
- password_hash = self.auth_handler().hash(password)
+ password_hash = yield self.auth_handler().hash(password)
if localpart:
yield self.check_username(localpart, guest_access_token=guest_access_token)
@@ -165,6 +181,13 @@ class RegistrationHandler(BaseHandler):
),
admin=admin,
)
+
+ if self.hs.config.user_directory_search_all_users:
+ profile = yield self.store.get_profileinfo(localpart)
+ yield self.user_directory_handler.handle_local_profile_change(
+ user_id, profile
+ )
+
else:
# autogen a sequential user ID
attempts = 0
@@ -192,10 +215,17 @@ class RegistrationHandler(BaseHandler):
token = None
attempts += 1
+ # auto-join the user to any rooms we're supposed to dump them into
+ fake_requester = create_requester(user_id)
+ for r in self.hs.config.auto_join_rooms:
+ try:
+ yield self._join_user_to_room(fake_requester, r)
+ except Exception as e:
+ logger.error("Failed to join new user to %r: %r", r, e)
+
# We used to generate default identicons here, but nowadays
# we want clients to generate their own as part of their branding
# rather than there being consistent matrix-wide ones, so we don't.
-
defer.returnValue((user_id, token))
@defer.inlineCallbacks
@@ -253,11 +283,10 @@ class RegistrationHandler(BaseHandler):
"""
Registers email_id as SAML2 Based Auth.
"""
- if urllib.quote(localpart) != localpart:
+ if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,
- "User ID must only contain characters which do not"
- " require URL encoding."
+ "User ID can only contain characters a-z, 0-9, or '=_-./'",
)
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
@@ -286,12 +315,12 @@ class RegistrationHandler(BaseHandler):
"""
for c in threepidCreds:
- logger.info("validating theeepidcred sid %s on id server %s",
+ logger.info("validating threepidcred sid %s on id server %s",
c['sid'], c['idServer'])
try:
identity_handler = self.hs.get_handlers().identity_handler
threepid = yield identity_handler.threepid_from_creds(c)
- except:
+ except Exception:
logger.exception("Couldn't validate 3pid")
raise RegistrationError(400, "Couldn't validate 3pid")
@@ -300,6 +329,11 @@ class RegistrationHandler(BaseHandler):
logger.info("got threepid with medium '%s' and address '%s'",
threepid['medium'], threepid['address'])
+ if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']):
+ raise RegistrationError(
+ 403, "Third party identifier is not allowed"
+ )
+
@defer.inlineCallbacks
def bind_emails(self, user_id, threepidCreds):
"""Links emails with a user ID and informs an identity server.
@@ -314,6 +348,14 @@ class RegistrationHandler(BaseHandler):
yield identity_handler.bind_threepid(c, user_id)
def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
+ # don't allow people to register the server notices mxid
+ if self._server_notices_mxid is not None:
+ if user_id == self._server_notices_mxid:
+ raise SynapseError(
+ 400, "This user ID is reserved.",
+ errcode=Codes.EXCLUSIVE
+ )
+
# valid user IDs must not clash with any user ID namespaces claimed by
# application services.
services = self.store.get_app_services()
@@ -332,9 +374,11 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks
def _generate_user_id(self, reseed=False):
if reseed or self._next_generated_user_id is None:
- self._next_generated_user_id = (
- yield self.store.find_next_generated_user_id_localpart()
- )
+ with (yield self._generate_user_id_linearizer.queue(())):
+ if reseed or self._next_generated_user_id is None:
+ self._next_generated_user_id = (
+ yield self.store.find_next_generated_user_id_localpart()
+ )
id = self._next_generated_user_id
self._next_generated_user_id += 1
@@ -391,8 +435,6 @@ class RegistrationHandler(BaseHandler):
Raises:
RegistrationError if there was a problem registering.
"""
- yield run_on_reactor()
-
if localpart is None:
raise SynapseError(400, "Request must include user id")
@@ -418,13 +460,12 @@ class RegistrationHandler(BaseHandler):
create_profile_with_localpart=user.localpart,
)
else:
- yield self.store.user_delete_access_tokens(user_id=user_id)
+ yield self._auth_handler.delete_access_tokens_for_user(user_id)
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname)
- profile_handler = self.hs.get_handlers().profile_handler
- yield profile_handler.set_displayname(
+ yield self.profile_handler.set_displayname(
user, requester, displayname, by_admin=True,
)
@@ -434,16 +475,59 @@ class RegistrationHandler(BaseHandler):
return self.hs.get_auth_handler()
@defer.inlineCallbacks
- def guest_access_token_for(self, medium, address, inviter_user_id):
+ def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
+ """Get a guest access token for a 3PID, creating a guest account if
+ one doesn't already exist.
+
+ Args:
+ medium (str)
+ address (str)
+ inviter_user_id (str): The user ID who is trying to invite the
+ 3PID
+
+ Returns:
+ Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
+ 3PID guest account.
+ """
access_token = yield self.store.get_3pid_guest_access_token(medium, address)
if access_token:
- defer.returnValue(access_token)
+ user_info = yield self.auth.get_user_by_access_token(
+ access_token
+ )
+
+ defer.returnValue((user_info["user"].to_string(), access_token))
- _, access_token = yield self.register(
+ user_id, access_token = yield self.register(
generate_token=True,
make_guest=True
)
access_token = yield self.store.save_or_get_3pid_guest_access_token(
medium, address, access_token, inviter_user_id
)
- defer.returnValue(access_token)
+
+ defer.returnValue((user_id, access_token))
+
+ @defer.inlineCallbacks
+ def _join_user_to_room(self, requester, room_identifier):
+ room_id = None
+ room_member_handler = self.hs.get_room_member_handler()
+ if RoomID.is_valid(room_identifier):
+ room_id = room_identifier
+ elif RoomAlias.is_valid(room_identifier):
+ room_alias = RoomAlias.from_string(room_identifier)
+ room_id, remote_room_hosts = (
+ yield room_member_handler.lookup_room_alias(room_alias)
+ )
+ room_id = room_id.to_string()
+ else:
+ raise SynapseError(400, "%s was not legal room ID or room alias" % (
+ room_identifier,
+ ))
+
+ yield room_member_handler.update_membership(
+ requester=requester,
+ target=requester.user,
+ room_id=room_id,
+ remote_room_hosts=remote_room_hosts,
+ action="join",
+ )
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 5698d28088..6150b7e226 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,23 +15,20 @@
# limitations under the License.
"""Contains functions for performing events on rooms."""
-from twisted.internet import defer
+import logging
+import math
+import string
+from collections import OrderedDict
-from ._base import BaseHandler
+from twisted.internet import defer
-from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
-from synapse.api.constants import (
- EventTypes, JoinRules, RoomCreationPreset
-)
-from synapse.api.errors import AuthError, StoreError, SynapseError
+from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
+from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
+from synapse.types import RoomAlias, RoomID, RoomStreamToken, UserID
from synapse.util import stringutils
from synapse.visibility import filter_events_for_client
-from collections import OrderedDict
-
-import logging
-import math
-import string
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -60,21 +58,43 @@ class RoomCreationHandler(BaseHandler):
},
}
+ def __init__(self, hs):
+ super(RoomCreationHandler, self).__init__(hs)
+
+ self.spam_checker = hs.get_spam_checker()
+ self.event_creation_handler = hs.get_event_creation_handler()
+
@defer.inlineCallbacks
- def create_room(self, requester, config, ratelimit=True):
+ def create_room(self, requester, config, ratelimit=True,
+ creator_join_profile=None):
""" Creates a new room.
Args:
- requester (Requester): The user who requested the room creation.
+ requester (synapse.types.Requester):
+ The user who requested the room creation.
config (dict) : A dict of configuration options.
+ ratelimit (bool): set to False to disable the rate limiter
+
+ creator_join_profile (dict|None):
+ Set to override the displayname and avatar for the creating
+ user in this room. If unset, displayname and avatar will be
+ derived from the user's profile. If set, should contain the
+ values to go in the body of the 'join' event (typically
+ `avatar_url` and/or `displayname`.
+
Returns:
- The new room ID.
+ Deferred[dict]:
+ a dict containing the keys `room_id` and, if an alias was
+ requested, `room_alias`.
Raises:
SynapseError if the room ID couldn't be stored, or something went
horribly wrong.
"""
user_id = requester.user.to_string()
+ if not self.spam_checker.user_may_create_room(user_id):
+ raise SynapseError(403, "You are not permitted to create rooms")
+
if ratelimit:
yield self.ratelimit(requester)
@@ -83,7 +103,7 @@ class RoomCreationHandler(BaseHandler):
if wchar in config["room_alias_name"]:
raise SynapseError(400, "Invalid characters in room alias")
- room_alias = RoomAlias.create(
+ room_alias = RoomAlias(
config["room_alias_name"],
self.hs.hostname,
)
@@ -92,7 +112,11 @@ class RoomCreationHandler(BaseHandler):
)
if mapping:
- raise SynapseError(400, "Room alias already taken")
+ raise SynapseError(
+ 400,
+ "Room alias already taken",
+ Codes.ROOM_IN_USE
+ )
else:
room_alias = None
@@ -100,9 +124,13 @@ class RoomCreationHandler(BaseHandler):
for i in invite_list:
try:
UserID.from_string(i)
- except:
+ except Exception:
raise SynapseError(400, "Invalid user_id: %s" % (i,))
+ yield self.event_creation_handler.assert_accepted_privacy_policy(
+ requester,
+ )
+
invite_3pid_list = config.get("invite_3pid", [])
visibility = config.get("visibility", None)
@@ -115,7 +143,7 @@ class RoomCreationHandler(BaseHandler):
while attempts < 5:
try:
random_string = stringutils.random_string(18)
- gen_room_id = RoomID.create(
+ gen_room_id = RoomID(
random_string,
self.hs.hostname,
)
@@ -155,25 +183,24 @@ class RoomCreationHandler(BaseHandler):
creation_content = config.get("creation_content", {})
- msg_handler = self.hs.get_handlers().message_handler
- room_member_handler = self.hs.get_handlers().room_member_handler
+ room_member_handler = self.hs.get_room_member_handler()
yield self._send_events_for_new_room(
requester,
room_id,
- msg_handler,
room_member_handler,
preset_config=preset_config,
invite_list=invite_list,
initial_state=initial_state,
creation_content=creation_content,
room_alias=room_alias,
- power_level_content_override=config.get("power_level_content_override", {})
+ power_level_content_override=config.get("power_level_content_override", {}),
+ creator_join_profile=creator_join_profile,
)
if "name" in config:
name = config["name"]
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Name,
@@ -186,7 +213,7 @@ class RoomCreationHandler(BaseHandler):
if "topic" in config:
topic = config["topic"]
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Topic,
@@ -197,12 +224,12 @@ class RoomCreationHandler(BaseHandler):
},
ratelimit=False)
- content = {}
- is_direct = config.get("is_direct", None)
- if is_direct:
- content["is_direct"] = is_direct
-
for invitee in invite_list:
+ content = {}
+ is_direct = config.get("is_direct", None)
+ if is_direct:
+ content["is_direct"] = is_direct
+
yield room_member_handler.update_membership(
requester,
UserID.from_string(invitee),
@@ -216,7 +243,7 @@ class RoomCreationHandler(BaseHandler):
id_server = invite_3pid["id_server"]
address = invite_3pid["address"]
medium = invite_3pid["medium"]
- yield self.hs.get_handlers().room_member_handler.do_3pid_invite(
+ yield self.hs.get_room_member_handler().do_3pid_invite(
room_id,
requester.user,
medium,
@@ -241,7 +268,6 @@ class RoomCreationHandler(BaseHandler):
self,
creator, # A Requester object.
room_id,
- msg_handler,
room_member_handler,
preset_config,
invite_list,
@@ -249,6 +275,7 @@ class RoomCreationHandler(BaseHandler):
creation_content,
room_alias,
power_level_content_override,
+ creator_join_profile,
):
def create(etype, content, **kwargs):
e = {
@@ -264,7 +291,7 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks
def send(etype, content, **kwargs):
event = create(etype, content, **kwargs)
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
creator,
event,
ratelimit=False
@@ -292,6 +319,7 @@ class RoomCreationHandler(BaseHandler):
room_id,
"join",
ratelimit=False,
+ content=creator_join_profile,
)
# We treat the power levels override specially as this needs to be one
@@ -367,7 +395,11 @@ class RoomCreationHandler(BaseHandler):
)
-class RoomContextHandler(BaseHandler):
+class RoomContextHandler(object):
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+
@defer.inlineCallbacks
def get_event_context(self, user, room_id, event_id, limit):
"""Retrieves events, pagination tokens and state around a given event
@@ -428,7 +460,7 @@ class RoomContextHandler(BaseHandler):
state = yield self.store.get_state_for_events(
[last_event_id], None
)
- results["state"] = state[last_event_id].values()
+ results["state"] = list(state[last_event_id].values())
results["start"] = now_token.copy_and_replace(
"room_key", results["start"]
@@ -468,12 +500,9 @@ class RoomEventSource(object):
user.to_string()
)
if app_service:
- events, end_key = yield self.store.get_appservice_room_stream(
- service=app_service,
- from_key=from_key,
- to_key=to_key,
- limit=limit,
- )
+ # We no longer support AS users using /sync directly.
+ # See https://github.com/matrix-org/matrix-doc/issues/1144
+ raise NotImplementedError()
else:
room_events = yield self.store.get_membership_changes_for_user(
user.to_string(), from_key, to_key
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 516cd9a6ac..828229f5c3 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -13,23 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+import logging
+from collections import namedtuple
-from ._base import BaseHandler
+from six import iteritems
+from six.moves import range
+
+import msgpack
+from unpaddedbase64 import decode_base64, encode_base64
+
+from twisted.internet import defer
-from synapse.api.constants import (
- EventTypes, JoinRules,
-)
+from synapse.api.constants import EventTypes, JoinRules
+from synapse.types import ThirdPartyInstanceID
from synapse.util.async import concurrently_execute
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.response_cache import ResponseCache
-from synapse.types import ThirdPartyInstanceID
-
-from collections import namedtuple
-from unpaddedbase64 import encode_base64, decode_base64
-import logging
-import msgpack
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -37,18 +38,19 @@ REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
# This is used to indicate we should only return rooms published to the main list.
-EMTPY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
+EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
class RoomListHandler(BaseHandler):
def __init__(self, hs):
super(RoomListHandler, self).__init__(hs)
- self.response_cache = ResponseCache(hs)
- self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000)
+ self.response_cache = ResponseCache(hs, "room_list")
+ self.remote_response_cache = ResponseCache(hs, "remote_room_list",
+ timeout_ms=30 * 1000)
def get_local_public_room_list(self, limit=None, since_token=None,
search_filter=None,
- network_tuple=EMTPY_THIRD_PARTY_ID,):
+ network_tuple=EMPTY_THIRD_PARTY_ID,):
"""Generate a local public room list.
There are multiple different lists: the main one plus one per third
@@ -70,25 +72,22 @@ class RoomListHandler(BaseHandler):
if search_filter:
# We explicitly don't bother caching searches or requests for
# appservice specific lists.
+ logger.info("Bypassing cache as search request.")
return self._get_public_room_list(
limit, since_token, search_filter, network_tuple=network_tuple,
)
key = (limit, since_token, network_tuple)
- result = self.response_cache.get(key)
- if not result:
- result = self.response_cache.set(
- key,
- self._get_public_room_list(
- limit, since_token, network_tuple=network_tuple
- )
- )
- return result
+ return self.response_cache.wrap(
+ key,
+ self._get_public_room_list,
+ limit, since_token, network_tuple=network_tuple,
+ )
@defer.inlineCallbacks
def _get_public_room_list(self, limit=None, since_token=None,
search_filter=None,
- network_tuple=EMTPY_THIRD_PARTY_ID,):
+ network_tuple=EMPTY_THIRD_PARTY_ID,):
if since_token and since_token != "END":
since_token = RoomListNextBatch.from_token(since_token)
else:
@@ -149,6 +148,8 @@ class RoomListHandler(BaseHandler):
# We want larger rooms to be first, hence negating num_joined_users
rooms_to_order_value[room_id] = (-num_joined_users, room_id)
+ logger.info("Getting ordering for %i rooms since %s",
+ len(room_ids), stream_token)
yield concurrently_execute(get_order_for_room, room_ids, 10)
sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
@@ -176,34 +177,43 @@ class RoomListHandler(BaseHandler):
rooms_to_scan = rooms_to_scan[:since_token.current_limit]
rooms_to_scan.reverse()
- # Actually generate the entries. _append_room_entry_to_chunk will append to
- # chunk but will stop if len(chunk) > limit
- chunk = []
- if limit and not search_filter:
+ logger.info("After sorting and filtering, %i rooms remain",
+ len(rooms_to_scan))
+
+ # _append_room_entry_to_chunk will append to chunk but will stop if
+ # len(chunk) > limit
+ #
+ # Normally we will generate enough results on the first iteration here,
+ # but if there is a search filter, _append_room_entry_to_chunk may
+ # filter some results out, in which case we loop again.
+ #
+ # We don't want to scan over the entire range either as that
+ # would potentially waste a lot of work.
+ #
+ # XXX if there is no limit, we may end up DoSing the server with
+ # calls to get_current_state_ids for every single room on the
+ # server. Surely we should cap this somehow?
+ #
+ if limit:
step = limit + 1
- for i in xrange(0, len(rooms_to_scan), step):
- # We iterate here because the vast majority of cases we'll stop
- # at first iteration, but occaisonally _append_room_entry_to_chunk
- # won't append to the chunk and so we need to loop again.
- # We don't want to scan over the entire range either as that
- # would potentially waste a lot of work.
- yield concurrently_execute(
- lambda r: self._append_room_entry_to_chunk(
- r, rooms_to_num_joined[r],
- chunk, limit, search_filter
- ),
- rooms_to_scan[i:i + step], 10
- )
- if len(chunk) >= limit + 1:
- break
else:
+ # step cannot be zero
+ step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
+
+ chunk = []
+ for i in range(0, len(rooms_to_scan), step):
+ batch = rooms_to_scan[i:i + step]
+ logger.info("Processing %i rooms for result", len(batch))
yield concurrently_execute(
lambda r: self._append_room_entry_to_chunk(
r, rooms_to_num_joined[r],
chunk, limit, search_filter
),
- rooms_to_scan, 5
+ batch, 5,
)
+ logger.info("Now %i rooms in result", len(chunk))
+ if len(chunk) >= limit + 1:
+ break
chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"]))
@@ -276,13 +286,14 @@ class RoomListHandler(BaseHandler):
# We've already got enough, so lets just drop it.
return
- result = yield self._generate_room_entry(room_id, num_joined_users)
+ result = yield self.generate_room_entry(room_id, num_joined_users)
if result and _matches_room_entry(result, search_filter):
chunk.append(result)
@cachedInlineCallbacks(num_args=1, cache_context=True)
- def _generate_room_entry(self, room_id, num_joined_users, cache_context):
+ def generate_room_entry(self, room_id, num_joined_users, cache_context,
+ with_alias=True, allow_private=False):
"""Returns the entry for a room
"""
result = {
@@ -295,7 +306,7 @@ class RoomListHandler(BaseHandler):
)
event_map = yield self.store.get_events([
- event_id for key, event_id in current_state_ids.iteritems()
+ event_id for key, event_id in iteritems(current_state_ids)
if key[0] in (
EventTypes.JoinRules,
EventTypes.Name,
@@ -316,14 +327,15 @@ class RoomListHandler(BaseHandler):
join_rules_event = current_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
join_rule = join_rules_event.content.get("join_rule", None)
- if join_rule and join_rule != JoinRules.PUBLIC:
+ if not allow_private and join_rule and join_rule != JoinRules.PUBLIC:
defer.returnValue(None)
- aliases = yield self.store.get_aliases_for_room(
- room_id, on_invalidate=cache_context.invalidate
- )
- if aliases:
- result["aliases"] = aliases
+ if with_alias:
+ aliases = yield self.store.get_aliases_for_room(
+ room_id, on_invalidate=cache_context.invalidate
+ )
+ if aliases:
+ result["aliases"] = aliases
name_event = yield current_state.get((EventTypes.Name, ""))
if name_event:
@@ -391,7 +403,7 @@ class RoomListHandler(BaseHandler):
def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
search_filter=None, include_all_networks=False,
third_party_instance_id=None,):
- repl_layer = self.hs.get_replication_layer()
+ repl_layer = self.hs.get_federation_client()
if search_filter:
# We can't cache when asking for search
return repl_layer.get_public_rooms(
@@ -404,18 +416,14 @@ class RoomListHandler(BaseHandler):
server_name, limit, since_token, include_all_networks,
third_party_instance_id,
)
- result = self.remote_response_cache.get(key)
- if not result:
- result = self.remote_response_cache.set(
- key,
- repl_layer.get_public_rooms(
- server_name, limit=limit, since_token=since_token,
- search_filter=search_filter,
- include_all_networks=include_all_networks,
- third_party_instance_id=third_party_instance_id,
- )
- )
- return result
+ return self.remote_response_cache.wrap(
+ key,
+ repl_layer.get_public_rooms,
+ server_name, limit=limit, since_token=since_token,
+ search_filter=search_filter,
+ include_all_networks=include_all_networks,
+ third_party_instance_id=third_party_instance_id,
+ )
class RoomListNextBatch(namedtuple("RoomListNextBatch", (
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 1b8dfa8254..0d4a3f4677 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,63 +14,161 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import abc
import logging
+from six.moves import http_client
+
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
-from twisted.internet import defer
from unpaddedbase64 import decode_base64
+from twisted.internet import defer
+
+import synapse.server
import synapse.types
-from synapse.api.constants import (
- EventTypes, Membership,
-)
-from synapse.api.errors import AuthError, SynapseError, Codes
-from synapse.types import UserID, RoomID
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.types import RoomID, UserID
from synapse.util.async import Linearizer
-from synapse.util.distributor import user_left_room, user_joined_room
-from ._base import BaseHandler
+from synapse.util.distributor import user_joined_room, user_left_room
logger = logging.getLogger(__name__)
id_server_scheme = "https://"
-class RoomMemberHandler(BaseHandler):
+class RoomMemberHandler(object):
# TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level
# API that takes ID strings and returns pagination chunks. These concerns
# ought to be separated out a lot better.
+ __metaclass__ = abc.ABCMeta
+
def __init__(self, hs):
- super(RoomMemberHandler, self).__init__(hs)
+ """
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.state_handler = hs.get_state_handler()
+ self.config = hs.config
+ self.simple_http_client = hs.get_simple_http_client()
+
+ self.federation_handler = hs.get_handlers().federation_handler
+ self.directory_handler = hs.get_handlers().directory_handler
+ self.registration_handler = hs.get_handlers().registration_handler
+ self.profile_handler = hs.get_profile_handler()
+ self.event_creation_hander = hs.get_event_creation_handler()
self.member_linearizer = Linearizer(name="member")
self.clock = hs.get_clock()
+ self.spam_checker = hs.get_spam_checker()
+ self._server_notices_mxid = self.config.server_notices_mxid
- self.distributor = hs.get_distributor()
- self.distributor.declare("user_joined_room")
- self.distributor.declare("user_left_room")
+ @abc.abstractmethod
+ def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+ """Try and join a room that this server is not in
+
+ Args:
+ requester (Requester)
+ remote_room_hosts (list[str]): List of servers that can be used
+ to join via.
+ room_id (str): Room that we are trying to join
+ user (UserID): User who is trying to join
+ content (dict): A dict that should be used as the content of the
+ join event.
+
+ Returns:
+ Deferred
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def _remote_reject_invite(self, remote_room_hosts, room_id, target):
+ """Attempt to reject an invite for a room this server is not in. If we
+ fail to do so we locally mark the invite as rejected.
+
+ Args:
+ requester (Requester)
+ remote_room_hosts (list[str]): List of servers to use to try and
+ reject invite
+ room_id (str)
+ target (UserID): The user rejecting the invite
+
+ Returns:
+ Deferred[dict]: A dictionary to be returned to the client, may
+ include event_id etc, or nothing if we locally rejected
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
+ """Get a guest access token for a 3PID, creating a guest account if
+ one doesn't already exist.
+
+ Args:
+ requester (Requester)
+ medium (str)
+ address (str)
+ inviter_user_id (str): The user ID who is trying to invite the
+ 3PID
+
+ Returns:
+ Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
+ 3PID guest account.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def _user_joined_room(self, target, room_id):
+ """Notifies distributor on master process that the user has joined the
+ room.
+
+ Args:
+ target (UserID)
+ room_id (str)
+
+ Returns:
+ Deferred|None
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def _user_left_room(self, target, room_id):
+ """Notifies distributor on master process that the user has left the
+ room.
+
+ Args:
+ target (UserID)
+ room_id (str)
+
+ Returns:
+ Deferred|None
+ """
+ raise NotImplementedError()
@defer.inlineCallbacks
def _local_membership_update(
self, requester, target, room_id, membership,
- prev_event_ids,
+ prev_events_and_hashes,
txn_id=None,
ratelimit=True,
content=None,
):
if content is None:
content = {}
- msg_handler = self.hs.get_handlers().message_handler
content["membership"] = membership
if requester.is_guest:
content["kind"] = "guest"
- event, context = yield msg_handler.create_event(
+ event, context = yield self.event_creation_hander.create_event(
requester,
{
"type": EventTypes.Member,
@@ -83,16 +182,18 @@ class RoomMemberHandler(BaseHandler):
},
token_id=requester.access_token_id,
txn_id=txn_id,
- prev_event_ids=prev_event_ids,
+ prev_events_and_hashes=prev_events_and_hashes,
)
# Check if this event matches the previous membership event for the user.
- duplicate = yield msg_handler.deduplicate_state_event(event, context)
+ duplicate = yield self.event_creation_hander.deduplicate_state_event(
+ event, context,
+ )
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
defer.returnValue(duplicate)
- yield msg_handler.handle_new_client_event(
+ yield self.event_creation_hander.handle_new_client_event(
requester,
event,
context,
@@ -100,7 +201,9 @@ class RoomMemberHandler(BaseHandler):
ratelimit=ratelimit,
)
- prev_member_event_id = context.prev_state_ids.get(
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+
+ prev_member_event_id = prev_state_ids.get(
(EventTypes.Member, target.to_string()),
None
)
@@ -114,33 +217,16 @@ class RoomMemberHandler(BaseHandler):
prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
- yield user_joined_room(self.distributor, target, room_id)
+ yield self._user_joined_room(target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
- user_left_room(self.distributor, target, room_id)
+ yield self._user_left_room(target, room_id)
defer.returnValue(event)
@defer.inlineCallbacks
- def remote_join(self, remote_room_hosts, room_id, user, content):
- if len(remote_room_hosts) == 0:
- raise SynapseError(404, "No known servers")
-
- # We don't do an auth check if we are doing an invite
- # join dance for now, since we're kinda implicitly checking
- # that we are allowed to join when we decide whether or not we
- # need to do the invite/join dance.
- yield self.hs.get_handlers().federation_handler.do_invite_join(
- remote_room_hosts,
- room_id,
- user.to_string(),
- content,
- )
- yield user_joined_room(self.distributor, user, room_id)
-
- @defer.inlineCallbacks
def update_membership(
self,
requester,
@@ -186,14 +272,19 @@ class RoomMemberHandler(BaseHandler):
content_specified = bool(content)
if content is None:
content = {}
+ else:
+ # We do a copy here as we potentially change some keys
+ # later on.
+ content = dict(content)
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
+ # if this is a join with a 3pid signature, we may need to turn a 3pid
+ # invite into a normal invite before we can handle the join.
if third_party_signed is not None:
- replication = self.hs.get_replication_layer()
- yield replication.exchange_third_party_invite(
+ yield self.federation_handler.exchange_third_party_invite(
third_party_signed["sender"],
target.to_string(),
room_id,
@@ -208,7 +299,51 @@ class RoomMemberHandler(BaseHandler):
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
- latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+ if effective_membership_state == Membership.INVITE:
+ # block any attempts to invite the server notices mxid
+ if target.to_string() == self._server_notices_mxid:
+ raise SynapseError(
+ http_client.FORBIDDEN,
+ "Cannot invite this user",
+ )
+
+ block_invite = False
+
+ if (self._server_notices_mxid is not None and
+ requester.user.to_string() == self._server_notices_mxid):
+ # allow the server notices mxid to send invites
+ is_requester_admin = True
+
+ else:
+ is_requester_admin = yield self.auth.is_server_admin(
+ requester.user,
+ )
+
+ if not is_requester_admin:
+ if self.config.block_non_admin_invites:
+ logger.info(
+ "Blocking invite: user is not admin and non-admin "
+ "invites disabled"
+ )
+ block_invite = True
+
+ if not self.spam_checker.user_may_invite(
+ requester.user.to_string(), target.to_string(), room_id,
+ ):
+ logger.info("Blocking invite due to spam checker")
+ block_invite = True
+
+ if block_invite:
+ raise SynapseError(
+ 403, "Invites have been disabled on this server",
+ )
+
+ prev_events_and_hashes = yield self.store.get_prev_events_for_room(
+ room_id,
+ )
+ latest_event_ids = (
+ event_id for (event_id, _, _) in prev_events_and_hashes
+ )
current_state_ids = yield self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids,
)
@@ -238,6 +373,20 @@ class RoomMemberHandler(BaseHandler):
if same_sender and same_membership and same_content:
defer.returnValue(old_state)
+ # we don't allow people to reject invites to the server notice
+ # room, but they can leave it once they are joined.
+ if (
+ old_membership == Membership.INVITE and
+ effective_membership_state == Membership.LEAVE
+ ):
+ is_blocked = yield self._is_server_notice_room(room_id)
+ if is_blocked:
+ raise SynapseError(
+ http_client.FORBIDDEN,
+ "You cannot reject this invite",
+ errcode=Codes.CANNOT_LEAVE_SERVER_NOTICE_ROOM,
+ )
+
is_host_in_room = yield self._is_host_in_room(current_state_ids)
if effective_membership_state == Membership.JOIN:
@@ -249,13 +398,13 @@ class RoomMemberHandler(BaseHandler):
raise AuthError(403, "Guest access not allowed")
if not is_host_in_room:
- inviter = yield self.get_inviter(target.to_string(), room_id)
+ inviter = yield self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
content["membership"] = Membership.JOIN
- profile = self.hs.get_handlers().profile_handler
+ profile = self.profile_handler
if not content_specified:
content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target)
@@ -263,15 +412,15 @@ class RoomMemberHandler(BaseHandler):
if requester.is_guest:
content["kind"] = "guest"
- ret = yield self.remote_join(
- remote_room_hosts, room_id, target, content
+ ret = yield self._remote_join(
+ requester, remote_room_hosts, room_id, target, content
)
defer.returnValue(ret)
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
- inviter = yield self.get_inviter(target.to_string(), room_id)
+ inviter = yield self._get_inviter(target.to_string(), room_id)
if not inviter:
raise SynapseError(404, "Not a known room")
@@ -285,28 +434,10 @@ class RoomMemberHandler(BaseHandler):
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
- fed_handler = self.hs.get_handlers().federation_handler
- try:
- ret = yield fed_handler.do_remotely_reject_invite(
- remote_room_hosts,
- room_id,
- target.to_string(),
- )
- defer.returnValue(ret)
- except Exception as e:
- # if we were unable to reject the exception, just mark
- # it as rejected on our end and plough ahead.
- #
- # The 'except' clause is very broad, but we need to
- # capture everything from DNS failures upwards
- #
- logger.warn("Failed to reject invite: %s", e)
-
- yield self.store.locally_reject_invite(
- target.to_string(), room_id
- )
-
- defer.returnValue({})
+ res = yield self._remote_reject_invite(
+ requester, remote_room_hosts, room_id, target,
+ )
+ defer.returnValue(res)
res = yield self._local_membership_update(
requester=requester,
@@ -315,7 +446,7 @@ class RoomMemberHandler(BaseHandler):
membership=effective_membership_state,
txn_id=txn_id,
ratelimit=ratelimit,
- prev_event_ids=latest_event_ids,
+ prev_events_and_hashes=prev_events_and_hashes,
content=content,
)
defer.returnValue(res)
@@ -361,14 +492,16 @@ class RoomMemberHandler(BaseHandler):
else:
requester = synapse.types.create_requester(target_user)
- message_handler = self.hs.get_handlers().message_handler
- prev_event = yield message_handler.deduplicate_state_event(event, context)
+ prev_event = yield self.event_creation_hander.deduplicate_state_event(
+ event, context,
+ )
if prev_event is not None:
return
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
if event.membership == Membership.JOIN:
if requester.is_guest:
- guest_can_join = yield self._can_guest_join(context.prev_state_ids)
+ guest_can_join = yield self._can_guest_join(prev_state_ids)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
@@ -379,7 +512,7 @@ class RoomMemberHandler(BaseHandler):
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
- yield message_handler.handle_new_client_event(
+ yield self.event_creation_hander.handle_new_client_event(
requester,
event,
context,
@@ -387,7 +520,7 @@ class RoomMemberHandler(BaseHandler):
ratelimit=ratelimit,
)
- prev_member_event_id = context.prev_state_ids.get(
+ prev_member_event_id = prev_state_ids.get(
(EventTypes.Member, event.state_key),
None
)
@@ -401,12 +534,12 @@ class RoomMemberHandler(BaseHandler):
prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
- yield user_joined_room(self.distributor, target_user, room_id)
+ yield self._user_joined_room(target_user, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
- user_left_room(self.distributor, target_user, room_id)
+ yield self._user_left_room(target_user, room_id)
@defer.inlineCallbacks
def _can_guest_join(self, current_state_ids):
@@ -440,7 +573,7 @@ class RoomMemberHandler(BaseHandler):
Raises:
SynapseError if room alias could not be found.
"""
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.directory_handler
mapping = yield directory_handler.get_association(room_alias)
if not mapping:
@@ -452,7 +585,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue((RoomID.from_string(room_id), servers))
@defer.inlineCallbacks
- def get_inviter(self, user_id, room_id):
+ def _get_inviter(self, user_id, room_id):
invite = yield self.store.get_invite_for_user_in_room(
user_id=user_id,
room_id=room_id,
@@ -471,6 +604,16 @@ class RoomMemberHandler(BaseHandler):
requester,
txn_id
):
+ if self.config.block_non_admin_invites:
+ is_requester_admin = yield self.auth.is_server_admin(
+ requester.user,
+ )
+ if not is_requester_admin:
+ raise SynapseError(
+ 403, "Invites have been disabled on this server",
+ Codes.FORBIDDEN,
+ )
+
invitee = yield self._lookup_3pid(
id_server, medium, address
)
@@ -508,7 +651,7 @@ class RoomMemberHandler(BaseHandler):
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
- data = yield self.hs.get_simple_http_client().get_json(
+ data = yield self.simple_http_client.get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
{
"medium": medium,
@@ -519,7 +662,7 @@ class RoomMemberHandler(BaseHandler):
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
- self.verify_any_signature(data, id_server)
+ yield self._verify_any_signature(data, id_server)
defer.returnValue(data["mxid"])
except IOError as e:
@@ -527,11 +670,11 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue(None)
@defer.inlineCallbacks
- def verify_any_signature(self, data, server_hostname):
+ def _verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
- key_data = yield self.hs.get_simple_http_client().get_json(
+ key_data = yield self.simple_http_client.get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s" %
(id_server_scheme, server_hostname, key_name,),
)
@@ -556,7 +699,7 @@ class RoomMemberHandler(BaseHandler):
user,
txn_id
):
- room_state = yield self.hs.get_state_handler().get_current_state(room_id)
+ room_state = yield self.state_handler.get_current_state(room_id)
inviter_display_name = ""
inviter_avatar_url = ""
@@ -591,6 +734,7 @@ class RoomMemberHandler(BaseHandler):
token, public_keys, fallback_public_key, display_name = (
yield self._ask_id_server_for_third_party_invite(
+ requester=requester,
id_server=id_server,
medium=medium,
address=address,
@@ -605,8 +749,7 @@ class RoomMemberHandler(BaseHandler):
)
)
- msg_handler = self.hs.get_handlers().message_handler
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_hander.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
@@ -628,6 +771,7 @@ class RoomMemberHandler(BaseHandler):
@defer.inlineCallbacks
def _ask_id_server_for_third_party_invite(
self,
+ requester,
id_server,
medium,
address,
@@ -644,6 +788,7 @@ class RoomMemberHandler(BaseHandler):
Asks an identity server for a third party invite.
Args:
+ requester (Requester)
id_server (str): hostname + optional port for the identity server.
medium (str): The literal string "email".
address (str): The third party address being invited.
@@ -685,24 +830,20 @@ class RoomMemberHandler(BaseHandler):
"sender_avatar_url": inviter_avatar_url,
}
- if self.hs.config.invite_3pid_guest:
- registration_handler = self.hs.get_handlers().registration_handler
- guest_access_token = yield registration_handler.guest_access_token_for(
+ if self.config.invite_3pid_guest:
+ guest_access_token, guest_user_id = yield self.get_or_register_3pid_guest(
+ requester=requester,
medium=medium,
address=address,
inviter_user_id=inviter_user_id,
)
- guest_user_info = yield self.hs.get_auth().get_user_by_access_token(
- guest_access_token
- )
-
invite_config.update({
"guest_access_token": guest_access_token,
- "guest_user_id": guest_user_info["user"].to_string(),
+ "guest_user_id": guest_user_id,
})
- data = yield self.hs.get_simple_http_client().post_urlencoded_get_json(
+ data = yield self.simple_http_client.post_urlencoded_get_json(
is_url,
invite_config
)
@@ -725,27 +866,6 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue((token, public_keys, fallback_public_key, display_name))
@defer.inlineCallbacks
- def forget(self, user, room_id):
- user_id = user.to_string()
-
- member = yield self.state_handler.get_current_state(
- room_id=room_id,
- event_type=EventTypes.Member,
- state_key=user_id
- )
- membership = member.membership if member else None
-
- if membership is not None and membership not in [
- Membership.LEAVE, Membership.BAN
- ]:
- raise SynapseError(400, "User %s in room %s" % (
- user_id, room_id
- ))
-
- if membership:
- yield self.store.forget(user_id, room_id)
-
- @defer.inlineCallbacks
def _is_host_in_room(self, current_state_ids):
# Have we just created the room, and is this about to be the very
# first member event?
@@ -766,3 +886,109 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue(True)
defer.returnValue(False)
+
+ @defer.inlineCallbacks
+ def _is_server_notice_room(self, room_id):
+ if self._server_notices_mxid is None:
+ defer.returnValue(False)
+ user_ids = yield self.store.get_users_in_room(room_id)
+ defer.returnValue(self._server_notices_mxid in user_ids)
+
+
+class RoomMemberMasterHandler(RoomMemberHandler):
+ def __init__(self, hs):
+ super(RoomMemberMasterHandler, self).__init__(hs)
+
+ self.distributor = hs.get_distributor()
+ self.distributor.declare("user_joined_room")
+ self.distributor.declare("user_left_room")
+
+ @defer.inlineCallbacks
+ def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+ """Implements RoomMemberHandler._remote_join
+ """
+ # filter ourselves out of remote_room_hosts: do_invite_join ignores it
+ # and if it is the only entry we'd like to return a 404 rather than a
+ # 500.
+
+ remote_room_hosts = [
+ host for host in remote_room_hosts if host != self.hs.hostname
+ ]
+
+ if len(remote_room_hosts) == 0:
+ raise SynapseError(404, "No known servers")
+
+ # We don't do an auth check if we are doing an invite
+ # join dance for now, since we're kinda implicitly checking
+ # that we are allowed to join when we decide whether or not we
+ # need to do the invite/join dance.
+ yield self.federation_handler.do_invite_join(
+ remote_room_hosts,
+ room_id,
+ user.to_string(),
+ content,
+ )
+ yield self._user_joined_room(user, room_id)
+
+ @defer.inlineCallbacks
+ def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
+ """Implements RoomMemberHandler._remote_reject_invite
+ """
+ fed_handler = self.federation_handler
+ try:
+ ret = yield fed_handler.do_remotely_reject_invite(
+ remote_room_hosts,
+ room_id,
+ target.to_string(),
+ )
+ defer.returnValue(ret)
+ except Exception as e:
+ # if we were unable to reject the exception, just mark
+ # it as rejected on our end and plough ahead.
+ #
+ # The 'except' clause is very broad, but we need to
+ # capture everything from DNS failures upwards
+ #
+ logger.warn("Failed to reject invite: %s", e)
+
+ yield self.store.locally_reject_invite(
+ target.to_string(), room_id
+ )
+ defer.returnValue({})
+
+ def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
+ """Implements RoomMemberHandler.get_or_register_3pid_guest
+ """
+ rg = self.registration_handler
+ return rg.get_or_register_3pid_guest(medium, address, inviter_user_id)
+
+ def _user_joined_room(self, target, room_id):
+ """Implements RoomMemberHandler._user_joined_room
+ """
+ return user_joined_room(self.distributor, target, room_id)
+
+ def _user_left_room(self, target, room_id):
+ """Implements RoomMemberHandler._user_left_room
+ """
+ return user_left_room(self.distributor, target, room_id)
+
+ @defer.inlineCallbacks
+ def forget(self, user, room_id):
+ user_id = user.to_string()
+
+ member = yield self.state_handler.get_current_state(
+ room_id=room_id,
+ event_type=EventTypes.Member,
+ state_key=user_id
+ )
+ membership = member.membership if member else None
+
+ if membership is not None and membership not in [
+ Membership.LEAVE, Membership.BAN
+ ]:
+ raise SynapseError(400, "User %s in room %s" % (
+ user_id, room_id
+ ))
+
+ if membership:
+ yield self.store.forget(user_id, room_id)
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
new file mode 100644
index 0000000000..22d8b4b0d3
--- /dev/null
+++ b/synapse/handlers/room_member_worker.py
@@ -0,0 +1,103 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.room_member import RoomMemberHandler
+from synapse.replication.http.membership import (
+ get_or_register_3pid_guest,
+ notify_user_membership_change,
+ remote_join,
+ remote_reject_invite,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class RoomMemberWorkerHandler(RoomMemberHandler):
+ @defer.inlineCallbacks
+ def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+ """Implements RoomMemberHandler._remote_join
+ """
+ if len(remote_room_hosts) == 0:
+ raise SynapseError(404, "No known servers")
+
+ ret = yield remote_join(
+ self.simple_http_client,
+ host=self.config.worker_replication_host,
+ port=self.config.worker_replication_http_port,
+ requester=requester,
+ remote_room_hosts=remote_room_hosts,
+ room_id=room_id,
+ user_id=user.to_string(),
+ content=content,
+ )
+
+ yield self._user_joined_room(user, room_id)
+
+ defer.returnValue(ret)
+
+ def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
+ """Implements RoomMemberHandler._remote_reject_invite
+ """
+ return remote_reject_invite(
+ self.simple_http_client,
+ host=self.config.worker_replication_host,
+ port=self.config.worker_replication_http_port,
+ requester=requester,
+ remote_room_hosts=remote_room_hosts,
+ room_id=room_id,
+ user_id=target.to_string(),
+ )
+
+ def _user_joined_room(self, target, room_id):
+ """Implements RoomMemberHandler._user_joined_room
+ """
+ return notify_user_membership_change(
+ self.simple_http_client,
+ host=self.config.worker_replication_host,
+ port=self.config.worker_replication_http_port,
+ user_id=target.to_string(),
+ room_id=room_id,
+ change="joined",
+ )
+
+ def _user_left_room(self, target, room_id):
+ """Implements RoomMemberHandler._user_left_room
+ """
+ return notify_user_membership_change(
+ self.simple_http_client,
+ host=self.config.worker_replication_host,
+ port=self.config.worker_replication_http_port,
+ user_id=target.to_string(),
+ room_id=room_id,
+ change="left",
+ )
+
+ def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
+ """Implements RoomMemberHandler.get_or_register_3pid_guest
+ """
+ return get_or_register_3pid_guest(
+ self.simple_http_client,
+ host=self.config.worker_replication_host,
+ port=self.config.worker_replication_http_port,
+ requester=requester,
+ medium=medium,
+ address=address,
+ inviter_user_id=inviter_user_id,
+ )
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index df75d70fac..69ae9731d5 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -13,21 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+import itertools
+import logging
-from ._base import BaseHandler
+from unpaddedbase64 import decode_base64, encode_base64
-from synapse.api.constants import Membership, EventTypes
-from synapse.api.filtering import Filter
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
+from synapse.api.filtering import Filter
from synapse.events.utils import serialize_event
from synapse.visibility import filter_events_for_client
-from unpaddedbase64 import decode_base64, encode_base64
-
-import itertools
-import logging
-
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -61,9 +60,16 @@ class SearchHandler(BaseHandler):
assert batch_group is not None
assert batch_group_key is not None
assert batch_token is not None
- except:
+ except Exception:
raise SynapseError(400, "Invalid batch")
+ logger.info(
+ "Search batch properties: %r, %r, %r",
+ batch_group, batch_group_key, batch_token,
+ )
+
+ logger.info("Search content: %s", content)
+
try:
room_cat = content["search_categories"]["room_events"]
@@ -271,6 +277,8 @@ class SearchHandler(BaseHandler):
# We should never get here due to the guard earlier.
raise NotImplementedError()
+ logger.info("Found %d events to return", len(allowed_events))
+
# If client has asked for "context" for each event (i.e. some surrounding
# events and state), fetch that
if event_context is not None:
@@ -282,6 +290,11 @@ class SearchHandler(BaseHandler):
event.room_id, event.event_id, before_limit, after_limit
)
+ logger.info(
+ "Context for search returned %d and %d events",
+ len(res["events_before"]), len(res["events_after"]),
+ )
+
res["events_before"] = yield filter_events_for_client(
self.store, user.to_string(), res["events_before"]
)
@@ -348,7 +361,7 @@ class SearchHandler(BaseHandler):
rooms = set(e.room_id for e in allowed_events)
for room_id in rooms:
state = yield self.state_handler.get_current_state(room_id)
- state_results[room_id] = state.values()
+ state_results[room_id] = list(state.values())
state_results.values()
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
new file mode 100644
index 0000000000..7ecdede4dc
--- /dev/null
+++ b/synapse/handlers/set_password.py
@@ -0,0 +1,57 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.errors import Codes, StoreError, SynapseError
+
+from ._base import BaseHandler
+
+logger = logging.getLogger(__name__)
+
+
+class SetPasswordHandler(BaseHandler):
+ """Handler which deals with changing user account passwords"""
+ def __init__(self, hs):
+ super(SetPasswordHandler, self).__init__(hs)
+ self._auth_handler = hs.get_auth_handler()
+ self._device_handler = hs.get_device_handler()
+
+ @defer.inlineCallbacks
+ def set_password(self, user_id, newpassword, requester=None):
+ password_hash = yield self._auth_handler.hash(newpassword)
+
+ except_device_id = requester.device_id if requester else None
+ except_access_token_id = requester.access_token_id if requester else None
+
+ try:
+ yield self.store.user_set_password_hash(user_id, password_hash)
+ except StoreError as e:
+ if e.code == 404:
+ raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
+ raise e
+
+ # we want to log out all of the user's other sessions. First delete
+ # all his other devices.
+ yield self._device_handler.delete_all_devices_for_user(
+ user_id, except_device_id=except_device_id,
+ )
+
+ # and now delete any access tokens which weren't associated with
+ # devices (or were associated with this device).
+ yield self._auth_handler.delete_access_tokens_for_user(
+ user_id, except_token_id=except_access_token_id,
+ )
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 91c6c6be3c..c24e35362a 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -13,20 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.api.constants import Membership, EventTypes
+import collections
+import itertools
+import logging
+
+from six import iteritems, itervalues
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.push.clientformat import format_push_rules_for_user
+from synapse.types import RoomStreamToken
from synapse.util.async import concurrently_execute
+from synapse.util.caches.response_cache import ResponseCache
from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure, measure_func
-from synapse.util.caches.response_cache import ResponseCache
-from synapse.push.clientformat import format_push_rules_for_user
from synapse.visibility import filter_events_for_client
-from synapse.types import RoomStreamToken
-
-from twisted.internet import defer
-
-import collections
-import logging
-import itertools
logger = logging.getLogger(__name__)
@@ -52,6 +54,7 @@ class TimelineBatch(collections.namedtuple("TimelineBatch", [
to tell if room needs to be part of the sync result.
"""
return bool(self.events)
+ __bool__ = __nonzero__ # python3
class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
@@ -76,6 +79,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
# nb the notification count does not, er, count: if there's nothing
# else in the result, we don't need to send it.
)
+ __bool__ = __nonzero__ # python3
class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [
@@ -95,6 +99,7 @@ class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [
or self.state
or self.account_data
)
+ __bool__ = __nonzero__ # python3
class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
@@ -106,6 +111,30 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
def __nonzero__(self):
"""Invited rooms should always be reported to the client"""
return True
+ __bool__ = __nonzero__ # python3
+
+
+class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [
+ "join",
+ "invite",
+ "leave",
+])):
+ __slots__ = []
+
+ def __nonzero__(self):
+ return bool(self.join or self.invite or self.leave)
+ __bool__ = __nonzero__ # python3
+
+
+class DeviceLists(collections.namedtuple("DeviceLists", [
+ "changed", # list of user_ids whose devices may have changed
+ "left", # list of user_ids whose devices we no longer track
+])):
+ __slots__ = []
+
+ def __nonzero__(self):
+ return bool(self.changed or self.left)
+ __bool__ = __nonzero__ # python3
class SyncResult(collections.namedtuple("SyncResult", [
@@ -116,9 +145,10 @@ class SyncResult(collections.namedtuple("SyncResult", [
"invited", # InvitedSyncResult for each invited room.
"archived", # ArchivedSyncResult for each archived room.
"to_device", # List of direct messages for the device.
- "device_lists", # List of user_ids whose devices have chanegd
+ "device_lists", # List of user_ids whose devices have changed
"device_one_time_keys_count", # Dict of algorithm to count for one time keys
# for this device
+ "groups",
])):
__slots__ = []
@@ -134,8 +164,10 @@ class SyncResult(collections.namedtuple("SyncResult", [
self.archived or
self.account_data or
self.to_device or
- self.device_lists
+ self.device_lists or
+ self.groups
)
+ __bool__ = __nonzero__ # python3
class SyncHandler(object):
@@ -146,7 +178,7 @@ class SyncHandler(object):
self.presence_handler = hs.get_presence_handler()
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
- self.response_cache = ResponseCache(hs)
+ self.response_cache = ResponseCache(hs, "sync")
self.state = hs.get_state_handler()
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
@@ -157,15 +189,11 @@ class SyncHandler(object):
Returns:
A Deferred SyncResult.
"""
- result = self.response_cache.get(sync_config.request_key)
- if not result:
- result = self.response_cache.set(
- sync_config.request_key,
- self._wait_for_sync_for_user(
- sync_config, since_token, timeout, full_state
- )
- )
- return result
+ return self.response_cache.wrap(
+ sync_config.request_key,
+ self._wait_for_sync_for_user,
+ sync_config, since_token, timeout, full_state,
+ )
@defer.inlineCallbacks
def _wait_for_sync_for_user(self, sync_config, since_token, timeout,
@@ -212,10 +240,10 @@ class SyncHandler(object):
defer.returnValue(rules)
@defer.inlineCallbacks
- def ephemeral_by_room(self, sync_config, now_token, since_token=None):
+ def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
"""Get the ephemeral events for each room the user is in
Args:
- sync_config (SyncConfig): The flags, filters and user for the sync.
+ sync_result_builder(SyncResultBuilder)
now_token (StreamToken): Where the server is currently up to.
since_token (StreamToken): Where the server was when the client
last synced.
@@ -225,10 +253,12 @@ class SyncHandler(object):
typing events for that room.
"""
+ sync_config = sync_result_builder.sync_config
+
with Measure(self.clock, "ephemeral_by_room"):
typing_key = since_token.typing_key if since_token else "0"
- room_ids = yield self.store.get_rooms_for_user(sync_config.user.to_string())
+ room_ids = sync_result_builder.joined_room_ids
typing_source = self.event_sources.sources["typing"]
typing, typing_key = yield typing_source.get_new_events(
@@ -247,7 +277,7 @@ class SyncHandler(object):
# result returned by the event source is poor form (it might cache
# the object)
room_id = event["room_id"]
- event_copy = {k: v for (k, v) in event.iteritems()
+ event_copy = {k: v for (k, v) in iteritems(event)
if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
@@ -266,7 +296,7 @@ class SyncHandler(object):
for event in receipts:
room_id = event["room_id"]
# exclude room id, as above
- event_copy = {k: v for (k, v) in event.iteritems()
+ event_copy = {k: v for (k, v) in iteritems(event)
if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
@@ -290,10 +320,20 @@ class SyncHandler(object):
if recents:
recents = sync_config.filter_collection.filter_room_timeline(recents)
+
+ # We check if there are any state events, if there are then we pass
+ # all current state events to the filter_events function. This is to
+ # ensure that we always include current state in the timeline
+ current_state_ids = frozenset()
+ if any(e.is_state() for e in recents):
+ current_state_ids = yield self.state.get_current_state_ids(room_id)
+ current_state_ids = frozenset(itervalues(current_state_ids))
+
recents = yield filter_events_for_client(
self.store,
sync_config.user.to_string(),
recents,
+ always_include_ids=current_state_ids,
)
else:
recents = []
@@ -316,19 +356,41 @@ class SyncHandler(object):
since_key = since_token.room_key
while limited and len(recents) < timeline_limit and max_repeat:
- events, end_key = yield self.store.get_room_events_stream_for_room(
- room_id,
- limit=load_limit + 1,
- from_key=since_key,
- to_key=end_key,
- )
+ # If we have a since_key then we are trying to get any events
+ # that have happened since `since_key` up to `end_key`, so we
+ # can just use `get_room_events_stream_for_room`.
+ # Otherwise, we want to return the last N events in the room
+ # in toplogical ordering.
+ if since_key:
+ events, end_key = yield self.store.get_room_events_stream_for_room(
+ room_id,
+ limit=load_limit + 1,
+ from_key=since_key,
+ to_key=end_key,
+ )
+ else:
+ events, end_key = yield self.store.get_recent_events_for_room(
+ room_id,
+ limit=load_limit + 1,
+ end_token=end_key,
+ )
loaded_recents = sync_config.filter_collection.filter_room_timeline(
events
)
+
+ # We check if there are any state events, if there are then we pass
+ # all current state events to the filter_events function. This is to
+ # ensure that we always include current state in the timeline
+ current_state_ids = frozenset()
+ if any(e.is_state() for e in loaded_recents):
+ current_state_ids = yield self.state.get_current_state_ids(room_id)
+ current_state_ids = frozenset(itervalues(current_state_ids))
+
loaded_recents = yield filter_events_for_client(
self.store,
sync_config.user.to_string(),
loaded_recents,
+ always_include_ids=current_state_ids,
)
loaded_recents.extend(recents)
recents = loaded_recents
@@ -381,7 +443,11 @@ class SyncHandler(object):
Returns:
A Deferred map from ((type, state_key)->Event)
"""
- last_events, token = yield self.store.get_recent_events_for_room(
+ # FIXME this claims to get the state at a stream position, but
+ # get_recent_events_for_room operates by topo ordering. This therefore
+ # does not reliably give you the state at the given stream position.
+ # (https://github.com/matrix-org/synapse/issues/3305)
+ last_events, _ = yield self.store.get_recent_events_for_room(
room_id, end_token=stream_position.room_key, limit=1,
)
@@ -475,11 +541,11 @@ class SyncHandler(object):
state = {}
if state_ids:
- state = yield self.store.get_events(state_ids.values())
+ state = yield self.store.get_events(list(state_ids.values()))
defer.returnValue({
(e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(state.values())
+ for e in sync_config.filter_collection.filter_room_state(list(state.values()))
})
@defer.inlineCallbacks
@@ -522,10 +588,22 @@ class SyncHandler(object):
# Always use the `now_token` in `SyncResultBuilder`
now_token = yield self.event_sources.get_current_token()
+ user_id = sync_config.user.to_string()
+ app_service = self.store.get_app_service_by_user_id(user_id)
+ if app_service:
+ # We no longer support AS users using /sync directly.
+ # See https://github.com/matrix-org/matrix-doc/issues/1144
+ raise NotImplementedError()
+ else:
+ joined_room_ids = yield self.get_rooms_for_user_at(
+ user_id, now_token.room_stream_id,
+ )
+
sync_result_builder = SyncResultBuilder(
sync_config, full_state,
since_token=since_token,
now_token=now_token,
+ joined_room_ids=joined_room_ids,
)
account_data_by_room = yield self._generate_sync_entry_for_account_data(
@@ -535,7 +613,8 @@ class SyncHandler(object):
res = yield self._generate_sync_entry_for_rooms(
sync_result_builder, account_data_by_room
)
- newly_joined_rooms, newly_joined_users = res
+ newly_joined_rooms, newly_joined_users, _, _ = res
+ _, _, newly_left_rooms, newly_left_users = res
block_all_presence_data = (
since_token is None and
@@ -549,17 +628,22 @@ class SyncHandler(object):
yield self._generate_sync_entry_for_to_device(sync_result_builder)
device_lists = yield self._generate_sync_entry_for_device_list(
- sync_result_builder
+ sync_result_builder,
+ newly_joined_rooms=newly_joined_rooms,
+ newly_joined_users=newly_joined_users,
+ newly_left_rooms=newly_left_rooms,
+ newly_left_users=newly_left_users,
)
device_id = sync_config.device_id
one_time_key_counts = {}
if device_id:
- user_id = sync_config.user.to_string()
one_time_key_counts = yield self.store.count_e2e_one_time_keys(
user_id, device_id
)
+ yield self._generate_sync_entry_for_groups(sync_result_builder)
+
defer.returnValue(SyncResult(
presence=sync_result_builder.presence,
account_data=sync_result_builder.account_data,
@@ -568,31 +652,103 @@ class SyncHandler(object):
archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device,
device_lists=device_lists,
+ groups=sync_result_builder.groups,
device_one_time_keys_count=one_time_key_counts,
next_batch=sync_result_builder.now_token,
))
+ @measure_func("_generate_sync_entry_for_groups")
+ @defer.inlineCallbacks
+ def _generate_sync_entry_for_groups(self, sync_result_builder):
+ user_id = sync_result_builder.sync_config.user.to_string()
+ since_token = sync_result_builder.since_token
+ now_token = sync_result_builder.now_token
+
+ if since_token and since_token.groups_key:
+ results = yield self.store.get_groups_changes_for_user(
+ user_id, since_token.groups_key, now_token.groups_key,
+ )
+ else:
+ results = yield self.store.get_all_groups_for_user(
+ user_id, now_token.groups_key,
+ )
+
+ invited = {}
+ joined = {}
+ left = {}
+ for result in results:
+ membership = result["membership"]
+ group_id = result["group_id"]
+ gtype = result["type"]
+ content = result["content"]
+
+ if membership == "join":
+ if gtype == "membership":
+ # TODO: Add profile
+ content.pop("membership", None)
+ joined[group_id] = content["content"]
+ else:
+ joined.setdefault(group_id, {})[gtype] = content
+ elif membership == "invite":
+ if gtype == "membership":
+ content.pop("membership", None)
+ invited[group_id] = content["content"]
+ else:
+ if gtype == "membership":
+ left[group_id] = content["content"]
+
+ sync_result_builder.groups = GroupsSyncResult(
+ join=joined,
+ invite=invited,
+ leave=left,
+ )
+
@measure_func("_generate_sync_entry_for_device_list")
@defer.inlineCallbacks
- def _generate_sync_entry_for_device_list(self, sync_result_builder):
+ def _generate_sync_entry_for_device_list(self, sync_result_builder,
+ newly_joined_rooms, newly_joined_users,
+ newly_left_rooms, newly_left_users):
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
if since_token and since_token.device_list_key:
- room_ids = yield self.store.get_rooms_for_user(user_id)
-
- user_ids_changed = set()
changed = yield self.store.get_user_whose_devices_changed(
since_token.device_list_key
)
- for other_user_id in changed:
- other_room_ids = yield self.store.get_rooms_for_user(other_user_id)
- if room_ids.intersection(other_room_ids):
- user_ids_changed.add(other_user_id)
- defer.returnValue(user_ids_changed)
+ # TODO: Be more clever than this, i.e. remove users who we already
+ # share a room with?
+ for room_id in newly_joined_rooms:
+ joined_users = yield self.state.get_current_user_in_room(room_id)
+ newly_joined_users.update(joined_users)
+
+ for room_id in newly_left_rooms:
+ left_users = yield self.state.get_current_user_in_room(room_id)
+ newly_left_users.update(left_users)
+
+ # TODO: Check that these users are actually new, i.e. either they
+ # weren't in the previous sync *or* they left and rejoined.
+ changed.update(newly_joined_users)
+
+ if not changed and not newly_left_users:
+ defer.returnValue(DeviceLists(
+ changed=[],
+ left=newly_left_users,
+ ))
+
+ users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+ user_id
+ )
+
+ defer.returnValue(DeviceLists(
+ changed=users_who_share_room & changed,
+ left=set(newly_left_users) - users_who_share_room,
+ ))
else:
- defer.returnValue([])
+ defer.returnValue(DeviceLists(
+ changed=[],
+ left=[],
+ ))
@defer.inlineCallbacks
def _generate_sync_entry_for_to_device(self, sync_result_builder):
@@ -738,7 +894,7 @@ class SyncHandler(object):
presence.extend(states)
# Deduplicate the presence entries so that there's at most one per user
- presence = {p.user_id: p for p in presence}.values()
+ presence = list({p.user_id: p for p in presence}.values())
presence = sync_config.filter_collection.filter_presence(
presence
@@ -756,8 +912,8 @@ class SyncHandler(object):
account_data_by_room(dict): Dictionary of per room account data
Returns:
- Deferred(tuple): Returns a 2-tuple of
- `(newly_joined_rooms, newly_joined_users)`
+ Deferred(tuple): Returns a 4-tuple of
+ `(newly_joined_rooms, newly_joined_users, newly_left_rooms, newly_left_users)`
"""
user_id = sync_result_builder.sync_config.user.to_string()
block_all_room_ephemeral = (
@@ -769,7 +925,7 @@ class SyncHandler(object):
ephemeral_by_room = {}
else:
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
- sync_result_builder.sync_config,
+ sync_result_builder,
now_token=sync_result_builder.now_token,
since_token=sync_result_builder.since_token,
)
@@ -788,7 +944,7 @@ class SyncHandler(object):
)
if not tags_by_room:
logger.debug("no-oping sync")
- defer.returnValue(([], []))
+ defer.returnValue(([], [], [], []))
ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
"m.ignored_user_list", user_id=user_id,
@@ -801,7 +957,7 @@ class SyncHandler(object):
if since_token:
res = yield self._get_rooms_changed(sync_result_builder, ignored_users)
- room_entries, invited, newly_joined_rooms = res
+ room_entries, invited, newly_joined_rooms, newly_left_rooms = res
tags_by_room = yield self.store.get_updated_tags(
user_id, since_token.account_data_key,
@@ -809,6 +965,7 @@ class SyncHandler(object):
else:
res = yield self._get_all_rooms(sync_result_builder, ignored_users)
room_entries, invited, newly_joined_rooms = res
+ newly_left_rooms = []
tags_by_room = yield self.store.get_tags_for_user(user_id)
@@ -829,17 +986,30 @@ class SyncHandler(object):
# Now we want to get any newly joined users
newly_joined_users = set()
+ newly_left_users = set()
if since_token:
for joined_sync in sync_result_builder.joined:
it = itertools.chain(
- joined_sync.timeline.events, joined_sync.state.values()
+ joined_sync.timeline.events, itervalues(joined_sync.state)
)
for event in it:
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
newly_joined_users.add(event.state_key)
-
- defer.returnValue((newly_joined_rooms, newly_joined_users))
+ else:
+ prev_content = event.unsigned.get("prev_content", {})
+ prev_membership = prev_content.get("membership", None)
+ if prev_membership == Membership.JOIN:
+ newly_left_users.add(event.state_key)
+
+ newly_left_users -= newly_joined_users
+
+ defer.returnValue((
+ newly_joined_rooms,
+ newly_joined_users,
+ newly_left_rooms,
+ newly_left_users,
+ ))
@defer.inlineCallbacks
def _have_rooms_changed(self, sync_result_builder):
@@ -860,15 +1030,8 @@ class SyncHandler(object):
if rooms_changed:
defer.returnValue(True)
- app_service = self.store.get_app_service_by_user_id(user_id)
- if app_service:
- rooms = yield self.store.get_app_service_rooms(app_service)
- joined_room_ids = set(r.room_id for r in rooms)
- else:
- joined_room_ids = yield self.store.get_rooms_for_user(user_id)
-
stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
- for room_id in joined_room_ids:
+ for room_id in sync_result_builder.joined_room_ids:
if self.store.has_room_changed_since(room_id, stream_id):
defer.returnValue(True)
defer.returnValue(False)
@@ -883,7 +1046,13 @@ class SyncHandler(object):
Returns:
Deferred(tuple): Returns a tuple of the form:
- `([RoomSyncResultBuilder], [InvitedSyncResult], newly_joined_rooms)`
+ `(room_entries, invited_rooms, newly_joined_rooms, newly_left_rooms)`
+
+ where:
+ room_entries is a list [RoomSyncResultBuilder]
+ invited_rooms is a list [InvitedSyncResult]
+ newly_joined rooms is a list[str] of room ids
+ newly_left_rooms is a list[str] of room ids
"""
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
@@ -892,13 +1061,6 @@ class SyncHandler(object):
assert since_token
- app_service = self.store.get_app_service_by_user_id(user_id)
- if app_service:
- rooms = yield self.store.get_app_service_rooms(app_service)
- joined_room_ids = set(r.room_id for r in rooms)
- else:
- joined_room_ids = yield self.store.get_rooms_for_user(user_id)
-
# Get a list of membership change events that have happened.
rooms_changed = yield self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key
@@ -909,16 +1071,29 @@ class SyncHandler(object):
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
newly_joined_rooms = []
+ newly_left_rooms = []
room_entries = []
invited = []
- for room_id, events in mem_change_events_by_room_id.items():
+ for room_id, events in iteritems(mem_change_events_by_room_id):
non_joins = [e for e in events if e.membership != Membership.JOIN]
has_join = len(non_joins) != len(events)
# We want to figure out if we joined the room at some point since
# the last sync (even if we have since left). This is to make sure
# we do send down the room, and with full state, where necessary
- if room_id in joined_room_ids or has_join:
+
+ old_state_ids = None
+ if room_id in sync_result_builder.joined_room_ids and non_joins:
+ # Always include if the user (re)joined the room, especially
+ # important so that device list changes are calculated correctly.
+ # If there are non join member events, but we are still in the room,
+ # then the user must have left and joined
+ newly_joined_rooms.append(room_id)
+
+ # User is in the room so we don't need to do the invite/leave checks
+ continue
+
+ if room_id in sync_result_builder.joined_room_ids or has_join:
old_state_ids = yield self.get_state_at(room_id, since_token)
old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
old_mem_ev = None
@@ -929,12 +1104,33 @@ class SyncHandler(object):
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
newly_joined_rooms.append(room_id)
- if room_id in joined_room_ids:
- continue
+ # If user is in the room then we don't need to do the invite/leave checks
+ if room_id in sync_result_builder.joined_room_ids:
+ continue
if not non_joins:
continue
+ # Check if we have left the room. This can either be because we were
+ # joined before *or* that we since joined and then left.
+ if events[-1].membership != Membership.JOIN:
+ if has_join:
+ newly_left_rooms.append(room_id)
+ else:
+ if not old_state_ids:
+ old_state_ids = yield self.get_state_at(room_id, since_token)
+ old_mem_ev_id = old_state_ids.get(
+ (EventTypes.Member, user_id),
+ None,
+ )
+ old_mem_ev = None
+ if old_mem_ev_id:
+ old_mem_ev = yield self.store.get_event(
+ old_mem_ev_id, allow_none=True
+ )
+ if old_mem_ev and old_mem_ev.membership == Membership.JOIN:
+ newly_left_rooms.append(room_id)
+
# Only bother if we're still currently invited
should_invite = non_joins[-1].membership == Membership.INVITE
if should_invite:
@@ -976,7 +1172,7 @@ class SyncHandler(object):
# Get all events for rooms we're currently joined to.
room_to_events = yield self.store.get_room_events_stream_for_rooms(
- room_ids=joined_room_ids,
+ room_ids=sync_result_builder.joined_room_ids,
from_key=since_token.room_key,
to_key=now_token.room_key,
limit=timeline_limit + 1,
@@ -984,7 +1180,7 @@ class SyncHandler(object):
# We loop through all room ids, even if there are no new events, in case
# there are non room events taht we need to notify about.
- for room_id in joined_room_ids:
+ for room_id in sync_result_builder.joined_room_ids:
room_entry = room_to_events.get(room_id, None)
if room_entry:
@@ -1012,7 +1208,7 @@ class SyncHandler(object):
upto_token=since_token,
))
- defer.returnValue((room_entries, invited, newly_joined_rooms))
+ defer.returnValue((room_entries, invited, newly_joined_rooms, newly_left_rooms))
@defer.inlineCallbacks
def _get_all_rooms(self, sync_result_builder, ignored_users):
@@ -1192,6 +1388,54 @@ class SyncHandler(object):
else:
raise Exception("Unrecognized rtype: %r", room_builder.rtype)
+ @defer.inlineCallbacks
+ def get_rooms_for_user_at(self, user_id, stream_ordering):
+ """Get set of joined rooms for a user at the given stream ordering.
+
+ The stream ordering *must* be recent, otherwise this may throw an
+ exception if older than a month. (This function is called with the
+ current token, which should be perfectly fine).
+
+ Args:
+ user_id (str)
+ stream_ordering (int)
+
+ ReturnValue:
+ Deferred[frozenset[str]]: Set of room_ids the user is in at given
+ stream_ordering.
+ """
+ joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(
+ user_id,
+ )
+
+ joined_room_ids = set()
+
+ # We need to check that the stream ordering of the join for each room
+ # is before the stream_ordering asked for. This might not be the case
+ # if the user joins a room between us getting the current token and
+ # calling `get_rooms_for_user_with_stream_ordering`.
+ # If the membership's stream ordering is after the given stream
+ # ordering, we need to go and work out if the user was in the room
+ # before.
+ for room_id, membership_stream_ordering in joined_rooms:
+ if membership_stream_ordering <= stream_ordering:
+ joined_room_ids.add(room_id)
+ continue
+
+ logger.info("User joined room after current token: %s", room_id)
+
+ extrems = yield self.store.get_forward_extremeties_for_room(
+ room_id, stream_ordering,
+ )
+ users_in_room = yield self.state.get_current_user_in_room(
+ room_id, extrems,
+ )
+ if user_id in users_in_room:
+ joined_room_ids.add(room_id)
+
+ joined_room_ids = frozenset(joined_room_ids)
+ defer.returnValue(joined_room_ids)
+
def _action_has_highlight(actions):
for action in actions:
@@ -1241,7 +1485,8 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
class SyncResultBuilder(object):
"Used to help build up a new SyncResult for a user"
- def __init__(self, sync_config, full_state, since_token, now_token):
+ def __init__(self, sync_config, full_state, since_token, now_token,
+ joined_room_ids):
"""
Args:
sync_config(SyncConfig)
@@ -1253,6 +1498,7 @@ class SyncResultBuilder(object):
self.full_state = full_state
self.since_token = since_token
self.now_token = now_token
+ self.joined_room_ids = joined_room_ids
self.presence = []
self.account_data = []
@@ -1260,6 +1506,8 @@ class SyncResultBuilder(object):
self.invited = []
self.archived = []
self.device = []
+ self.groups = None
+ self.to_device = []
class RoomSyncResultBuilder(object):
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 82dedbbc99..2d2d3d5a0d 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -13,17 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+from collections import namedtuple
+
from twisted.internet import defer
-from synapse.api.errors import SynapseError, AuthError
-from synapse.util.logcontext import preserve_fn
+from synapse.api.errors import AuthError, SynapseError
+from synapse.types import UserID, get_domain_from_id
+from synapse.util.logcontext import run_in_background
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
-from synapse.types import UserID, get_domain_from_id
-
-import logging
-
-from collections import namedtuple
logger = logging.getLogger(__name__)
@@ -56,7 +55,7 @@ class TypingHandler(object):
self.federation = hs.get_federation_sender()
- hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu)
+ hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
hs.get_distributor().observe("user_left_room", self.user_left_room)
@@ -97,7 +96,8 @@ class TypingHandler(object):
if self.hs.is_mine_id(member.user_id):
last_fed_poke = self._member_last_federation_poke.get(member, None)
if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
- preserve_fn(self._push_remote)(
+ run_in_background(
+ self._push_remote,
member=member,
typing=True
)
@@ -196,7 +196,7 @@ class TypingHandler(object):
def _push_update(self, member, typing):
if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users.
- preserve_fn(self._push_remote)(member, typing)
+ run_in_background(self._push_remote, member, typing)
self._push_update_local(
member=member,
@@ -205,28 +205,31 @@ class TypingHandler(object):
@defer.inlineCallbacks
def _push_remote(self, member, typing):
- users = yield self.state.get_current_user_in_room(member.room_id)
- self._member_last_federation_poke[member] = self.clock.time_msec()
+ try:
+ users = yield self.state.get_current_user_in_room(member.room_id)
+ self._member_last_federation_poke[member] = self.clock.time_msec()
- now = self.clock.time_msec()
- self.wheel_timer.insert(
- now=now,
- obj=member,
- then=now + FEDERATION_PING_INTERVAL,
- )
+ now = self.clock.time_msec()
+ self.wheel_timer.insert(
+ now=now,
+ obj=member,
+ then=now + FEDERATION_PING_INTERVAL,
+ )
- for domain in set(get_domain_from_id(u) for u in users):
- if domain != self.server_name:
- self.federation.send_edu(
- destination=domain,
- edu_type="m.typing",
- content={
- "room_id": member.room_id,
- "user_id": member.user_id,
- "typing": typing,
- },
- key=member,
- )
+ for domain in set(get_domain_from_id(u) for u in users):
+ if domain != self.server_name:
+ self.federation.send_edu(
+ destination=domain,
+ edu_type="m.typing",
+ content={
+ "room_id": member.room_id,
+ "user_id": member.user_id,
+ "typing": typing,
+ },
+ key=member,
+ )
+ except Exception:
+ logger.exception("Error pushing typing notif to remotes")
@defer.inlineCallbacks
def _recv_edu(self, origin, content):
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 2a49456bfc..37dda64587 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -14,18 +14,20 @@
# limitations under the License.
import logging
+
+from six import iteritems
+
from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.storage.roommember import ProfileInfo
+from synapse.types import get_localpart_from_id
from synapse.util.metrics import Measure
-from synapse.util.async import sleep
-
logger = logging.getLogger(__name__)
-class UserDirectoyHandler(object):
+class UserDirectoryHandler(object):
"""Handles querying of and keeping updated the user_directory.
N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY
@@ -41,9 +43,10 @@ class UserDirectoyHandler(object):
one public room.
"""
- INITIAL_SLEEP_MS = 50
- INITIAL_SLEEP_COUNT = 100
- INITIAL_BATCH_SIZE = 100
+ INITIAL_ROOM_SLEEP_MS = 50
+ INITIAL_ROOM_SLEEP_COUNT = 100
+ INITIAL_ROOM_BATCH_SIZE = 100
+ INITIAL_USER_SLEEP_MS = 10
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -53,6 +56,7 @@ class UserDirectoyHandler(object):
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
self.update_user_directory = hs.config.update_user_directory
+ self.search_all_users = hs.config.user_directory_search_all_users
# When start up for the first time we need to populate the user_directory.
# This is a set of user_id's we've inserted already
@@ -111,6 +115,22 @@ class UserDirectoyHandler(object):
self._is_processing = False
@defer.inlineCallbacks
+ def handle_local_profile_change(self, user_id, profile):
+ """Called to update index of our local user profiles when they change
+ irrespective of any rooms the user may be in.
+ """
+ yield self.store.update_profile_in_user_dir(
+ user_id, profile.display_name, profile.avatar_url, None,
+ )
+
+ @defer.inlineCallbacks
+ def handle_user_deactivated(self, user_id):
+ """Called when a user ID is deactivated
+ """
+ yield self.store.remove_from_user_dir(user_id)
+ yield self.store.remove_from_user_in_public_room(user_id)
+
+ @defer.inlineCallbacks
def _unsafe_process(self):
# If self.pos is None then means we haven't fetched it from DB
if self.pos is None:
@@ -148,16 +168,30 @@ class UserDirectoyHandler(object):
room_ids = yield self.store.get_all_rooms()
logger.info("Doing initial update of user directory. %d rooms", len(room_ids))
- num_processed_rooms = 1
+ num_processed_rooms = 0
for room_id in room_ids:
- logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids))
- yield self._handle_intial_room(room_id)
+ logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids))
+ yield self._handle_initial_room(room_id)
num_processed_rooms += 1
- yield sleep(self.INITIAL_SLEEP_MS / 1000.)
+ yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
logger.info("Processed all rooms.")
+ if self.search_all_users:
+ num_processed_users = 0
+ user_ids = yield self.store.get_all_local_users()
+ logger.info("Doing initial update of user directory. %d users", len(user_ids))
+ for user_id in user_ids:
+ # We add profiles for all users even if they don't match the
+ # include pattern, just in case we want to change it in future
+ logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids))
+ yield self._handle_local_user(user_id)
+ num_processed_users += 1
+ yield self.clock.sleep(self.INITIAL_USER_SLEEP_MS / 1000.)
+
+ logger.info("Processed all users")
+
self.initially_handled_users = None
self.initially_handled_users_in_public = None
self.initially_handled_users_share = None
@@ -166,7 +200,7 @@ class UserDirectoyHandler(object):
yield self.store.update_user_directory_stream_pos(new_pos)
@defer.inlineCallbacks
- def _handle_intial_room(self, room_id):
+ def _handle_initial_room(self, room_id):
"""Called when we initially fill out user_directory one room at a time
"""
is_in_room = yield self.store.is_host_joined(room_id, self.server_name)
@@ -201,8 +235,8 @@ class UserDirectoyHandler(object):
to_update = set()
count = 0
for user_id in user_ids:
- if count % self.INITIAL_SLEEP_COUNT == 0:
- yield sleep(self.INITIAL_SLEEP_MS / 1000.)
+ if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
+ yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
if not self.is_mine_id(user_id):
count += 1
@@ -216,8 +250,8 @@ class UserDirectoyHandler(object):
if user_id == other_user_id:
continue
- if count % self.INITIAL_SLEEP_COUNT == 0:
- yield sleep(self.INITIAL_SLEEP_MS / 1000.)
+ if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
+ yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
count += 1
user_set = (user_id, other_user_id)
@@ -237,13 +271,13 @@ class UserDirectoyHandler(object):
else:
self.initially_handled_users_share_private_room.add(user_set)
- if len(to_insert) > self.INITIAL_BATCH_SIZE:
+ if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE:
yield self.store.add_users_who_share_room(
room_id, not is_public, to_insert,
)
to_insert.clear()
- if len(to_update) > self.INITIAL_BATCH_SIZE:
+ if len(to_update) > self.INITIAL_ROOM_BATCH_SIZE:
yield self.store.update_users_who_share_room(
room_id, not is_public, to_update,
)
@@ -377,7 +411,7 @@ class UserDirectoyHandler(object):
if change:
users_with_profile = yield self.state.get_current_user_in_room(room_id)
- for user_id, profile in users_with_profile.iteritems():
+ for user_id, profile in iteritems(users_with_profile):
yield self._handle_new_user(room_id, user_id, profile)
else:
users = yield self.store.get_users_in_public_due_to_room(room_id)
@@ -385,14 +419,28 @@ class UserDirectoyHandler(object):
yield self._handle_remove_user(room_id, user_id)
@defer.inlineCallbacks
+ def _handle_local_user(self, user_id):
+ """Adds a new local roomless user into the user_directory_search table.
+ Used to populate up the user index when we have an
+ user_directory_search_all_users specified.
+ """
+ logger.debug("Adding new local user to dir, %r", user_id)
+
+ profile = yield self.store.get_profileinfo(get_localpart_from_id(user_id))
+
+ row = yield self.store.get_user_in_directory(user_id)
+ if not row:
+ yield self.store.add_profiles_to_user_dir(None, {user_id: profile})
+
+ @defer.inlineCallbacks
def _handle_new_user(self, room_id, user_id, profile):
"""Called when we might need to add user to directory
Args:
- room_id (str): room_id that user joined or started being public that
+ room_id (str): room_id that user joined or started being public
user_id (str)
"""
- logger.debug("Adding user to dir, %r", user_id)
+ logger.debug("Adding new user to dir, %r", user_id)
row = yield self.store.get_user_in_directory(user_id)
if not row:
@@ -407,7 +455,7 @@ class UserDirectoyHandler(object):
if not row:
yield self.store.add_users_to_public_room(room_id, [user_id])
else:
- logger.debug("Not adding user to public dir, %r", user_id)
+ logger.debug("Not adding new user to public dir, %r", user_id)
# Now we update users who share rooms with users. We do this by getting
# all the current users in the room and seeing which aren't already
|