From 5eac0b7e76c8316604480faf2d6158a0e1d68466 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Aug 2020 09:00:59 -0400 Subject: Add more types to synapse.storage.database. (#8127) --- synapse/storage/databases/main/ui_auth.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'synapse/storage/databases') diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 37276f73f8..d80d7da895 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -19,6 +19,7 @@ from canonicaljson import json from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict from synapse.util import stringutils as stringutils @@ -214,14 +215,16 @@ class UIAuthWorkerStore(SQLBaseStore): value, ) - def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any): + def _set_ui_auth_session_data_txn( + self, txn: LoggingTransaction, session_id: str, key: str, value: Any + ): # Get the current value. result = self.db_pool.simple_select_one_txn( txn, table="ui_auth_sessions", keyvalues={"session_id": session_id}, retcols=("serverdict",), - ) + ) # type: Dict[str, Any] # type: ignore # Update it and add it back to the database. serverdict = db_to_json(result["serverdict"]) @@ -275,7 +278,9 @@ class UIAuthStore(UIAuthWorkerStore): expiration_time, ) - def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int): + def _delete_old_ui_auth_sessions_txn( + self, txn: LoggingTransaction, expiration_time: int + ): # Get the expired sessions. sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?" txn.execute(sql, [expiration_time]) -- cgit 1.5.1 From dbc630a628e4fc6eb5eff09ce5edba062c0e9955 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Aug 2020 10:32:33 -0400 Subject: Use the JSON encoder without whitespace in more places. (#8124) --- changelog.d/8124.misc | 1 + synapse/handlers/devicemessage.py | 5 ++--- synapse/logging/opentracing.py | 5 ++--- synapse/rest/well_known.py | 4 ++-- synapse/storage/background_updates.py | 5 ++--- synapse/storage/databases/main/appservice.py | 5 ++--- synapse/storage/databases/main/room.py | 5 ++--- synapse/storage/databases/main/tags.py | 7 +++---- synapse/storage/databases/main/ui_auth.py | 11 +++++------ 9 files changed, 21 insertions(+), 27 deletions(-) create mode 100644 changelog.d/8124.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/8124.misc b/changelog.d/8124.misc new file mode 100644 index 0000000000..9fac710205 --- /dev/null +++ b/changelog.d/8124.misc @@ -0,0 +1 @@ +Reduce the amount of whitespace in JSON stored and sent in responses. diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 610b08d00b..dcb4c82244 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -16,8 +16,6 @@ import logging from typing import Any, Dict -from canonicaljson import json - from synapse.api.errors import SynapseError from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( @@ -27,6 +25,7 @@ from synapse.logging.opentracing import ( start_active_span, ) from synapse.types import UserID, get_domain_from_id +from synapse.util import json_encoder from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) @@ -174,7 +173,7 @@ class DeviceMessageHandler(object): "sender": sender_user_id, "type": message_type, "message_id": message_id, - "org.matrix.opentracing_context": json.dumps(context), + "org.matrix.opentracing_context": json_encoder.encode(context), } log_kv({"local_messages": local_messages}) diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index abe532d350..d39ac62168 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -172,12 +172,11 @@ from functools import wraps from typing import TYPE_CHECKING, Dict, Optional, Type import attr -from canonicaljson import json from twisted.internet import defer from synapse.config import ConfigError -from synapse.util import json_decoder +from synapse.util import json_decoder, json_encoder if TYPE_CHECKING: from synapse.http.site import SynapseRequest @@ -693,7 +692,7 @@ def active_span_context_as_string(): opentracing.tracer.inject( opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier ) - return json.dumps(carrier) + return json_encoder.encode(carrier) @only_if_tracing diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index 20177b44e7..e15e13b756 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import logging from twisted.web.resource import Resource from synapse.http.server import set_cors_headers +from synapse.util import json_encoder logger = logging.getLogger(__name__) @@ -67,4 +67,4 @@ class WellKnownResource(Resource): logger.debug("returning: %s", r) request.setHeader(b"Content-Type", b"application/json") - return json.dumps(r).encode("utf-8") + return json_encoder.encode(r).encode("utf-8") diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 90a1f9e8b1..56818f4df8 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -16,9 +16,8 @@ import logging from typing import Optional -from canonicaljson import json - from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util import json_encoder from . import engines @@ -457,7 +456,7 @@ class BackgroundUpdater(object): progress(dict): The progress of the update. """ - progress_json = json.dumps(progress) + progress_json = json_encoder.encode(progress) self.db_pool.simple_update_one_txn( txn, diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 02568a2391..77723f7d4d 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -16,13 +16,12 @@ import logging import re -from canonicaljson import json - from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.util import json_encoder logger = logging.getLogger(__name__) @@ -204,7 +203,7 @@ class ApplicationServiceTransactionWorkerStore( new_txn_id = max(highest_txn_id, last_txn_id) + 1 # Insert new txn into txn table - event_ids = json.dumps([e.event_id for e in events]) + event_ids = json_encoder.encode([e.event_id for e in events]) txn.execute( "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " "VALUES(?,?,?)", diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index aef08c7e12..7d3ac47261 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -21,8 +21,6 @@ from abc import abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Tuple -from canonicaljson import json - from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions @@ -30,6 +28,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchStore from synapse.types import ThirdPartyInstanceID +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -1310,7 +1309,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): "event_id": event_id, "user_id": user_id, "reason": reason, - "content": json.dumps(content), + "content": json_encoder.encode(content), }, desc="add_event_report", ) diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index e4e0a0c433..ade7abc927 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -17,11 +17,10 @@ import logging from typing import Dict, List, Tuple -from canonicaljson import json - from synapse.storage._base import db_to_json from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.types import JsonDict +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -98,7 +97,7 @@ class TagsWorkerStore(AccountDataWorkerStore): txn.execute(sql, (user_id, room_id)) tags = [] for tag, content in txn: - tags.append(json.dumps(tag) + ":" + content) + tags.append(json_encoder.encode(tag) + ":" + content) tag_json = "{" + ",".join(tags) + "}" results.append((stream_id, (user_id, room_id, tag_json))) @@ -200,7 +199,7 @@ class TagsStore(TagsWorkerStore): Returns: The next account data ID. """ - content_json = json.dumps(content) + content_json = json_encoder.encode(content) def add_tag_txn(txn, next_id): self.db_pool.simple_upsert_txn( diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index d80d7da895..6281a41a3d 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -15,13 +15,12 @@ from typing import Any, Dict, Optional, Union import attr -from canonicaljson import json from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict -from synapse.util import stringutils as stringutils +from synapse.util import json_encoder, stringutils @attr.s @@ -73,7 +72,7 @@ class UIAuthWorkerStore(SQLBaseStore): StoreError if a unique session ID cannot be generated. """ # The clientdict gets stored as JSON. - clientdict_json = json.dumps(clientdict) + clientdict_json = json_encoder.encode(clientdict) # autogen a session ID and try to create it. We may clash, so just # try a few times till one goes through, giving up eventually. @@ -144,7 +143,7 @@ class UIAuthWorkerStore(SQLBaseStore): await self.db_pool.simple_upsert( table="ui_auth_sessions_credentials", keyvalues={"session_id": session_id, "stage_type": stage_type}, - values={"result": json.dumps(result)}, + values={"result": json_encoder.encode(result)}, desc="mark_ui_auth_stage_complete", ) except self.db_pool.engine.module.IntegrityError: @@ -185,7 +184,7 @@ class UIAuthWorkerStore(SQLBaseStore): The dictionary from the client root level, not the 'auth' key. """ # The clientdict gets stored as JSON. - clientdict_json = json.dumps(clientdict) + clientdict_json = json_encoder.encode(clientdict) await self.db_pool.simple_update_one( table="ui_auth_sessions", @@ -234,7 +233,7 @@ class UIAuthWorkerStore(SQLBaseStore): txn, table="ui_auth_sessions", keyvalues={"session_id": session_id}, - updatevalues={"serverdict": json.dumps(serverdict)}, + updatevalues={"serverdict": json_encoder.encode(serverdict)}, ) async def get_ui_auth_session_data( -- cgit 1.5.1 From 318f4e738e66a1376a8db25c217a090384083f2d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 20 Aug 2020 16:42:12 +0100 Subject: Be more tolerant of membership events in unknown rooms (#8110) It turns out that not all out-of-band membership events are labelled as such, so we need to be more accepting here. --- changelog.d/8110.bugfix | 1 + synapse/events/__init__.py | 2 ++ synapse/storage/databases/main/events_worker.py | 31 ++++++++++++++++++++----- 3 files changed, 28 insertions(+), 6 deletions(-) create mode 100644 changelog.d/8110.bugfix (limited to 'synapse/storage/databases') diff --git a/changelog.d/8110.bugfix b/changelog.d/8110.bugfix new file mode 100644 index 0000000000..5269a232e1 --- /dev/null +++ b/changelog.d/8110.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.12.0 which could cause `/sync` requests to fail with a 404 if you had a very old outstanding room invite. diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index cc5deca75b..67db763dbf 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -133,6 +133,8 @@ class _EventInternalMetadata(object): rejection. This is needed as those events are marked as outliers, but they still need to be processed as if they're new events (e.g. updating invite state in the database, relaying to clients, etc). + + (Added in synapse 0.99.0, so may be unreliable for events received before that) """ return self._dict.get("out_of_band_membership", False) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 4a3333c0db..e1241a724b 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -620,19 +620,38 @@ class EventsWorkerStore(SQLBaseStore): room_version_id = row["room_version_id"] if not room_version_id: - # this should only happen for out-of-band membership events - if not internal_metadata.get("out_of_band_membership"): - logger.warning( - "Room %s for event %s is unknown", d["room_id"], event_id + # this should only happen for out-of-band membership events which + # arrived before #6983 landed. For all other events, we should have + # an entry in the 'rooms' table. + # + # However, the 'out_of_band_membership' flag is unreliable for older + # invites, so just accept it for all membership events. + # + if d["type"] != EventTypes.Member: + raise Exception( + "Room %s for event %s is unknown" % (d["room_id"], event_id) ) - continue - # take a wild stab at the room version based on the event format + # so, assuming this is an out-of-band-invite that arrived before #6983 + # landed, we know that the room version must be v5 or earlier (because + # v6 hadn't been invented at that point, so invites from such rooms + # would have been rejected.) + # + # The main reason we need to know the room version here (other than + # choosing the right python Event class) is in case the event later has + # to be redacted - and all the room versions up to v5 used the same + # redaction algorithm. + # + # So, the following approximations should be adequate. + if format_version == EventFormatVersions.V1: + # if it's event format v1 then it must be room v1 or v2 room_version = RoomVersions.V1 elif format_version == EventFormatVersions.V2: + # if it's event format v2 then it must be room v3 room_version = RoomVersions.V3 else: + # if it's event format v3 then it must be room v4 or v5 room_version = RoomVersions.V5 else: room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) -- cgit 1.5.1 From 3f91638da6ea0aeaf789ddc8ca1e624a11b7ebb2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Aug 2020 15:42:58 -0400 Subject: Allow denying or shadow banning registrations via the spam checker (#8034) --- changelog.d/8034.feature | 1 + synapse/events/spamcheck.py | 35 ++++++++++++++- synapse/handlers/auth.py | 8 ++++ synapse/handlers/cas_handler.py | 11 ++++- synapse/handlers/oidc_handler.py | 21 +++++++-- synapse/handlers/register.py | 26 ++++++++++- synapse/handlers/saml_handler.py | 18 +++++++- synapse/rest/client/v2_alpha/register.py | 5 +++ synapse/spam_checker_api/__init__.py | 11 +++++ .../main/schema/delta/58/07persist_ui_auth_ips.sql | 25 +++++++++++ synapse/storage/databases/main/ui_auth.py | 39 +++++++++++++++- tests/handlers/test_oidc.py | 18 ++++++-- tests/handlers/test_register.py | 52 +++++++++++++++++++++- tests/handlers/test_user_directory.py | 6 +-- 14 files changed, 258 insertions(+), 18 deletions(-) create mode 100644 changelog.d/8034.feature create mode 100644 synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql (limited to 'synapse/storage/databases') diff --git a/changelog.d/8034.feature b/changelog.d/8034.feature new file mode 100644 index 0000000000..813e6d0903 --- /dev/null +++ b/changelog.d/8034.feature @@ -0,0 +1 @@ +Add support for shadow-banning users (ignoring any message send requests). diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 1ffc9525d1..a7cddac974 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -15,9 +15,10 @@ # limitations under the License. import inspect -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Tuple -from synapse.spam_checker_api import SpamCheckerApi +from synapse.spam_checker_api import RegistrationBehaviour, SpamCheckerApi +from synapse.types import Collection MYPY = False if MYPY: @@ -160,3 +161,33 @@ class SpamChecker(object): return True return False + + def check_registration_for_spam( + self, + email_threepid: Optional[dict], + username: Optional[str], + request_info: Collection[Tuple[str, str]], + ) -> RegistrationBehaviour: + """Checks if we should allow the given registration request. + + Args: + email_threepid: The email threepid used for registering, if any + username: The request user name, if any + request_info: List of tuples of user agent and IP that + were used during the registration process. + + Returns: + Enum for how the request should be handled + """ + + for spam_checker in self.spam_checkers: + # For backwards compatibility, only run if the method exists on the + # spam checker + checker = getattr(spam_checker, "check_registration_for_spam", None) + if checker: + behaviour = checker(email_threepid, username, request_info) + assert isinstance(behaviour, RegistrationBehaviour) + if behaviour != RegistrationBehaviour.ALLOW: + return behaviour + + return RegistrationBehaviour.ALLOW diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 68d6870e40..654f58ddae 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -364,6 +364,14 @@ class AuthHandler(BaseHandler): # authentication flow. await self.store.set_ui_auth_clientdict(sid, clientdict) + user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + + await self.store.add_user_agent_ip_to_ui_auth_session( + session.session_id, user_agent, clientip + ) + if not authdict: raise InteractiveAuthIncompleteError( session.session_id, self._auth_dict_for_flows(flows, session.session_id) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 786e608fa2..a4cc4b9a5a 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -35,6 +35,7 @@ class CasHandler: """ def __init__(self, hs): + self.hs = hs self._hostname = hs.hostname self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() @@ -210,8 +211,16 @@ class CasHandler: else: if not registered_user_id: + # Pull out the user-agent and IP from the request. + user_agent = request.requestHeaders.getRawHeaders( + b"User-Agent", default=[b""] + )[0].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=user_display_name + localpart=localpart, + default_display_name=user_display_name, + user_agent_ips=(user_agent, ip_address), ) await self._auth_handler.complete_sso_login( diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index dd3703cbd2..c5bd2fea68 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -93,6 +93,7 @@ class OidcHandler: """ def __init__(self, hs: "HomeServer"): + self.hs = hs self._callback_url = hs.config.oidc_callback_url # type: str self._scopes = hs.config.oidc_scopes # type: List[str] self._client_auth = ClientAuth( @@ -689,9 +690,17 @@ class OidcHandler: self._render_error(request, "invalid_token", str(e)) return + # Pull out the user-agent and IP from the request. + user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + # Call the mapper to register/login the user try: - user_id = await self._map_userinfo_to_user(userinfo, token) + user_id = await self._map_userinfo_to_user( + userinfo, token, user_agent, ip_address + ) except MappingException as e: logger.exception("Could not map user") self._render_error(request, "mapping_error", str(e)) @@ -828,7 +837,9 @@ class OidcHandler: now = self._clock.time_msec() return now < expiry - async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str: + async def _map_userinfo_to_user( + self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str + ) -> str: """Maps a UserInfo object to a mxid. UserInfo should have a claim that uniquely identifies users. This claim @@ -843,6 +854,8 @@ class OidcHandler: Args: userinfo: an object representing the user token: a dict with the tokens obtained from the provider + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. Raises: MappingException: if there was an error while mapping some properties @@ -899,7 +912,9 @@ class OidcHandler: # It's the first time this user is logging in and the mapped mxid was # not taken, register the user registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=attributes["display_name"], + localpart=localpart, + default_display_name=attributes["display_name"], + user_agent_ips=(user_agent, ip_address), ) await self._datastore.record_user_external_id( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ccd96e4626..cde2dbca92 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -26,6 +26,7 @@ from synapse.replication.http.register import ( ReplicationPostRegisterActionsServlet, ReplicationRegisterServlet, ) +from synapse.spam_checker_api import RegistrationBehaviour from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester @@ -52,6 +53,8 @@ class RegistrationHandler(BaseHandler): self.macaroon_gen = hs.get_macaroon_generator() self._server_notices_mxid = hs.config.server_notices_mxid + self.spam_checker = hs.get_spam_checker() + if hs.config.worker_app: self._register_client = ReplicationRegisterServlet.make_client(hs) self._register_device_client = RegisterDeviceReplicationServlet.make_client( @@ -144,7 +147,7 @@ class RegistrationHandler(BaseHandler): address=None, bind_emails=[], by_admin=False, - shadow_banned=False, + user_agent_ips=None, ): """Registers a new client on the server. @@ -162,7 +165,8 @@ class RegistrationHandler(BaseHandler): bind_emails (List[str]): list of emails to bind to this account. by_admin (bool): True if this registration is being made via the admin api, otherwise False. - shadow_banned (bool): Shadow-ban the created user. + user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used + during the registration process. Returns: str: user_id Raises: @@ -170,6 +174,24 @@ class RegistrationHandler(BaseHandler): """ self.check_registration_ratelimit(address) + result = self.spam_checker.check_registration_for_spam( + threepid, localpart, user_agent_ips or [], + ) + + if result == RegistrationBehaviour.DENY: + logger.info( + "Blocked registration of %r", localpart, + ) + # We return a 429 to make it not obvious that they've been + # denied. + raise SynapseError(429, "Rate limited") + + shadow_banned = result == RegistrationBehaviour.SHADOW_BAN + if shadow_banned: + logger.info( + "Shadow banning registration of %r", localpart, + ) + # do not check_auth_blocking if the call is coming through the Admin API if not by_admin: await self.auth.check_auth_blocking(threepid=threepid) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index c1fcb98454..b426199aa6 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -54,6 +54,7 @@ class Saml2SessionData: class SamlHandler: def __init__(self, hs: "synapse.server.HomeServer"): + self.hs = hs self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() @@ -133,8 +134,14 @@ class SamlHandler: # the dict. self.expire_sessions() + # Pull out the user-agent and IP from the request. + user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + user_id, current_session = await self._map_saml_response_to_user( - resp_bytes, relay_state + resp_bytes, relay_state, user_agent, ip_address ) # Complete the interactive auth session or the login. @@ -147,7 +154,11 @@ class SamlHandler: await self._auth_handler.complete_sso_login(user_id, request, relay_state) async def _map_saml_response_to_user( - self, resp_bytes: str, client_redirect_url: str + self, + resp_bytes: str, + client_redirect_url: str, + user_agent: str, + ip_address: str, ) -> Tuple[str, Optional[Saml2SessionData]]: """ Given a sample response, retrieve the cached session and user for it. @@ -155,6 +166,8 @@ class SamlHandler: Args: resp_bytes: The SAML response. client_redirect_url: The redirect URL passed in by the client. + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. Returns: Tuple of the user ID and SAML session associated with this response. @@ -291,6 +304,7 @@ class SamlHandler: localpart=localpart, default_display_name=displayname, bind_emails=emails, + user_agent_ips=(user_agent, ip_address), ) await self._datastore.record_user_external_id( diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 7290fd0756..be0e680ac5 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -591,12 +591,17 @@ class RegisterRestServlet(RestServlet): Codes.THREEPID_IN_USE, ) + entries = await self.store.get_user_agents_ips_to_ui_auth_session( + session_id + ) + registered_user_id = await self.registration_handler.register_user( localpart=desired_username, password_hash=password_hash, guest_access_token=guest_access_token, threepid=threepid, address=client_addr, + user_agent_ips=entries, ) # Necessary due to auth checks prior to the threepid being # written to the db diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py index 7f63f1bfa0..9be92e2565 100644 --- a/synapse/spam_checker_api/__init__.py +++ b/synapse/spam_checker_api/__init__.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from enum import Enum from twisted.internet import defer @@ -25,6 +26,16 @@ if MYPY: logger = logging.getLogger(__name__) +class RegistrationBehaviour(Enum): + """ + Enum to define whether a registration request should allowed, denied, or shadow-banned. + """ + + ALLOW = "allow" + SHADOW_BAN = "shadow_ban" + DENY = "deny" + + class SpamCheckerApi(object): """A proxy object that gets passed to spam checkers so they can get access to rooms and other relevant information. diff --git a/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql new file mode 100644 index 0000000000..4cc96a5341 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql @@ -0,0 +1,25 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- A table of the IP address and user-agent used to complete each step of a +-- user-interactive authentication session. +CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips( + session_id TEXT NOT NULL, + ip TEXT NOT NULL, + user_agent TEXT NOT NULL, + UNIQUE (session_id, ip, user_agent), + FOREIGN KEY (session_id) + REFERENCES ui_auth_sessions (session_id) +); diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 6281a41a3d..9eef8e57c5 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import attr @@ -260,6 +260,34 @@ class UIAuthWorkerStore(SQLBaseStore): return serverdict.get(key, default) + async def add_user_agent_ip_to_ui_auth_session( + self, session_id: str, user_agent: str, ip: str, + ): + """Add the given user agent / IP to the tracking table + """ + await self.db_pool.simple_upsert( + table="ui_auth_sessions_ips", + keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip}, + values={}, + desc="add_user_agent_ip_to_ui_auth_session", + ) + + async def get_user_agents_ips_to_ui_auth_session( + self, session_id: str, + ) -> List[Tuple[str, str]]: + """Get the given user agents / IPs used during the ui auth process + + Returns: + List of user_agent/ip pairs + """ + rows = await self.db_pool.simple_select_list( + table="ui_auth_sessions_ips", + keyvalues={"session_id": session_id}, + retcols=("user_agent", "ip"), + desc="get_user_agents_ips_to_ui_auth_session", + ) + return [(row["user_agent"], row["ip"]) for row in rows] + class UIAuthStore(UIAuthWorkerStore): def delete_old_ui_auth_sessions(self, expiration_time: int): @@ -285,6 +313,15 @@ class UIAuthStore(UIAuthWorkerStore): txn.execute(sql, [expiration_time]) session_ids = [r[0] for r in txn.fetchall()] + # Delete the corresponding IP/user agents. + self.db_pool.simple_delete_many_txn( + txn, + table="ui_auth_sessions_ips", + column="session_id", + iterable=session_ids, + keyvalues={}, + ) + # Delete the corresponding completed credentials. self.db_pool.simple_delete_many_txn( txn, diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 1bb25ab684..f92f3b8c15 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -374,12 +374,16 @@ class OidcHandlerTestCase(HomeserverTestCase): self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) self.handler._auth_handler.complete_sso_login = simple_async_mock() - request = Mock(spec=["args", "getCookie", "addCookie"]) + request = Mock( + spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"] + ) code = "code" state = "state" nonce = "nonce" client_redirect_url = "http://client/redirect" + user_agent = "Browser" + ip_address = "10.0.0.1" session = self.handler._generate_oidc_session_token( state=state, nonce=nonce, @@ -392,6 +396,10 @@ class OidcHandlerTestCase(HomeserverTestCase): request.args[b"code"] = [code.encode("utf-8")] request.args[b"state"] = [state.encode("utf-8")] + request.requestHeaders = Mock(spec=["getRawHeaders"]) + request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")] + request.getClientIP.return_value = ip_address + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) self.handler._auth_handler.complete_sso_login.assert_called_once_with( @@ -399,7 +407,9 @@ class OidcHandlerTestCase(HomeserverTestCase): ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._map_userinfo_to_user.assert_called_once_with( + userinfo, token, user_agent, ip_address + ) self.handler._fetch_userinfo.assert_not_called() self.handler._render_error.assert_not_called() @@ -431,7 +441,9 @@ class OidcHandlerTestCase(HomeserverTestCase): ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_not_called() - self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._map_userinfo_to_user.assert_called_once_with( + userinfo, token, user_agent, ip_address + ) self.handler._fetch_userinfo.assert_called_once_with(token) self.handler._render_error.assert_not_called() diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index e364b1bd62..5c92d0e8c9 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,18 +17,21 @@ from mock import Mock from twisted.internet import defer +from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import Codes, ResourceLimitError, SynapseError from synapse.handlers.register import RegistrationHandler +from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import RoomAlias, UserID, create_requester from tests.test_utils import make_awaitable from tests.unittest import override_config +from tests.utils import mock_getRawHeaders from .. import unittest -class RegistrationHandlers(object): +class RegistrationHandlers: def __init__(self, hs): self.registration_handler = RegistrationHandler(hs) @@ -475,6 +478,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.handler.register_user(localpart=invalid_user_id), SynapseError ) + def test_spam_checker_deny(self): + """A spam checker can deny registration, which results in an error.""" + + class DenyAll: + def check_registration_for_spam( + self, email_threepid, username, request_info + ): + return RegistrationBehaviour.DENY + + # Configure a spam checker that denies all users. + spam_checker = self.hs.get_spam_checker() + spam_checker.spam_checkers = [DenyAll()] + + self.get_failure(self.handler.register_user(localpart="user"), SynapseError) + + def test_spam_checker_shadow_ban(self): + """A spam checker can choose to shadow-ban a user, which allows registration to succeed.""" + + class BanAll: + def check_registration_for_spam( + self, email_threepid, username, request_info + ): + return RegistrationBehaviour.SHADOW_BAN + + # Configure a spam checker that denies all users. + spam_checker = self.hs.get_spam_checker() + spam_checker.spam_checkers = [BanAll()] + + user_id = self.get_success(self.handler.register_user(localpart="user")) + + # Get an access token. + token = self.macaroon_generator.generate_access_token(user_id) + self.get_success( + self.store.add_access_token_to_user( + user_id=user_id, token=token, device_id=None, valid_until_ms=None + ) + ) + + # Ensure the user was marked as shadow-banned. + request = Mock(args={}) + request.args[b"access_token"] = [token.encode("ascii")] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + auth = Auth(self.hs) + requester = self.get_success(auth.get_user_by_req(request)) + + self.assertTrue(requester.shadow_banned) + async def get_or_create_user( self, requester, localpart, displayname, password_hash=None ): diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 31ed89a5cd..87be94111f 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -238,7 +238,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def test_spam_checker(self): """ - A user which fails to the spam checks will not appear in search results. + A user which fails the spam checks will not appear in search results. """ u1 = self.register_user("user1", "pass") u1_token = self.login(u1, "pass") @@ -269,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Configure a spam checker that does not filter any users. spam_checker = self.hs.get_spam_checker() - class AllowAll(object): + class AllowAll: def check_username_for_spam(self, user_profile): # Allow all users. return False @@ -282,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual(len(s["results"]), 1) # Configure a spam checker that filters all users. - class BlockAll(object): + class BlockAll: def check_username_for_spam(self, user_profile): # All users are spammy. return True -- cgit 1.5.1 From 09fd0eda81c8cdab67d0fee83c4035886c881983 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 21 Aug 2020 10:06:45 +0100 Subject: Micro-optimisations to get_auth_chain_ids (#8132) --- changelog.d/8132.misc | 1 + synapse/storage/databases/main/event_federation.py | 40 +++++++++------------- 2 files changed, 18 insertions(+), 23 deletions(-) create mode 100644 changelog.d/8132.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/8132.misc b/changelog.d/8132.misc new file mode 100644 index 0000000000..7afa267c69 --- /dev/null +++ b/changelog.d/8132.misc @@ -0,0 +1 @@ +Micro-optimisations to get_auth_chain_ids. diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 4826be630c..e6a97b018c 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -15,14 +15,16 @@ import itertools import logging from queue import Empty, PriorityQueue -from typing import Dict, Iterable, List, Optional, Set, Tuple +from typing import Dict, Iterable, List, Set, Tuple from synapse.api.errors import StoreError +from synapse.events import EventBase from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore +from synapse.types import Collection from synapse.util.caches.descriptors import cached from synapse.util.iterutils import batch_iter @@ -30,12 +32,14 @@ logger = logging.getLogger(__name__) class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): - async def get_auth_chain(self, event_ids, include_given=False): + async def get_auth_chain( + self, event_ids: Collection[str], include_given: bool = False + ) -> List[EventBase]: """Get auth events for given event_ids. The events *must* be state events. Args: - event_ids (list): state events - include_given (bool): include the given events in result + event_ids: state events + include_given: include the given events in result Returns: list of events @@ -45,43 +49,34 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas ) return await self.get_events_as_list(event_ids) - def get_auth_chain_ids( - self, - event_ids: List[str], - include_given: bool = False, - ignore_events: Optional[Set[str]] = None, - ): + async def get_auth_chain_ids( + self, event_ids: Collection[str], include_given: bool = False, + ) -> List[str]: """Get auth events for given event_ids. The events *must* be state events. Args: event_ids: state events include_given: include the given events in result - ignore_events: Set of events to exclude from the returned auth - chain. This is useful if the caller will just discard the - given events anyway, and saves us from figuring out their auth - chains if not required. Returns: list of event_ids """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given, - ignore_events, ) - def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events): - if ignore_events is None: - ignore_events = set() - + def _get_auth_chain_ids_txn( + self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool + ) -> List[str]: if include_given: results = set(event_ids) else: results = set() - base_sql = "SELECT auth_id FROM event_auth WHERE " + base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE " front = set(event_ids) while front: @@ -93,7 +88,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas txn.execute(base_sql + clause, args) new_front.update(r[0] for r in txn) - new_front -= ignore_events new_front -= results front = new_front -- cgit 1.5.1 From 3f49f74610197d32fe73678cabc10f08732e66b8 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 24 Aug 2020 11:33:55 +0100 Subject: Don't fail /submit_token requests on incorrect session ID if request_token_inhibit_3pid_errors is turned on (#7991) * Don't raise session_id errors on submit_token if request_token_inhibit_3pid_errors is set * Changelog * Also wait some time before responding to /requestToken * Incorporate review * Update synapse/storage/databases/main/registration.py Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> * Incorporate review Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/7991.misc | 1 + synapse/rest/client/v2_alpha/account.py | 10 +++++++++ synapse/rest/client/v2_alpha/register.py | 7 ++++++ synapse/storage/databases/main/registration.py | 25 ++++++++++++++++----- tests/storage/test_registration.py | 31 ++++++++++++++++++++++++++ 5 files changed, 68 insertions(+), 6 deletions(-) create mode 100644 changelog.d/7991.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/7991.misc b/changelog.d/7991.misc new file mode 100644 index 0000000000..1562e3af9e --- /dev/null +++ b/changelog.d/7991.misc @@ -0,0 +1 @@ +Don't fail `/submit_token` requests on incorrect session ID if `request_token_inhibit_3pid_errors` is turned on. diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 203e76b9f2..3481477731 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import random from http import HTTPStatus from synapse.api.constants import LoginType @@ -109,6 +110,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): if self.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) @@ -448,6 +452,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): if self.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) @@ -516,6 +523,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): if self.hs.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index be0e680ac5..51372cdb5e 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -16,6 +16,7 @@ import hmac import logging +import random from typing import List, Union import synapse @@ -131,6 +132,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): if self.hs.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) @@ -203,6 +207,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): if self.hs.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError( diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 068ad22b30..321a51cc6a 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -889,6 +889,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): super(RegistrationStore, self).__init__(database, db_conn, hs) self._account_validity = hs.config.account_validity + self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors if self._account_validity.enabled: self._clock.call_later( @@ -1302,15 +1303,22 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) if not row: - raise ThreepidValidationError(400, "Unknown session_id") + if self._ignore_unknown_session_error: + # If we need to inhibit the error caused by an incorrect session ID, + # use None as placeholder values for the client secret and the + # validation timestamp. + # It shouldn't be an issue because they're both only checked after + # the token check, which should fail. And if it doesn't for some + # reason, the next check is on the client secret, which is NOT NULL, + # so we don't have to worry about the client secret matching by + # accident. + row = {"client_secret": None, "validated_at": None} + else: + raise ThreepidValidationError(400, "Unknown session_id") + retrieved_client_secret = row["client_secret"] validated_at = row["validated_at"] - if retrieved_client_secret != client_secret: - raise ThreepidValidationError( - 400, "This client_secret does not match the provided session_id" - ) - row = self.db_pool.simple_select_one_txn( txn, table="threepid_validation_token", @@ -1326,6 +1334,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): expires = row["expires"] next_link = row["next_link"] + if retrieved_client_secret != client_secret: + raise ThreepidValidationError( + 400, "This client_secret does not match the provided session_id" + ) + # If the session is already validated, no need to revalidate if validated_at: return next_link diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 840db66072..58f827d8d3 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.api.constants import UserTypes +from synapse.api.errors import ThreepidValidationError from tests import unittest from tests.utils import setup_test_homeserver @@ -122,3 +123,33 @@ class RegistrationStoreTestCase(unittest.TestCase): ) res = yield self.store.is_support_user(SUPPORT_USER) self.assertTrue(res) + + @defer.inlineCallbacks + def test_3pid_inhibit_invalid_validation_session_error(self): + """Tests that enabling the configuration option to inhibit 3PID errors on + /requestToken also inhibits validation errors caused by an unknown session ID. + """ + + # Check that, with the config setting set to false (the default value), a + # validation error is caused by the unknown session ID. + try: + yield defer.ensureDeferred( + self.store.validate_threepid_session( + "fake_sid", "fake_client_secret", "fake_token", 0, + ) + ) + except ThreepidValidationError as e: + self.assertEquals(e.msg, "Unknown session_id", e) + + # Set the config setting to true. + self.store._ignore_unknown_session_error = True + + # Check that now the validation error is caused by the token not matching. + try: + yield defer.ensureDeferred( + self.store.validate_threepid_session( + "fake_sid", "fake_client_secret", "fake_token", 0, + ) + ) + except ThreepidValidationError as e: + self.assertEquals(e.msg, "Validation token not found or has expired", e) -- cgit 1.5.1 From 97962ad17b204be0a88ef0cd3026f11c359fdb4a Mon Sep 17 00:00:00 2001 From: Manuel Stahl <37705355+awesome-manuel@users.noreply.github.com> Date: Tue, 25 Aug 2020 15:18:14 +0200 Subject: Search in columns 'name' and 'displayname' in the admin users endpoint (#7377) * Search in columns 'name' and 'displayname' in the admin users endpoint Signed-off-by: Manuel Stahl --- changelog.d/7377.misc | 1 + docs/admin_api/user_admin_api.rst | 6 ++++-- synapse/rest/admin/users.py | 4 +++- synapse/storage/databases/main/__init__.py | 31 ++++++++++++++++++------------ 4 files changed, 27 insertions(+), 15 deletions(-) create mode 100644 changelog.d/7377.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/7377.misc b/changelog.d/7377.misc new file mode 100644 index 0000000000..67e2da0dcb --- /dev/null +++ b/changelog.d/7377.misc @@ -0,0 +1 @@ +Search in columns 'name' and 'displayname' in the admin users endpoint. Contributed by Awesome Technologies Innovationslabor GmbH. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index be05128b3e..99948ec061 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -119,8 +119,10 @@ from a previous call. The parameter ``limit`` is optional but is used for pagination, denoting the maximum number of items to return in this call. Defaults to ``100``. -The parameter ``user_id`` is optional and filters to only users with user IDs -that contain this value. +The parameter ``user_id`` is optional and can be used to filter by user id. + +The parameter ``name`` is optional and can be used to list only users with the +local part of the user ID or display name that contain this value. The parameter ``guests`` is optional and if ``false`` will **exclude** guest users. Defaults to ``true`` to include guest users. diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index cc0bdfa5c9..f3e77da850 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -73,6 +73,7 @@ class UsersRestServletV2(RestServlet): The parameters `from` and `limit` are required only for pagination. By default, a `limit` of 100 is used. The parameter `user_id` can be used to filter by user id. + The parameter `name` can be used to filter by user id or display name. The parameter `guests` can be used to exclude guest users. The parameter `deactivated` can be used to include deactivated users. """ @@ -89,11 +90,12 @@ class UsersRestServletV2(RestServlet): start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) user_id = parse_string(request, "user_id", default=None) + name = parse_string(request, "name", default=None) guests = parse_boolean(request, "guests", default=True) deactivated = parse_boolean(request, "deactivated", default=False) users, total = await self.store.get_users_paginate( - start, limit, user_id, guests, deactivated + start, limit, user_id, name, guests, deactivated ) ret = {"users": users, "total": total} if len(users) >= limit: diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 17fa470919..0ed726fee0 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -498,7 +498,7 @@ class DataStore( ) def get_users_paginate( - self, start, limit, name=None, guests=True, deactivated=False + self, start, limit, user_id=None, name=None, guests=True, deactivated=False ): """Function to retrieve a paginated list of users from users list. This will return a json list of users and the @@ -507,7 +507,8 @@ class DataStore( Args: start (int): start number to begin the query from limit (int): number of rows to retrieve - name (string): filter for user names + user_id (string): search for user_id + name (string): search for local part of user_id or display name guests (bool): whether to in include guest users deactivated (bool): whether to include deactivated users Returns: @@ -516,11 +517,14 @@ class DataStore( def get_users_paginate_txn(txn): filters = [] - args = [] + args = [self.hs.config.server_name] if name: + filters.append("(name LIKE ? OR displayname LIKE ?)") + args.extend(["@%" + name + "%:%", "%" + name + "%"]) + elif user_id: filters.append("name LIKE ?") - args.append("%" + name + "%") + args.extend(["%" + user_id + "%"]) if not guests: filters.append("is_guest = 0") @@ -530,20 +534,23 @@ class DataStore( where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" - sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause) - txn.execute(sql, args) - count = txn.fetchone()[0] - - args = [self.hs.config.server_name] + args + [limit, start] - sql = """ - SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url + sql_base = """ FROM users as u LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ? {} - ORDER BY u.name LIMIT ? OFFSET ? """.format( where_clause ) + sql = "SELECT COUNT(*) as total_users " + sql_base + txn.execute(sql, args) + count = txn.fetchone()[0] + + sql = ( + "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url " + + sql_base + + " ORDER BY u.name LIMIT ? OFFSET ?" + ) + args += [limit, start] txn.execute(sql, args) users = self.db_pool.cursor_to_dict(txn) return users, count -- cgit 1.5.1 From 74bf8d4d0659a87152804dc56df9284be87512bb Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 25 Aug 2020 15:03:24 +0100 Subject: Wording fixes to 'name' user admin api filter (#8163) Some fixes to wording I noticed after merging #7377. --- changelog.d/7377.misc | 2 +- changelog.d/8163.misc | 1 + docs/admin_api/user_admin_api.rst | 7 ++++--- synapse/storage/databases/main/__init__.py | 2 +- 4 files changed, 7 insertions(+), 5 deletions(-) create mode 100644 changelog.d/8163.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/7377.misc b/changelog.d/7377.misc index 67e2da0dcb..b3ec08855b 100644 --- a/changelog.d/7377.misc +++ b/changelog.d/7377.misc @@ -1 +1 @@ -Search in columns 'name' and 'displayname' in the admin users endpoint. Contributed by Awesome Technologies Innovationslabor GmbH. +Add filter `name` to the `/users` admin API, which filters by user ID or displayname. Contributed by Awesome Technologies Innovationslabor GmbH. diff --git a/changelog.d/8163.misc b/changelog.d/8163.misc new file mode 100644 index 0000000000..b3ec08855b --- /dev/null +++ b/changelog.d/8163.misc @@ -0,0 +1 @@ +Add filter `name` to the `/users` admin API, which filters by user ID or displayname. Contributed by Awesome Technologies Innovationslabor GmbH. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index 9cb924895f..d6e3194cda 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -119,10 +119,11 @@ from a previous call. The parameter ``limit`` is optional but is used for pagination, denoting the maximum number of items to return in this call. Defaults to ``100``. -The parameter ``user_id`` is optional and can be used to filter by user id. +The parameter ``user_id`` is optional and filters to only return users with user IDs +that contain this value. This parameter is ignored when using the ``name`` parameter. -The parameter ``name`` is optional and can be used to list only users with the -local part of the user ID or display name that contain this value. +The parameter ``name`` is optional and filters to only return users with user ID localparts +**or** displaynames that contain this value. The parameter ``guests`` is optional and if ``false`` will **exclude** guest users. Defaults to ``true`` to include guest users. diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 0ed726fee0..0934ae276c 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -507,7 +507,7 @@ class DataStore( Args: start (int): start number to begin the query from limit (int): number of rows to retrieve - user_id (string): search for user_id + user_id (string): search for user_id. ignored if name is not None name (string): search for local part of user_id or display name guests (bool): whether to in include guest users deactivated (bool): whether to include deactivated users -- cgit 1.5.1 From 2231dffee6788836c86e868dd29574970b13dd18 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 25 Aug 2020 15:10:08 +0100 Subject: Make StreamIdGen `get_next` and `get_next_mult` async (#8161) This is mainly so that `StreamIdGenerator` and `MultiWriterIdGenerator` will have the same interface, allowing them to be used interchangeably. --- changelog.d/8161.misc | 1 + synapse/storage/databases/main/account_data.py | 4 +-- synapse/storage/databases/main/deviceinbox.py | 4 +-- synapse/storage/databases/main/devices.py | 8 +++-- synapse/storage/databases/main/end_to_end_keys.py | 43 ++++++++++++----------- synapse/storage/databases/main/events.py | 4 +-- synapse/storage/databases/main/group_server.py | 2 +- synapse/storage/databases/main/presence.py | 2 +- synapse/storage/databases/main/push_rule.py | 8 ++--- synapse/storage/databases/main/pusher.py | 4 +-- synapse/storage/databases/main/receipts.py | 3 +- synapse/storage/databases/main/room.py | 6 ++-- synapse/storage/databases/main/tags.py | 4 +-- synapse/storage/util/id_generators.py | 10 +++--- 14 files changed, 54 insertions(+), 49 deletions(-) create mode 100644 changelog.d/8161.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/8161.misc b/changelog.d/8161.misc new file mode 100644 index 0000000000..89ff274de3 --- /dev/null +++ b/changelog.d/8161.misc @@ -0,0 +1 @@ +Refactor `StreamIdGenerator` and `MultiWriterIdGenerator` to have the same interface. diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 82aac2bbf3..04042a2c98 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -336,7 +336,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. @@ -384,7 +384,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 1f6e995c4f..bb85637a95 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) - with self._device_inbox_id_gen.get_next() as stream_id: + with await self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id @@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) txn, stream_id, local_messages_by_user_then_device ) - with self._device_inbox_id_gen.get_next() as stream_id: + with await self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 9a786e2929..03b45dbc4d 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -380,7 +380,7 @@ class DeviceWorkerStore(SQLBaseStore): THe new stream ID. """ - with self._device_list_id_gen.get_next() as stream_id: + with await self._device_list_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, @@ -1146,7 +1146,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not device_ids: return - with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids: + with await self._device_list_id_gen.get_next_mult( + len(device_ids) + ) as stream_ids: await self.db_pool.runInteraction( "add_device_change_to_stream", self._add_device_change_to_stream_txn, @@ -1159,7 +1161,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return stream_ids[-1] context = get_active_span_text_map() - with self._device_list_id_gen.get_next_mult( + with await self._device_list_id_gen.get_next_mult( len(hosts) * len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index f93e0d320d..385868bdab 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -648,7 +648,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) - def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key): + def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id): """Set a user's cross-signing key. Args: @@ -658,6 +658,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key key (dict): the key data + stream_id (int) """ # the 'key' dict will look something like: # { @@ -695,23 +696,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): ) # and finally, store the key itself - with self._cross_signing_id_gen.get_next() as stream_id: - self.db_pool.simple_insert_txn( - txn, - "e2e_cross_signing_keys", - values={ - "user_id": user_id, - "keytype": key_type, - "keydata": json_encoder.encode(key), - "stream_id": stream_id, - }, - ) + self.db_pool.simple_insert_txn( + txn, + "e2e_cross_signing_keys", + values={ + "user_id": user_id, + "keytype": key_type, + "keydata": json_encoder.encode(key), + "stream_id": stream_id, + }, + ) self._invalidate_cache_and_stream( txn, self._get_bare_e2e_cross_signing_keys, (user_id,) ) - def set_e2e_cross_signing_key(self, user_id, key_type, key): + async def set_e2e_cross_signing_key(self, user_id, key_type, key): """Set a user's cross-signing key. Args: @@ -719,13 +719,16 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): key_type (str): the type of cross-signing key to set key (dict): the key data """ - return self.db_pool.runInteraction( - "add_e2e_cross_signing_key", - self._set_e2e_cross_signing_key_txn, - user_id, - key_type, - key, - ) + + with await self._cross_signing_id_gen.get_next() as stream_id: + return await self.db_pool.runInteraction( + "add_e2e_cross_signing_key", + self._set_e2e_cross_signing_key_txn, + user_id, + key_type, + key, + stream_id, + ) def store_e2e_cross_signing_signatures(self, user_id, signatures): """Stores cross-signing signatures. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index b90e6de2d5..6313b41eef 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -153,11 +153,11 @@ class PersistEventsStore: # Note: Multiple instances of this function cannot be in flight at # the same time for the same room. if backfilled: - stream_ordering_manager = self._backfill_id_gen.get_next_mult( + stream_ordering_manager = await self._backfill_id_gen.get_next_mult( len(events_and_contexts) ) else: - stream_ordering_manager = self._stream_id_gen.get_next_mult( + stream_ordering_manager = await self._stream_id_gen.get_next_mult( len(events_and_contexts) ) diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 0e3b8739c6..a488e0924b 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -1182,7 +1182,7 @@ class GroupServerStore(GroupServerWorkerStore): return next_id - with self._group_updates_id_gen.get_next() as next_id: + with await self._group_updates_id_gen.get_next() as next_id: res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 4e3ec02d14..c9f655dfb7 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -23,7 +23,7 @@ from synapse.util.iterutils import batch_iter class PresenceStore(SQLBaseStore): async def update_presence(self, presence_states): - stream_ordering_manager = self._presence_id_gen.get_next_mult( + stream_ordering_manager = await self._presence_id_gen.get_next_mult( len(presence_states) ) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index a585e54812..2fb5b02d7d 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore): ) -> None: conditions_json = json_encoder.encode(conditions) actions_json = json_encoder.encode(actions) - with self._push_rules_stream_id_gen.get_next() as stream_id: + with await self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() if before or after: @@ -560,7 +560,7 @@ class PushRuleStore(PushRulesWorkerStore): txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" ) - with self._push_rules_stream_id_gen.get_next() as stream_id: + with await self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( @@ -571,7 +571,7 @@ class PushRuleStore(PushRulesWorkerStore): ) async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None: - with self._push_rules_stream_id_gen.get_next() as stream_id: + with await self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( @@ -646,7 +646,7 @@ class PushRuleStore(PushRulesWorkerStore): data={"actions": actions_json}, ) - with self._push_rules_stream_id_gen.get_next() as stream_id: + with await self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 1126fd0751..c388468273 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore): last_stream_ordering, profile_tag="", ) -> None: - with self._pushers_id_gen.get_next() as stream_id: + with await self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so simple_upsert will retry await self.db_pool.simple_upsert( @@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore): }, ) - with self._pushers_id_gen.get_next() as stream_id: + with await self._pushers_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "delete_pusher", delete_pusher_txn, stream_id ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 19ad1c056f..6821476ee0 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -520,8 +520,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "insert_receipt_conv", graph_to_linear ) - stream_id_manager = self._receipts_id_gen.get_next() - with stream_id_manager as stream_id: + with await self._receipts_id_gen.get_next() as stream_id: event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 7d3ac47261..b3772be2b2 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1129,7 +1129,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with self._public_room_id_gen.get_next() as next_id: + with await self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "store_room_txn", store_room_txn, next_id ) @@ -1196,7 +1196,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with self._public_room_id_gen.get_next() as next_id: + with await self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) @@ -1276,7 +1276,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with self._public_room_id_gen.get_next() as next_id: + with await self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index ade7abc927..0c34bbf21a 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore): ) self._update_revision_txn(txn, user_id, room_id, next_id) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) @@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore): txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 0bf772d4d1..ddb5c8c60c 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -80,7 +80,7 @@ class StreamIdGenerator(object): upwards, -1 to grow downwards. Usage: - with stream_id_gen.get_next() as stream_id: + with await stream_id_gen.get_next() as stream_id: # ... persist event ... """ @@ -95,10 +95,10 @@ class StreamIdGenerator(object): ) self._unfinished_ids = deque() # type: Deque[int] - def get_next(self): + async def get_next(self): """ Usage: - with stream_id_gen.get_next() as stream_id: + with await stream_id_gen.get_next() as stream_id: # ... persist event ... """ with self._lock: @@ -117,10 +117,10 @@ class StreamIdGenerator(object): return manager() - def get_next_mult(self, n): + async def get_next_mult(self, n): """ Usage: - with stream_id_gen.get_next(n) as stream_ids: + with await stream_id_gen.get_next(n) as stream_ids: # ... persist events ... """ with self._lock: -- cgit 1.5.1