diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index 9442ae6f1d..1a50a2ec98 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -13,11 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.appservice.scheduler import AppServiceScheduler
-from synapse.appservice.api import ApplicationServiceApi
from .register import RegistrationHandler
from .room import (
- RoomCreationHandler, RoomListHandler, RoomContextHandler,
+ RoomCreationHandler, RoomContextHandler,
)
from .room_member import RoomMemberHandler
from .message import MessageHandler
@@ -26,8 +24,6 @@ from .federation import FederationHandler
from .profile import ProfileHandler
from .directory import DirectoryHandler
from .admin import AdminHandler
-from .appservice import ApplicationServicesHandler
-from .auth import AuthHandler
from .identity import IdentityHandler
from .receipts import ReceiptsHandler
from .search import SearchHandler
@@ -35,10 +31,21 @@ from .search import SearchHandler
class Handlers(object):
- """ A collection of all the event handlers.
+ """ Deprecated. A collection of handlers.
- There's no need to lazily create these; we'll just make them all eagerly
- at construction time.
+ At some point most of the classes whose name ended "Handler" were
+ accessed through this class.
+
+ However this makes it painful to unit test the handlers and to run cut
+ down versions of synapse that only use specific handlers because using a
+ single handler required creating all of the handlers. So some of the
+ handlers have been lifted out of the Handlers object and are now accessed
+ directly through the homeserver object itself.
+
+ Any new handlers should follow the new pattern of being accessed through
+ the homeserver object and should not be added to the Handlers object.
+
+ The remaining handlers should be moved out of the handlers object.
"""
def __init__(self, hs):
@@ -50,19 +57,9 @@ class Handlers(object):
self.event_handler = EventHandler(hs)
self.federation_handler = FederationHandler(hs)
self.profile_handler = ProfileHandler(hs)
- self.room_list_handler = RoomListHandler(hs)
self.directory_handler = DirectoryHandler(hs)
self.admin_handler = AdminHandler(hs)
self.receipts_handler = ReceiptsHandler(hs)
- asapi = ApplicationServiceApi(hs)
- self.appservice_handler = ApplicationServicesHandler(
- hs, asapi, AppServiceScheduler(
- clock=hs.get_clock(),
- store=hs.get_datastore(),
- as_api=asapi
- )
- )
- self.auth_handler = AuthHandler(hs)
self.identity_handler = IdentityHandler(hs)
self.search_handler = SearchHandler(hs)
self.room_context_handler = RoomContextHandler(hs)
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index c904c6c500..11081a0cd5 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
-from synapse.api.errors import LimitExceededError
+import synapse.types
from synapse.api.constants import Membership, EventTypes
-from synapse.types import UserID, Requester
-
-
-import logging
+from synapse.api.errors import LimitExceededError
+from synapse.types import UserID
logger = logging.getLogger(__name__)
@@ -31,11 +31,15 @@ class BaseHandler(object):
Common base class for the event handlers.
Attributes:
- store (synapse.storage.events.StateStore):
+ store (synapse.storage.DataStore):
state_handler (synapse.state.StateHandler):
"""
def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer):
+ """
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
@@ -120,7 +124,8 @@ class BaseHandler(object):
# and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having
# homeserver.
- requester = Requester(target_user, "", True)
+ requester = synapse.types.create_requester(
+ target_user, is_guest=True)
handler = self.hs.get_handlers().room_member_handler
yield handler.update_membership(
requester,
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 75fc74c797..051ccdb380 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -17,7 +17,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.appservice import ApplicationService
-from synapse.types import UserID
import logging
@@ -35,16 +34,13 @@ def log_failure(failure):
)
-# NB: Purposefully not inheriting BaseHandler since that contains way too much
-# setup code which this handler does not need or use. This makes testing a lot
-# easier.
class ApplicationServicesHandler(object):
- def __init__(self, hs, appservice_api, appservice_scheduler):
+ def __init__(self, hs):
self.store = hs.get_datastore()
- self.hs = hs
- self.appservice_api = appservice_api
- self.scheduler = appservice_scheduler
+ self.is_mine_id = hs.is_mine_id
+ self.appservice_api = hs.get_application_service_api()
+ self.scheduler = hs.get_application_service_scheduler()
self.started_scheduler = False
@defer.inlineCallbacks
@@ -169,8 +165,7 @@ class ApplicationServicesHandler(object):
@defer.inlineCallbacks
def _is_unknown_user(self, user_id):
- user = UserID.from_string(user_id)
- if not self.hs.is_mine(user):
+ if not self.is_mine_id(user_id):
# we don't know if they are unknown or not since it isn't one of our
# users. We can't poke ASes.
defer.returnValue(False)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 68d0d78fc6..82998a81ce 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -18,8 +18,9 @@ from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import LoginType
from synapse.types import UserID
-from synapse.api.errors import AuthError, LoginError, Codes
+from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
from synapse.util.async import run_on_reactor
+from synapse.config.ldap import LDAPMode
from twisted.web.client import PartialDownloadError
@@ -28,6 +29,12 @@ import bcrypt
import pymacaroons
import simplejson
+try:
+ import ldap3
+except ImportError:
+ ldap3 = None
+ pass
+
import synapse.util.stringutils as stringutils
@@ -38,6 +45,10 @@ class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer):
+ """
super(AuthHandler, self).__init__(hs)
self.checkers = {
LoginType.PASSWORD: self._check_password_auth,
@@ -50,19 +61,23 @@ class AuthHandler(BaseHandler):
self.INVALID_TOKEN_HTTP_STATUS = 401
self.ldap_enabled = hs.config.ldap_enabled
- self.ldap_server = hs.config.ldap_server
- self.ldap_port = hs.config.ldap_port
- self.ldap_tls = hs.config.ldap_tls
- self.ldap_search_base = hs.config.ldap_search_base
- self.ldap_search_property = hs.config.ldap_search_property
- self.ldap_email_property = hs.config.ldap_email_property
- self.ldap_full_name_property = hs.config.ldap_full_name_property
-
- if self.ldap_enabled is True:
- import ldap
- logger.info("Import ldap version: %s", ldap.__version__)
+ if self.ldap_enabled:
+ if not ldap3:
+ raise RuntimeError(
+ 'Missing ldap3 library. This is required for LDAP Authentication.'
+ )
+ self.ldap_mode = hs.config.ldap_mode
+ self.ldap_uri = hs.config.ldap_uri
+ self.ldap_start_tls = hs.config.ldap_start_tls
+ self.ldap_base = hs.config.ldap_base
+ self.ldap_filter = hs.config.ldap_filter
+ self.ldap_attributes = hs.config.ldap_attributes
+ if self.ldap_mode == LDAPMode.SEARCH:
+ self.ldap_bind_dn = hs.config.ldap_bind_dn
+ self.ldap_bind_password = hs.config.ldap_bind_password
self.hs = hs # FIXME better possibility to access registrationHandler later?
+ self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
@@ -220,7 +235,6 @@ class AuthHandler(BaseHandler):
sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default)
- @defer.inlineCallbacks
def _check_password_auth(self, authdict, _):
if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
@@ -230,11 +244,7 @@ class AuthHandler(BaseHandler):
if not user_id.startswith('@'):
user_id = UserID.create(user_id, self.hs.hostname).to_string()
- if not (yield self._check_password(user_id, password)):
- logger.warn("Failed password login for user %s", user_id)
- raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-
- defer.returnValue(user_id)
+ return self._check_password(user_id, password)
@defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip):
@@ -270,8 +280,17 @@ class AuthHandler(BaseHandler):
data = pde.response
resp_body = simplejson.loads(data)
- if 'success' in resp_body and resp_body['success']:
- defer.returnValue(True)
+ if 'success' in resp_body:
+ # Note that we do NOT check the hostname here: we explicitly
+ # intend the CAPTCHA to be presented by whatever client the
+ # user is using, we just care that they have completed a CAPTCHA.
+ logger.info(
+ "%s reCAPTCHA from hostname %s",
+ "Successful" if resp_body['success'] else "Failed",
+ resp_body.get('hostname')
+ )
+ if resp_body['success']:
+ defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
@defer.inlineCallbacks
@@ -338,67 +357,84 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id]
- @defer.inlineCallbacks
- def login_with_password(self, user_id, password):
+ def validate_password_login(self, user_id, password):
"""
Authenticates the user with their username and password.
Used only by the v1 login API.
Args:
- user_id (str): User ID
+ user_id (str): complete @user:id
password (str): Password
Returns:
- A tuple of:
- The user's ID.
- The access token for the user's session.
- The refresh token for the user's session.
+ defer.Deferred: (str) canonical user id
Raises:
- StoreError if there was a problem storing the token.
+ StoreError if there was a problem accessing the database
LoginError if there was an authentication problem.
"""
-
- if not (yield self._check_password(user_id, password)):
- logger.warn("Failed password login for user %s", user_id)
- raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-
- logger.info("Logging in user %s", user_id)
- access_token = yield self.issue_access_token(user_id)
- refresh_token = yield self.issue_refresh_token(user_id)
- defer.returnValue((user_id, access_token, refresh_token))
+ return self._check_password(user_id, password)
@defer.inlineCallbacks
- def get_login_tuple_for_user_id(self, user_id):
+ def get_login_tuple_for_user_id(self, user_id, device_id=None,
+ initial_display_name=None):
"""
Gets login tuple for the user with the given user ID.
+
+ Creates a new access/refresh token for the user.
+
The user is assumed to have been authenticated by some other
- machanism (e.g. CAS)
+ machanism (e.g. CAS), and the user_id converted to the canonical case.
+
+ The device will be recorded in the table if it is not there already.
Args:
- user_id (str): User ID
+ user_id (str): canonical User ID
+ device_id (str|None): the device ID to associate with the tokens.
+ None to leave the tokens unassociated with a device (deprecated:
+ we should always have a device ID)
+ initial_display_name (str): display name to associate with the
+ device if it needs re-registering
Returns:
A tuple of:
- The user's ID.
The access token for the user's session.
The refresh token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
- user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id)
+ logger.info("Logging in user %s on device %s", user_id, device_id)
+ access_token = yield self.issue_access_token(user_id, device_id)
+ refresh_token = yield self.issue_refresh_token(user_id, device_id)
+
+ # the device *should* have been registered before we got here; however,
+ # it's possible we raced against a DELETE operation. The thing we
+ # really don't want is active access_tokens without a record of the
+ # device, so we double-check it here.
+ if device_id is not None:
+ yield self.device_handler.check_device_registered(
+ user_id, device_id, initial_display_name
+ )
- logger.info("Logging in user %s", user_id)
- access_token = yield self.issue_access_token(user_id)
- refresh_token = yield self.issue_refresh_token(user_id)
- defer.returnValue((user_id, access_token, refresh_token))
+ defer.returnValue((access_token, refresh_token))
@defer.inlineCallbacks
- def does_user_exist(self, user_id):
+ def check_user_exists(self, user_id):
+ """
+ Checks to see if a user with the given id exists. Will check case
+ insensitively, but return None if there are multiple inexact matches.
+
+ Args:
+ (str) user_id: complete @user:id
+
+ Returns:
+ defer.Deferred: (str) canonical_user_id, or None if zero or
+ multiple matches
+ """
try:
- yield self._find_user_id_and_pwd_hash(user_id)
- defer.returnValue(True)
+ res = yield self._find_user_id_and_pwd_hash(user_id)
+ defer.returnValue(res[0])
except LoginError:
- defer.returnValue(False)
+ defer.returnValue(None)
@defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id):
@@ -428,84 +464,232 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def _check_password(self, user_id, password):
- """
+ """Authenticate a user against the LDAP and local databases.
+
+ user_id is checked case insensitively against the local database, but
+ will throw if there are multiple inexact matches.
+
+ Args:
+ user_id (str): complete @user:id
Returns:
- True if the user_id successfully authenticated
+ (str) the canonical_user_id
+ Raises:
+ LoginError if the password was incorrect
"""
valid_ldap = yield self._check_ldap_password(user_id, password)
if valid_ldap:
- defer.returnValue(True)
-
- valid_local_password = yield self._check_local_password(user_id, password)
- if valid_local_password:
- defer.returnValue(True)
+ defer.returnValue(user_id)
- defer.returnValue(False)
+ result = yield self._check_local_password(user_id, password)
+ defer.returnValue(result)
@defer.inlineCallbacks
def _check_local_password(self, user_id, password):
- try:
- user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
- defer.returnValue(self.validate_hash(password, password_hash))
- except LoginError:
- defer.returnValue(False)
+ """Authenticate a user against the local password database.
+
+ user_id is checked case insensitively, but will throw if there are
+ multiple inexact matches.
+
+ Args:
+ user_id (str): complete @user:id
+ Returns:
+ (str) the canonical_user_id
+ Raises:
+ LoginError if the password was incorrect
+ """
+ user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
+ result = self.validate_hash(password, password_hash)
+ if not result:
+ logger.warn("Failed password login for user %s", user_id)
+ raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+ defer.returnValue(user_id)
@defer.inlineCallbacks
def _check_ldap_password(self, user_id, password):
- if not self.ldap_enabled:
- logger.debug("LDAP not configured")
+ """ Attempt to authenticate a user against an LDAP Server
+ and register an account if none exists.
+
+ Returns:
+ True if authentication against LDAP was successful
+ """
+
+ if not ldap3 or not self.ldap_enabled:
defer.returnValue(False)
- import ldap
+ if self.ldap_mode not in LDAPMode.LIST:
+ raise RuntimeError(
+ 'Invalid ldap mode specified: {mode}'.format(
+ mode=self.ldap_mode
+ )
+ )
- logger.info("Authenticating %s with LDAP" % user_id)
try:
- ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port)
- logger.debug("Connecting LDAP server at %s" % ldap_url)
- l = ldap.initialize(ldap_url)
- if self.ldap_tls:
- logger.debug("Initiating TLS")
- self._connection.start_tls_s()
-
- local_name = UserID.from_string(user_id).localpart
-
- dn = "%s=%s, %s" % (
- self.ldap_search_property,
- local_name,
- self.ldap_search_base)
- logger.debug("DN for LDAP authentication: %s" % dn)
-
- l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8'))
-
- if not (yield self.does_user_exist(user_id)):
- handler = self.hs.get_handlers().registration_handler
- user_id, access_token = (
- yield handler.register(localpart=local_name)
+ server = ldap3.Server(self.ldap_uri)
+ logger.debug(
+ "Attempting ldap connection with %s",
+ self.ldap_uri
+ )
+
+ localpart = UserID.from_string(user_id).localpart
+ if self.ldap_mode == LDAPMode.SIMPLE:
+ # bind with the the local users ldap credentials
+ bind_dn = "{prop}={value},{base}".format(
+ prop=self.ldap_attributes['uid'],
+ value=localpart,
+ base=self.ldap_base
+ )
+ conn = ldap3.Connection(server, bind_dn, password)
+ logger.debug(
+ "Established ldap connection in simple mode: %s",
+ conn
)
+ if self.ldap_start_tls:
+ conn.start_tls()
+ logger.debug(
+ "Upgraded ldap connection in simple mode through StartTLS: %s",
+ conn
+ )
+
+ conn.bind()
+
+ elif self.ldap_mode == LDAPMode.SEARCH:
+ # connect with preconfigured credentials and search for local user
+ conn = ldap3.Connection(
+ server,
+ self.ldap_bind_dn,
+ self.ldap_bind_password
+ )
+ logger.debug(
+ "Established ldap connection in search mode: %s",
+ conn
+ )
+
+ if self.ldap_start_tls:
+ conn.start_tls()
+ logger.debug(
+ "Upgraded ldap connection in search mode through StartTLS: %s",
+ conn
+ )
+
+ conn.bind()
+
+ # find matching dn
+ query = "({prop}={value})".format(
+ prop=self.ldap_attributes['uid'],
+ value=localpart
+ )
+ if self.ldap_filter:
+ query = "(&{query}{filter})".format(
+ query=query,
+ filter=self.ldap_filter
+ )
+ logger.debug("ldap search filter: %s", query)
+ result = conn.search(self.ldap_base, query)
+
+ if result and len(conn.response) == 1:
+ # found exactly one result
+ user_dn = conn.response[0]['dn']
+ logger.debug('ldap search found dn: %s', user_dn)
+
+ # unbind and reconnect, rebind with found dn
+ conn.unbind()
+ conn = ldap3.Connection(
+ server,
+ user_dn,
+ password,
+ auto_bind=True
+ )
+ else:
+ # found 0 or > 1 results, abort!
+ logger.warn(
+ "ldap search returned unexpected (%d!=1) amount of results",
+ len(conn.response)
+ )
+ defer.returnValue(False)
+
+ logger.info(
+ "User authenticated against ldap server: %s",
+ conn
+ )
+
+ # check for existing account, if none exists, create one
+ if not (yield self.check_user_exists(user_id)):
+ # query user metadata for account creation
+ query = "({prop}={value})".format(
+ prop=self.ldap_attributes['uid'],
+ value=localpart
+ )
+
+ if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
+ query = "(&{filter}{user_filter})".format(
+ filter=query,
+ user_filter=self.ldap_filter
+ )
+ logger.debug("ldap registration filter: %s", query)
+
+ result = conn.search(
+ search_base=self.ldap_base,
+ search_filter=query,
+ attributes=[
+ self.ldap_attributes['name'],
+ self.ldap_attributes['mail']
+ ]
+ )
+
+ if len(conn.response) == 1:
+ attrs = conn.response[0]['attributes']
+ mail = attrs[self.ldap_attributes['mail']][0]
+ name = attrs[self.ldap_attributes['name']][0]
+
+ # create account
+ registration_handler = self.hs.get_handlers().registration_handler
+ user_id, access_token = (
+ yield registration_handler.register(localpart=localpart)
+ )
+
+ # TODO: bind email, set displayname with data from ldap directory
+
+ logger.info(
+ "ldap registration successful: %d: %s (%s, %)",
+ user_id,
+ localpart,
+ name,
+ mail
+ )
+ else:
+ logger.warn(
+ "ldap registration failed: unexpected (%d!=1) amount of results",
+ len(result)
+ )
+ defer.returnValue(False)
+
defer.returnValue(True)
- except ldap.LDAPError, e:
- logger.warn("LDAP error: %s", e)
+ except ldap3.core.exceptions.LDAPException as e:
+ logger.warn("Error during ldap authentication: %s", e)
defer.returnValue(False)
@defer.inlineCallbacks
- def issue_access_token(self, user_id):
+ def issue_access_token(self, user_id, device_id=None):
access_token = self.generate_access_token(user_id)
- yield self.store.add_access_token_to_user(user_id, access_token)
+ yield self.store.add_access_token_to_user(user_id, access_token,
+ device_id)
defer.returnValue(access_token)
@defer.inlineCallbacks
- def issue_refresh_token(self, user_id):
+ def issue_refresh_token(self, user_id, device_id=None):
refresh_token = self.generate_refresh_token(user_id)
- yield self.store.add_refresh_token_to_user(user_id, refresh_token)
+ yield self.store.add_refresh_token_to_user(user_id, refresh_token,
+ device_id)
defer.returnValue(refresh_token)
- def generate_access_token(self, user_id, extra_caveats=None):
+ def generate_access_token(self, user_id, extra_caveats=None,
+ duration_in_ms=(60 * 60 * 1000)):
extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access")
now = self.hs.get_clock().time_msec()
- expiry = now + (60 * 60 * 1000)
+ expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat)
@@ -529,14 +713,20 @@ class AuthHandler(BaseHandler):
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
+ def generate_delete_pusher_token(self, user_id):
+ macaroon = self._generate_base_macaroon(user_id)
+ macaroon.add_first_party_caveat("type = delete_pusher")
+ return macaroon.serialize()
+
def validate_short_term_login_token_and_get_user_id(self, login_token):
+ auth_api = self.hs.get_auth()
try:
macaroon = pymacaroons.Macaroon.deserialize(login_token)
- auth_api = self.hs.get_auth()
- auth_api.validate_macaroon(macaroon, "login", True)
- return self.get_user_from_macaroon(macaroon)
- except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
- raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
+ user_id = auth_api.get_user_id_from_macaroon(macaroon)
+ auth_api.validate_macaroon(macaroon, "login", True, user_id)
+ return user_id
+ except Exception:
+ raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon(
@@ -547,23 +737,18 @@ class AuthHandler(BaseHandler):
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
- def get_user_from_macaroon(self, macaroon):
- user_prefix = "user_id = "
- for caveat in macaroon.caveats:
- if caveat.caveat_id.startswith(user_prefix):
- return caveat.caveat_id[len(user_prefix):]
- raise AuthError(
- self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
- errcode=Codes.UNKNOWN_TOKEN
- )
-
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword)
except_access_token_ids = [requester.access_token_id] if requester else []
- yield self.store.user_set_password_hash(user_id, password_hash)
+ try:
+ yield self.store.user_set_password_hash(user_id, password_hash)
+ except StoreError as e:
+ if e.code == 404:
+ raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
+ raise e
yield self.store.user_delete_access_tokens(
user_id, except_access_token_ids
)
@@ -603,7 +788,8 @@ class AuthHandler(BaseHandler):
Returns:
Hashed password (str).
"""
- return bcrypt.hashpw(password, bcrypt.gensalt(self.bcrypt_rounds))
+ return bcrypt.hashpw(password + self.hs.config.password_pepper,
+ bcrypt.gensalt(self.bcrypt_rounds))
def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash.
@@ -616,6 +802,7 @@ class AuthHandler(BaseHandler):
Whether self.hash(password) == stored_hash (bool).
"""
if stored_hash:
- return bcrypt.hashpw(password, stored_hash) == stored_hash
+ return bcrypt.hashpw(password + self.hs.config.password_pepper,
+ stored_hash.encode('utf-8')) == stored_hash
else:
return False
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
new file mode 100644
index 0000000000..8d630c6b1a
--- /dev/null
+++ b/synapse/handlers/device.py
@@ -0,0 +1,181 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.api import errors
+from synapse.util import stringutils
+from twisted.internet import defer
+from ._base import BaseHandler
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class DeviceHandler(BaseHandler):
+ def __init__(self, hs):
+ super(DeviceHandler, self).__init__(hs)
+
+ @defer.inlineCallbacks
+ def check_device_registered(self, user_id, device_id,
+ initial_device_display_name=None):
+ """
+ If the given device has not been registered, register it with the
+ supplied display name.
+
+ If no device_id is supplied, we make one up.
+
+ Args:
+ user_id (str): @user:id
+ device_id (str | None): device id supplied by client
+ initial_device_display_name (str | None): device display name from
+ client
+ Returns:
+ str: device id (generated if none was supplied)
+ """
+ if device_id is not None:
+ yield self.store.store_device(
+ user_id=user_id,
+ device_id=device_id,
+ initial_device_display_name=initial_device_display_name,
+ ignore_if_known=True,
+ )
+ defer.returnValue(device_id)
+
+ # if the device id is not specified, we'll autogen one, but loop a few
+ # times in case of a clash.
+ attempts = 0
+ while attempts < 5:
+ try:
+ device_id = stringutils.random_string_with_symbols(16)
+ yield self.store.store_device(
+ user_id=user_id,
+ device_id=device_id,
+ initial_device_display_name=initial_device_display_name,
+ ignore_if_known=False,
+ )
+ defer.returnValue(device_id)
+ except errors.StoreError:
+ attempts += 1
+
+ raise errors.StoreError(500, "Couldn't generate a device ID.")
+
+ @defer.inlineCallbacks
+ def get_devices_by_user(self, user_id):
+ """
+ Retrieve the given user's devices
+
+ Args:
+ user_id (str):
+ Returns:
+ defer.Deferred: list[dict[str, X]]: info on each device
+ """
+
+ device_map = yield self.store.get_devices_by_user(user_id)
+
+ ips = yield self.store.get_last_client_ip_by_device(
+ devices=((user_id, device_id) for device_id in device_map.keys())
+ )
+
+ devices = device_map.values()
+ for device in devices:
+ _update_device_from_client_ips(device, ips)
+
+ defer.returnValue(devices)
+
+ @defer.inlineCallbacks
+ def get_device(self, user_id, device_id):
+ """ Retrieve the given device
+
+ Args:
+ user_id (str):
+ device_id (str):
+
+ Returns:
+ defer.Deferred: dict[str, X]: info on the device
+ Raises:
+ errors.NotFoundError: if the device was not found
+ """
+ try:
+ device = yield self.store.get_device(user_id, device_id)
+ except errors.StoreError:
+ raise errors.NotFoundError
+ ips = yield self.store.get_last_client_ip_by_device(
+ devices=((user_id, device_id),)
+ )
+ _update_device_from_client_ips(device, ips)
+ defer.returnValue(device)
+
+ @defer.inlineCallbacks
+ def delete_device(self, user_id, device_id):
+ """ Delete the given device
+
+ Args:
+ user_id (str):
+ device_id (str):
+
+ Returns:
+ defer.Deferred:
+ """
+
+ try:
+ yield self.store.delete_device(user_id, device_id)
+ except errors.StoreError, e:
+ if e.code == 404:
+ # no match
+ pass
+ else:
+ raise
+
+ yield self.store.user_delete_access_tokens(
+ user_id, device_id=device_id,
+ delete_refresh_tokens=True,
+ )
+
+ yield self.store.delete_e2e_keys_by_device(
+ user_id=user_id, device_id=device_id
+ )
+
+ @defer.inlineCallbacks
+ def update_device(self, user_id, device_id, content):
+ """ Update the given device
+
+ Args:
+ user_id (str):
+ device_id (str):
+ content (dict): body of update request
+
+ Returns:
+ defer.Deferred:
+ """
+
+ try:
+ yield self.store.update_device(
+ user_id,
+ device_id,
+ new_display_name=content.get("display_name")
+ )
+ except errors.StoreError, e:
+ if e.code == 404:
+ raise errors.NotFoundError()
+ else:
+ raise
+
+
+def _update_device_from_client_ips(device, client_ips):
+ ip = client_ips.get((device["user_id"], device["device_id"]), {})
+ device.update({
+ "last_seen_ts": ip.get("last_seen"),
+ "last_seen_ip": ip.get("ip"),
+ })
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 8eeb225811..4bea7f2b19 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -33,6 +33,7 @@ class DirectoryHandler(BaseHandler):
super(DirectoryHandler, self).__init__(hs)
self.state = hs.get_state_handler()
+ self.appservice_handler = hs.get_application_service_handler()
self.federation = hs.get_replication_layer()
self.federation.register_query_handler(
@@ -281,7 +282,7 @@ class DirectoryHandler(BaseHandler):
)
if not result:
# Query AS to see if it exists
- as_handler = self.hs.get_handlers().appservice_handler
+ as_handler = self.appservice_handler
result = yield as_handler.query_room_alias_exists(room_alias)
defer.returnValue(result)
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
new file mode 100644
index 0000000000..2c7bfd91ed
--- /dev/null
+++ b/synapse/handlers/e2e_keys.py
@@ -0,0 +1,139 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import json
+import logging
+
+from twisted.internet import defer
+
+from synapse.api import errors
+import synapse.types
+
+logger = logging.getLogger(__name__)
+
+
+class E2eKeysHandler(object):
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.federation = hs.get_replication_layer()
+ self.is_mine_id = hs.is_mine_id
+ self.server_name = hs.hostname
+
+ # doesn't really work as part of the generic query API, because the
+ # query request requires an object POST, but we abuse the
+ # "query handler" interface.
+ self.federation.register_query_handler(
+ "client_keys", self.on_federation_query_client_keys
+ )
+
+ @defer.inlineCallbacks
+ def query_devices(self, query_body):
+ """ Handle a device key query from a client
+
+ {
+ "device_keys": {
+ "<user_id>": ["<device_id>"]
+ }
+ }
+ ->
+ {
+ "device_keys": {
+ "<user_id>": {
+ "<device_id>": {
+ ...
+ }
+ }
+ }
+ }
+ """
+ device_keys_query = query_body.get("device_keys", {})
+
+ # separate users by domain.
+ # make a map from domain to user_id to device_ids
+ queries_by_domain = collections.defaultdict(dict)
+ for user_id, device_ids in device_keys_query.items():
+ user = synapse.types.UserID.from_string(user_id)
+ queries_by_domain[user.domain][user_id] = device_ids
+
+ # do the queries
+ # TODO: do these in parallel
+ results = {}
+ for destination, destination_query in queries_by_domain.items():
+ if destination == self.server_name:
+ res = yield self.query_local_devices(destination_query)
+ else:
+ res = yield self.federation.query_client_keys(
+ destination, {"device_keys": destination_query}
+ )
+ res = res["device_keys"]
+ for user_id, keys in res.items():
+ if user_id in destination_query:
+ results[user_id] = keys
+
+ defer.returnValue((200, {"device_keys": results}))
+
+ @defer.inlineCallbacks
+ def query_local_devices(self, query):
+ """Get E2E device keys for local users
+
+ Args:
+ query (dict[string, list[string]|None): map from user_id to a list
+ of devices to query (None for all devices)
+
+ Returns:
+ defer.Deferred: (resolves to dict[string, dict[string, dict]]):
+ map from user_id -> device_id -> device details
+ """
+ local_query = []
+
+ result_dict = {}
+ for user_id, device_ids in query.items():
+ if not self.is_mine_id(user_id):
+ logger.warning("Request for keys for non-local user %s",
+ user_id)
+ raise errors.SynapseError(400, "Not a user here")
+
+ if not device_ids:
+ local_query.append((user_id, None))
+ else:
+ for device_id in device_ids:
+ local_query.append((user_id, device_id))
+
+ # make sure that each queried user appears in the result dict
+ result_dict[user_id] = {}
+
+ results = yield self.store.get_e2e_device_keys(local_query)
+
+ # Build the result structure, un-jsonify the results, and add the
+ # "unsigned" section
+ for user_id, device_keys in results.items():
+ for device_id, device_info in device_keys.items():
+ r = json.loads(device_info["key_json"])
+ r["unsigned"] = {}
+ display_name = device_info["device_display_name"]
+ if display_name is not None:
+ r["unsigned"]["device_display_name"] = display_name
+ result_dict[user_id][device_id] = r
+
+ defer.returnValue(result_dict)
+
+ @defer.inlineCallbacks
+ def on_federation_query_client_keys(self, query_body):
+ """ Handle a device key query from a federated server
+ """
+ device_keys_query = query_body.get("device_keys", {})
+ res = yield self.query_local_devices(device_keys_query)
+ defer.returnValue({"device_keys": res})
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 648a505e65..ff6bb475b5 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -66,10 +66,6 @@ class FederationHandler(BaseHandler):
self.hs = hs
- self.distributor.observe("user_joined_room", self.user_joined_room)
-
- self.waiting_for_join_list = {}
-
self.store = hs.get_datastore()
self.replication_layer = hs.get_replication_layer()
self.state_handler = hs.get_state_handler()
@@ -128,7 +124,7 @@ class FederationHandler(BaseHandler):
try:
event_stream_id, max_stream_id = yield self._persist_auth_tree(
- auth_chain, state, event
+ origin, auth_chain, state, event
)
except AuthError as e:
raise FederationError(
@@ -253,7 +249,7 @@ class FederationHandler(BaseHandler):
if ev.type != EventTypes.Member:
continue
try:
- domain = UserID.from_string(ev.state_key).domain
+ domain = get_domain_from_id(ev.state_key)
except:
continue
@@ -339,29 +335,58 @@ class FederationHandler(BaseHandler):
state_events.update({s.event_id: s for s in state})
events_to_state[e_id] = state
- seen_events = yield self.store.have_events(
- set(auth_events.keys()) | set(state_events.keys())
- )
-
- all_events = events + state_events.values() + auth_events.values()
required_auth = set(
- a_id for event in all_events for a_id, _ in event.auth_events
+ a_id
+ for event in events + state_events.values() + auth_events.values()
+ for a_id, _ in event.auth_events
)
-
+ auth_events.update({
+ e_id: event_map[e_id] for e_id in required_auth if e_id in event_map
+ })
missing_auth = required_auth - set(auth_events)
- results = yield defer.gatherResults(
- [
- self.replication_layer.get_pdu(
- [dest],
- event_id,
- outlier=True,
- timeout=10000,
+ failed_to_fetch = set()
+
+ # Try and fetch any missing auth events from both DB and remote servers.
+ # We repeatedly do this until we stop finding new auth events.
+ while missing_auth - failed_to_fetch:
+ logger.info("Missing auth for backfill: %r", missing_auth)
+ ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
+ auth_events.update(ret_events)
+
+ required_auth.update(
+ a_id for event in ret_events.values() for a_id, _ in event.auth_events
+ )
+ missing_auth = required_auth - set(auth_events)
+
+ if missing_auth - failed_to_fetch:
+ logger.info(
+ "Fetching missing auth for backfill: %r",
+ missing_auth - failed_to_fetch
)
- for event_id in missing_auth
- ],
- consumeErrors=True
- ).addErrback(unwrapFirstError)
- auth_events.update({a.event_id: a for a in results})
+
+ results = yield defer.gatherResults(
+ [
+ self.replication_layer.get_pdu(
+ [dest],
+ event_id,
+ outlier=True,
+ timeout=10000,
+ )
+ for event_id in missing_auth - failed_to_fetch
+ ],
+ consumeErrors=True
+ ).addErrback(unwrapFirstError)
+ auth_events.update({a.event_id: a for a in results})
+ required_auth.update(
+ a_id for event in results for a_id, _ in event.auth_events
+ )
+ missing_auth = required_auth - set(auth_events)
+
+ failed_to_fetch = missing_auth - set(auth_events)
+
+ seen_events = yield self.store.have_events(
+ set(auth_events.keys()) | set(state_events.keys())
+ )
ev_infos = []
for a in auth_events.values():
@@ -374,6 +399,7 @@ class FederationHandler(BaseHandler):
(auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id]
for a_id, _ in a.auth_events
+ if a_id in auth_events
}
})
@@ -385,6 +411,7 @@ class FederationHandler(BaseHandler):
(auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id]
for a_id, _ in event_map[e_id].auth_events
+ if a_id in auth_events
}
})
@@ -403,7 +430,7 @@ class FederationHandler(BaseHandler):
# previous to work out the state.
# TODO: We can probably do something more clever here.
yield self._handle_new_event(
- dest, event
+ dest, event, backfilled=True,
)
defer.returnValue(events)
@@ -639,7 +666,7 @@ class FederationHandler(BaseHandler):
pass
event_stream_id, max_stream_id = yield self._persist_auth_tree(
- auth_chain, state, event
+ origin, auth_chain, state, event
)
with PreserveLoggingContext():
@@ -690,7 +717,9 @@ class FederationHandler(BaseHandler):
logger.warn("Failed to create join %r because %s", event, e)
raise e
- self.auth.check(event, auth_events=context.current_state)
+ # The remote hasn't signed it yet, obviously. We'll do the full checks
+ # when we get the event back in `on_send_join_request`
+ self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
defer.returnValue(event)
@@ -920,7 +949,9 @@ class FederationHandler(BaseHandler):
)
try:
- self.auth.check(event, auth_events=context.current_state)
+ # The remote hasn't signed it yet, obviously. We'll do the full checks
+ # when we get the event back in `on_send_leave_request`
+ self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
except AuthError as e:
logger.warn("Failed to create new leave %r because %s", event, e)
raise e
@@ -989,14 +1020,9 @@ class FederationHandler(BaseHandler):
defer.returnValue(None)
@defer.inlineCallbacks
- def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
+ def get_state_for_pdu(self, room_id, event_id):
yield run_on_reactor()
- if do_auth:
- in_room = yield self.auth.check_host_in_room(room_id, origin)
- if not in_room:
- raise AuthError(403, "Host not in room.")
-
state_groups = yield self.store.get_state_groups(
room_id, [event_id]
)
@@ -1020,13 +1046,16 @@ class FederationHandler(BaseHandler):
res = results.values()
for event in res:
- event.signatures.update(
- compute_event_signature(
- event,
- self.hs.hostname,
- self.hs.config.signing_key[0]
+ # We sign these again because there was a bug where we
+ # incorrectly signed things the first time round
+ if self.hs.is_mine_id(event.event_id):
+ event.signatures.update(
+ compute_event_signature(
+ event,
+ self.hs.hostname,
+ self.hs.config.signing_key[0]
+ )
)
- )
defer.returnValue(res)
else:
@@ -1064,16 +1093,17 @@ class FederationHandler(BaseHandler):
)
if event:
- # FIXME: This is a temporary work around where we occasionally
- # return events slightly differently than when they were
- # originally signed
- event.signatures.update(
- compute_event_signature(
- event,
- self.hs.hostname,
- self.hs.config.signing_key[0]
+ if self.hs.is_mine_id(event.event_id):
+ # FIXME: This is a temporary work around where we occasionally
+ # return events slightly differently than when they were
+ # originally signed
+ event.signatures.update(
+ compute_event_signature(
+ event,
+ self.hs.hostname,
+ self.hs.config.signing_key[0]
+ )
)
- )
if do_auth:
in_room = yield self.auth.check_host_in_room(
@@ -1083,6 +1113,12 @@ class FederationHandler(BaseHandler):
if not in_room:
raise AuthError(403, "Host not in room.")
+ events = yield self._filter_events_for_server(
+ origin, event.room_id, [event]
+ )
+
+ event = events[0]
+
defer.returnValue(event)
else:
defer.returnValue(None)
@@ -1091,15 +1127,6 @@ class FederationHandler(BaseHandler):
def get_min_depth_for_context(self, context):
return self.store.get_min_depth(context)
- @log_function
- def user_joined_room(self, user, room_id):
- waiters = self.waiting_for_join_list.get(
- (user.to_string(), room_id),
- []
- )
- while waiters:
- waiters.pop().callback(None)
-
@defer.inlineCallbacks
@log_function
def _handle_new_event(self, origin, event, state=None, auth_events=None,
@@ -1122,11 +1149,12 @@ class FederationHandler(BaseHandler):
backfilled=backfilled,
)
- # this intentionally does not yield: we don't care about the result
- # and don't need to wait for it.
- preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
- event_stream_id, max_stream_id
- )
+ if not backfilled:
+ # this intentionally does not yield: we don't care about the result
+ # and don't need to wait for it.
+ preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
+ event_stream_id, max_stream_id
+ )
defer.returnValue((context, event_stream_id, max_stream_id))
@@ -1158,11 +1186,19 @@ class FederationHandler(BaseHandler):
)
@defer.inlineCallbacks
- def _persist_auth_tree(self, auth_events, state, event):
+ def _persist_auth_tree(self, origin, auth_events, state, event):
"""Checks the auth chain is valid (and passes auth checks) for the
state and event. Then persists the auth chain and state atomically.
Persists the event seperately.
+ Will attempt to fetch missing auth events.
+
+ Args:
+ origin (str): Where the events came from
+ auth_events (list)
+ state (list)
+ event (Event)
+
Returns:
2-tuple of (event_stream_id, max_stream_id) from the persist_event
call for `event`
@@ -1175,7 +1211,7 @@ class FederationHandler(BaseHandler):
event_map = {
e.event_id: e
- for e in auth_events
+ for e in itertools.chain(auth_events, state, [event])
}
create_event = None
@@ -1184,10 +1220,29 @@ class FederationHandler(BaseHandler):
create_event = e
break
+ missing_auth_events = set()
+ for e in itertools.chain(auth_events, state, [event]):
+ for e_id, _ in e.auth_events:
+ if e_id not in event_map:
+ missing_auth_events.add(e_id)
+
+ for e_id in missing_auth_events:
+ m_ev = yield self.replication_layer.get_pdu(
+ [origin],
+ e_id,
+ outlier=True,
+ timeout=10000,
+ )
+ if m_ev and m_ev.event_id == e_id:
+ event_map[e_id] = m_ev
+ else:
+ logger.info("Failed to find auth event %r", e_id)
+
for e in itertools.chain(auth_events, state, [event]):
auth_for_e = {
(event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
for e_id, _ in e.auth_events
+ if e_id in event_map
}
if create_event:
auth_for_e[(EventTypes.Create, "")] = create_event
@@ -1421,7 +1476,7 @@ class FederationHandler(BaseHandler):
local_view = dict(auth_events)
remote_view = dict(auth_events)
remote_view.update({
- (d.type, d.state_key): d for d in different_events
+ (d.type, d.state_key): d for d in different_events if d
})
new_state, prev_state = self.state_handler.resolve_events(
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 656ce124f9..559e5d5a71 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -21,7 +21,7 @@ from synapse.api.errors import (
)
from ._base import BaseHandler
from synapse.util.async import run_on_reactor
-from synapse.api.errors import SynapseError
+from synapse.api.errors import SynapseError, Codes
import json
import logging
@@ -41,6 +41,20 @@ class IdentityHandler(BaseHandler):
hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
)
+ def _should_trust_id_server(self, id_server):
+ if id_server not in self.trusted_id_servers:
+ if self.trust_any_id_server_just_for_testing_do_not_use:
+ logger.warn(
+ "Trusting untrustworthy ID server %r even though it isn't"
+ " in the trusted id list for testing because"
+ " 'use_insecure_ssl_client_just_for_testing_do_not_use'"
+ " is set in the config",
+ id_server,
+ )
+ else:
+ return False
+ return True
+
@defer.inlineCallbacks
def threepid_from_creds(self, creds):
yield run_on_reactor()
@@ -59,19 +73,12 @@ class IdentityHandler(BaseHandler):
else:
raise SynapseError(400, "No client_secret in creds")
- if id_server not in self.trusted_id_servers:
- if self.trust_any_id_server_just_for_testing_do_not_use:
- logger.warn(
- "Trusting untrustworthy ID server %r even though it isn't"
- " in the trusted id list for testing because"
- " 'use_insecure_ssl_client_just_for_testing_do_not_use'"
- " is set in the config",
- id_server,
- )
- else:
- logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
- 'credentials', id_server)
- defer.returnValue(None)
+ if not self._should_trust_id_server(id_server):
+ logger.warn(
+ '%s is not a trusted ID server: rejecting 3pid ' +
+ 'credentials', id_server
+ )
+ defer.returnValue(None)
data = {}
try:
@@ -129,6 +136,12 @@ class IdentityHandler(BaseHandler):
def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
yield run_on_reactor()
+ if not self._should_trust_id_server(id_server):
+ raise SynapseError(
+ 400, "Untrusted ID server '%s'" % id_server,
+ Codes.SERVER_NOT_TRUSTED
+ )
+
params = {
'email': email,
'client_secret': client_secret,
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index c41dafdef5..dc76d34a52 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -26,9 +26,9 @@ from synapse.types import (
UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id
)
from synapse.util import unwrapFirstError
-from synapse.util.async import concurrently_execute
+from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
from synapse.util.caches.snapshot_cache import SnapshotCache
-from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
+from synapse.util.logcontext import preserve_fn
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -50,9 +50,23 @@ class MessageHandler(BaseHandler):
self.validator = EventValidator()
self.snapshot_cache = SnapshotCache()
+ self.pagination_lock = ReadWriteLock()
+
+ @defer.inlineCallbacks
+ def purge_history(self, room_id, event_id):
+ event = yield self.store.get_event(event_id)
+
+ if event.room_id != room_id:
+ raise SynapseError(400, "Event is for wrong room.")
+
+ depth = event.depth
+
+ with (yield self.pagination_lock.write(room_id)):
+ yield self.store.delete_old_state(room_id, depth)
+
@defer.inlineCallbacks
def get_messages(self, requester, room_id=None, pagin_config=None,
- as_client_event=True):
+ as_client_event=True, event_filter=None):
"""Get messages in a room.
Args:
@@ -61,11 +75,11 @@ class MessageHandler(BaseHandler):
pagin_config (synapse.api.streams.PaginationConfig): The pagination
config rules to apply, if any.
as_client_event (bool): True to get events in client-server format.
+ event_filter (Filter): Filter to apply to results or None
Returns:
dict: Pagination API results
"""
user_id = requester.user.to_string()
- data_source = self.hs.get_event_sources().sources["room"]
if pagin_config.from_token:
room_token = pagin_config.from_token.room_key
@@ -85,42 +99,48 @@ class MessageHandler(BaseHandler):
source_config = pagin_config.get_source_config("room")
- membership, member_event_id = yield self._check_in_room_or_world_readable(
- room_id, user_id
- )
+ with (yield self.pagination_lock.read(room_id)):
+ membership, member_event_id = yield self._check_in_room_or_world_readable(
+ room_id, user_id
+ )
- if source_config.direction == 'b':
- # if we're going backwards, we might need to backfill. This
- # requires that we have a topo token.
- if room_token.topological:
- max_topo = room_token.topological
- else:
- max_topo = yield self.store.get_max_topological_token_for_stream_and_room(
- room_id, room_token.stream
- )
+ if source_config.direction == 'b':
+ # if we're going backwards, we might need to backfill. This
+ # requires that we have a topo token.
+ if room_token.topological:
+ max_topo = room_token.topological
+ else:
+ max_topo = yield self.store.get_max_topological_token(
+ room_id, room_token.stream
+ )
+
+ if membership == Membership.LEAVE:
+ # If they have left the room then clamp the token to be before
+ # they left the room, to save the effort of loading from the
+ # database.
+ leave_token = yield self.store.get_topological_token_for_event(
+ member_event_id
+ )
+ leave_token = RoomStreamToken.parse(leave_token)
+ if leave_token.topological < max_topo:
+ source_config.from_key = str(leave_token)
- if membership == Membership.LEAVE:
- # If they have left the room then clamp the token to be before
- # they left the room, to save the effort of loading from the
- # database.
- leave_token = yield self.store.get_topological_token_for_event(
- member_event_id
+ yield self.hs.get_handlers().federation_handler.maybe_backfill(
+ room_id, max_topo
)
- leave_token = RoomStreamToken.parse(leave_token)
- if leave_token.topological < max_topo:
- source_config.from_key = str(leave_token)
- yield self.hs.get_handlers().federation_handler.maybe_backfill(
- room_id, max_topo
+ events, next_key = yield self.store.paginate_room_events(
+ room_id=room_id,
+ from_key=source_config.from_key,
+ to_key=source_config.to_key,
+ direction=source_config.direction,
+ limit=source_config.limit,
+ event_filter=event_filter,
)
- events, next_key = yield data_source.get_pagination_rows(
- requester.user, source_config, room_id
- )
-
- next_token = pagin_config.from_token.copy_and_replace(
- "room_key", next_key
- )
+ next_token = pagin_config.from_token.copy_and_replace(
+ "room_key", next_key
+ )
if not events:
defer.returnValue({
@@ -129,6 +149,9 @@ class MessageHandler(BaseHandler):
"end": next_token.to_string(),
})
+ if event_filter:
+ events = event_filter.filter(events)
+
events = yield filter_events_for_client(
self.store,
user_id,
@@ -908,13 +931,16 @@ class MessageHandler(BaseHandler):
"Failed to get destination from event %s", s.event_id
)
- with PreserveLoggingContext():
- # Don't block waiting on waking up all the listeners.
+ @defer.inlineCallbacks
+ def _notify():
+ yield run_on_reactor()
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
+ preserve_fn(_notify)()
+
# If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 37f57301fb..6b70fa3817 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -50,6 +50,8 @@ timers_fired_counter = metrics.register_counter("timers_fired")
federation_presence_counter = metrics.register_counter("federation_presence")
bump_active_time_counter = metrics.register_counter("bump_active_time")
+get_updates_counter = metrics.register_counter("get_updates", labels=["type"])
+
# If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them
# "currently_active"
@@ -68,6 +70,10 @@ FEDERATION_TIMEOUT = 30 * 60 * 1000
# How often to resend presence to remote servers
FEDERATION_PING_INTERVAL = 25 * 60 * 1000
+# How long we will wait before assuming that the syncs from an external process
+# are dead.
+EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000
+
assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
@@ -158,15 +164,26 @@ class PresenceHandler(object):
self.serial_to_user = {}
self._next_serial = 1
- # Keeps track of the number of *ongoing* syncs. While this is non zero
- # a user will never go offline.
+ # Keeps track of the number of *ongoing* syncs on this process. While
+ # this is non zero a user will never go offline.
self.user_to_num_current_syncs = {}
+ # Keeps track of the number of *ongoing* syncs on other processes.
+ # While any sync is ongoing on another process the user will never
+ # go offline.
+ # Each process has a unique identifier and an update frequency. If
+ # no update is received from that process within the update period then
+ # we assume that all the sync requests on that process have stopped.
+ # Stored as a dict from process_id to set of user_id, and a dict of
+ # process_id to millisecond timestamp last updated.
+ self.external_process_to_current_syncs = {}
+ self.external_process_last_updated_ms = {}
+
# Start a LoopingCall in 30s that fires every 5s.
# The initial delay is to allow disconnected clients a chance to
# reconnect before we treat them as offline.
self.clock.call_later(
- 30 * 1000,
+ 30,
self.clock.looping_call,
self._handle_timeouts,
5000,
@@ -266,31 +283,48 @@ class PresenceHandler(object):
"""Checks the presence of users that have timed out and updates as
appropriate.
"""
+ logger.info("Handling presence timeouts")
now = self.clock.time_msec()
- with Measure(self.clock, "presence_handle_timeouts"):
- # Fetch the list of users that *may* have timed out. Things may have
- # changed since the timeout was set, so we won't necessarily have to
- # take any action.
- users_to_check = self.wheel_timer.fetch(now)
+ try:
+ with Measure(self.clock, "presence_handle_timeouts"):
+ # Fetch the list of users that *may* have timed out. Things may have
+ # changed since the timeout was set, so we won't necessarily have to
+ # take any action.
+ users_to_check = set(self.wheel_timer.fetch(now))
+
+ # Check whether the lists of syncing processes from an external
+ # process have expired.
+ expired_process_ids = [
+ process_id for process_id, last_update
+ in self.external_process_last_updated_ms.items()
+ if now - last_update > EXTERNAL_PROCESS_EXPIRY
+ ]
+ for process_id in expired_process_ids:
+ users_to_check.update(
+ self.external_process_last_updated_ms.pop(process_id, ())
+ )
+ self.external_process_last_update.pop(process_id)
- states = [
- self.user_to_current_state.get(
- user_id, UserPresenceState.default(user_id)
- )
- for user_id in set(users_to_check)
- ]
+ states = [
+ self.user_to_current_state.get(
+ user_id, UserPresenceState.default(user_id)
+ )
+ for user_id in users_to_check
+ ]
- timers_fired_counter.inc_by(len(states))
+ timers_fired_counter.inc_by(len(states))
- changes = handle_timeouts(
- states,
- is_mine_fn=self.is_mine_id,
- user_to_num_current_syncs=self.user_to_num_current_syncs,
- now=now,
- )
+ changes = handle_timeouts(
+ states,
+ is_mine_fn=self.is_mine_id,
+ syncing_user_ids=self.get_currently_syncing_users(),
+ now=now,
+ )
- preserve_fn(self._update_states)(changes)
+ preserve_fn(self._update_states)(changes)
+ except:
+ logger.exception("Exception in _handle_timeouts loop")
@defer.inlineCallbacks
def bump_presence_active_time(self, user):
@@ -363,6 +397,74 @@ class PresenceHandler(object):
defer.returnValue(_user_syncing())
+ def get_currently_syncing_users(self):
+ """Get the set of user ids that are currently syncing on this HS.
+ Returns:
+ set(str): A set of user_id strings.
+ """
+ syncing_user_ids = {
+ user_id for user_id, count in self.user_to_num_current_syncs.items()
+ if count
+ }
+ for user_ids in self.external_process_to_current_syncs.values():
+ syncing_user_ids.update(user_ids)
+ return syncing_user_ids
+
+ @defer.inlineCallbacks
+ def update_external_syncs(self, process_id, syncing_user_ids):
+ """Update the syncing users for an external process
+
+ Args:
+ process_id(str): An identifier for the process the users are
+ syncing against. This allows synapse to process updates
+ as user start and stop syncing against a given process.
+ syncing_user_ids(set(str)): The set of user_ids that are
+ currently syncing on that server.
+ """
+
+ # Grab the previous list of user_ids that were syncing on that process
+ prev_syncing_user_ids = (
+ self.external_process_to_current_syncs.get(process_id, set())
+ )
+ # Grab the current presence state for both the users that are syncing
+ # now and the users that were syncing before this update.
+ prev_states = yield self.current_state_for_users(
+ syncing_user_ids | prev_syncing_user_ids
+ )
+ updates = []
+ time_now_ms = self.clock.time_msec()
+
+ # For each new user that is syncing check if we need to mark them as
+ # being online.
+ for new_user_id in syncing_user_ids - prev_syncing_user_ids:
+ prev_state = prev_states[new_user_id]
+ if prev_state.state == PresenceState.OFFLINE:
+ updates.append(prev_state.copy_and_replace(
+ state=PresenceState.ONLINE,
+ last_active_ts=time_now_ms,
+ last_user_sync_ts=time_now_ms,
+ ))
+ else:
+ updates.append(prev_state.copy_and_replace(
+ last_user_sync_ts=time_now_ms,
+ ))
+
+ # For each user that is still syncing or stopped syncing update the
+ # last sync time so that we will correctly apply the grace period when
+ # they stop syncing.
+ for old_user_id in prev_syncing_user_ids:
+ prev_state = prev_states[old_user_id]
+ updates.append(prev_state.copy_and_replace(
+ last_user_sync_ts=time_now_ms,
+ ))
+
+ yield self._update_states(updates)
+
+ # Update the last updated time for the process. We expire the entries
+ # if we don't receive an update in the given timeframe.
+ self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
+ self.external_process_to_current_syncs[process_id] = syncing_user_ids
+
@defer.inlineCallbacks
def current_state_for_user(self, user_id):
"""Get the current presence state for a user.
@@ -879,13 +981,13 @@ class PresenceEventSource(object):
user_ids_changed = set()
changed = None
- if from_key and max_token - from_key < 100:
- # For small deltas, its quicker to get all changes and then
- # work out if we share a room or they're in our presence list
+ if from_key:
changed = stream_change_cache.get_all_entities_changed(from_key)
- # get_all_entities_changed can return None
- if changed is not None:
+ if changed is not None and len(changed) < 500:
+ # For small deltas, its quicker to get all changes and then
+ # work out if we share a room or they're in our presence list
+ get_updates_counter.inc("stream")
for other_user_id in changed:
if other_user_id in friends:
user_ids_changed.add(other_user_id)
@@ -897,6 +999,8 @@ class PresenceEventSource(object):
else:
# Too many possible updates. Find all users we can see and check
# if any of them have changed.
+ get_updates_counter.inc("full")
+
user_ids_to_check = set()
for room_id in room_ids:
users = yield self.store.get_users_in_room(room_id)
@@ -935,15 +1039,14 @@ class PresenceEventSource(object):
return self.get_new_events(user, from_key=None, include_offline=False)
-def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now):
+def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
"""Checks the presence of users that have timed out and updates as
appropriate.
Args:
user_states(list): List of UserPresenceState's to check.
is_mine_fn (fn): Function that returns if a user_id is ours
- user_to_num_current_syncs (dict): Mapping of user_id to number of currently
- active syncs.
+ syncing_user_ids (set): Set of user_ids with active syncs.
now (int): Current time in ms.
Returns:
@@ -954,21 +1057,20 @@ def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now):
for state in user_states:
is_mine = is_mine_fn(state.user_id)
- new_state = handle_timeout(state, is_mine, user_to_num_current_syncs, now)
+ new_state = handle_timeout(state, is_mine, syncing_user_ids, now)
if new_state:
changes[state.user_id] = new_state
return changes.values()
-def handle_timeout(state, is_mine, user_to_num_current_syncs, now):
+def handle_timeout(state, is_mine, syncing_user_ids, now):
"""Checks the presence of the user to see if any of the timers have elapsed
Args:
state (UserPresenceState)
is_mine (bool): Whether the user is ours
- user_to_num_current_syncs (dict): Mapping of user_id to number of currently
- active syncs.
+ syncing_user_ids (set): Set of user_ids with active syncs.
now (int): Current time in ms.
Returns:
@@ -1002,7 +1104,7 @@ def handle_timeout(state, is_mine, user_to_num_current_syncs, now):
# If there are have been no sync for a while (and none ongoing),
# set presence to offline
- if not user_to_num_current_syncs.get(user_id, 0):
+ if user_id not in syncing_user_ids:
if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT:
state = state.copy_and_replace(
state=PresenceState.OFFLINE,
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index e37409170d..d9ac09078d 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -13,15 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
+import synapse.types
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
-from synapse.types import UserID, Requester
-
+from synapse.types import UserID
from ._base import BaseHandler
-import logging
-
logger = logging.getLogger(__name__)
@@ -36,13 +36,6 @@ class ProfileHandler(BaseHandler):
"profile", self.on_profile_query
)
- distributor = hs.get_distributor()
-
- distributor.observe("registered_user", self.registered_user)
-
- def registered_user(self, user):
- return self.store.create_profile(user.localpart)
-
@defer.inlineCallbacks
def get_displayname(self, target_user):
if self.hs.is_mine(target_user):
@@ -172,7 +165,9 @@ class ProfileHandler(BaseHandler):
try:
# Assume the user isn't a guest because we don't let guests set
# profile or avatar data.
- requester = Requester(user, "", False)
+ # XXX why are we recreating `requester` here for each room?
+ # what was wrong with the `requester` we were passed?
+ requester = synapse.types.create_requester(user)
yield handler.update_membership(
requester,
user,
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 5883b9111e..dd75c4fecf 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -14,19 +14,19 @@
# limitations under the License.
"""Contains functions for registering clients."""
+import logging
+import urllib
+
from twisted.internet import defer
-from synapse.types import UserID
+import synapse.types
from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
)
-from ._base import BaseHandler
-from synapse.util.async import run_on_reactor
from synapse.http.client import CaptchaServerHttpClient
-from synapse.util.distributor import registered_user
-
-import logging
-import urllib
+from synapse.types import UserID
+from synapse.util.async import run_on_reactor
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -37,8 +37,6 @@ class RegistrationHandler(BaseHandler):
super(RegistrationHandler, self).__init__(hs)
self.auth = hs.get_auth()
- self.distributor = hs.get_distributor()
- self.distributor.declare("registered_user")
self.captcha_client = CaptchaServerHttpClient(hs)
self._next_generated_user_id = None
@@ -55,6 +53,13 @@ class RegistrationHandler(BaseHandler):
Codes.INVALID_USERNAME
)
+ if localpart[0] == '_':
+ raise SynapseError(
+ 400,
+ "User ID may not begin with _",
+ Codes.INVALID_USERNAME
+ )
+
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
@@ -93,7 +98,8 @@ class RegistrationHandler(BaseHandler):
password=None,
generate_token=True,
guest_access_token=None,
- make_guest=False
+ make_guest=False,
+ admin=False,
):
"""Registers a new client on the server.
@@ -101,8 +107,13 @@ class RegistrationHandler(BaseHandler):
localpart : The local part of the user ID to register. If None,
one will be generated.
password (str) : The password to assign to this user so they can
- login again. This can be None which means they cannot login again
- via a password (e.g. the user is an application service user).
+ login again. This can be None which means they cannot login again
+ via a password (e.g. the user is an application service user).
+ generate_token (bool): Whether a new access token should be
+ generated. Having this be True should be considered deprecated,
+ since it offers no means of associating a device_id with the
+ access_token. Instead you should call auth_handler.issue_access_token
+ after registration.
Returns:
A tuple of (user_id, access_token).
Raises:
@@ -140,9 +151,12 @@ class RegistrationHandler(BaseHandler):
password_hash=password_hash,
was_guest=was_guest,
make_guest=make_guest,
+ create_profile_with_localpart=(
+ # If the user was a guest then they already have a profile
+ None if was_guest else user.localpart
+ ),
+ admin=admin,
)
-
- yield registered_user(self.distributor, user)
else:
# autogen a sequential user ID
attempts = 0
@@ -160,7 +174,8 @@ class RegistrationHandler(BaseHandler):
user_id=user_id,
token=token,
password_hash=password_hash,
- make_guest=make_guest
+ make_guest=make_guest,
+ create_profile_with_localpart=user.localpart,
)
except SynapseError:
# if user id is taken, just generate another
@@ -168,7 +183,6 @@ class RegistrationHandler(BaseHandler):
user_id = None
token = None
attempts += 1
- yield registered_user(self.distributor, user)
# We used to generate default identicons here, but nowadays
# we want clients to generate their own as part of their branding
@@ -195,15 +209,13 @@ class RegistrationHandler(BaseHandler):
user_id, allowed_appservice=service
)
- token = self.auth_handler().generate_access_token(user_id)
yield self.store.register(
user_id=user_id,
- token=token,
password_hash="",
appservice_id=service_id,
+ create_profile_with_localpart=user.localpart,
)
- yield registered_user(self.distributor, user)
- defer.returnValue((user_id, token))
+ defer.returnValue(user_id)
@defer.inlineCallbacks
def check_recaptcha(self, ip, private_key, challenge, response):
@@ -248,9 +260,9 @@ class RegistrationHandler(BaseHandler):
yield self.store.register(
user_id=user_id,
token=token,
- password_hash=None
+ password_hash=None,
+ create_profile_with_localpart=user.localpart,
)
- yield registered_user(self.distributor, user)
except Exception as e:
yield self.store.add_access_token_to_user(user_id, token)
# Ignore Registration errors
@@ -359,8 +371,10 @@ class RegistrationHandler(BaseHandler):
defer.returnValue(data)
@defer.inlineCallbacks
- def get_or_create_user(self, localpart, displayname, duration_seconds):
- """Creates a new user or returns an access token for an existing one
+ def get_or_create_user(self, localpart, displayname, duration_in_ms,
+ password_hash=None):
+ """Creates a new user if the user does not exist,
+ else revokes all previous access tokens and generates a new one.
Args:
localpart : The local part of the user ID to register. If None,
@@ -387,32 +401,32 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
- auth_handler = self.hs.get_handlers().auth_handler
- token = auth_handler.generate_short_term_login_token(user_id, duration_seconds)
+ token = self.auth_handler().generate_access_token(
+ user_id, None, duration_in_ms)
if need_register:
yield self.store.register(
user_id=user_id,
token=token,
- password_hash=None
+ password_hash=password_hash,
+ create_profile_with_localpart=user.localpart,
)
-
- yield registered_user(self.distributor, user)
else:
- yield self.store.flush_user(user_id=user_id)
+ yield self.store.user_delete_access_tokens(user_id=user_id)
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname)
profile_handler = self.hs.get_handlers().profile_handler
+ requester = synapse.types.create_requester(user)
yield profile_handler.set_displayname(
- user, user, displayname
+ user, requester, displayname
)
defer.returnValue((user_id, token))
def auth_handler(self):
- return self.hs.get_handlers().auth_handler
+ return self.hs.get_auth_handler()
@defer.inlineCallbacks
def guest_access_token_for(self, medium, address, inviter_user_id):
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 3d63b3c513..bf6b1c1535 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -20,7 +20,7 @@ from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
from synapse.api.constants import (
- EventTypes, JoinRules, RoomCreationPreset,
+ EventTypes, JoinRules, RoomCreationPreset, Membership,
)
from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.util import stringutils
@@ -36,6 +36,8 @@ import string
logger = logging.getLogger(__name__)
+REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
+
id_server_scheme = "https://"
@@ -343,9 +345,15 @@ class RoomCreationHandler(BaseHandler):
class RoomListHandler(BaseHandler):
def __init__(self, hs):
super(RoomListHandler, self).__init__(hs)
- self.response_cache = ResponseCache()
+ self.response_cache = ResponseCache(hs)
+ self.remote_list_request_cache = ResponseCache(hs)
+ self.remote_list_cache = {}
+ self.fetch_looping_call = hs.get_clock().looping_call(
+ self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL
+ )
+ self.fetch_all_remote_lists()
- def get_public_room_list(self):
+ def get_local_public_room_list(self):
result = self.response_cache.get(())
if not result:
result = self.response_cache.set((), self._get_public_room_list())
@@ -359,14 +367,10 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks
def handle_room(room_id):
- # We pull each bit of state out indvidually to avoid pulling the
- # full state into memory. Due to how the caching works this should
- # be fairly quick, even if not originally in the cache.
- def get_state(etype, state_key):
- return self.state_handler.get_current_state(room_id, etype, state_key)
+ current_state = yield self.state_handler.get_current_state(room_id)
# Double check that this is actually a public room.
- join_rules_event = yield get_state(EventTypes.JoinRules, "")
+ join_rules_event = current_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
join_rule = join_rules_event.content.get("join_rule", None)
if join_rule and join_rule != JoinRules.PUBLIC:
@@ -374,47 +378,51 @@ class RoomListHandler(BaseHandler):
result = {"room_id": room_id}
- joined_users = yield self.store.get_users_in_room(room_id)
- if len(joined_users) == 0:
+ num_joined_users = len([
+ 1 for _, event in current_state.items()
+ if event.type == EventTypes.Member
+ and event.membership == Membership.JOIN
+ ])
+ if num_joined_users == 0:
return
- result["num_joined_members"] = len(joined_users)
+ result["num_joined_members"] = num_joined_users
aliases = yield self.store.get_aliases_for_room(room_id)
if aliases:
result["aliases"] = aliases
- name_event = yield get_state(EventTypes.Name, "")
+ name_event = yield current_state.get((EventTypes.Name, ""))
if name_event:
name = name_event.content.get("name", None)
if name:
result["name"] = name
- topic_event = yield get_state(EventTypes.Topic, "")
+ topic_event = current_state.get((EventTypes.Topic, ""))
if topic_event:
topic = topic_event.content.get("topic", None)
if topic:
result["topic"] = topic
- canonical_event = yield get_state(EventTypes.CanonicalAlias, "")
+ canonical_event = current_state.get((EventTypes.CanonicalAlias, ""))
if canonical_event:
canonical_alias = canonical_event.content.get("alias", None)
if canonical_alias:
result["canonical_alias"] = canonical_alias
- visibility_event = yield get_state(EventTypes.RoomHistoryVisibility, "")
+ visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, ""))
visibility = None
if visibility_event:
visibility = visibility_event.content.get("history_visibility", None)
result["world_readable"] = visibility == "world_readable"
- guest_event = yield get_state(EventTypes.GuestAccess, "")
+ guest_event = current_state.get((EventTypes.GuestAccess, ""))
guest = None
if guest_event:
guest = guest_event.content.get("guest_access", None)
result["guest_can_join"] = guest == "can_join"
- avatar_event = yield get_state("m.room.avatar", "")
+ avatar_event = current_state.get(("m.room.avatar", ""))
if avatar_event:
avatar_url = avatar_event.content.get("url", None)
if avatar_url:
@@ -427,6 +435,55 @@ class RoomListHandler(BaseHandler):
# FIXME (erikj): START is no longer a valid value
defer.returnValue({"start": "START", "end": "END", "chunk": results})
+ @defer.inlineCallbacks
+ def fetch_all_remote_lists(self):
+ deferred = self.hs.get_replication_layer().get_public_rooms(
+ self.hs.config.secondary_directory_servers
+ )
+ self.remote_list_request_cache.set((), deferred)
+ self.remote_list_cache = yield deferred
+
+ @defer.inlineCallbacks
+ def get_aggregated_public_room_list(self):
+ """
+ Get the public room list from this server and the servers
+ specified in the secondary_directory_servers config option.
+ XXX: Pagination...
+ """
+ # We return the results from out cache which is updated by a looping call,
+ # unless we're missing a cache entry, in which case wait for the result
+ # of the fetch if there's one in progress. If not, omit that server.
+ wait = False
+ for s in self.hs.config.secondary_directory_servers:
+ if s not in self.remote_list_cache:
+ logger.warn("No cached room list from %s: waiting for fetch", s)
+ wait = True
+ break
+
+ if wait and self.remote_list_request_cache.get(()):
+ yield self.remote_list_request_cache.get(())
+
+ public_rooms = yield self.get_local_public_room_list()
+
+ # keep track of which room IDs we've seen so we can de-dup
+ room_ids = set()
+
+ # tag all the ones in our list with our server name.
+ # Also add the them to the de-deping set
+ for room in public_rooms['chunk']:
+ room["server_name"] = self.hs.hostname
+ room_ids.add(room["room_id"])
+
+ # Now add the results from federation
+ for server_name, server_result in self.remote_list_cache.items():
+ for room in server_result["chunk"]:
+ if room["room_id"] not in room_ids:
+ room["server_name"] = server_name
+ public_rooms["chunk"].append(room)
+ room_ids.add(room["room_id"])
+
+ defer.returnValue(public_rooms)
+
class RoomContextHandler(BaseHandler):
@defer.inlineCallbacks
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 7e616f44fd..8cec8fc4ed 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -14,24 +14,22 @@
# limitations under the License.
-from twisted.internet import defer
+import logging
-from ._base import BaseHandler
+from signedjson.key import decode_verify_key_bytes
+from signedjson.sign import verify_signed_json
+from twisted.internet import defer
+from unpaddedbase64 import decode_base64
-from synapse.types import UserID, RoomID, Requester
+import synapse.types
from synapse.api.constants import (
EventTypes, Membership,
)
from synapse.api.errors import AuthError, SynapseError, Codes
+from synapse.types import UserID, RoomID
from synapse.util.async import Linearizer
from synapse.util.distributor import user_left_room, user_joined_room
-
-from signedjson.sign import verify_signed_json
-from signedjson.key import decode_verify_key_bytes
-
-from unpaddedbase64 import decode_base64
-
-import logging
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -315,7 +313,7 @@ class RoomMemberHandler(BaseHandler):
)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
- requester = Requester(target_user, None, False)
+ requester = synapse.types.create_requester(target_user)
message_handler = self.hs.get_handlers().message_handler
prev_event = message_handler.deduplicate_state_event(event, context)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 9ebfccc8bf..0ee4ebe504 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.streams.config import PaginationConfig
from synapse.api.constants import Membership, EventTypes
from synapse.util.async import concurrently_execute
from synapse.util.logcontext import LoggingContext
@@ -139,7 +138,7 @@ class SyncHandler(object):
self.presence_handler = hs.get_presence_handler()
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
- self.response_cache = ResponseCache()
+ self.response_cache = ResponseCache(hs)
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
full_state=False):
@@ -194,195 +193,15 @@ class SyncHandler(object):
Returns:
A Deferred SyncResult.
"""
- if since_token is None or full_state:
- return self.full_state_sync(sync_config, since_token)
- else:
- return self.incremental_sync_with_gap(sync_config, since_token)
-
- @defer.inlineCallbacks
- def full_state_sync(self, sync_config, timeline_since_token):
- """Get a sync for a client which is starting without any state.
-
- If a 'message_since_token' is given, only timeline events which have
- happened since that token will be returned.
-
- Returns:
- A Deferred SyncResult.
- """
- now_token = yield self.event_sources.get_current_token()
-
- now_token, ephemeral_by_room = yield self.ephemeral_by_room(
- sync_config, now_token
- )
-
- presence_stream = self.event_sources.sources["presence"]
- # TODO (mjark): This looks wrong, shouldn't we be getting the presence
- # UP to the present rather than after the present?
- pagination_config = PaginationConfig(from_token=now_token)
- presence, _ = yield presence_stream.get_pagination_rows(
- user=sync_config.user,
- pagination_config=pagination_config.get_source_config("presence"),
- key=None
- )
-
- membership_list = (
- Membership.INVITE, Membership.JOIN, Membership.LEAVE, Membership.BAN
- )
-
- room_list = yield self.store.get_rooms_for_user_where_membership_is(
- user_id=sync_config.user.to_string(),
- membership_list=membership_list
- )
-
- account_data, account_data_by_room = (
- yield self.store.get_account_data_for_user(
- sync_config.user.to_string()
- )
- )
-
- account_data['m.push_rules'] = yield self.push_rules_for_user(
- sync_config.user
- )
-
- tags_by_room = yield self.store.get_tags_for_user(
- sync_config.user.to_string()
- )
-
- ignored_users = account_data.get(
- "m.ignored_user_list", {}
- ).get("ignored_users", {}).keys()
-
- joined = []
- invited = []
- archived = []
-
- user_id = sync_config.user.to_string()
-
- @defer.inlineCallbacks
- def _generate_room_entry(event):
- if event.membership == Membership.JOIN:
- room_result = yield self.full_state_sync_for_joined_room(
- room_id=event.room_id,
- sync_config=sync_config,
- now_token=now_token,
- timeline_since_token=timeline_since_token,
- ephemeral_by_room=ephemeral_by_room,
- tags_by_room=tags_by_room,
- account_data_by_room=account_data_by_room,
- )
- joined.append(room_result)
- elif event.membership == Membership.INVITE:
- if event.sender in ignored_users:
- return
- invite = yield self.store.get_event(event.event_id)
- invited.append(InvitedSyncResult(
- room_id=event.room_id,
- invite=invite,
- ))
- elif event.membership in (Membership.LEAVE, Membership.BAN):
- # Always send down rooms we were banned or kicked from.
- if not sync_config.filter_collection.include_leave:
- if event.membership == Membership.LEAVE:
- if user_id == event.sender:
- return
-
- leave_token = now_token.copy_and_replace(
- "room_key", "s%d" % (event.stream_ordering,)
- )
- room_result = yield self.full_state_sync_for_archived_room(
- sync_config=sync_config,
- room_id=event.room_id,
- leave_event_id=event.event_id,
- leave_token=leave_token,
- timeline_since_token=timeline_since_token,
- tags_by_room=tags_by_room,
- account_data_by_room=account_data_by_room,
- )
- archived.append(room_result)
-
- yield concurrently_execute(_generate_room_entry, room_list, 10)
-
- account_data_for_user = sync_config.filter_collection.filter_account_data(
- self.account_data_for_user(account_data)
- )
-
- presence = sync_config.filter_collection.filter_presence(
- presence
- )
-
- defer.returnValue(SyncResult(
- presence=presence,
- account_data=account_data_for_user,
- joined=joined,
- invited=invited,
- archived=archived,
- next_batch=now_token,
- ))
-
- @defer.inlineCallbacks
- def full_state_sync_for_joined_room(self, room_id, sync_config,
- now_token, timeline_since_token,
- ephemeral_by_room, tags_by_room,
- account_data_by_room):
- """Sync a room for a client which is starting without any state
- Returns:
- A Deferred JoinedSyncResult.
- """
-
- batch = yield self.load_filtered_recents(
- room_id, sync_config, now_token, since_token=timeline_since_token
- )
-
- room_sync = yield self.incremental_sync_with_gap_for_room(
- room_id, sync_config,
- now_token=now_token,
- since_token=timeline_since_token,
- ephemeral_by_room=ephemeral_by_room,
- tags_by_room=tags_by_room,
- account_data_by_room=account_data_by_room,
- batch=batch,
- full_state=True,
- )
-
- defer.returnValue(room_sync)
+ return self.generate_sync_result(sync_config, since_token, full_state)
@defer.inlineCallbacks
def push_rules_for_user(self, user):
user_id = user.to_string()
- rawrules = yield self.store.get_push_rules_for_user(user_id)
- enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
- rules = format_push_rules_for_user(user, rawrules, enabled_map)
+ rules = yield self.store.get_push_rules_for_user(user_id)
+ rules = format_push_rules_for_user(user, rules)
defer.returnValue(rules)
- def account_data_for_user(self, account_data):
- account_data_events = []
-
- for account_data_type, content in account_data.items():
- account_data_events.append({
- "type": account_data_type,
- "content": content,
- })
-
- return account_data_events
-
- def account_data_for_room(self, room_id, tags_by_room, account_data_by_room):
- account_data_events = []
- tags = tags_by_room.get(room_id)
- if tags is not None:
- account_data_events.append({
- "type": "m.tag",
- "content": {"tags": tags},
- })
-
- account_data = account_data_by_room.get(room_id, {})
- for account_data_type, content in account_data.items():
- account_data_events.append({
- "type": account_data_type,
- "content": content,
- })
-
- return account_data_events
-
@defer.inlineCallbacks
def ephemeral_by_room(self, sync_config, now_token, since_token=None):
"""Get the ephemeral events for each room the user is in
@@ -445,258 +264,22 @@ class SyncHandler(object):
defer.returnValue((now_token, ephemeral_by_room))
- def full_state_sync_for_archived_room(self, room_id, sync_config,
- leave_event_id, leave_token,
- timeline_since_token, tags_by_room,
- account_data_by_room):
- """Sync a room for a client which is starting without any state
- Returns:
- A Deferred ArchivedSyncResult.
- """
-
- return self.incremental_sync_for_archived_room(
- sync_config, room_id, leave_event_id, timeline_since_token, tags_by_room,
- account_data_by_room, full_state=True, leave_token=leave_token,
- )
-
- @defer.inlineCallbacks
- def incremental_sync_with_gap(self, sync_config, since_token):
- """ Get the incremental delta needed to bring the client up to
- date with the server.
- Returns:
- A Deferred SyncResult.
- """
- now_token = yield self.event_sources.get_current_token()
-
- rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
- room_ids = [room.room_id for room in rooms]
-
- presence_source = self.event_sources.sources["presence"]
- presence, presence_key = yield presence_source.get_new_events(
- user=sync_config.user,
- from_key=since_token.presence_key,
- limit=sync_config.filter_collection.presence_limit(),
- room_ids=room_ids,
- is_guest=sync_config.is_guest,
- )
- now_token = now_token.copy_and_replace("presence_key", presence_key)
-
- now_token, ephemeral_by_room = yield self.ephemeral_by_room(
- sync_config, now_token, since_token
- )
-
- app_service = yield self.store.get_app_service_by_user_id(
- sync_config.user.to_string()
- )
- if app_service:
- rooms = yield self.store.get_app_service_rooms(app_service)
- joined_room_ids = set(r.room_id for r in rooms)
- else:
- rooms = yield self.store.get_rooms_for_user(
- sync_config.user.to_string()
- )
- joined_room_ids = set(r.room_id for r in rooms)
-
- user_id = sync_config.user.to_string()
-
- timeline_limit = sync_config.filter_collection.timeline_limit()
-
- tags_by_room = yield self.store.get_updated_tags(
- user_id,
- since_token.account_data_key,
- )
-
- account_data, account_data_by_room = (
- yield self.store.get_updated_account_data_for_user(
- user_id,
- since_token.account_data_key,
- )
- )
-
- push_rules_changed = yield self.store.have_push_rules_changed_for_user(
- user_id, int(since_token.push_rules_key)
- )
-
- if push_rules_changed:
- account_data["m.push_rules"] = yield self.push_rules_for_user(
- sync_config.user
- )
-
- ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
- "m.ignored_user_list", user_id=user_id,
- )
-
- if ignored_account_data:
- ignored_users = ignored_account_data.get("ignored_users", {}).keys()
- else:
- ignored_users = frozenset()
-
- # Get a list of membership change events that have happened.
- rooms_changed = yield self.store.get_membership_changes_for_user(
- user_id, since_token.room_key, now_token.room_key
- )
-
- mem_change_events_by_room_id = {}
- for event in rooms_changed:
- mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
-
- newly_joined_rooms = []
- archived = []
- invited = []
- for room_id, events in mem_change_events_by_room_id.items():
- non_joins = [e for e in events if e.membership != Membership.JOIN]
- has_join = len(non_joins) != len(events)
-
- # We want to figure out if we joined the room at some point since
- # the last sync (even if we have since left). This is to make sure
- # we do send down the room, and with full state, where necessary
- if room_id in joined_room_ids or has_join:
- old_state = yield self.get_state_at(room_id, since_token)
- old_mem_ev = old_state.get((EventTypes.Member, user_id), None)
- if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
- newly_joined_rooms.append(room_id)
-
- if room_id in joined_room_ids:
- continue
-
- if not non_joins:
- continue
-
- # Only bother if we're still currently invited
- should_invite = non_joins[-1].membership == Membership.INVITE
- if should_invite:
- if event.sender not in ignored_users:
- room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
- if room_sync:
- invited.append(room_sync)
-
- # Always include leave/ban events. Just take the last one.
- # TODO: How do we handle ban -> leave in same batch?
- leave_events = [
- e for e in non_joins
- if e.membership in (Membership.LEAVE, Membership.BAN)
- ]
-
- if leave_events:
- leave_event = leave_events[-1]
- room_sync = yield self.incremental_sync_for_archived_room(
- sync_config, room_id, leave_event.event_id, since_token,
- tags_by_room, account_data_by_room,
- full_state=room_id in newly_joined_rooms
- )
- if room_sync:
- archived.append(room_sync)
-
- # Get all events for rooms we're currently joined to.
- room_to_events = yield self.store.get_room_events_stream_for_rooms(
- room_ids=joined_room_ids,
- from_key=since_token.room_key,
- to_key=now_token.room_key,
- limit=timeline_limit + 1,
- )
-
- joined = []
- # We loop through all room ids, even if there are no new events, in case
- # there are non room events taht we need to notify about.
- for room_id in joined_room_ids:
- room_entry = room_to_events.get(room_id, None)
-
- if room_entry:
- events, start_key = room_entry
-
- prev_batch_token = now_token.copy_and_replace("room_key", start_key)
-
- newly_joined_room = room_id in newly_joined_rooms
- full_state = newly_joined_room
-
- batch = yield self.load_filtered_recents(
- room_id, sync_config, prev_batch_token,
- since_token=since_token,
- recents=events,
- newly_joined_room=newly_joined_room,
- )
- else:
- batch = TimelineBatch(
- events=[],
- prev_batch=since_token,
- limited=False,
- )
- full_state = False
-
- room_sync = yield self.incremental_sync_with_gap_for_room(
- room_id=room_id,
- sync_config=sync_config,
- since_token=since_token,
- now_token=now_token,
- ephemeral_by_room=ephemeral_by_room,
- tags_by_room=tags_by_room,
- account_data_by_room=account_data_by_room,
- batch=batch,
- full_state=full_state,
- )
- if room_sync:
- joined.append(room_sync)
-
- # For each newly joined room, we want to send down presence of
- # existing users.
- presence_handler = self.presence_handler
- extra_presence_users = set()
- for room_id in newly_joined_rooms:
- users = yield self.store.get_users_in_room(event.room_id)
- extra_presence_users.update(users)
-
- # For each new member, send down presence.
- for joined_sync in joined:
- it = itertools.chain(joined_sync.timeline.events, joined_sync.state.values())
- for event in it:
- if event.type == EventTypes.Member:
- if event.membership == Membership.JOIN:
- extra_presence_users.add(event.state_key)
-
- states = yield presence_handler.get_states(
- [u for u in extra_presence_users if u != user_id],
- as_event=True,
- )
- presence.extend(states)
-
- account_data_for_user = sync_config.filter_collection.filter_account_data(
- self.account_data_for_user(account_data)
- )
-
- presence = sync_config.filter_collection.filter_presence(
- presence
- )
-
- defer.returnValue(SyncResult(
- presence=presence,
- account_data=account_data_for_user,
- joined=joined,
- invited=invited,
- archived=archived,
- next_batch=now_token,
- ))
-
@defer.inlineCallbacks
- def load_filtered_recents(self, room_id, sync_config, now_token,
- since_token=None, recents=None, newly_joined_room=False):
+ def _load_filtered_recents(self, room_id, sync_config, now_token,
+ since_token=None, recents=None, newly_joined_room=False):
"""
Returns:
a Deferred TimelineBatch
"""
with Measure(self.clock, "load_filtered_recents"):
- filtering_factor = 2
timeline_limit = sync_config.filter_collection.timeline_limit()
- load_limit = max(timeline_limit * filtering_factor, 10)
- max_repeat = 5 # Only try a few times per room, otherwise
- room_key = now_token.room_key
- end_key = room_key
if recents is None or newly_joined_room or timeline_limit < len(recents):
limited = True
else:
limited = False
- if recents is not None:
+ if recents:
recents = sync_config.filter_collection.filter_room_timeline(recents)
recents = yield filter_events_for_client(
self.store,
@@ -706,6 +289,19 @@ class SyncHandler(object):
else:
recents = []
+ if not limited:
+ defer.returnValue(TimelineBatch(
+ events=recents,
+ prev_batch=now_token,
+ limited=False
+ ))
+
+ filtering_factor = 2
+ load_limit = max(timeline_limit * filtering_factor, 10)
+ max_repeat = 5 # Only try a few times per room, otherwise
+ room_key = now_token.room_key
+ end_key = room_key
+
since_key = None
if since_token and not newly_joined_room:
since_key = since_token.room_key
@@ -749,103 +345,6 @@ class SyncHandler(object):
))
@defer.inlineCallbacks
- def incremental_sync_with_gap_for_room(self, room_id, sync_config,
- since_token, now_token,
- ephemeral_by_room, tags_by_room,
- account_data_by_room,
- batch, full_state=False):
- state = yield self.compute_state_delta(
- room_id, batch, sync_config, since_token, now_token,
- full_state=full_state
- )
-
- account_data = self.account_data_for_room(
- room_id, tags_by_room, account_data_by_room
- )
-
- account_data = sync_config.filter_collection.filter_room_account_data(
- account_data
- )
-
- ephemeral = sync_config.filter_collection.filter_room_ephemeral(
- ephemeral_by_room.get(room_id, [])
- )
-
- unread_notifications = {}
- room_sync = JoinedSyncResult(
- room_id=room_id,
- timeline=batch,
- state=state,
- ephemeral=ephemeral,
- account_data=account_data,
- unread_notifications=unread_notifications,
- )
-
- if room_sync:
- notifs = yield self.unread_notifs_for_room_id(
- room_id, sync_config
- )
-
- if notifs is not None:
- unread_notifications["notification_count"] = notifs["notify_count"]
- unread_notifications["highlight_count"] = notifs["highlight_count"]
-
- logger.debug("Room sync: %r", room_sync)
-
- defer.returnValue(room_sync)
-
- @defer.inlineCallbacks
- def incremental_sync_for_archived_room(self, sync_config, room_id, leave_event_id,
- since_token, tags_by_room,
- account_data_by_room, full_state,
- leave_token=None):
- """ Get the incremental delta needed to bring the client up to date for
- the archived room.
- Returns:
- A Deferred ArchivedSyncResult
- """
-
- if not leave_token:
- stream_token = yield self.store.get_stream_token_for_event(
- leave_event_id
- )
-
- leave_token = since_token.copy_and_replace("room_key", stream_token)
-
- if since_token and since_token.is_after(leave_token):
- defer.returnValue(None)
-
- batch = yield self.load_filtered_recents(
- room_id, sync_config, leave_token, since_token,
- )
-
- logger.debug("Recents %r", batch)
-
- state_events_delta = yield self.compute_state_delta(
- room_id, batch, sync_config, since_token, leave_token,
- full_state=full_state
- )
-
- account_data = self.account_data_for_room(
- room_id, tags_by_room, account_data_by_room
- )
-
- account_data = sync_config.filter_collection.filter_room_account_data(
- account_data
- )
-
- room_sync = ArchivedSyncResult(
- room_id=room_id,
- timeline=batch,
- state=state_events_delta,
- account_data=account_data,
- )
-
- logger.debug("Room sync: %r", room_sync)
-
- defer.returnValue(room_sync)
-
- @defer.inlineCallbacks
def get_state_after_event(self, event):
"""
Get the room state after the given event
@@ -970,26 +469,6 @@ class SyncHandler(object):
for e in sync_config.filter_collection.filter_room_state(state.values())
})
- def check_joined_room(self, sync_config, state_delta):
- """
- Check if the user has just joined the given room (so should
- be given the full state)
-
- Args:
- sync_config(synapse.handlers.sync.SyncConfig):
- state_delta(dict[(str,str), synapse.events.FrozenEvent]): the
- difference in state since the last sync
-
- Returns:
- A deferred Tuple (state_delta, limited)
- """
- join_event = state_delta.get((
- EventTypes.Member, sync_config.user.to_string()), None)
- if join_event is not None:
- if join_event.content["membership"] == Membership.JOIN:
- return True
- return False
-
@defer.inlineCallbacks
def unread_notifs_for_room_id(self, room_id, sync_config):
with Measure(self.clock, "unread_notifs_for_room_id"):
@@ -1010,6 +489,551 @@ class SyncHandler(object):
# count is whatever it was last time.
defer.returnValue(None)
+ @defer.inlineCallbacks
+ def generate_sync_result(self, sync_config, since_token=None, full_state=False):
+ """Generates a sync result.
+
+ Args:
+ sync_config (SyncConfig)
+ since_token (StreamToken)
+ full_state (bool)
+
+ Returns:
+ Deferred(SyncResult)
+ """
+
+ # NB: The now_token gets changed by some of the generate_sync_* methods,
+ # this is due to some of the underlying streams not supporting the ability
+ # to query up to a given point.
+ # Always use the `now_token` in `SyncResultBuilder`
+ now_token = yield self.event_sources.get_current_token()
+
+ sync_result_builder = SyncResultBuilder(
+ sync_config, full_state,
+ since_token=since_token,
+ now_token=now_token,
+ )
+
+ account_data_by_room = yield self._generate_sync_entry_for_account_data(
+ sync_result_builder
+ )
+
+ res = yield self._generate_sync_entry_for_rooms(
+ sync_result_builder, account_data_by_room
+ )
+ newly_joined_rooms, newly_joined_users = res
+
+ yield self._generate_sync_entry_for_presence(
+ sync_result_builder, newly_joined_rooms, newly_joined_users
+ )
+
+ defer.returnValue(SyncResult(
+ presence=sync_result_builder.presence,
+ account_data=sync_result_builder.account_data,
+ joined=sync_result_builder.joined,
+ invited=sync_result_builder.invited,
+ archived=sync_result_builder.archived,
+ next_batch=sync_result_builder.now_token,
+ ))
+
+ @defer.inlineCallbacks
+ def _generate_sync_entry_for_account_data(self, sync_result_builder):
+ """Generates the account data portion of the sync response. Populates
+ `sync_result_builder` with the result.
+
+ Args:
+ sync_result_builder(SyncResultBuilder)
+
+ Returns:
+ Deferred(dict): A dictionary containing the per room account data.
+ """
+ sync_config = sync_result_builder.sync_config
+ user_id = sync_result_builder.sync_config.user.to_string()
+ since_token = sync_result_builder.since_token
+
+ if since_token and not sync_result_builder.full_state:
+ account_data, account_data_by_room = (
+ yield self.store.get_updated_account_data_for_user(
+ user_id,
+ since_token.account_data_key,
+ )
+ )
+
+ push_rules_changed = yield self.store.have_push_rules_changed_for_user(
+ user_id, int(since_token.push_rules_key)
+ )
+
+ if push_rules_changed:
+ account_data["m.push_rules"] = yield self.push_rules_for_user(
+ sync_config.user
+ )
+ else:
+ account_data, account_data_by_room = (
+ yield self.store.get_account_data_for_user(
+ sync_config.user.to_string()
+ )
+ )
+
+ account_data['m.push_rules'] = yield self.push_rules_for_user(
+ sync_config.user
+ )
+
+ account_data_for_user = sync_config.filter_collection.filter_account_data([
+ {"type": account_data_type, "content": content}
+ for account_data_type, content in account_data.items()
+ ])
+
+ sync_result_builder.account_data = account_data_for_user
+
+ defer.returnValue(account_data_by_room)
+
+ @defer.inlineCallbacks
+ def _generate_sync_entry_for_presence(self, sync_result_builder, newly_joined_rooms,
+ newly_joined_users):
+ """Generates the presence portion of the sync response. Populates the
+ `sync_result_builder` with the result.
+
+ Args:
+ sync_result_builder(SyncResultBuilder)
+ newly_joined_rooms(list): List of rooms that the user has joined
+ since the last sync (or empty if an initial sync)
+ newly_joined_users(list): List of users that have joined rooms
+ since the last sync (or empty if an initial sync)
+ """
+ now_token = sync_result_builder.now_token
+ sync_config = sync_result_builder.sync_config
+ user = sync_result_builder.sync_config.user
+
+ presence_source = self.event_sources.sources["presence"]
+
+ since_token = sync_result_builder.since_token
+ if since_token and not sync_result_builder.full_state:
+ presence_key = since_token.presence_key
+ include_offline = True
+ else:
+ presence_key = None
+ include_offline = False
+
+ presence, presence_key = yield presence_source.get_new_events(
+ user=user,
+ from_key=presence_key,
+ is_guest=sync_config.is_guest,
+ include_offline=include_offline,
+ )
+ sync_result_builder.now_token = now_token.copy_and_replace(
+ "presence_key", presence_key
+ )
+
+ extra_users_ids = set(newly_joined_users)
+ for room_id in newly_joined_rooms:
+ users = yield self.store.get_users_in_room(room_id)
+ extra_users_ids.update(users)
+ extra_users_ids.discard(user.to_string())
+
+ states = yield self.presence_handler.get_states(
+ extra_users_ids,
+ as_event=True,
+ )
+ presence.extend(states)
+
+ # Deduplicate the presence entries so that there's at most one per user
+ presence = {p["content"]["user_id"]: p for p in presence}.values()
+
+ presence = sync_config.filter_collection.filter_presence(
+ presence
+ )
+
+ sync_result_builder.presence = presence
+
+ @defer.inlineCallbacks
+ def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_room):
+ """Generates the rooms portion of the sync response. Populates the
+ `sync_result_builder` with the result.
+
+ Args:
+ sync_result_builder(SyncResultBuilder)
+ account_data_by_room(dict): Dictionary of per room account data
+
+ Returns:
+ Deferred(tuple): Returns a 2-tuple of
+ `(newly_joined_rooms, newly_joined_users)`
+ """
+ user_id = sync_result_builder.sync_config.user.to_string()
+
+ now_token, ephemeral_by_room = yield self.ephemeral_by_room(
+ sync_result_builder.sync_config,
+ now_token=sync_result_builder.now_token,
+ since_token=sync_result_builder.since_token,
+ )
+ sync_result_builder.now_token = now_token
+
+ ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
+ "m.ignored_user_list", user_id=user_id,
+ )
+
+ if ignored_account_data:
+ ignored_users = ignored_account_data.get("ignored_users", {}).keys()
+ else:
+ ignored_users = frozenset()
+
+ if sync_result_builder.since_token:
+ res = yield self._get_rooms_changed(sync_result_builder, ignored_users)
+ room_entries, invited, newly_joined_rooms = res
+
+ tags_by_room = yield self.store.get_updated_tags(
+ user_id,
+ sync_result_builder.since_token.account_data_key,
+ )
+ else:
+ res = yield self._get_all_rooms(sync_result_builder, ignored_users)
+ room_entries, invited, newly_joined_rooms = res
+
+ tags_by_room = yield self.store.get_tags_for_user(user_id)
+
+ def handle_room_entries(room_entry):
+ return self._generate_room_entry(
+ sync_result_builder,
+ ignored_users,
+ room_entry,
+ ephemeral=ephemeral_by_room.get(room_entry.room_id, []),
+ tags=tags_by_room.get(room_entry.room_id),
+ account_data=account_data_by_room.get(room_entry.room_id, {}),
+ always_include=sync_result_builder.full_state,
+ )
+
+ yield concurrently_execute(handle_room_entries, room_entries, 10)
+
+ sync_result_builder.invited.extend(invited)
+
+ # Now we want to get any newly joined users
+ newly_joined_users = set()
+ if sync_result_builder.since_token:
+ for joined_sync in sync_result_builder.joined:
+ it = itertools.chain(
+ joined_sync.timeline.events, joined_sync.state.values()
+ )
+ for event in it:
+ if event.type == EventTypes.Member:
+ if event.membership == Membership.JOIN:
+ newly_joined_users.add(event.state_key)
+
+ defer.returnValue((newly_joined_rooms, newly_joined_users))
+
+ @defer.inlineCallbacks
+ def _get_rooms_changed(self, sync_result_builder, ignored_users):
+ """Gets the the changes that have happened since the last sync.
+
+ Args:
+ sync_result_builder(SyncResultBuilder)
+ ignored_users(set(str)): Set of users ignored by user.
+
+ Returns:
+ Deferred(tuple): Returns a tuple of the form:
+ `([RoomSyncResultBuilder], [InvitedSyncResult], newly_joined_rooms)`
+ """
+ user_id = sync_result_builder.sync_config.user.to_string()
+ since_token = sync_result_builder.since_token
+ now_token = sync_result_builder.now_token
+ sync_config = sync_result_builder.sync_config
+
+ assert since_token
+
+ app_service = yield self.store.get_app_service_by_user_id(user_id)
+ if app_service:
+ rooms = yield self.store.get_app_service_rooms(app_service)
+ joined_room_ids = set(r.room_id for r in rooms)
+ else:
+ rooms = yield self.store.get_rooms_for_user(user_id)
+ joined_room_ids = set(r.room_id for r in rooms)
+
+ # Get a list of membership change events that have happened.
+ rooms_changed = yield self.store.get_membership_changes_for_user(
+ user_id, since_token.room_key, now_token.room_key
+ )
+
+ mem_change_events_by_room_id = {}
+ for event in rooms_changed:
+ mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
+
+ newly_joined_rooms = []
+ room_entries = []
+ invited = []
+ for room_id, events in mem_change_events_by_room_id.items():
+ non_joins = [e for e in events if e.membership != Membership.JOIN]
+ has_join = len(non_joins) != len(events)
+
+ # We want to figure out if we joined the room at some point since
+ # the last sync (even if we have since left). This is to make sure
+ # we do send down the room, and with full state, where necessary
+ if room_id in joined_room_ids or has_join:
+ old_state = yield self.get_state_at(room_id, since_token)
+ old_mem_ev = old_state.get((EventTypes.Member, user_id), None)
+ if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
+ newly_joined_rooms.append(room_id)
+
+ if room_id in joined_room_ids:
+ continue
+
+ if not non_joins:
+ continue
+
+ # Only bother if we're still currently invited
+ should_invite = non_joins[-1].membership == Membership.INVITE
+ if should_invite:
+ if event.sender not in ignored_users:
+ room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
+ if room_sync:
+ invited.append(room_sync)
+
+ # Always include leave/ban events. Just take the last one.
+ # TODO: How do we handle ban -> leave in same batch?
+ leave_events = [
+ e for e in non_joins
+ if e.membership in (Membership.LEAVE, Membership.BAN)
+ ]
+
+ if leave_events:
+ leave_event = leave_events[-1]
+ leave_stream_token = yield self.store.get_stream_token_for_event(
+ leave_event.event_id
+ )
+ leave_token = since_token.copy_and_replace(
+ "room_key", leave_stream_token
+ )
+
+ if since_token and since_token.is_after(leave_token):
+ continue
+
+ room_entries.append(RoomSyncResultBuilder(
+ room_id=room_id,
+ rtype="archived",
+ events=None,
+ newly_joined=room_id in newly_joined_rooms,
+ full_state=False,
+ since_token=since_token,
+ upto_token=leave_token,
+ ))
+
+ timeline_limit = sync_config.filter_collection.timeline_limit()
+
+ # Get all events for rooms we're currently joined to.
+ room_to_events = yield self.store.get_room_events_stream_for_rooms(
+ room_ids=joined_room_ids,
+ from_key=since_token.room_key,
+ to_key=now_token.room_key,
+ limit=timeline_limit + 1,
+ )
+
+ # We loop through all room ids, even if there are no new events, in case
+ # there are non room events taht we need to notify about.
+ for room_id in joined_room_ids:
+ room_entry = room_to_events.get(room_id, None)
+
+ if room_entry:
+ events, start_key = room_entry
+
+ prev_batch_token = now_token.copy_and_replace("room_key", start_key)
+
+ room_entries.append(RoomSyncResultBuilder(
+ room_id=room_id,
+ rtype="joined",
+ events=events,
+ newly_joined=room_id in newly_joined_rooms,
+ full_state=False,
+ since_token=None if room_id in newly_joined_rooms else since_token,
+ upto_token=prev_batch_token,
+ ))
+ else:
+ room_entries.append(RoomSyncResultBuilder(
+ room_id=room_id,
+ rtype="joined",
+ events=[],
+ newly_joined=room_id in newly_joined_rooms,
+ full_state=False,
+ since_token=since_token,
+ upto_token=since_token,
+ ))
+
+ defer.returnValue((room_entries, invited, newly_joined_rooms))
+
+ @defer.inlineCallbacks
+ def _get_all_rooms(self, sync_result_builder, ignored_users):
+ """Returns entries for all rooms for the user.
+
+ Args:
+ sync_result_builder(SyncResultBuilder)
+ ignored_users(set(str)): Set of users ignored by user.
+
+ Returns:
+ Deferred(tuple): Returns a tuple of the form:
+ `([RoomSyncResultBuilder], [InvitedSyncResult], [])`
+ """
+
+ user_id = sync_result_builder.sync_config.user.to_string()
+ since_token = sync_result_builder.since_token
+ now_token = sync_result_builder.now_token
+ sync_config = sync_result_builder.sync_config
+
+ membership_list = (
+ Membership.INVITE, Membership.JOIN, Membership.LEAVE, Membership.BAN
+ )
+
+ room_list = yield self.store.get_rooms_for_user_where_membership_is(
+ user_id=user_id,
+ membership_list=membership_list
+ )
+
+ room_entries = []
+ invited = []
+
+ for event in room_list:
+ if event.membership == Membership.JOIN:
+ room_entries.append(RoomSyncResultBuilder(
+ room_id=event.room_id,
+ rtype="joined",
+ events=None,
+ newly_joined=False,
+ full_state=True,
+ since_token=since_token,
+ upto_token=now_token,
+ ))
+ elif event.membership == Membership.INVITE:
+ if event.sender in ignored_users:
+ continue
+ invite = yield self.store.get_event(event.event_id)
+ invited.append(InvitedSyncResult(
+ room_id=event.room_id,
+ invite=invite,
+ ))
+ elif event.membership in (Membership.LEAVE, Membership.BAN):
+ # Always send down rooms we were banned or kicked from.
+ if not sync_config.filter_collection.include_leave:
+ if event.membership == Membership.LEAVE:
+ if user_id == event.sender:
+ continue
+
+ leave_token = now_token.copy_and_replace(
+ "room_key", "s%d" % (event.stream_ordering,)
+ )
+ room_entries.append(RoomSyncResultBuilder(
+ room_id=event.room_id,
+ rtype="archived",
+ events=None,
+ newly_joined=False,
+ full_state=True,
+ since_token=since_token,
+ upto_token=leave_token,
+ ))
+
+ defer.returnValue((room_entries, invited, []))
+
+ @defer.inlineCallbacks
+ def _generate_room_entry(self, sync_result_builder, ignored_users,
+ room_builder, ephemeral, tags, account_data,
+ always_include=False):
+ """Populates the `joined` and `archived` section of `sync_result_builder`
+ based on the `room_builder`.
+
+ Args:
+ sync_result_builder(SyncResultBuilder)
+ ignored_users(set(str)): Set of users ignored by user.
+ room_builder(RoomSyncResultBuilder)
+ ephemeral(list): List of new ephemeral events for room
+ tags(list): List of *all* tags for room, or None if there has been
+ no change.
+ account_data(list): List of new account data for room
+ always_include(bool): Always include this room in the sync response,
+ even if empty.
+ """
+ newly_joined = room_builder.newly_joined
+ full_state = (
+ room_builder.full_state
+ or newly_joined
+ or sync_result_builder.full_state
+ )
+ events = room_builder.events
+
+ # We want to shortcut out as early as possible.
+ if not (always_include or account_data or ephemeral or full_state):
+ if events == [] and tags is None:
+ return
+
+ since_token = sync_result_builder.since_token
+ now_token = sync_result_builder.now_token
+ sync_config = sync_result_builder.sync_config
+
+ room_id = room_builder.room_id
+ since_token = room_builder.since_token
+ upto_token = room_builder.upto_token
+
+ batch = yield self._load_filtered_recents(
+ room_id, sync_config,
+ now_token=upto_token,
+ since_token=since_token,
+ recents=events,
+ newly_joined_room=newly_joined,
+ )
+
+ account_data_events = []
+ if tags is not None:
+ account_data_events.append({
+ "type": "m.tag",
+ "content": {"tags": tags},
+ })
+
+ for account_data_type, content in account_data.items():
+ account_data_events.append({
+ "type": account_data_type,
+ "content": content,
+ })
+
+ account_data = sync_config.filter_collection.filter_room_account_data(
+ account_data_events
+ )
+
+ ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral)
+
+ if not (always_include or batch or account_data or ephemeral or full_state):
+ return
+
+ state = yield self.compute_state_delta(
+ room_id, batch, sync_config, since_token, now_token,
+ full_state=full_state
+ )
+
+ if room_builder.rtype == "joined":
+ unread_notifications = {}
+ room_sync = JoinedSyncResult(
+ room_id=room_id,
+ timeline=batch,
+ state=state,
+ ephemeral=ephemeral,
+ account_data=account_data_events,
+ unread_notifications=unread_notifications,
+ )
+
+ if room_sync or always_include:
+ notifs = yield self.unread_notifs_for_room_id(
+ room_id, sync_config
+ )
+
+ if notifs is not None:
+ unread_notifications["notification_count"] = notifs["notify_count"]
+ unread_notifications["highlight_count"] = notifs["highlight_count"]
+
+ sync_result_builder.joined.append(room_sync)
+ elif room_builder.rtype == "archived":
+ room_sync = ArchivedSyncResult(
+ room_id=room_id,
+ timeline=batch,
+ state=state,
+ account_data=account_data,
+ )
+ if room_sync or always_include:
+ sync_result_builder.archived.append(room_sync)
+ else:
+ raise Exception("Unrecognized rtype: %r", room_builder.rtype)
+
def _action_has_highlight(actions):
for action in actions:
@@ -1057,3 +1081,51 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
(e.type, e.state_key): e
for e in evs
}
+
+
+class SyncResultBuilder(object):
+ "Used to help build up a new SyncResult for a user"
+ def __init__(self, sync_config, full_state, since_token, now_token):
+ """
+ Args:
+ sync_config(SyncConfig)
+ full_state(bool): The full_state flag as specified by user
+ since_token(StreamToken): The token supplied by user, or None.
+ now_token(StreamToken): The token to sync up to.
+ """
+ self.sync_config = sync_config
+ self.full_state = full_state
+ self.since_token = since_token
+ self.now_token = now_token
+
+ self.presence = []
+ self.account_data = []
+ self.joined = []
+ self.invited = []
+ self.archived = []
+
+
+class RoomSyncResultBuilder(object):
+ """Stores information needed to create either a `JoinedSyncResult` or
+ `ArchivedSyncResult`.
+ """
+ def __init__(self, room_id, rtype, events, newly_joined, full_state,
+ since_token, upto_token):
+ """
+ Args:
+ room_id(str)
+ rtype(str): One of `"joined"` or `"archived"`
+ events(list): List of events to include in the room, (more events
+ may be added when generating result).
+ newly_joined(bool): If the user has newly joined the room
+ full_state(bool): Whether the full state should be sent in result
+ since_token(StreamToken): Earliest point to return events from, or None
+ upto_token(StreamToken): Latest point to return events from.
+ """
+ self.room_id = room_id
+ self.rtype = rtype
+ self.events = events
+ self.newly_joined = newly_joined
+ self.full_state = full_state
+ self.since_token = since_token
+ self.upto_token = upto_token
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index d46f05f426..5589296c09 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
# A tiny object useful for storing a user's membership in a room, as a mapping
# key
-RoomMember = namedtuple("RoomMember", ("room_id", "user"))
+RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))
class TypingHandler(object):
@@ -38,7 +38,7 @@ class TypingHandler(object):
self.store = hs.get_datastore()
self.server_name = hs.config.server_name
self.auth = hs.get_auth()
- self.is_mine = hs.is_mine
+ self.is_mine_id = hs.is_mine_id
self.notifier = hs.get_notifier()
self.clock = hs.get_clock()
@@ -67,20 +67,23 @@ class TypingHandler(object):
@defer.inlineCallbacks
def started_typing(self, target_user, auth_user, room_id, timeout):
- if not self.is_mine(target_user):
+ target_user_id = target_user.to_string()
+ auth_user_id = auth_user.to_string()
+
+ if not self.is_mine_id(target_user_id):
raise SynapseError(400, "User is not hosted on this Home Server")
- if target_user != auth_user:
+ if target_user_id != auth_user_id:
raise AuthError(400, "Cannot set another user's typing state")
- yield self.auth.check_joined_room(room_id, target_user.to_string())
+ yield self.auth.check_joined_room(room_id, target_user_id)
logger.debug(
- "%s has started typing in %s", target_user.to_string(), room_id
+ "%s has started typing in %s", target_user_id, room_id
)
until = self.clock.time_msec() + timeout
- member = RoomMember(room_id=room_id, user=target_user)
+ member = RoomMember(room_id=room_id, user_id=target_user_id)
was_present = member in self._member_typing_until
@@ -104,25 +107,28 @@ class TypingHandler(object):
yield self._push_update(
room_id=room_id,
- user=target_user,
+ user_id=target_user_id,
typing=True,
)
@defer.inlineCallbacks
def stopped_typing(self, target_user, auth_user, room_id):
- if not self.is_mine(target_user):
+ target_user_id = target_user.to_string()
+ auth_user_id = auth_user.to_string()
+
+ if not self.is_mine_id(target_user_id):
raise SynapseError(400, "User is not hosted on this Home Server")
- if target_user != auth_user:
+ if target_user_id != auth_user_id:
raise AuthError(400, "Cannot set another user's typing state")
- yield self.auth.check_joined_room(room_id, target_user.to_string())
+ yield self.auth.check_joined_room(room_id, target_user_id)
logger.debug(
- "%s has stopped typing in %s", target_user.to_string(), room_id
+ "%s has stopped typing in %s", target_user_id, room_id
)
- member = RoomMember(room_id=room_id, user=target_user)
+ member = RoomMember(room_id=room_id, user_id=target_user_id)
if member in self._member_typing_timer:
self.clock.cancel_call_later(self._member_typing_timer[member])
@@ -132,8 +138,9 @@ class TypingHandler(object):
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
- if self.is_mine(user):
- member = RoomMember(room_id=room_id, user=user)
+ user_id = user.to_string()
+ if self.is_mine_id(user_id):
+ member = RoomMember(room_id=room_id, user_id=user_id)
yield self._stopped_typing(member)
@defer.inlineCallbacks
@@ -144,7 +151,7 @@ class TypingHandler(object):
yield self._push_update(
room_id=member.room_id,
- user=member.user,
+ user_id=member.user_id,
typing=False,
)
@@ -156,7 +163,7 @@ class TypingHandler(object):
del self._member_typing_timer[member]
@defer.inlineCallbacks
- def _push_update(self, room_id, user, typing):
+ def _push_update(self, room_id, user_id, typing):
domains = yield self.store.get_joined_hosts_for_room(room_id)
deferreds = []
@@ -164,7 +171,7 @@ class TypingHandler(object):
if domain == self.server_name:
self._push_update_local(
room_id=room_id,
- user=user,
+ user_id=user_id,
typing=typing
)
else:
@@ -173,7 +180,7 @@ class TypingHandler(object):
edu_type="m.typing",
content={
"room_id": room_id,
- "user_id": user.to_string(),
+ "user_id": user_id,
"typing": typing,
},
))
@@ -183,23 +190,26 @@ class TypingHandler(object):
@defer.inlineCallbacks
def _recv_edu(self, origin, content):
room_id = content["room_id"]
- user = UserID.from_string(content["user_id"])
+ user_id = content["user_id"]
+
+ # Check that the string is a valid user id
+ UserID.from_string(user_id)
domains = yield self.store.get_joined_hosts_for_room(room_id)
if self.server_name in domains:
self._push_update_local(
room_id=room_id,
- user=user,
+ user_id=user_id,
typing=content["typing"]
)
- def _push_update_local(self, room_id, user, typing):
+ def _push_update_local(self, room_id, user_id, typing):
room_set = self._room_typing.setdefault(room_id, set())
if typing:
- room_set.add(user)
+ room_set.add(user_id)
else:
- room_set.discard(user)
+ room_set.discard(user_id)
self._latest_room_serial += 1
self._room_serials[room_id] = self._latest_room_serial
@@ -211,13 +221,14 @@ class TypingHandler(object):
def get_all_typing_updates(self, last_id, current_id):
# TODO: Work out a way to do this without scanning the entire state.
+ if last_id == current_id:
+ return []
+
rows = []
for room_id, serial in self._room_serials.items():
if last_id < serial and serial <= current_id:
typing = self._room_typing[room_id]
- typing_bytes = json.dumps([
- u.to_string() for u in typing
- ], ensure_ascii=False)
+ typing_bytes = json.dumps(list(typing), ensure_ascii=False)
rows.append((serial, room_id, typing_bytes))
rows.sort()
return rows
@@ -239,7 +250,7 @@ class TypingNotificationEventSource(object):
"type": "m.typing",
"room_id": room_id,
"content": {
- "user_ids": [u.to_string() for u in typing],
+ "user_ids": list(typing),
},
}
|