diff --git a/docs/code_style.rst b/docs/code_style.rst
index dc40a7ab7b..8d73d17beb 100644
--- a/docs/code_style.rst
+++ b/docs/code_style.rst
@@ -43,7 +43,10 @@ Basically, PEP8
together, or want to deliberately extend or preserve vertical/horizontal
space)
-Comments should follow the google code style. This is so that we can generate
-documentation with sphinx (http://sphinxcontrib-napoleon.readthedocs.org/en/latest/)
+Comments should follow the `google code style <http://google.github.io/styleguide/pyguide.html?showone=Comments#Comments>`_.
+This is so that we can generate documentation with
+`sphinx <http://sphinxcontrib-napoleon.readthedocs.org/en/latest/>`_. See the
+`examples <http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_
+in the sphinx documentation.
Code should pass pep8 --max-line-length=100 without any warnings.
diff --git a/docs/turn-howto.rst b/docs/turn-howto.rst
index e2c73458e2..04c0100715 100644
--- a/docs/turn-howto.rst
+++ b/docs/turn-howto.rst
@@ -9,31 +9,35 @@ the Home Server to generate credentials that are valid for use on the TURN
server through the use of a secret shared between the Home Server and the
TURN server.
-This document described how to install coturn
-(https://code.google.com/p/coturn/) which also supports the TURN REST API,
+This document describes how to install coturn
+(https://github.com/coturn/coturn) which also supports the TURN REST API,
and integrate it with synapse.
coturn Setup
============
+You may be able to setup coturn via your package manager, or set it up manually using the usual ``configure, make, make install`` process.
+
1. Check out coturn::
- svn checkout http://coturn.googlecode.com/svn/trunk/ coturn
+
+ git clone https://github.com/coturn/coturn.git coturn
cd coturn
2. Configure it::
+
./configure
- You may need to install libevent2: if so, you should do so
+ You may need to install ``libevent2``: if so, you should do so
in the way recommended by your operating system.
You can ignore warnings about lack of database support: a
database is unnecessary for this purpose.
3. Build and install it::
+
make
make install
- 4. Make a config file in /etc/turnserver.conf. You can customise
- a config file from turnserver.conf.default. The relevant
+ 4. Create or edit the config file in ``/etc/turnserver.conf``. The relevant
lines, with example values, are::
lt-cred-mech
@@ -41,7 +45,7 @@ coturn Setup
static-auth-secret=[your secret key here]
realm=turn.myserver.org
- See turnserver.conf.default for explanations of the options.
+ See turnserver.conf for explanations of the options.
One way to generate the static-auth-secret is with pwgen::
pwgen -s 64 1
@@ -54,6 +58,7 @@ coturn Setup
import your private key and certificate.
7. Start the turn server::
+
bin/turnserver -o
diff --git a/jenkins-dendron-postgres.sh b/jenkins-dendron-postgres.sh
index 7e6f24aa7d..50268e0982 100755
--- a/jenkins-dendron-postgres.sh
+++ b/jenkins-dendron-postgres.sh
@@ -70,6 +70,7 @@ cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PORT_BASE:=8000}
+: ${PORT_COUNT=20}
./jenkins/prep_sytest_for_postgres.sh
@@ -81,6 +82,6 @@ echo >&2 "Running sytest with PostgreSQL";
--dendron $WORKSPACE/dendron/bin/dendron \
--pusher \
--synchrotron \
- --port-base $PORT_BASE
+ --port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1))
cd ..
diff --git a/jenkins-postgres.sh b/jenkins-postgres.sh
index ae6b111591..2f0768fcb7 100755
--- a/jenkins-postgres.sh
+++ b/jenkins-postgres.sh
@@ -44,6 +44,7 @@ cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PORT_BASE:=8000}
+: ${PORT_COUNT=20}
./jenkins/prep_sytest_for_postgres.sh
@@ -51,7 +52,7 @@ echo >&2 "Running sytest with PostgreSQL";
./jenkins/install_and_run.sh --coverage \
--python $TOX_BIN/python \
--synapse-directory $WORKSPACE \
- --port-base $PORT_BASE
+ --port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1)) \
cd ..
cp sytest/.coverage.* .
diff --git a/jenkins-sqlite.sh b/jenkins-sqlite.sh
index 9398d9db15..da603c5af8 100755
--- a/jenkins-sqlite.sh
+++ b/jenkins-sqlite.sh
@@ -41,11 +41,12 @@ cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
-: ${PORT_BASE:=8500}
+: ${PORT_COUNT=20}
+: ${PORT_BASE:=8000}
./jenkins/install_and_run.sh --coverage \
--python $TOX_BIN/python \
--synapse-directory $WORKSPACE \
- --port-base $PORT_BASE
+ --port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1)) \
cd ..
cp sytest/.coverage.* .
diff --git a/res/templates/notif_mail.html b/res/templates/notif_mail.html
index 8aee68b591..535bea764d 100644
--- a/res/templates/notif_mail.html
+++ b/res/templates/notif_mail.html
@@ -36,7 +36,7 @@
<div class="debug">
Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
an event was received at {{ reason.received_at|format_ts("%c") }}
- which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} (delay_before_mail_ms) mins ago,
+ which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago,
{% if reason.last_sent_ts %}
and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
diff --git a/scripts/hash_password b/scripts/hash_password
index e784600989..215ab25cfe 100755
--- a/scripts/hash_password
+++ b/scripts/hash_password
@@ -1,10 +1,16 @@
#!/usr/bin/env python
import argparse
+
+import sys
+
import bcrypt
import getpass
+import yaml
+
bcrypt_rounds=12
+password_pepper = ""
def prompt_for_pass():
password = getpass.getpass("Password: ")
@@ -28,12 +34,22 @@ if __name__ == "__main__":
default=None,
help="New password for user. Will prompt if omitted.",
)
+ parser.add_argument(
+ "-c", "--config",
+ type=argparse.FileType('r'),
+ help="Path to server config file. Used to read in bcrypt_rounds and password_pepper.",
+ )
args = parser.parse_args()
+ if "config" in args and args.config:
+ config = yaml.safe_load(args.config)
+ bcrypt_rounds = config.get("bcrypt_rounds", bcrypt_rounds)
+ password_config = config.get("password_config", {})
+ password_pepper = password_config.get("pepper", password_pepper)
password = args.password
if not password:
password = prompt_for_pass()
- print bcrypt.hashpw(password, bcrypt.gensalt(bcrypt_rounds))
+ print bcrypt.hashpw(password + password_pepper, bcrypt.gensalt(bcrypt_rounds))
diff --git a/scripts/register_new_matrix_user b/scripts/register_new_matrix_user
index 27a6250b14..12ed20d623 100755
--- a/scripts/register_new_matrix_user
+++ b/scripts/register_new_matrix_user
@@ -25,18 +25,26 @@ import urllib2
import yaml
-def request_registration(user, password, server_location, shared_secret):
+def request_registration(user, password, server_location, shared_secret, admin=False):
mac = hmac.new(
key=shared_secret,
- msg=user,
digestmod=hashlib.sha1,
- ).hexdigest()
+ )
+
+ mac.update(user)
+ mac.update("\x00")
+ mac.update(password)
+ mac.update("\x00")
+ mac.update("admin" if admin else "notadmin")
+
+ mac = mac.hexdigest()
data = {
"user": user,
"password": password,
"mac": mac,
"type": "org.matrix.login.shared_secret",
+ "admin": admin,
}
server_location = server_location.rstrip("/")
@@ -68,7 +76,7 @@ def request_registration(user, password, server_location, shared_secret):
sys.exit(1)
-def register_new_user(user, password, server_location, shared_secret):
+def register_new_user(user, password, server_location, shared_secret, admin):
if not user:
try:
default_user = getpass.getuser()
@@ -99,7 +107,14 @@ def register_new_user(user, password, server_location, shared_secret):
print "Passwords do not match"
sys.exit(1)
- request_registration(user, password, server_location, shared_secret)
+ if not admin:
+ admin = raw_input("Make admin [no]: ")
+ if admin in ("y", "yes", "true"):
+ admin = True
+ else:
+ admin = False
+
+ request_registration(user, password, server_location, shared_secret, bool(admin))
if __name__ == "__main__":
@@ -119,6 +134,11 @@ if __name__ == "__main__":
default=None,
help="New password for user. Will prompt if omitted.",
)
+ parser.add_argument(
+ "-a", "--admin",
+ action="store_true",
+ help="Register new user as an admin. Will prompt if omitted.",
+ )
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
@@ -151,4 +171,4 @@ if __name__ == "__main__":
else:
secret = args.shared_secret
- register_new_user(args.user, args.password, args.server_url, secret)
+ register_new_user(args.user, args.password, args.server_url, secret, args.admin)
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index b106fbed6d..b219b46a4b 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -42,8 +42,9 @@ class Codes(object):
TOO_LARGE = "M_TOO_LARGE"
EXCLUSIVE = "M_EXCLUSIVE"
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
- THREEPID_IN_USE = "THREEPID_IN_USE"
+ THREEPID_IN_USE = "M_THREEPID_IN_USE"
INVALID_USERNAME = "M_INVALID_USERNAME"
+ SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
class CodeMessageException(RuntimeError):
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 40ffd9bf0d..9c2dd32953 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -147,7 +147,7 @@ class SynapseHomeServer(HomeServer):
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
- self, self.config.uploads_path, self.auth, self.content_addr
+ self, self.config.uploads_path
),
})
@@ -301,7 +301,6 @@ def setup(config_options):
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
- content_addr=config.content_addr,
version_string=version_string,
database_engine=database_engine,
)
diff --git a/synapse/config/ldap.py b/synapse/config/ldap.py
index 9c14593a99..d83c2230be 100644
--- a/synapse/config/ldap.py
+++ b/synapse/config/ldap.py
@@ -13,40 +13,88 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import Config
+from ._base import Config, ConfigError
+
+
+MISSING_LDAP3 = (
+ "Missing ldap3 library. This is required for LDAP Authentication."
+)
+
+
+class LDAPMode(object):
+ SIMPLE = "simple",
+ SEARCH = "search",
+
+ LIST = (SIMPLE, SEARCH)
class LDAPConfig(Config):
def read_config(self, config):
- ldap_config = config.get("ldap_config", None)
- if ldap_config:
- self.ldap_enabled = ldap_config.get("enabled", False)
- self.ldap_server = ldap_config["server"]
- self.ldap_port = ldap_config["port"]
- self.ldap_tls = ldap_config.get("tls", False)
- self.ldap_search_base = ldap_config["search_base"]
- self.ldap_search_property = ldap_config["search_property"]
- self.ldap_email_property = ldap_config["email_property"]
- self.ldap_full_name_property = ldap_config["full_name_property"]
- else:
- self.ldap_enabled = False
- self.ldap_server = None
- self.ldap_port = None
- self.ldap_tls = False
- self.ldap_search_base = None
- self.ldap_search_property = None
- self.ldap_email_property = None
- self.ldap_full_name_property = None
+ ldap_config = config.get("ldap_config", {})
+
+ self.ldap_enabled = ldap_config.get("enabled", False)
+
+ if self.ldap_enabled:
+ # verify dependencies are available
+ try:
+ import ldap3
+ ldap3 # to stop unused lint
+ except ImportError:
+ raise ConfigError(MISSING_LDAP3)
+
+ self.ldap_mode = LDAPMode.SIMPLE
+
+ # verify config sanity
+ self.require_keys(ldap_config, [
+ "uri",
+ "base",
+ "attributes",
+ ])
+
+ self.ldap_uri = ldap_config["uri"]
+ self.ldap_start_tls = ldap_config.get("start_tls", False)
+ self.ldap_base = ldap_config["base"]
+ self.ldap_attributes = ldap_config["attributes"]
+
+ if "bind_dn" in ldap_config:
+ self.ldap_mode = LDAPMode.SEARCH
+ self.require_keys(ldap_config, [
+ "bind_dn",
+ "bind_password",
+ ])
+
+ self.ldap_bind_dn = ldap_config["bind_dn"]
+ self.ldap_bind_password = ldap_config["bind_password"]
+ self.ldap_filter = ldap_config.get("filter", None)
+
+ # verify attribute lookup
+ self.require_keys(ldap_config['attributes'], [
+ "uid",
+ "name",
+ "mail",
+ ])
+
+ def require_keys(self, config, required):
+ missing = [key for key in required if key not in config]
+ if missing:
+ raise ConfigError(
+ "LDAP enabled but missing required config values: {}".format(
+ ", ".join(missing)
+ )
+ )
def default_config(self, **kwargs):
return """\
# ldap_config:
# enabled: true
- # server: "ldap://localhost"
- # port: 389
- # tls: false
- # search_base: "ou=Users,dc=example,dc=com"
- # search_property: "cn"
- # email_property: "email"
- # full_name_property: "givenName"
+ # uri: "ldap://ldap.example.com:389"
+ # start_tls: true
+ # base: "ou=users,dc=example,dc=com"
+ # attributes:
+ # uid: "cn"
+ # mail: "email"
+ # name: "givenName"
+ # #bind_dn:
+ # #bind_password:
+ # #filter: "(objectClass=posixAccount)"
"""
diff --git a/synapse/config/password.py b/synapse/config/password.py
index dec801ef41..a4bd171399 100644
--- a/synapse/config/password.py
+++ b/synapse/config/password.py
@@ -23,10 +23,14 @@ class PasswordConfig(Config):
def read_config(self, config):
password_config = config.get("password_config", {})
self.password_enabled = password_config.get("enabled", True)
+ self.password_pepper = password_config.get("pepper", "")
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable password for login.
password_config:
enabled: true
+ # Uncomment and change to a secret random string for extra security.
+ # DO NOT CHANGE THIS AFTER INITIAL SETUP!
+ #pepper: ""
"""
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 7840dc3ad6..51eaf423ce 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -107,26 +107,6 @@ class ServerConfig(Config):
]
})
- # Attempt to guess the content_addr for the v0 content repostitory
- content_addr = config.get("content_addr")
- if not content_addr:
- for listener in self.listeners:
- if listener["type"] == "http" and not listener.get("tls", False):
- unsecure_port = listener["port"]
- break
- else:
- raise RuntimeError("Could not determine 'content_addr'")
-
- host = self.server_name
- if ':' not in host:
- host = "%s:%d" % (host, unsecure_port)
- else:
- host = host.split(':')[0]
- host = "%s:%d" % (host, unsecure_port)
- content_addr = "http://%s" % (host,)
-
- self.content_addr = content_addr
-
def default_config(self, server_name, **kwargs):
if ":" in server_name:
bind_port = int(server_name.split(":")[1])
@@ -169,7 +149,6 @@ class ServerConfig(Config):
# room directory.
# secondary_directory_servers:
# - matrix.org
- # - vector.im
# List of ports that Synapse should listen on, their purpose and their
# configuration.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 2a589524a4..85f5e752fe 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -49,6 +49,7 @@ class FederationServer(FederationBase):
super(FederationServer, self).__init__(hs)
self._room_pdu_linearizer = Linearizer()
+ self._server_linearizer = Linearizer()
def set_handler(self, handler):
"""Sets the handler that the replication layer will use to communicate
@@ -89,11 +90,14 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
def on_backfill_request(self, origin, room_id, versions, limit):
- pdus = yield self.handler.on_backfill_request(
- origin, room_id, versions, limit
- )
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ pdus = yield self.handler.on_backfill_request(
+ origin, room_id, versions, limit
+ )
+
+ res = self._transaction_from_pdus(pdus).get_dict()
- defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
+ defer.returnValue((200, res))
@defer.inlineCallbacks
@log_function
@@ -184,27 +188,28 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
def on_context_state_request(self, origin, room_id, event_id):
- if event_id:
- pdus = yield self.handler.get_state_for_pdu(
- origin, room_id, event_id,
- )
- auth_chain = yield self.store.get_auth_chain(
- [pdu.event_id for pdu in pdus]
- )
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ if event_id:
+ pdus = yield self.handler.get_state_for_pdu(
+ origin, room_id, event_id,
+ )
+ auth_chain = yield self.store.get_auth_chain(
+ [pdu.event_id for pdu in pdus]
+ )
- for event in auth_chain:
- # We sign these again because there was a bug where we
- # incorrectly signed things the first time round
- if self.hs.is_mine_id(event.event_id):
- event.signatures.update(
- compute_event_signature(
- event,
- self.hs.hostname,
- self.hs.config.signing_key[0]
+ for event in auth_chain:
+ # We sign these again because there was a bug where we
+ # incorrectly signed things the first time round
+ if self.hs.is_mine_id(event.event_id):
+ event.signatures.update(
+ compute_event_signature(
+ event,
+ self.hs.hostname,
+ self.hs.config.signing_key[0]
+ )
)
- )
- else:
- raise NotImplementedError("Specify an event")
+ else:
+ raise NotImplementedError("Specify an event")
defer.returnValue((200, {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
@@ -283,14 +288,16 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_event_auth(self, origin, room_id, event_id):
- time_now = self._clock.time_msec()
- auth_pdus = yield self.handler.on_event_auth(event_id)
- defer.returnValue((200, {
- "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
- }))
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ time_now = self._clock.time_msec()
+ auth_pdus = yield self.handler.on_event_auth(event_id)
+ res = {
+ "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
+ }
+ defer.returnValue((200, res))
@defer.inlineCallbacks
- def on_query_auth_request(self, origin, content, event_id):
+ def on_query_auth_request(self, origin, content, room_id, event_id):
"""
Content is a dict with keys::
auth_chain (list): A list of events that give the auth chain.
@@ -309,32 +316,33 @@ class FederationServer(FederationBase):
Returns:
Deferred: Results in `dict` with the same format as `content`
"""
- auth_chain = [
- self.event_from_pdu_json(e)
- for e in content["auth_chain"]
- ]
-
- signed_auth = yield self._check_sigs_and_hash_and_fetch(
- origin, auth_chain, outlier=True
- )
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ auth_chain = [
+ self.event_from_pdu_json(e)
+ for e in content["auth_chain"]
+ ]
+
+ signed_auth = yield self._check_sigs_and_hash_and_fetch(
+ origin, auth_chain, outlier=True
+ )
- ret = yield self.handler.on_query_auth(
- origin,
- event_id,
- signed_auth,
- content.get("rejects", []),
- content.get("missing", []),
- )
+ ret = yield self.handler.on_query_auth(
+ origin,
+ event_id,
+ signed_auth,
+ content.get("rejects", []),
+ content.get("missing", []),
+ )
- time_now = self._clock.time_msec()
- send_content = {
- "auth_chain": [
- e.get_pdu_json(time_now)
- for e in ret["auth_chain"]
- ],
- "rejects": ret.get("rejects", []),
- "missing": ret.get("missing", []),
- }
+ time_now = self._clock.time_msec()
+ send_content = {
+ "auth_chain": [
+ e.get_pdu_json(time_now)
+ for e in ret["auth_chain"]
+ ],
+ "rejects": ret.get("rejects", []),
+ "missing": ret.get("missing", []),
+ }
defer.returnValue(
(200, send_content)
@@ -386,21 +394,24 @@ class FederationServer(FederationBase):
@log_function
def on_get_missing_events(self, origin, room_id, earliest_events,
latest_events, limit, min_depth):
- logger.info(
- "on_get_missing_events: earliest_events: %r, latest_events: %r,"
- " limit: %d, min_depth: %d",
- earliest_events, latest_events, limit, min_depth
- )
- missing_events = yield self.handler.on_get_missing_events(
- origin, room_id, earliest_events, latest_events, limit, min_depth
- )
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ logger.info(
+ "on_get_missing_events: earliest_events: %r, latest_events: %r,"
+ " limit: %d, min_depth: %d",
+ earliest_events, latest_events, limit, min_depth
+ )
+ missing_events = yield self.handler.on_get_missing_events(
+ origin, room_id, earliest_events, latest_events, limit, min_depth
+ )
- if len(missing_events) < 5:
- logger.info("Returning %d events: %r", len(missing_events), missing_events)
- else:
- logger.info("Returning %d events", len(missing_events))
+ if len(missing_events) < 5:
+ logger.info(
+ "Returning %d events: %r", len(missing_events), missing_events
+ )
+ else:
+ logger.info("Returning %d events", len(missing_events))
- time_now = self._clock.time_msec()
+ time_now = self._clock.time_msec()
defer.returnValue({
"events": [ev.get_pdu_json(time_now) for ev in missing_events],
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 8a1965f45a..26fa88ae84 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -388,7 +388,7 @@ class FederationQueryAuthServlet(BaseFederationServlet):
@defer.inlineCallbacks
def on_POST(self, origin, content, query, context, event_id):
new_content = yield self.handler.on_query_auth_request(
- origin, content, event_id
+ origin, content, context, event_id
)
defer.returnValue((200, new_content))
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index b38f81e999..e259213a36 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -20,6 +20,7 @@ from synapse.api.constants import LoginType
from synapse.types import UserID
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
from synapse.util.async import run_on_reactor
+from synapse.config.ldap import LDAPMode
from twisted.web.client import PartialDownloadError
@@ -28,6 +29,12 @@ import bcrypt
import pymacaroons
import simplejson
+try:
+ import ldap3
+except ImportError:
+ ldap3 = None
+ pass
+
import synapse.util.stringutils as stringutils
@@ -50,17 +57,20 @@ class AuthHandler(BaseHandler):
self.INVALID_TOKEN_HTTP_STATUS = 401
self.ldap_enabled = hs.config.ldap_enabled
- self.ldap_server = hs.config.ldap_server
- self.ldap_port = hs.config.ldap_port
- self.ldap_tls = hs.config.ldap_tls
- self.ldap_search_base = hs.config.ldap_search_base
- self.ldap_search_property = hs.config.ldap_search_property
- self.ldap_email_property = hs.config.ldap_email_property
- self.ldap_full_name_property = hs.config.ldap_full_name_property
-
- if self.ldap_enabled is True:
- import ldap
- logger.info("Import ldap version: %s", ldap.__version__)
+ if self.ldap_enabled:
+ if not ldap3:
+ raise RuntimeError(
+ 'Missing ldap3 library. This is required for LDAP Authentication.'
+ )
+ self.ldap_mode = hs.config.ldap_mode
+ self.ldap_uri = hs.config.ldap_uri
+ self.ldap_start_tls = hs.config.ldap_start_tls
+ self.ldap_base = hs.config.ldap_base
+ self.ldap_filter = hs.config.ldap_filter
+ self.ldap_attributes = hs.config.ldap_attributes
+ if self.ldap_mode == LDAPMode.SEARCH:
+ self.ldap_bind_dn = hs.config.ldap_bind_dn
+ self.ldap_bind_password = hs.config.ldap_bind_password
self.hs = hs # FIXME better possibility to access registrationHandler later?
@@ -452,40 +462,167 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def _check_ldap_password(self, user_id, password):
- if not self.ldap_enabled:
- logger.debug("LDAP not configured")
+ """ Attempt to authenticate a user against an LDAP Server
+ and register an account if none exists.
+
+ Returns:
+ True if authentication against LDAP was successful
+ """
+
+ if not ldap3 or not self.ldap_enabled:
defer.returnValue(False)
- import ldap
+ if self.ldap_mode not in LDAPMode.LIST:
+ raise RuntimeError(
+ 'Invalid ldap mode specified: {mode}'.format(
+ mode=self.ldap_mode
+ )
+ )
- logger.info("Authenticating %s with LDAP" % user_id)
try:
- ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port)
- logger.debug("Connecting LDAP server at %s" % ldap_url)
- l = ldap.initialize(ldap_url)
- if self.ldap_tls:
- logger.debug("Initiating TLS")
- self._connection.start_tls_s()
+ server = ldap3.Server(self.ldap_uri)
+ logger.debug(
+ "Attempting ldap connection with %s",
+ self.ldap_uri
+ )
- local_name = UserID.from_string(user_id).localpart
+ localpart = UserID.from_string(user_id).localpart
+ if self.ldap_mode == LDAPMode.SIMPLE:
+ # bind with the the local users ldap credentials
+ bind_dn = "{prop}={value},{base}".format(
+ prop=self.ldap_attributes['uid'],
+ value=localpart,
+ base=self.ldap_base
+ )
+ conn = ldap3.Connection(server, bind_dn, password)
+ logger.debug(
+ "Established ldap connection in simple mode: %s",
+ conn
+ )
- dn = "%s=%s, %s" % (
- self.ldap_search_property,
- local_name,
- self.ldap_search_base)
- logger.debug("DN for LDAP authentication: %s" % dn)
+ if self.ldap_start_tls:
+ conn.start_tls()
+ logger.debug(
+ "Upgraded ldap connection in simple mode through StartTLS: %s",
+ conn
+ )
+
+ conn.bind()
+
+ elif self.ldap_mode == LDAPMode.SEARCH:
+ # connect with preconfigured credentials and search for local user
+ conn = ldap3.Connection(
+ server,
+ self.ldap_bind_dn,
+ self.ldap_bind_password
+ )
+ logger.debug(
+ "Established ldap connection in search mode: %s",
+ conn
+ )
+
+ if self.ldap_start_tls:
+ conn.start_tls()
+ logger.debug(
+ "Upgraded ldap connection in search mode through StartTLS: %s",
+ conn
+ )
- l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8'))
+ conn.bind()
+ # find matching dn
+ query = "({prop}={value})".format(
+ prop=self.ldap_attributes['uid'],
+ value=localpart
+ )
+ if self.ldap_filter:
+ query = "(&{query}{filter})".format(
+ query=query,
+ filter=self.ldap_filter
+ )
+ logger.debug("ldap search filter: %s", query)
+ result = conn.search(self.ldap_base, query)
+
+ if result and len(conn.response) == 1:
+ # found exactly one result
+ user_dn = conn.response[0]['dn']
+ logger.debug('ldap search found dn: %s', user_dn)
+
+ # unbind and reconnect, rebind with found dn
+ conn.unbind()
+ conn = ldap3.Connection(
+ server,
+ user_dn,
+ password,
+ auto_bind=True
+ )
+ else:
+ # found 0 or > 1 results, abort!
+ logger.warn(
+ "ldap search returned unexpected (%d!=1) amount of results",
+ len(conn.response)
+ )
+ defer.returnValue(False)
+
+ logger.info(
+ "User authenticated against ldap server: %s",
+ conn
+ )
+
+ # check for existing account, if none exists, create one
if not (yield self.does_user_exist(user_id)):
- handler = self.hs.get_handlers().registration_handler
- user_id, access_token = (
- yield handler.register(localpart=local_name)
+ # query user metadata for account creation
+ query = "({prop}={value})".format(
+ prop=self.ldap_attributes['uid'],
+ value=localpart
+ )
+
+ if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
+ query = "(&{filter}{user_filter})".format(
+ filter=query,
+ user_filter=self.ldap_filter
+ )
+ logger.debug("ldap registration filter: %s", query)
+
+ result = conn.search(
+ search_base=self.ldap_base,
+ search_filter=query,
+ attributes=[
+ self.ldap_attributes['name'],
+ self.ldap_attributes['mail']
+ ]
)
+ if len(conn.response) == 1:
+ attrs = conn.response[0]['attributes']
+ mail = attrs[self.ldap_attributes['mail']][0]
+ name = attrs[self.ldap_attributes['name']][0]
+
+ # create account
+ registration_handler = self.hs.get_handlers().registration_handler
+ user_id, access_token = (
+ yield registration_handler.register(localpart=localpart)
+ )
+
+ # TODO: bind email, set displayname with data from ldap directory
+
+ logger.info(
+ "ldap registration successful: %d: %s (%s, %)",
+ user_id,
+ localpart,
+ name,
+ mail
+ )
+ else:
+ logger.warn(
+ "ldap registration failed: unexpected (%d!=1) amount of results",
+ len(result)
+ )
+ defer.returnValue(False)
+
defer.returnValue(True)
- except ldap.LDAPError, e:
- logger.warn("LDAP error: %s", e)
+ except ldap3.core.exceptions.LDAPException as e:
+ logger.warn("Error during ldap authentication: %s", e)
defer.returnValue(False)
@defer.inlineCallbacks
@@ -613,7 +750,8 @@ class AuthHandler(BaseHandler):
Returns:
Hashed password (str).
"""
- return bcrypt.hashpw(password, bcrypt.gensalt(self.bcrypt_rounds))
+ return bcrypt.hashpw(password + self.hs.config.password_pepper,
+ bcrypt.gensalt(self.bcrypt_rounds))
def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash.
@@ -626,6 +764,7 @@ class AuthHandler(BaseHandler):
Whether self.hash(password) == stored_hash (bool).
"""
if stored_hash:
- return bcrypt.hashpw(password, stored_hash.encode('utf-8')) == stored_hash
+ return bcrypt.hashpw(password + self.hs.config.password_pepper,
+ stored_hash.encode('utf-8')) == stored_hash
else:
return False
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 6c0bc7eafa..351b218247 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1413,7 +1413,7 @@ class FederationHandler(BaseHandler):
local_view = dict(auth_events)
remote_view = dict(auth_events)
remote_view.update({
- (d.type, d.state_key): d for d in different_events
+ (d.type, d.state_key): d for d in different_events if d
})
new_state, prev_state = self.state_handler.resolve_events(
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 656ce124f9..559e5d5a71 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -21,7 +21,7 @@ from synapse.api.errors import (
)
from ._base import BaseHandler
from synapse.util.async import run_on_reactor
-from synapse.api.errors import SynapseError
+from synapse.api.errors import SynapseError, Codes
import json
import logging
@@ -41,6 +41,20 @@ class IdentityHandler(BaseHandler):
hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
)
+ def _should_trust_id_server(self, id_server):
+ if id_server not in self.trusted_id_servers:
+ if self.trust_any_id_server_just_for_testing_do_not_use:
+ logger.warn(
+ "Trusting untrustworthy ID server %r even though it isn't"
+ " in the trusted id list for testing because"
+ " 'use_insecure_ssl_client_just_for_testing_do_not_use'"
+ " is set in the config",
+ id_server,
+ )
+ else:
+ return False
+ return True
+
@defer.inlineCallbacks
def threepid_from_creds(self, creds):
yield run_on_reactor()
@@ -59,19 +73,12 @@ class IdentityHandler(BaseHandler):
else:
raise SynapseError(400, "No client_secret in creds")
- if id_server not in self.trusted_id_servers:
- if self.trust_any_id_server_just_for_testing_do_not_use:
- logger.warn(
- "Trusting untrustworthy ID server %r even though it isn't"
- " in the trusted id list for testing because"
- " 'use_insecure_ssl_client_just_for_testing_do_not_use'"
- " is set in the config",
- id_server,
- )
- else:
- logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
- 'credentials', id_server)
- defer.returnValue(None)
+ if not self._should_trust_id_server(id_server):
+ logger.warn(
+ '%s is not a trusted ID server: rejecting 3pid ' +
+ 'credentials', id_server
+ )
+ defer.returnValue(None)
data = {}
try:
@@ -129,6 +136,12 @@ class IdentityHandler(BaseHandler):
def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
yield run_on_reactor()
+ if not self._should_trust_id_server(id_server):
+ raise SynapseError(
+ 400, "Untrusted ID server '%s'" % id_server,
+ Codes.SERVER_NOT_TRUSTED
+ )
+
params = {
'email': email,
'client_secret': client_secret,
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 15caf1950a..ad2753c1b5 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -26,7 +26,7 @@ from synapse.types import (
UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id
)
from synapse.util import unwrapFirstError
-from synapse.util.async import concurrently_execute, run_on_reactor
+from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import preserve_fn
from synapse.visibility import filter_events_for_client
@@ -50,6 +50,20 @@ class MessageHandler(BaseHandler):
self.validator = EventValidator()
self.snapshot_cache = SnapshotCache()
+ self.pagination_lock = ReadWriteLock()
+
+ @defer.inlineCallbacks
+ def purge_history(self, room_id, event_id):
+ event = yield self.store.get_event(event_id)
+
+ if event.room_id != room_id:
+ raise SynapseError(400, "Event is for wrong room.")
+
+ depth = event.depth
+
+ with (yield self.pagination_lock.write(room_id)):
+ yield self.store.delete_old_state(room_id, depth)
+
@defer.inlineCallbacks
def get_messages(self, requester, room_id=None, pagin_config=None,
as_client_event=True):
@@ -85,42 +99,43 @@ class MessageHandler(BaseHandler):
source_config = pagin_config.get_source_config("room")
- membership, member_event_id = yield self._check_in_room_or_world_readable(
- room_id, user_id
- )
+ with (yield self.pagination_lock.read(room_id)):
+ membership, member_event_id = yield self._check_in_room_or_world_readable(
+ room_id, user_id
+ )
- if source_config.direction == 'b':
- # if we're going backwards, we might need to backfill. This
- # requires that we have a topo token.
- if room_token.topological:
- max_topo = room_token.topological
- else:
- max_topo = yield self.store.get_max_topological_token_for_stream_and_room(
- room_id, room_token.stream
- )
+ if source_config.direction == 'b':
+ # if we're going backwards, we might need to backfill. This
+ # requires that we have a topo token.
+ if room_token.topological:
+ max_topo = room_token.topological
+ else:
+ max_topo = yield self.store.get_max_topological_token(
+ room_id, room_token.stream
+ )
+
+ if membership == Membership.LEAVE:
+ # If they have left the room then clamp the token to be before
+ # they left the room, to save the effort of loading from the
+ # database.
+ leave_token = yield self.store.get_topological_token_for_event(
+ member_event_id
+ )
+ leave_token = RoomStreamToken.parse(leave_token)
+ if leave_token.topological < max_topo:
+ source_config.from_key = str(leave_token)
- if membership == Membership.LEAVE:
- # If they have left the room then clamp the token to be before
- # they left the room, to save the effort of loading from the
- # database.
- leave_token = yield self.store.get_topological_token_for_event(
- member_event_id
+ yield self.hs.get_handlers().federation_handler.maybe_backfill(
+ room_id, max_topo
)
- leave_token = RoomStreamToken.parse(leave_token)
- if leave_token.topological < max_topo:
- source_config.from_key = str(leave_token)
- yield self.hs.get_handlers().federation_handler.maybe_backfill(
- room_id, max_topo
+ events, next_key = yield data_source.get_pagination_rows(
+ requester.user, source_config, room_id
)
- events, next_key = yield data_source.get_pagination_rows(
- requester.user, source_config, room_id
- )
-
- next_token = pagin_config.from_token.copy_and_replace(
- "room_key", next_key
- )
+ next_token = pagin_config.from_token.copy_and_replace(
+ "room_key", next_key
+ )
if not events:
defer.returnValue({
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 0b7517221d..8c3381df8a 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -90,7 +90,8 @@ class RegistrationHandler(BaseHandler):
password=None,
generate_token=True,
guest_access_token=None,
- make_guest=False
+ make_guest=False,
+ admin=False,
):
"""Registers a new client on the server.
@@ -141,6 +142,7 @@ class RegistrationHandler(BaseHandler):
# If the user was a guest then they already have a profile
None if was_guest else user.localpart
),
+ admin=admin,
)
else:
# autogen a sequential user ID
@@ -358,7 +360,8 @@ class RegistrationHandler(BaseHandler):
defer.returnValue(data)
@defer.inlineCallbacks
- def get_or_create_user(self, localpart, displayname, duration_seconds):
+ def get_or_create_user(self, localpart, displayname, duration_seconds,
+ password_hash=None):
"""Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one.
@@ -394,7 +397,7 @@ class RegistrationHandler(BaseHandler):
yield self.store.register(
user_id=user_id,
token=token,
- password_hash=None,
+ password_hash=password_hash,
create_profile_with_localpart=user.localpart,
)
else:
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 3992804845..2acc6cc214 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -38,6 +38,7 @@ class HttpPusher(object):
self.hs = hs
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
+ self.state_handler = self.hs.get_state_handler()
self.user_id = pusherdict['user_name']
self.app_id = pusherdict['app_id']
self.app_display_name = pusherdict['app_display_name']
@@ -237,7 +238,9 @@ class HttpPusher(object):
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge):
- ctx = yield push_tools.get_context_for_event(self.hs.get_datastore(), event)
+ ctx = yield push_tools.get_context_for_event(
+ self.state_handler, event, self.user_id
+ )
d = {
'notification': {
@@ -269,8 +272,8 @@ class HttpPusher(object):
if 'content' in event:
d['notification']['content'] = event.content
- if len(ctx['aliases']):
- d['notification']['room_alias'] = ctx['aliases'][0]
+ # We no longer send aliases separately, instead, we send the human
+ # readable name of the room, which may be an alias.
if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0:
d['notification']['sender_display_name'] = ctx['sender_display_name']
if 'name' in ctx and len(ctx['name']) > 0:
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 89a3b5e90a..6f2d1ad57d 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -14,6 +14,9 @@
# limitations under the License.
from twisted.internet import defer
+from synapse.util.presentable_names import (
+ calculate_room_name, name_from_member_event
+)
@defer.inlineCallbacks
@@ -45,24 +48,21 @@ def get_badge_count(store, user_id):
@defer.inlineCallbacks
-def get_context_for_event(store, ev):
- name_aliases = yield store.get_room_name_and_aliases(
- ev.room_id
- )
+def get_context_for_event(state_handler, ev, user_id):
+ ctx = {}
- ctx = {'aliases': name_aliases[1]}
- if name_aliases[0] is not None:
- ctx['name'] = name_aliases[0]
+ room_state = yield state_handler.get_current_state(ev.room_id)
- their_member_events_for_room = yield store.get_current_state(
- room_id=ev.room_id,
- event_type='m.room.member',
- state_key=ev.user_id
+ # we no longer bother setting room_alias, and make room_name the
+ # human-readable name instead, be that m.room.namer, an alias or
+ # a list of people in the room
+ name = calculate_room_name(
+ room_state, user_id, fallback_to_single_member=False
)
- for mev in their_member_events_for_room:
- if mev.content['membership'] == 'join' and 'displayname' in mev.content:
- dn = mev.content['displayname']
- if dn is not None:
- ctx['sender_display_name'] = dn
+ if name:
+ ctx['name'] = name
+
+ sender_state_event = room_state[("m.room.member", ev.sender)]
+ ctx['sender_display_name'] = name_from_member_event(sender_state_event)
defer.returnValue(ctx)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index e0a7a19777..e024cec0a2 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -48,6 +48,9 @@ CONDITIONAL_REQUIREMENTS = {
"Jinja2>=2.8": ["Jinja2>=2.8"],
"bleach>=1.4.2": ["bleach>=1.4.2"],
},
+ "ldap": {
+ "ldap3>=1.0": ["ldap3>=1.0"],
+ },
}
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 877c68508c..369d839464 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -18,7 +18,6 @@ from ._slaved_id_tracker import SlavedIdTracker
from synapse.api.constants import EventTypes
from synapse.events import FrozenEvent
from synapse.storage import DataStore
-from synapse.storage.room import RoomStore
from synapse.storage.roommember import RoomMemberStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.event_push_actions import EventPushActionsStore
@@ -64,7 +63,6 @@ class SlavedEventStore(BaseSlavedStore):
# Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them.
- get_room_name_and_aliases = RoomStore.__dict__["get_room_name_and_aliases"]
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
get_latest_event_ids_in_room = EventFederationStore.__dict__[
@@ -202,7 +200,6 @@ class SlavedEventStore(BaseSlavedStore):
self.get_rooms_for_user.invalidate_all()
self.get_users_in_room.invalidate((event.room_id,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
- self.get_room_name_and_aliases.invalidate((event.room_id,))
self._invalidate_get_event_cache(event.event_id)
@@ -246,9 +243,3 @@ class SlavedEventStore(BaseSlavedStore):
self._get_current_state_for_key.invalidate((
event.room_id, event.type, event.state_key
))
-
- if event.type in [EventTypes.Name, EventTypes.Aliases]:
- self.get_room_name_and_aliases.invalidate(
- (event.room_id,)
- )
- pass
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index aa05b3f023..b0cb31a448 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -46,5 +46,82 @@ class WhoisRestServlet(ClientV1RestServlet):
defer.returnValue((200, ret))
+class PurgeMediaCacheRestServlet(ClientV1RestServlet):
+ PATTERNS = client_path_patterns("/admin/purge_media_cache")
+
+ def __init__(self, hs):
+ self.media_repository = hs.get_media_repository()
+ super(PurgeMediaCacheRestServlet, self).__init__(hs)
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ before_ts = request.args.get("before_ts", None)
+ if not before_ts:
+ raise SynapseError(400, "Missing 'before_ts' arg")
+
+ logger.info("before_ts: %r", before_ts[0])
+
+ try:
+ before_ts = int(before_ts[0])
+ except Exception:
+ raise SynapseError(400, "Invalid 'before_ts' arg")
+
+ ret = yield self.media_repository.delete_old_remote_media(before_ts)
+
+ defer.returnValue((200, ret))
+
+
+class PurgeHistoryRestServlet(ClientV1RestServlet):
+ PATTERNS = client_path_patterns(
+ "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+ )
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, room_id, event_id):
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ yield self.handlers.message_handler.purge_history(room_id, event_id)
+
+ defer.returnValue((200, {}))
+
+
+class DeactivateAccountRestServlet(ClientV1RestServlet):
+ PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ super(DeactivateAccountRestServlet, self).__init__(hs)
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, target_user_id):
+ UserID.from_string(target_user_id)
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ # FIXME: Theoretically there is a race here wherein user resets password
+ # using threepid.
+ yield self.store.user_delete_access_tokens(target_user_id)
+ yield self.store.user_delete_threepids(target_user_id)
+ yield self.store.user_set_password_hash(target_user_id, None)
+
+ defer.returnValue((200, {}))
+
+
def register_servlets(hs, http_server):
WhoisRestServlet(hs).register(http_server)
+ PurgeMediaCacheRestServlet(hs).register(http_server)
+ DeactivateAccountRestServlet(hs).register(http_server)
+ PurgeHistoryRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index e3f4fbb0bb..ce7099b18f 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -324,6 +324,14 @@ class RegisterRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Shared secret registration is not enabled")
user = register_json["user"].encode("utf-8")
+ password = register_json["password"].encode("utf-8")
+ admin = register_json.get("admin", None)
+
+ # Its important to check as we use null bytes as HMAC field separators
+ if "\x00" in user:
+ raise SynapseError(400, "Invalid user")
+ if "\x00" in password:
+ raise SynapseError(400, "Invalid password")
# str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface
@@ -331,17 +339,21 @@ class RegisterRestServlet(ClientV1RestServlet):
want_mac = hmac.new(
key=self.hs.config.registration_shared_secret,
- msg=user,
digestmod=sha1,
- ).hexdigest()
-
- password = register_json["password"].encode("utf-8")
+ )
+ want_mac.update(user)
+ want_mac.update("\x00")
+ want_mac.update(password)
+ want_mac.update("\x00")
+ want_mac.update("admin" if admin else "notadmin")
+ want_mac = want_mac.hexdigest()
if compare_digest(want_mac, got_mac):
handler = self.handlers.registration_handler
user_id, token = yield handler.register(
localpart=user,
password=password,
+ admin=bool(admin),
)
self._remove_session(session)
defer.returnValue({
@@ -410,12 +422,15 @@ class CreateUserRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Failed to parse 'duration_seconds'")
if duration_seconds > self.direct_user_creation_max_duration:
duration_seconds = self.direct_user_creation_max_duration
+ password_hash = user_json["password_hash"].encode("utf-8") \
+ if user_json.get("password_hash") else None
handler = self.handlers.registration_handler
user_id, token = yield handler.get_or_create_user(
localpart=localpart,
displayname=displayname,
- duration_seconds=duration_seconds
+ duration_seconds=duration_seconds,
+ password_hash=password_hash
)
defer.returnValue({
diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py
index d9fc045fc6..956bd5da75 100644
--- a/synapse/rest/media/v0/content_repository.py
+++ b/synapse/rest/media/v0/content_repository.py
@@ -15,14 +15,12 @@
from synapse.http.server import respond_with_json_bytes, finish_request
-from synapse.util.stringutils import random_string
from synapse.api.errors import (
- cs_exception, SynapseError, CodeMessageException, Codes, cs_error
+ Codes, cs_error
)
from twisted.protocols.basic import FileSender
from twisted.web import server, resource
-from twisted.internet import defer
import base64
import simplejson as json
@@ -50,64 +48,10 @@ class ContentRepoResource(resource.Resource):
"""
isLeaf = True
- def __init__(self, hs, directory, auth, external_addr):
+ def __init__(self, hs, directory):
resource.Resource.__init__(self)
self.hs = hs
self.directory = directory
- self.auth = auth
- self.external_addr = external_addr.rstrip('/')
- self.max_upload_size = hs.config.max_upload_size
-
- if not os.path.isdir(self.directory):
- os.mkdir(self.directory)
- logger.info("ContentRepoResource : Created %s directory.",
- self.directory)
-
- @defer.inlineCallbacks
- def map_request_to_name(self, request):
- # auth the user
- requester = yield self.auth.get_user_by_req(request)
-
- # namespace all file uploads on the user
- prefix = base64.urlsafe_b64encode(
- requester.user.to_string()
- ).replace('=', '')
-
- # use a random string for the main portion
- main_part = random_string(24)
-
- # suffix with a file extension if we can make one. This is nice to
- # provide a hint to clients on the file information. We will also reuse
- # this info to spit back the content type to the client.
- suffix = ""
- if request.requestHeaders.hasHeader("Content-Type"):
- content_type = request.requestHeaders.getRawHeaders(
- "Content-Type")[0]
- suffix = "." + base64.urlsafe_b64encode(content_type)
- if (content_type.split("/")[0].lower() in
- ["image", "video", "audio"]):
- file_ext = content_type.split("/")[-1]
- # be a little paranoid and only allow a-z
- file_ext = re.sub("[^a-z]", "", file_ext)
- suffix += "." + file_ext
-
- file_name = prefix + main_part + suffix
- file_path = os.path.join(self.directory, file_name)
- logger.info("User %s is uploading a file to path %s",
- request.user.user_id.to_string(),
- file_path)
-
- # keep trying to make a non-clashing file, with a sensible max attempts
- attempts = 0
- while os.path.exists(file_path):
- main_part = random_string(24)
- file_name = prefix + main_part + suffix
- file_path = os.path.join(self.directory, file_name)
- attempts += 1
- if attempts > 25: # really? Really?
- raise SynapseError(500, "Unable to create file.")
-
- defer.returnValue(file_path)
def render_GET(self, request):
# no auth here on purpose, to allow anyone to view, even across home
@@ -155,58 +99,6 @@ class ContentRepoResource(resource.Resource):
return server.NOT_DONE_YET
- def render_POST(self, request):
- self._async_render(request)
- return server.NOT_DONE_YET
-
def render_OPTIONS(self, request):
respond_with_json_bytes(request, 200, {}, send_cors=True)
return server.NOT_DONE_YET
-
- @defer.inlineCallbacks
- def _async_render(self, request):
- try:
- # TODO: The checks here are a bit late. The content will have
- # already been uploaded to a tmp file at this point
- content_length = request.getHeader("Content-Length")
- if content_length is None:
- raise SynapseError(
- msg="Request must specify a Content-Length", code=400
- )
- if int(content_length) > self.max_upload_size:
- raise SynapseError(
- msg="Upload request body is too large",
- code=413,
- )
-
- fname = yield self.map_request_to_name(request)
-
- # TODO I have a suspicious feeling this is just going to block
- with open(fname, "wb") as f:
- f.write(request.content.read())
-
- # FIXME (erikj): These should use constants.
- file_name = os.path.basename(fname)
- # FIXME: we can't assume what the repo's public mounted path is
- # ...plus self-signed SSL won't work to remote clients anyway
- # ...and we can't assume that it's SSL anyway, as we might want to
- # serve it via the non-SSL listener...
- url = "%s/_matrix/content/%s" % (
- self.external_addr, file_name
- )
-
- respond_with_json_bytes(request, 200,
- json.dumps({"content_token": url}),
- send_cors=True)
-
- except CodeMessageException as e:
- logger.exception(e)
- respond_with_json_bytes(request, e.code,
- json.dumps(cs_exception(e)))
- except Exception as e:
- logger.error("Failed to store file: %s" % e)
- respond_with_json_bytes(
- request,
- 500,
- json.dumps({"error": "Internal server error"}),
- send_cors=True)
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index 422ab86fb3..0137458f71 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -65,3 +65,9 @@ class MediaFilePaths(object):
file_id[0:2], file_id[2:4], file_id[4:],
file_name
)
+
+ def remote_media_thumbnail_dir(self, server_name, file_id):
+ return os.path.join(
+ self.base_path, "remote_thumbnail", server_name,
+ file_id[0:2], file_id[2:4], file_id[4:],
+ )
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 2468c3ac42..692e078419 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -30,11 +30,13 @@ from synapse.api.errors import SynapseError
from twisted.internet import defer, threads
-from synapse.util.async import ObservableDeferred
+from synapse.util.async import Linearizer
from synapse.util.stringutils import is_ascii
from synapse.util.logcontext import preserve_context_over_fn
import os
+import errno
+import shutil
import cgi
import logging
@@ -43,8 +45,11 @@ import urlparse
logger = logging.getLogger(__name__)
+UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
+
+
class MediaRepository(object):
- def __init__(self, hs, filepaths):
+ def __init__(self, hs):
self.auth = hs.get_auth()
self.client = MatrixFederationHttpClient(hs)
self.clock = hs.get_clock()
@@ -52,11 +57,28 @@ class MediaRepository(object):
self.store = hs.get_datastore()
self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels
- self.filepaths = filepaths
- self.downloads = {}
+ self.filepaths = MediaFilePaths(hs.config.media_store_path)
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
+ self.remote_media_linearizer = Linearizer()
+
+ self.recently_accessed_remotes = set()
+
+ self.clock.looping_call(
+ self._update_recently_accessed_remotes,
+ UPDATE_RECENTLY_ACCESSED_REMOTES_TS
+ )
+
+ @defer.inlineCallbacks
+ def _update_recently_accessed_remotes(self):
+ media = self.recently_accessed_remotes
+ self.recently_accessed_remotes = set()
+
+ yield self.store.update_cached_last_access_time(
+ media, self.clock.time_msec()
+ )
+
@staticmethod
def _makedirs(filepath):
dirname = os.path.dirname(filepath)
@@ -93,22 +115,12 @@ class MediaRepository(object):
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
+ @defer.inlineCallbacks
def get_remote_media(self, server_name, media_id):
key = (server_name, media_id)
- download = self.downloads.get(key)
- if download is None:
- download = self._get_remote_media_impl(server_name, media_id)
- download = ObservableDeferred(
- download,
- consumeErrors=True
- )
- self.downloads[key] = download
-
- @download.addBoth
- def callback(media_info):
- del self.downloads[key]
- return media_info
- return download.observe()
+ with (yield self.remote_media_linearizer.queue(key)):
+ media_info = yield self._get_remote_media_impl(server_name, media_id)
+ defer.returnValue(media_info)
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
@@ -119,6 +131,11 @@ class MediaRepository(object):
media_info = yield self._download_remote_file(
server_name, media_id
)
+ else:
+ self.recently_accessed_remotes.add((server_name, media_id))
+ yield self.store.update_cached_last_access_time(
+ [(server_name, media_id)], self.clock.time_msec()
+ )
defer.returnValue(media_info)
@defer.inlineCallbacks
@@ -416,6 +433,41 @@ class MediaRepository(object):
"height": m_height,
})
+ @defer.inlineCallbacks
+ def delete_old_remote_media(self, before_ts):
+ old_media = yield self.store.get_remote_media_before(before_ts)
+
+ deleted = 0
+
+ for media in old_media:
+ origin = media["media_origin"]
+ media_id = media["media_id"]
+ file_id = media["filesystem_id"]
+ key = (origin, media_id)
+
+ logger.info("Deleting: %r", key)
+
+ with (yield self.remote_media_linearizer.queue(key)):
+ full_path = self.filepaths.remote_media_filepath(origin, file_id)
+ try:
+ os.remove(full_path)
+ except OSError as e:
+ logger.warn("Failed to remove file: %r", full_path)
+ if e.errno == errno.ENOENT:
+ pass
+ else:
+ continue
+
+ thumbnail_dir = self.filepaths.remote_media_thumbnail_dir(
+ origin, file_id
+ )
+ shutil.rmtree(thumbnail_dir, ignore_errors=True)
+
+ yield self.store.delete_remote_media(origin, media_id)
+ deleted += 1
+
+ defer.returnValue({"deleted": deleted})
+
class MediaRepositoryResource(Resource):
"""File uploading and downloading.
@@ -464,9 +516,8 @@ class MediaRepositoryResource(Resource):
def __init__(self, hs):
Resource.__init__(self)
- filepaths = MediaFilePaths(hs.config.media_store_path)
- media_repo = MediaRepository(hs, filepaths)
+ media_repo = hs.get_media_repository()
self.putChild("upload", UploadResource(hs, media_repo))
self.putChild("download", DownloadResource(hs, media_repo))
diff --git a/synapse/server.py b/synapse/server.py
index dd4b81c658..d49a1a8a96 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -45,6 +45,7 @@ from synapse.crypto.keyring import Keyring
from synapse.push.pusherpool import PusherPool
from synapse.events.builder import EventBuilderFactory
from synapse.api.filtering import Filtering
+from synapse.rest.media.v1.media_repository import MediaRepository
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
@@ -113,6 +114,7 @@ class HomeServer(object):
'filtering',
'http_client_context_factory',
'simple_http_client',
+ 'media_repository',
]
def __init__(self, hostname, **kwargs):
@@ -233,6 +235,9 @@ class HomeServer(object):
**self.db_config.get("args", {})
)
+ def build_media_repository(self):
+ return MediaRepository(self)
+
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 32c6677d47..d766a30299 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -807,6 +807,11 @@ class SQLBaseStore(object):
if txn.rowcount > 1:
raise StoreError(500, "more than one row matched")
+ def _simple_delete(self, table, keyvalues, desc):
+ return self.runInteraction(
+ desc, self._simple_delete_txn, table, keyvalues
+ )
+
@staticmethod
def _simple_delete_txn(txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % (
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 940e11d7a2..3d93285f84 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -16,6 +16,8 @@
from ._base import SQLBaseStore
from twisted.internet import defer
from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.types import RoomStreamToken
+from .stream import lower_bound
import logging
import ujson as json
@@ -73,6 +75,9 @@ class EventPushActionsStore(SQLBaseStore):
stream_ordering = results[0][0]
topological_ordering = results[0][1]
+ token = RoomStreamToken(
+ topological_ordering, stream_ordering
+ )
sql = (
"SELECT sum(notif), sum(highlight)"
@@ -80,15 +85,10 @@ class EventPushActionsStore(SQLBaseStore):
" WHERE"
" user_id = ?"
" AND room_id = ?"
- " AND ("
- " topological_ordering > ?"
- " OR (topological_ordering = ? AND stream_ordering > ?)"
- ")"
- )
- txn.execute(sql, (
- user_id, room_id,
- topological_ordering, topological_ordering, stream_ordering
- ))
+ " AND %s"
+ ) % (lower_bound(token, self.database_engine, inclusive=False),)
+
+ txn.execute(sql, (user_id, room_id))
row = txn.fetchone()
if row:
return {
@@ -152,7 +152,7 @@ class EventPushActionsStore(SQLBaseStore):
if max_stream_ordering is not None:
sql += " AND ep.stream_ordering <= ?"
args.append(max_stream_ordering)
- sql += " ORDER BY ep.stream_ordering ASC LIMIT ?"
+ sql += " ORDER BY ep.stream_ordering DESC LIMIT ?"
args.append(limit)
txn.execute(sql, args)
return txn.fetchall()
@@ -176,14 +176,16 @@ class EventPushActionsStore(SQLBaseStore):
if max_stream_ordering is not None:
sql += " AND ep.stream_ordering <= ?"
args.append(max_stream_ordering)
- sql += " ORDER BY ep.stream_ordering ASC"
+ sql += " ORDER BY ep.stream_ordering DESC LIMIT ?"
+ args.append(limit)
txn.execute(sql, args)
return txn.fetchall()
no_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range", get_no_receipt
)
- defer.returnValue([
+ # Make a list of dicts from the two sets of results.
+ notifs = [
{
"event_id": row[0],
"room_id": row[1],
@@ -191,7 +193,16 @@ class EventPushActionsStore(SQLBaseStore):
"actions": json.loads(row[3]),
"received_ts": row[4],
} for row in after_read_receipt + no_read_receipt
- ])
+ ]
+
+ # Now sort it so it's ordered correctly, since currently it will
+ # contain results from the first query, correctly ordered, followed
+ # by results from the second query, but we want them all ordered
+ # by received_ts
+ notifs.sort(key=lambda r: -(r['received_ts'] or 0))
+
+ # Now return the first `limit`
+ defer.returnValue(notifs[:limit])
@defer.inlineCallbacks
def get_time_of_last_push_action_before(self, stream_ordering):
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 6d978ffcd5..b582942164 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -23,6 +23,7 @@ from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes
+from synapse.api.errors import SynapseError
from canonicaljson import encode_canonical_json
from collections import deque, namedtuple
@@ -355,7 +356,6 @@ class EventsStore(SQLBaseStore):
txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
- txn.call_after(self.get_room_name_and_aliases.invalidate, (event.room_id,))
# Add an entry to the current_state_resets table to record the point
# where we clobbered the current state
@@ -666,12 +666,6 @@ class EventsStore(SQLBaseStore):
(event.room_id, event.type, event.state_key,)
)
- if event.type in [EventTypes.Name, EventTypes.Aliases]:
- txn.call_after(
- self.get_room_name_and_aliases.invalidate,
- (event.room_id,)
- )
-
self._simple_upsert_txn(
txn,
"current_state_events",
@@ -1288,6 +1282,156 @@ class EventsStore(SQLBaseStore):
)
return self.runInteraction("get_all_new_events", get_all_new_events_txn)
+ def delete_old_state(self, room_id, topological_ordering):
+ return self.runInteraction(
+ "delete_old_state",
+ self._delete_old_state_txn, room_id, topological_ordering
+ )
+
+ def _delete_old_state_txn(self, txn, room_id, topological_ordering):
+ """Deletes old room state
+ """
+
+ # Tables that should be pruned:
+ # event_auth
+ # event_backward_extremities
+ # event_content_hashes
+ # event_destinations
+ # event_edge_hashes
+ # event_edges
+ # event_forward_extremities
+ # event_json
+ # event_push_actions
+ # event_reference_hashes
+ # event_search
+ # event_signatures
+ # event_to_state_groups
+ # events
+ # rejections
+ # room_depth
+ # state_groups
+ # state_groups_state
+
+ # First ensure that we're not about to delete all the forward extremeties
+ txn.execute(
+ "SELECT e.event_id, e.depth FROM events as e "
+ "INNER JOIN event_forward_extremities as f "
+ "ON e.event_id = f.event_id "
+ "AND e.room_id = f.room_id "
+ "WHERE f.room_id = ?",
+ (room_id,)
+ )
+ rows = txn.fetchall()
+ max_depth = max(row[0] for row in rows)
+
+ if max_depth <= topological_ordering:
+ # We need to ensure we don't delete all the events from the datanase
+ # otherwise we wouldn't be able to send any events (due to not
+ # having any backwards extremeties)
+ raise SynapseError(
+ 400, "topological_ordering is greater than forward extremeties"
+ )
+
+ txn.execute(
+ "SELECT event_id, state_key FROM events"
+ " LEFT JOIN state_events USING (room_id, event_id)"
+ " WHERE room_id = ? AND topological_ordering < ?",
+ (room_id, topological_ordering,)
+ )
+ event_rows = txn.fetchall()
+
+ # We calculate the new entries for the backward extremeties by finding
+ # all events that point to events that are to be purged
+ txn.execute(
+ "SELECT e.event_id FROM events as e"
+ " INNER JOIN event_edges as ed ON e.event_id = ed.prev_event_id"
+ " INNER JOIN events as e2 ON e2.event_id = ed.event_id"
+ " WHERE e.room_id = ? AND e.topological_ordering < ?"
+ " AND e2.topological_ordering >= ?",
+ (room_id, topological_ordering, topological_ordering)
+ )
+ new_backwards_extrems = txn.fetchall()
+
+ # Get all state groups that are only referenced by events that are
+ # to be deleted.
+ txn.execute(
+ "SELECT state_group FROM event_to_state_groups"
+ " INNER JOIN events USING (event_id)"
+ " WHERE state_group IN ("
+ " SELECT DISTINCT state_group FROM events"
+ " INNER JOIN event_to_state_groups USING (event_id)"
+ " WHERE room_id = ? AND topological_ordering < ?"
+ " )"
+ " GROUP BY state_group HAVING MAX(topological_ordering) < ?",
+ (room_id, topological_ordering, topological_ordering)
+ )
+ state_rows = txn.fetchall()
+ txn.executemany(
+ "DELETE FROM state_groups_state WHERE state_group = ?",
+ state_rows
+ )
+ txn.executemany(
+ "DELETE FROM state_groups WHERE id = ?",
+ state_rows
+ )
+ # Delete all non-state
+ txn.executemany(
+ "DELETE FROM event_to_state_groups WHERE event_id = ?",
+ [(event_id,) for event_id, _ in event_rows]
+ )
+
+ txn.execute(
+ "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
+ (topological_ordering, room_id,)
+ )
+
+ # Delete all remote non-state events
+ to_delete = [
+ (event_id,) for event_id, state_key in event_rows
+ if state_key is None and not self.hs.is_mine_id(event_id)
+ ]
+ for table in (
+ "events",
+ "event_json",
+ "event_auth",
+ "event_content_hashes",
+ "event_destinations",
+ "event_edge_hashes",
+ "event_edges",
+ "event_forward_extremities",
+ "event_push_actions",
+ "event_reference_hashes",
+ "event_search",
+ "event_signatures",
+ "rejections",
+ "event_backward_extremities",
+ ):
+ txn.executemany(
+ "DELETE FROM %s WHERE event_id = ?" % (table,),
+ to_delete
+ )
+
+ # Update backward extremeties
+ txn.executemany(
+ "INSERT INTO event_backward_extremities (room_id, event_id)"
+ " VALUES (?, ?)",
+ [(room_id, event_id) for event_id, in new_backwards_extrems]
+ )
+
+ txn.executemany(
+ "DELETE FROM events WHERE event_id = ?",
+ to_delete
+ )
+ # Mark all state and own events as outliers
+ txn.executemany(
+ "UPDATE events SET outlier = ?"
+ " WHERE event_id = ?",
+ [
+ (True, event_id,) for event_id, state_key in event_rows
+ if state_key is not None or self.hs.is_mine_id(event_id)
+ ]
+ )
+
AllNewEventsResult = namedtuple("AllNewEventsResult", [
"new_forward_events", "new_backfill_events",
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index a820fcf07f..4c0f82353d 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -157,10 +157,25 @@ class MediaRepositoryStore(SQLBaseStore):
"created_ts": time_now_ms,
"upload_name": upload_name,
"filesystem_id": filesystem_id,
+ "last_access_ts": time_now_ms,
},
desc="store_cached_remote_media",
)
+ def update_cached_last_access_time(self, origin_id_tuples, time_ts):
+ def update_cache_txn(txn):
+ sql = (
+ "UPDATE remote_media_cache SET last_access_ts = ?"
+ " WHERE media_origin = ? AND media_id = ?"
+ )
+
+ txn.executemany(sql, (
+ (time_ts, media_origin, media_id)
+ for media_origin, media_id in origin_id_tuples
+ ))
+
+ return self.runInteraction("update_cached_last_access_time", update_cache_txn)
+
def get_remote_media_thumbnails(self, origin, media_id):
return self._simple_select_list(
"remote_media_cache_thumbnails",
@@ -190,3 +205,32 @@ class MediaRepositoryStore(SQLBaseStore):
},
desc="store_remote_media_thumbnail",
)
+
+ def get_remote_media_before(self, before_ts):
+ sql = (
+ "SELECT media_origin, media_id, filesystem_id"
+ " FROM remote_media_cache"
+ " WHERE last_access_ts < ?"
+ )
+
+ return self._execute(
+ "get_remote_media_before", self.cursor_to_dict, sql, before_ts
+ )
+
+ def delete_remote_media(self, media_origin, media_id):
+ def delete_remote_media_txn(txn):
+ self._simple_delete_txn(
+ txn,
+ "remote_media_cache",
+ keyvalues={
+ "media_origin": media_origin, "media_id": media_id
+ },
+ )
+ self._simple_delete_txn(
+ txn,
+ "remote_media_cache_thumbnails",
+ keyvalues={
+ "media_origin": media_origin, "media_id": media_id
+ },
+ )
+ return self.runInteraction("delete_remote_media", delete_remote_media_txn)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index c8487c8838..8801669a6b 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 32
+SCHEMA_VERSION = 33
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 3de9e0f709..0a68341494 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -77,7 +77,7 @@ class RegistrationStore(SQLBaseStore):
@defer.inlineCallbacks
def register(self, user_id, token, password_hash,
was_guest=False, make_guest=False, appservice_id=None,
- create_profile_with_localpart=None):
+ create_profile_with_localpart=None, admin=False):
"""Attempts to register an account.
Args:
@@ -104,6 +104,7 @@ class RegistrationStore(SQLBaseStore):
make_guest,
appservice_id,
create_profile_with_localpart,
+ admin
)
self.get_user_by_id.invalidate((user_id,))
self.is_guest.invalidate((user_id,))
@@ -118,6 +119,7 @@ class RegistrationStore(SQLBaseStore):
make_guest,
appservice_id,
create_profile_with_localpart,
+ admin,
):
now = int(self.clock.time())
@@ -125,29 +127,33 @@ class RegistrationStore(SQLBaseStore):
try:
if was_guest:
- txn.execute("UPDATE users SET"
- " password_hash = ?,"
- " upgrade_ts = ?,"
- " is_guest = ?"
- " WHERE name = ?",
- [password_hash, now, 1 if make_guest else 0, user_id])
+ self._simple_update_one_txn(
+ txn,
+ "users",
+ keyvalues={
+ "name": user_id,
+ },
+ updatevalues={
+ "password_hash": password_hash,
+ "upgrade_ts": now,
+ "is_guest": 1 if make_guest else 0,
+ "appservice_id": appservice_id,
+ "admin": 1 if admin else 0,
+ }
+ )
else:
- txn.execute("INSERT INTO users "
- "("
- " name,"
- " password_hash,"
- " creation_ts,"
- " is_guest,"
- " appservice_id"
- ") "
- "VALUES (?,?,?,?,?)",
- [
- user_id,
- password_hash,
- now,
- 1 if make_guest else 0,
- appservice_id,
- ])
+ self._simple_insert_txn(
+ txn,
+ "users",
+ values={
+ "name": user_id,
+ "password_hash": password_hash,
+ "creation_ts": now,
+ "is_guest": 1 if make_guest else 0,
+ "appservice_id": appservice_id,
+ "admin": 1 if admin else 0,
+ }
+ )
except self.database_engine.module.IntegrityError:
raise StoreError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
@@ -384,6 +390,15 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(ret['user_id'])
defer.returnValue(None)
+ def user_delete_threepids(self, user_id):
+ return self._simple_delete(
+ "user_threepids",
+ keyvalues={
+ "user_id": user_id,
+ },
+ desc="user_delete_threepids",
+ )
+
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 97f9f1929c..8251f58670 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -18,7 +18,6 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks
from .engines import PostgresEngine, Sqlite3Engine
import collections
@@ -192,49 +191,6 @@ class RoomStore(SQLBaseStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
- @cachedInlineCallbacks()
- def get_room_name_and_aliases(self, room_id):
- def get_room_name(txn):
- sql = (
- "SELECT name FROM room_names"
- " INNER JOIN current_state_events USING (room_id, event_id)"
- " WHERE room_id = ?"
- " LIMIT 1"
- )
-
- txn.execute(sql, (room_id,))
- rows = txn.fetchall()
- if rows:
- return rows[0][0]
- else:
- return None
-
- return [row[0] for row in txn.fetchall()]
-
- def get_room_aliases(txn):
- sql = (
- "SELECT content FROM current_state_events"
- " INNER JOIN events USING (room_id, event_id)"
- " WHERE room_id = ?"
- )
- txn.execute(sql, (room_id,))
- return [row[0] for row in txn.fetchall()]
-
- name = yield self.runInteraction("get_room_name", get_room_name)
- alias_contents = yield self.runInteraction("get_room_aliases", get_room_aliases)
-
- aliases = []
-
- for c in alias_contents:
- try:
- content = json.loads(c)
- except:
- continue
-
- aliases.extend(content.get('aliases', []))
-
- defer.returnValue((name, aliases))
-
def add_event_report(self, room_id, event_id, user_id, reason, content,
received_ts):
next_id = self._event_reports_id_gen.get_next()
diff --git a/synapse/storage/schema/delta/33/remote_media_ts.py b/synapse/storage/schema/delta/33/remote_media_ts.py
new file mode 100644
index 0000000000..55ae43f395
--- /dev/null
+++ b/synapse/storage/schema/delta/33/remote_media_ts.py
@@ -0,0 +1,31 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+
+
+ALTER_TABLE = "ALTER TABLE remote_media_cache ADD COLUMN last_access_ts BIGINT"
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ cur.execute(ALTER_TABLE)
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+ cur.execute(
+ database_engine.convert_param_style(
+ "UPDATE remote_media_cache SET last_access_ts = ?"
+ ),
+ (int(time.time() * 1000),)
+ )
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index b9ad965fd6..c33ac5a8d7 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -40,6 +40,7 @@ from synapse.util.caches.descriptors import cached
from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken
from synapse.util.logcontext import preserve_fn
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
import logging
@@ -54,25 +55,43 @@ _STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological"
-def lower_bound(token):
+def lower_bound(token, engine, inclusive=False):
+ inclusive = "=" if inclusive else ""
if token.topological is None:
- return "(%d < %s)" % (token.stream, "stream_ordering")
+ return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering")
else:
- return "(%d < %s OR (%d = %s AND %d < %s))" % (
+ if isinstance(engine, PostgresEngine):
+ # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
+ # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
+ # use the later form when running against postgres.
+ return "((%d,%d) <%s (%s,%s))" % (
+ token.topological, token.stream, inclusive,
+ "topological_ordering", "stream_ordering",
+ )
+ return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
token.topological, "topological_ordering",
token.topological, "topological_ordering",
- token.stream, "stream_ordering",
+ token.stream, inclusive, "stream_ordering",
)
-def upper_bound(token):
+def upper_bound(token, engine, inclusive=True):
+ inclusive = "=" if inclusive else ""
if token.topological is None:
- return "(%d >= %s)" % (token.stream, "stream_ordering")
+ return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering")
else:
- return "(%d > %s OR (%d = %s AND %d >= %s))" % (
+ if isinstance(engine, PostgresEngine):
+ # Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well
+ # as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
+ # use the later form when running against postgres.
+ return "((%d,%d) >%s (%s,%s))" % (
+ token.topological, token.stream, inclusive,
+ "topological_ordering", "stream_ordering",
+ )
+ return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
token.topological, "topological_ordering",
token.topological, "topological_ordering",
- token.stream, "stream_ordering",
+ token.stream, inclusive, "stream_ordering",
)
@@ -308,18 +327,22 @@ class StreamStore(SQLBaseStore):
args = [False, room_id]
if direction == 'b':
order = "DESC"
- bounds = upper_bound(RoomStreamToken.parse(from_key))
+ bounds = upper_bound(
+ RoomStreamToken.parse(from_key), self.database_engine
+ )
if to_key:
- bounds = "%s AND %s" % (
- bounds, lower_bound(RoomStreamToken.parse(to_key))
- )
+ bounds = "%s AND %s" % (bounds, lower_bound(
+ RoomStreamToken.parse(to_key), self.database_engine
+ ))
else:
order = "ASC"
- bounds = lower_bound(RoomStreamToken.parse(from_key))
+ bounds = lower_bound(
+ RoomStreamToken.parse(from_key), self.database_engine
+ )
if to_key:
- bounds = "%s AND %s" % (
- bounds, upper_bound(RoomStreamToken.parse(to_key))
- )
+ bounds = "%s AND %s" % (bounds, upper_bound(
+ RoomStreamToken.parse(to_key), self.database_engine
+ ))
if int(limit) > 0:
args.append(int(limit))
@@ -487,13 +510,13 @@ class StreamStore(SQLBaseStore):
row["topological_ordering"], row["stream_ordering"],)
)
- def get_max_topological_token_for_stream_and_room(self, room_id, stream_key):
+ def get_max_topological_token(self, room_id, stream_key):
sql = (
"SELECT max(topological_ordering) FROM events"
" WHERE room_id = ? AND stream_ordering < ?"
)
return self._execute(
- "get_max_topological_token_for_stream_and_room", None,
+ "get_max_topological_token", None,
sql, room_id, stream_key,
).addCallback(
lambda r: r[0][0] if r else 0
@@ -586,32 +609,60 @@ class StreamStore(SQLBaseStore):
retcols=["stream_ordering", "topological_ordering"],
)
- stream_ordering = results["stream_ordering"]
- topological_ordering = results["topological_ordering"]
-
- query_before = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND (topological_ordering < ?"
- " OR (topological_ordering = ? AND stream_ordering < ?))"
- " ORDER BY topological_ordering DESC, stream_ordering DESC"
- " LIMIT ?"
+ token = RoomStreamToken(
+ results["topological_ordering"],
+ results["stream_ordering"],
)
- query_after = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND (topological_ordering > ?"
- " OR (topological_ordering = ? AND stream_ordering > ?))"
- " ORDER BY topological_ordering ASC, stream_ordering ASC"
- " LIMIT ?"
- )
+ if isinstance(self.database_engine, Sqlite3Engine):
+ # SQLite3 doesn't optimise ``(x < a) OR (x = a AND y < b)``
+ # So we give pass it to SQLite3 as the UNION ALL of the two queries.
+
+ query_before = (
+ "SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND topological_ordering < ?"
+ " UNION ALL"
+ " SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering < ?"
+ " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
+ )
+ before_args = (
+ room_id, token.topological,
+ room_id, token.topological, token.stream,
+ before_limit,
+ )
- txn.execute(
- query_before,
- (
- room_id, topological_ordering, topological_ordering,
- stream_ordering, before_limit,
+ query_after = (
+ "SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND topological_ordering > ?"
+ " UNION ALL"
+ " SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering > ?"
+ " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
)
- )
+ after_args = (
+ room_id, token.topological,
+ room_id, token.topological, token.stream,
+ after_limit,
+ )
+ else:
+ query_before = (
+ "SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND %s"
+ " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
+ ) % (upper_bound(token, self.database_engine, inclusive=False),)
+
+ before_args = (room_id, before_limit)
+
+ query_after = (
+ "SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND %s"
+ " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
+ ) % (lower_bound(token, self.database_engine, inclusive=False),)
+
+ after_args = (room_id, after_limit)
+
+ txn.execute(query_before, before_args)
rows = self.cursor_to_dict(txn)
events_before = [r["event_id"] for r in rows]
@@ -623,17 +674,11 @@ class StreamStore(SQLBaseStore):
))
else:
start_token = str(RoomStreamToken(
- topological_ordering,
- stream_ordering - 1,
+ token.topological,
+ token.stream - 1,
))
- txn.execute(
- query_after,
- (
- room_id, topological_ordering, topological_ordering,
- stream_ordering, after_limit,
- )
- )
+ txn.execute(query_after, after_args)
rows = self.cursor_to_dict(txn)
events_after = [r["event_id"] for r in rows]
@@ -644,10 +689,7 @@ class StreamStore(SQLBaseStore):
rows[-1]["stream_ordering"],
))
else:
- end_token = str(RoomStreamToken(
- topological_ordering,
- stream_ordering,
- ))
+ end_token = str(token)
return {
"before": {
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 40be7fe7e3..c84b23ff46 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -194,3 +194,85 @@ class Linearizer(object):
self.key_to_defer.pop(key, None)
defer.returnValue(_ctx_manager())
+
+
+class ReadWriteLock(object):
+ """A deferred style read write lock.
+
+ Example:
+
+ with (yield read_write_lock.read("test_key")):
+ # do some work
+ """
+
+ # IMPLEMENTATION NOTES
+ #
+ # We track the most recent queued reader and writer deferreds (which get
+ # resolved when they release the lock).
+ #
+ # Read: We know its safe to acquire a read lock when the latest writer has
+ # been resolved. The new reader is appeneded to the list of latest readers.
+ #
+ # Write: We know its safe to acquire the write lock when both the latest
+ # writers and readers have been resolved. The new writer replaces the latest
+ # writer.
+
+ def __init__(self):
+ # Latest readers queued
+ self.key_to_current_readers = {}
+
+ # Latest writer queued
+ self.key_to_current_writer = {}
+
+ @defer.inlineCallbacks
+ def read(self, key):
+ new_defer = defer.Deferred()
+
+ curr_readers = self.key_to_current_readers.setdefault(key, set())
+ curr_writer = self.key_to_current_writer.get(key, None)
+
+ curr_readers.add(new_defer)
+
+ # We wait for the latest writer to finish writing. We can safely ignore
+ # any existing readers... as they're readers.
+ yield curr_writer
+
+ @contextmanager
+ def _ctx_manager():
+ try:
+ yield
+ finally:
+ new_defer.callback(None)
+ self.key_to_current_readers.get(key, set()).discard(new_defer)
+
+ defer.returnValue(_ctx_manager())
+
+ @defer.inlineCallbacks
+ def write(self, key):
+ new_defer = defer.Deferred()
+
+ curr_readers = self.key_to_current_readers.get(key, set())
+ curr_writer = self.key_to_current_writer.get(key, None)
+
+ # We wait on all latest readers and writer.
+ to_wait_on = list(curr_readers)
+ if curr_writer:
+ to_wait_on.append(curr_writer)
+
+ # We can clear the list of current readers since the new writer waits
+ # for them to finish.
+ curr_readers.clear()
+ self.key_to_current_writer[key] = new_defer
+
+ yield defer.gatherResults(to_wait_on)
+
+ @contextmanager
+ def _ctx_manager():
+ try:
+ yield
+ finally:
+ new_defer.callback(None)
+ if self.key_to_current_writer[key] == new_defer:
+ self.key_to_current_writer.pop(key)
+
+ defer.returnValue(_ctx_manager())
diff --git a/synapse/util/presentable_names.py b/synapse/util/presentable_names.py
index a6866f6117..4c54812e6f 100644
--- a/synapse/util/presentable_names.py
+++ b/synapse/util/presentable_names.py
@@ -25,7 +25,8 @@ ALIAS_RE = re.compile(r"^#.*:.+$")
ALL_ALONE = "Empty Room"
-def calculate_room_name(room_state, user_id, fallback_to_members=True):
+def calculate_room_name(room_state, user_id, fallback_to_members=True,
+ fallback_to_single_member=True):
"""
Works out a user-facing name for the given room as per Matrix
spec recommendations.
@@ -129,6 +130,8 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True):
return name_from_member_event(all_members[0])
else:
return ALL_ALONE
+ elif len(other_members) == 1 and not fallback_to_single_member:
+ return None
else:
return descriptor_from_member_events(other_members)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 17587fda00..f33e6f60fb 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -59,47 +59,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
[unpatch() for unpatch in self.unpatches]
@defer.inlineCallbacks
- def test_room_name_and_aliases(self):
- create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
- yield self.persist(type="m.room.member", key=USER_ID, membership="join")
- yield self.persist(type="m.room.name", key="", name="name1")
- yield self.persist(
- type="m.room.aliases", key="blue", aliases=["#1:blue"]
- )
- yield self.replicate()
- yield self.check(
- "get_room_name_and_aliases", (ROOM_ID,), ("name1", ["#1:blue"])
- )
-
- # Set the room name.
- yield self.persist(type="m.room.name", key="", name="name2")
- yield self.replicate()
- yield self.check(
- "get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#1:blue"])
- )
-
- # Set the room aliases.
- yield self.persist(
- type="m.room.aliases", key="blue", aliases=["#2:blue"]
- )
- yield self.replicate()
- yield self.check(
- "get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#2:blue"])
- )
-
- # Leave and join the room clobbering the state.
- yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
- yield self.persist(
- type="m.room.member", key=USER_ID, membership="join",
- reset_state=[create]
- )
- yield self.replicate()
-
- yield self.check(
- "get_room_name_and_aliases", (ROOM_ID,), (None, [])
- )
-
- @defer.inlineCallbacks
def test_room_members(self):
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate()
diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py
new file mode 100644
index 0000000000..1d745ae1a7
--- /dev/null
+++ b/tests/util/test_rwlock.py
@@ -0,0 +1,85 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from tests import unittest
+
+from synapse.util.async import ReadWriteLock
+
+
+class ReadWriteLockTestCase(unittest.TestCase):
+
+ def _assert_called_before_not_after(self, lst, first_false):
+ for i, d in enumerate(lst[:first_false]):
+ self.assertTrue(d.called, msg="%d was unexpectedly false" % i)
+
+ for i, d in enumerate(lst[first_false:]):
+ self.assertFalse(
+ d.called, msg="%d was unexpectedly true" % (i + first_false)
+ )
+
+ def test_rwlock(self):
+ rwlock = ReadWriteLock()
+
+ key = object()
+
+ ds = [
+ rwlock.read(key), # 0
+ rwlock.read(key), # 1
+ rwlock.write(key), # 2
+ rwlock.write(key), # 3
+ rwlock.read(key), # 4
+ rwlock.read(key), # 5
+ rwlock.write(key), # 6
+ ]
+
+ self._assert_called_before_not_after(ds, 2)
+
+ with ds[0].result:
+ self._assert_called_before_not_after(ds, 2)
+ self._assert_called_before_not_after(ds, 2)
+
+ with ds[1].result:
+ self._assert_called_before_not_after(ds, 2)
+ self._assert_called_before_not_after(ds, 3)
+
+ with ds[2].result:
+ self._assert_called_before_not_after(ds, 3)
+ self._assert_called_before_not_after(ds, 4)
+
+ with ds[3].result:
+ self._assert_called_before_not_after(ds, 4)
+ self._assert_called_before_not_after(ds, 6)
+
+ with ds[5].result:
+ self._assert_called_before_not_after(ds, 6)
+ self._assert_called_before_not_after(ds, 6)
+
+ with ds[4].result:
+ self._assert_called_before_not_after(ds, 6)
+ self._assert_called_before_not_after(ds, 7)
+
+ with ds[6].result:
+ pass
+
+ d = rwlock.write(key)
+ self.assertTrue(d.called)
+ with d.result:
+ pass
+
+ d = rwlock.read(key)
+ self.assertTrue(d.called)
+ with d.result:
+ pass
diff --git a/tests/utils.py b/tests/utils.py
index 6e41ae1ff6..ed547bc39b 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -56,6 +56,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.use_frozen_dicts = True
config.database_config = {"name": "sqlite3"}
+ config.ldap_enabled = False
if "clock" not in kargs:
kargs["clock"] = MockClock()
|