summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/errors.py18
-rw-r--r--synapse/api/filtering.py11
-rw-r--r--synapse/app/__init__.py9
-rw-r--r--synapse/config/__main__.py2
-rw-r--r--synapse/config/server.py38
-rw-r--r--synapse/federation/transaction_queue.py24
-rw-r--r--synapse/handlers/pagination.py21
-rw-r--r--synapse/http/matrixfederationclient.py250
-rw-r--r--synapse/python_dependencies.py55
-rw-r--r--synapse/rest/client/v2_alpha/account_data.py34
-rw-r--r--synapse/rest/media/v1/media_repository.py7
-rw-r--r--synapse/storage/_base.py10
-rw-r--r--synapse/storage/registration.py50
14 files changed, 350 insertions, 181 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py

index 27241cb364..2935238fa2 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py
@@ -27,4 +27,4 @@ try: except ImportError: pass -__version__ = "0.34.0" +__version__ = "0.34.1" diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 48b903374d..0b464834ce 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py
@@ -348,6 +348,24 @@ class IncompatibleRoomVersionError(SynapseError): ) +class RequestSendFailed(RuntimeError): + """Sending a HTTP request over federation failed due to not being able to + talk to the remote server for some reason. + + This exception is used to differentiate "expected" errors that arise due to + networking (e.g. DNS failures, connection timeouts etc), versus unexpected + errors (like programming errors). + """ + def __init__(self, inner_exception, can_retry): + super(RequestSendFailed, self).__init__( + "Failed to send request: %s: %s" % ( + type(inner_exception).__name__, inner_exception, + ) + ) + self.inner_exception = inner_exception + self.can_retry = can_retry + + def cs_error(msg, code=Codes.UNKNOWN, **kwargs): """ Utility method for constructing an error response for client-server interactions. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 677c0bdd4c..16ad654864 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py
@@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from six import text_type + import jsonschema from canonicaljson import json from jsonschema import FormatChecker @@ -353,7 +355,7 @@ class Filter(object): sender = event.user_id room_id = None ev_type = "m.presence" - is_url = False + contains_url = False else: sender = event.get("sender", None) if not sender: @@ -368,13 +370,16 @@ class Filter(object): room_id = event.get("room_id", None) ev_type = event.get("type", None) - is_url = "url" in event.get("content", {}) + + content = event.get("content", {}) + # check if there is a string url field in the content for filtering purposes + contains_url = isinstance(content.get("url"), text_type) return self.check_fields( room_id, sender, ev_type, - is_url, + contains_url, ) def check_fields(self, room_id, sender, event_type, contains_url): diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py
index 233bf43fc8..b45adafdd3 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py
@@ -19,15 +19,8 @@ from synapse import python_dependencies # noqa: E402 sys.dont_write_bytecode = True - try: python_dependencies.check_requirements() except python_dependencies.DependencyException as e: - message = "\n".join([ - "Missing Requirements: %s" % (", ".join(e.dependencies),), - "To install run:", - " pip install --upgrade --force %s" % (" ".join(e.dependencies),), - "", - ]) - sys.stderr.writelines(message) + sys.stderr.writelines(e.message) sys.exit(1) diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py
index 79fe9c3dac..fca35b008c 100644 --- a/synapse/config/__main__.py +++ b/synapse/config/__main__.py
@@ -16,7 +16,7 @@ from synapse.config._base import ConfigError if __name__ == "__main__": import sys - from homeserver import HomeServerConfig + from synapse.config.homeserver import HomeServerConfig action = sys.argv[1] diff --git a/synapse/config/server.py b/synapse/config/server.py
index 120c2b81fc..fb57791098 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py
@@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2017 New Vector Ltd +# Copyright 2017-2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ import logging import os.path from synapse.http.endpoint import parse_and_validate_server_name +from synapse.python_dependencies import DependencyException, check_requirements from ._base import Config, ConfigError @@ -204,6 +205,8 @@ class ServerConfig(Config): ] }) + _check_resource_config(self.listeners) + def default_config(self, server_name, data_dir_path, **kwargs): _, bind_port = parse_and_validate_server_name(server_name) if bind_port is not None: @@ -465,3 +468,36 @@ def _warn_if_webclient_configured(listeners): if name == 'webclient': logger.warning(NO_MORE_WEB_CLIENT_WARNING) return + + +KNOWN_RESOURCES = ( + 'client', + 'consent', + 'federation', + 'keys', + 'media', + 'metrics', + 'replication', + 'static', + 'webclient', +) + + +def _check_resource_config(listeners): + resource_names = set( + res_name + for listener in listeners + for res in listener.get("resources", []) + for res_name in res.get("names", []) + ) + + for resource in resource_names: + if resource not in KNOWN_RESOURCES: + raise ConfigError( + "Unknown listener resource '%s'" % (resource, ) + ) + if resource == "consent": + try: + check_requirements('resources.consent') + except DependencyException as e: + raise ConfigError(e.message) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 099ace28c1..fe787abaeb 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py
@@ -22,7 +22,11 @@ from prometheus_client import Counter from twisted.internet import defer import synapse.metrics -from synapse.api.errors import FederationDeniedError, HttpResponseException +from synapse.api.errors import ( + FederationDeniedError, + HttpResponseException, + RequestSendFailed, +) from synapse.handlers.presence import format_user_presence_state, get_interested_remotes from synapse.metrics import ( LaterGauge, @@ -518,11 +522,21 @@ class TransactionQueue(object): ) except FederationDeniedError as e: logger.info(e) - except Exception as e: - logger.warn( - "TX [%s] Failed to send transaction: %s", + except HttpResponseException as e: + logger.warning( + "TX [%s] Received %d response to transaction: %s", + destination, e.code, e, + ) + except RequestSendFailed as e: + logger.warning("TX [%s] Failed to send transaction: %s", destination, e) + + for p, _ in pending_pdus: + logger.info("Failed to send event %s to %s", p.event_id, + destination) + except Exception: + logger.exception( + "TX [%s] Failed to send transaction", destination, - e, ) for p, _ in pending_pdus: logger.info("Failed to send event %s to %s", p.event_id, diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 43f81bd607..9d257ecf31 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py
@@ -235,6 +235,17 @@ class PaginationHandler(object): "room_key", next_key ) + if events: + if event_filter: + events = event_filter.filter(events) + + events = yield filter_events_for_client( + self.store, + user_id, + events, + is_peeking=(member_event_id is None), + ) + if not events: defer.returnValue({ "chunk": [], @@ -242,16 +253,6 @@ class PaginationHandler(object): "end": next_token.to_string(), }) - if event_filter: - events = event_filter.filter(events) - - events = yield filter_events_for_client( - self.store, - user_id, - events, - is_peeking=(member_event_id is None), - ) - state = None if event_filter and event_filter.lazy_load_members(): # TODO: remove redundant members diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 24b6110c20..f2a42f97a6 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py
@@ -19,7 +19,7 @@ import random import sys from io import BytesIO -from six import PY3, string_types +from six import PY3, raise_from, string_types from six.moves import urllib import attr @@ -41,6 +41,7 @@ from synapse.api.errors import ( Codes, FederationDeniedError, HttpResponseException, + RequestSendFailed, SynapseError, ) from synapse.http.endpoint import matrix_federation_endpoint @@ -228,19 +229,18 @@ class MatrixFederationHttpClient(object): backoff_on_404 (bool): Back off if we get a 404 Returns: - Deferred: resolves with the http response object on success. - - Fails with ``HttpResponseException``: if we get an HTTP response - code >= 300. - - Fails with ``NotRetryingDestination`` if we are not yet ready - to retry this server. - - Fails with ``FederationDeniedError`` if this destination - is not on our federation whitelist - - (May also fail with plenty of other Exceptions for things like DNS - failures, connection failures, SSL failures.) + Deferred[twisted.web.client.Response]: resolves with the HTTP + response object on success. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + NotRetryingDestination: If we are not yet ready to retry this + server. + FederationDeniedError: If this destination is not on our + federation whitelist + RequestSendFailed: If there were problems connecting to the + remote, due to e.g. DNS failures, connection timeouts etc. """ if timeout: _sec_timeout = timeout / 1000 @@ -335,23 +335,74 @@ class MatrixFederationHttpClient(object): reactor=self.hs.get_reactor(), ) - with Measure(self.clock, "outbound_request"): - response = yield make_deferred_yieldable( - request_deferred, + try: + with Measure(self.clock, "outbound_request"): + response = yield make_deferred_yieldable( + request_deferred, + ) + except DNSLookupError as e: + raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e) + except Exception as e: + raise_from(RequestSendFailed(e, can_retry=True), e) + + logger.info( + "{%s} [%s] Got response headers: %d %s", + request.txn_id, + request.destination, + response.code, + response.phrase.decode('ascii', errors='replace'), + ) + + if 200 <= response.code < 300: + pass + else: + # :'( + # Update transactions table? + d = treq.content(response) + d = timeout_deferred( + d, + timeout=_sec_timeout, + reactor=self.hs.get_reactor(), + ) + + try: + body = yield make_deferred_yieldable(d) + except Exception as e: + # Eh, we're already going to raise an exception so lets + # ignore if this fails. + logger.warn( + "{%s} [%s] Failed to get error response: %s %s: %s", + request.txn_id, + request.destination, + request.method, + url_str, + _flatten_response_never_received(e), + ) + body = None + + e = HttpResponseException( + response.code, response.phrase, body ) + # Retry if the error is a 429 (Too Many Requests), + # otherwise just raise a standard HttpResponseException + if response.code == 429: + raise_from(RequestSendFailed(e, can_retry=True), e) + else: + raise e + break - except Exception as e: + except RequestSendFailed as e: logger.warn( "{%s} [%s] Request failed: %s %s: %s", request.txn_id, request.destination, request.method, url_str, - _flatten_response_never_received(e), + _flatten_response_never_received(e.inner_exception), ) - if not retry_on_dns_fail and isinstance(e, DNSLookupError): + if not e.can_retry: raise if retries_left and not timeout: @@ -376,29 +427,16 @@ class MatrixFederationHttpClient(object): else: raise - logger.info( - "{%s} [%s] Got response headers: %d %s", - request.txn_id, - request.destination, - response.code, - response.phrase.decode('ascii', errors='replace'), - ) - - if 200 <= response.code < 300: - pass - else: - # :'( - # Update transactions table? - d = treq.content(response) - d = timeout_deferred( - d, - timeout=_sec_timeout, - reactor=self.hs.get_reactor(), - ) - body = yield make_deferred_yieldable(d) - raise HttpResponseException( - response.code, response.phrase, body - ) + except Exception as e: + logger.warn( + "{%s} [%s] Request failed: %s %s: %s", + request.txn_id, + request.destination, + request.method, + url_str, + _flatten_response_never_received(e), + ) + raise defer.returnValue(response) @@ -477,17 +515,18 @@ class MatrixFederationHttpClient(object): requests) Returns: - Deferred: Succeeds when we get a 2xx HTTP response. The result - will be the decoded JSON body. - - Fails with ``HttpResponseException`` if we get an HTTP response - code >= 300. - - Fails with ``NotRetryingDestination`` if we are not yet ready - to retry this server. - - Fails with ``FederationDeniedError`` if this destination - is not on our federation whitelist + Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + result will be the decoded JSON body. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + NotRetryingDestination: If we are not yet ready to retry this + server. + FederationDeniedError: If this destination is not on our + federation whitelist + RequestSendFailed: If there were problems connecting to the + remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( @@ -531,17 +570,18 @@ class MatrixFederationHttpClient(object): try the request anyway. args (dict): query params Returns: - Deferred: Succeeds when we get a 2xx HTTP response. The result - will be the decoded JSON body. - - Fails with ``HttpResponseException`` if we get an HTTP response - code >= 300. - - Fails with ``NotRetryingDestination`` if we are not yet ready - to retry this server. - - Fails with ``FederationDeniedError`` if this destination - is not on our federation whitelist + Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + result will be the decoded JSON body. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + NotRetryingDestination: If we are not yet ready to retry this + server. + FederationDeniedError: If this destination is not on our + federation whitelist + RequestSendFailed: If there were problems connecting to the + remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( @@ -586,17 +626,18 @@ class MatrixFederationHttpClient(object): ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. Returns: - Deferred: Succeeds when we get a 2xx HTTP response. The result - will be the decoded JSON body. - - Fails with ``HttpResponseException`` if we get an HTTP response - code >= 300. - - Fails with ``NotRetryingDestination`` if we are not yet ready - to retry this server. - - Fails with ``FederationDeniedError`` if this destination - is not on our federation whitelist + Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + result will be the decoded JSON body. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + NotRetryingDestination: If we are not yet ready to retry this + server. + FederationDeniedError: If this destination is not on our + federation whitelist + RequestSendFailed: If there were problems connecting to the + remote, due to e.g. DNS failures, connection timeouts etc. """ logger.debug("get_json args: %s", args) @@ -637,17 +678,18 @@ class MatrixFederationHttpClient(object): ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. Returns: - Deferred: Succeeds when we get a 2xx HTTP response. The result - will be the decoded JSON body. - - Fails with ``HttpResponseException`` if we get an HTTP response - code >= 300. - - Fails with ``NotRetryingDestination`` if we are not yet ready - to retry this server. - - Fails with ``FederationDeniedError`` if this destination - is not on our federation whitelist + Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + result will be the decoded JSON body. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + NotRetryingDestination: If we are not yet ready to retry this + server. + FederationDeniedError: If this destination is not on our + federation whitelist + RequestSendFailed: If there were problems connecting to the + remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( method="DELETE", @@ -680,18 +722,20 @@ class MatrixFederationHttpClient(object): args (dict): Optional dictionary used to create the query string. ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. - Returns: - Deferred: resolves with an (int,dict) tuple of the file length and - a dict of the response headers. - - Fails with ``HttpResponseException`` if we get an HTTP response code - >= 300 - - Fails with ``NotRetryingDestination`` if we are not yet ready - to retry this server. - Fails with ``FederationDeniedError`` if this destination - is not on our federation whitelist + Returns: + Deferred[tuple[int, dict]]: Resolves with an (int,dict) tuple of + the file length and a dict of the response headers. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + NotRetryingDestination: If we are not yet ready to retry this + server. + FederationDeniedError: If this destination is not on our + federation whitelist + RequestSendFailed: If there were problems connecting to the + remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( method="GET", @@ -784,21 +828,21 @@ def check_content_type_is_json(headers): headers (twisted.web.http_headers.Headers): headers to check Raises: - RuntimeError if the + RequestSendFailed: if the Content-Type header is missing or isn't JSON """ c_type = headers.getRawHeaders(b"Content-Type") if c_type is None: - raise RuntimeError( + raise RequestSendFailed(RuntimeError( "No Content-Type header" - ) + ), can_retry=False) c_type = c_type[0].decode('ascii') # only the first header val, options = cgi.parse_header(c_type) if val != "application/json": - raise RuntimeError( + raise RequestSendFailed(RuntimeError( "Content-Type not application/json: was '%s'" % c_type - ) + ), can_retry=False) def encode_query_args(args): diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 2c65ef5856..69c5f9fe2e 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py
@@ -65,9 +65,13 @@ REQUIREMENTS = [ ] CONDITIONAL_REQUIREMENTS = { - "email.enable_notifs": ["Jinja2>=2.8", "bleach>=1.4.2"], + "email.enable_notifs": ["Jinja2>=2.9", "bleach>=1.4.2"], "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"], "postgres": ["psycopg2>=2.6"], + + # ConsentResource uses select_autoescape, which arrived in jinja 2.9 + "resources.consent": ["Jinja2>=2.9"], + "saml2": ["pysaml2>=4.5.0"], "url_preview": ["lxml>=3.5.0"], "test": ["mock>=2.0"], @@ -84,18 +88,30 @@ def list_requirements(): class DependencyException(Exception): @property + def message(self): + return "\n".join([ + "Missing Requirements: %s" % (", ".join(self.dependencies),), + "To install run:", + " pip install --upgrade --force %s" % (" ".join(self.dependencies),), + "", + ]) + + @property def dependencies(self): for i in self.args[0]: yield '"' + i + '"' -def check_requirements(_get_distribution=get_distribution): - +def check_requirements(for_feature=None, _get_distribution=get_distribution): deps_needed = [] errors = [] - # Check the base dependencies exist -- they all must be installed. - for dependency in REQUIREMENTS: + if for_feature: + reqs = CONDITIONAL_REQUIREMENTS[for_feature] + else: + reqs = REQUIREMENTS + + for dependency in reqs: try: _get_distribution(dependency) except VersionConflict as e: @@ -108,23 +124,24 @@ def check_requirements(_get_distribution=get_distribution): deps_needed.append(dependency) errors.append("Needed %s but it was not installed" % (dependency,)) - # Check the optional dependencies are up to date. We allow them to not be - # installed. - OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) - - for dependency in OPTS: - try: - _get_distribution(dependency) - except VersionConflict: - deps_needed.append(dependency) - errors.append("Needed %s but it was not installed" % (dependency,)) - except DistributionNotFound: - # If it's not found, we don't care - pass + if not for_feature: + # Check the optional dependencies are up to date. We allow them to not be + # installed. + OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) + + for dependency in OPTS: + try: + _get_distribution(dependency) + except VersionConflict: + deps_needed.append(dependency) + errors.append("Needed %s but it was not installed" % (dependency,)) + except DistributionNotFound: + # If it's not found, we don't care + pass if deps_needed: for e in errors: - logging.exception(e) + logging.error(e) raise DependencyException(deps_needed) diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 371e9aa354..f171b8d626 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -17,7 +17,7 @@ import logging from twisted.internet import defer -from synapse.api.errors import AuthError, SynapseError +from synapse.api.errors import AuthError, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request from ._base import client_v2_patterns @@ -28,6 +28,7 @@ logger = logging.getLogger(__name__) class AccountDataServlet(RestServlet): """ PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1 + GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1 """ PATTERNS = client_v2_patterns( "/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)" @@ -57,10 +58,26 @@ class AccountDataServlet(RestServlet): defer.returnValue((200, {})) + @defer.inlineCallbacks + def on_GET(self, request, user_id, account_data_type): + requester = yield self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): + raise AuthError(403, "Cannot get account data for other users.") + + event = yield self.store.get_global_account_data_by_type_for_user( + account_data_type, user_id, + ) + + if event is None: + raise NotFoundError("Account data not found") + + defer.returnValue((200, event)) + class RoomAccountDataServlet(RestServlet): """ PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 + GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 """ PATTERNS = client_v2_patterns( "/user/(?P<user_id>[^/]*)" @@ -99,6 +116,21 @@ class RoomAccountDataServlet(RestServlet): defer.returnValue((200, {})) + @defer.inlineCallbacks + def on_GET(self, request, user_id, room_id, account_data_type): + requester = yield self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): + raise AuthError(403, "Cannot get account data for other users.") + + event = yield self.store.get_account_data_for_room_and_type( + user_id, room_id, account_data_type, + ) + + if event is None: + raise NotFoundError("Room account data not found") + + defer.returnValue((200, event)) + def register_servlets(hs, http_server): AccountDataServlet(hs).register(http_server) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index e117836e9a..bdffa97805 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py
@@ -30,6 +30,7 @@ from synapse.api.errors import ( FederationDeniedError, HttpResponseException, NotFoundError, + RequestSendFailed, SynapseError, ) from synapse.metrics.background_process_metrics import run_as_background_process @@ -372,10 +373,10 @@ class MediaRepository(object): "allow_remote": "false", } ) - except twisted.internet.error.DNSLookupError as e: - logger.warn("HTTP error fetching remote media %s/%s: %r", + except RequestSendFailed as e: + logger.warn("Request failed fetching remote media %s/%s: %r", server_name, media_id, e) - raise NotFoundError() + raise SynapseError(502, "Failed to fetch remote media") except HttpResponseException as e: logger.warn("HTTP error fetching remote media %s/%s: %s", diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 1d3069b143..865b5e915a 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py
@@ -547,11 +547,19 @@ class SQLBaseStore(object): if lock: self.database_engine.lock_table(txn, table) + def _getwhere(key): + # If the value we're passing in is None (aka NULL), we need to use + # IS, not =, as NULL = NULL equals NULL (False). + if keyvalues[key] is None: + return "%s IS ?" % (key,) + else: + return "%s = ?" % (key,) + # First try to update. sql = "UPDATE %s SET %s WHERE %s" % ( table, ", ".join("%s = ?" % (k,) for k in values), - " AND ".join("%s = ?" % (k,) for k in keyvalues) + " AND ".join(_getwhere(k) for k in keyvalues) ) sqlargs = list(values.values()) + list(keyvalues.values()) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 10c3b9757f..c9e11c3135 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py
@@ -114,6 +114,31 @@ class RegistrationWorkerStore(SQLBaseStore): return None + @cachedInlineCallbacks() + def is_support_user(self, user_id): + """Determines if the user is of type UserTypes.SUPPORT + + Args: + user_id (str): user id to test + + Returns: + Deferred[bool]: True if user is of type UserTypes.SUPPORT + """ + res = yield self.runInteraction( + "is_support_user", self.is_support_user_txn, user_id + ) + defer.returnValue(res) + + def is_support_user_txn(self, txn, user_id): + res = self._simple_select_one_onecol_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + retcol="user_type", + allow_none=True, + ) + return True if res == UserTypes.SUPPORT else False + class RegistrationStore(RegistrationWorkerStore, background_updates.BackgroundUpdateStore): @@ -465,31 +490,6 @@ class RegistrationStore(RegistrationWorkerStore, defer.returnValue(res if res else False) - @cachedInlineCallbacks() - def is_support_user(self, user_id): - """Determines if the user is of type UserTypes.SUPPORT - - Args: - user_id (str): user id to test - - Returns: - Deferred[bool]: True if user is of type UserTypes.SUPPORT - """ - res = yield self.runInteraction( - "is_support_user", self.is_support_user_txn, user_id - ) - defer.returnValue(res) - - def is_support_user_txn(self, txn, user_id): - res = self._simple_select_one_onecol_txn( - txn=txn, - table="users", - keyvalues={"name": user_id}, - retcol="user_type", - allow_none=True, - ) - return True if res == UserTypes.SUPPORT else False - @defer.inlineCallbacks def user_add_threepid(self, user_id, medium, address, validated_at, added_at): yield self._simple_upsert("user_threepids", {