diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 4d84f4595a..628292b890 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -270,7 +270,7 @@ def start(hs, listeners=None):
# Start the tracer
synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa
- hs.config
+ hs
)
# It is now safe to start your Synapse.
@@ -316,7 +316,7 @@ def setup_sentry(hs):
scope.set_tag("matrix_server_name", hs.config.server_name)
app = hs.config.worker_app if hs.config.worker_app else "synapse.app.homeserver"
- name = hs.config.worker_name if hs.config.worker_name else "master"
+ name = hs.get_instance_name()
scope.set_tag("worker_app", app)
scope.set_tag("worker_name", name)
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 2a56fe0bd5..0ace7b787d 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -127,6 +127,7 @@ from synapse.storage.data_stores.main.monthly_active_users import (
MonthlyActiveUsersWorkerStore,
)
from synapse.storage.data_stores.main.presence import UserPresenceState
+from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
@@ -439,6 +440,7 @@ class GenericWorkerSlavedStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
# rather than going via the correct worker.
UserDirectoryStore,
+ UIAuthWorkerStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,
SlavedReceiptsStore,
@@ -960,17 +962,22 @@ def start(config_options):
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
- ss = GenericWorkerServer(
+ hs = GenericWorkerServer(
config.server_name,
config=config,
version_string="Synapse/" + get_version_string(synapse),
)
- setup_logging(ss, config, use_worker_options=True)
+ setup_logging(hs, config, use_worker_options=True)
+
+ hs.setup()
+
+ # Ensure the replication streamer is always started in case we write to any
+ # streams. Will no-op if no streams can be written to by this worker.
+ hs.get_replication_streamer()
- ss.setup()
reactor.addSystemEventTrigger(
- "before", "startup", _base.start, ss, config.worker_listeners
+ "before", "startup", _base.start, hs, config.worker_listeners
)
_base.start_worker_reactor("synapse-generic-worker", config)
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index bfa9d28999..30d1050a91 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -657,6 +657,12 @@ def read_config_files(config_files):
for config_file in config_files:
with open(config_file) as file_stream:
yaml_config = yaml.safe_load(file_stream)
+
+ if not isinstance(yaml_config, dict):
+ err = "File %r is empty or doesn't parse into a key-value map. IGNORING."
+ print(err % (config_file,))
+ continue
+
specified_config.update(yaml_config)
if "server_name" not in specified_config:
diff --git a/synapse/config/database.py b/synapse/config/database.py
index c27fef157b..5b662d1b01 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -138,7 +138,7 @@ class DatabaseConfig(Config):
database_path = config.get("database_path")
if multi_database_config and database_config:
- raise ConfigError("Can't specify both 'database' and 'datbases' in config")
+ raise ConfigError("Can't specify both 'database' and 'databases' in config")
if multi_database_config:
if database_path:
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index f31fc85ec8..76b8957ea5 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -108,9 +108,14 @@ class EmailConfig(Config):
if self.trusted_third_party_id_servers:
# XXX: It's a little confusing that account_threepid_delegate_email is modified
# both in RegistrationConfig and here. We should factor this bit out
- self.account_threepid_delegate_email = self.trusted_third_party_id_servers[
- 0
- ] # type: Optional[str]
+
+ first_trusted_identity_server = self.trusted_third_party_id_servers[0]
+
+ # trusted_third_party_id_servers does not contain a scheme whereas
+ # account_threepid_delegate_email is expected to. Presume https
+ self.account_threepid_delegate_email = (
+ "https://" + first_trusted_identity_server
+ ) # type: Optional[str]
self.using_identity_server_from_trusted_list = True
else:
raise ConfigError(
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 6cd37d4324..cac6bc0139 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -113,6 +113,30 @@ class SSOConfig(Config):
#
# * server_name: the homeserver's name.
#
+ # * HTML page which notifies the user that they are authenticating to confirm
+ # an operation on their account during the user interactive authentication
+ # process: 'sso_auth_confirm.html'.
+ #
+ # When rendering, this template is given the following variables:
+ # * redirect_url: the URL the user is about to be redirected to. Needs
+ # manual escaping (see
+ # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+ #
+ # * description: the operation which the user is being asked to confirm
+ #
+ # * HTML page shown after a successful user interactive authentication session:
+ # 'sso_auth_success.html'.
+ #
+ # Note that this page must include the JavaScript which notifies of a successful authentication
+ # (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback).
+ #
+ # This template has no additional variables.
+ #
+ # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
+ # attempts to login: 'sso_account_deactivated.html'.
+ #
+ # This template has no additional variables.
+ #
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index d950a8b246..1eec3874b6 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -37,15 +37,16 @@ An attestation is a signed blob of json that looks like:
import logging
import random
+from typing import Tuple
from signedjson.sign import sign_json
from twisted.internet import defer
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
-from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id
+from synapse.util.async_helpers import yieldable_gather_results
logger = logging.getLogger(__name__)
@@ -162,19 +163,19 @@ class GroupAttestionRenewer(object):
def _start_renew_attestations(self):
return run_as_background_process("renew_attestations", self._renew_attestations)
- @defer.inlineCallbacks
- def _renew_attestations(self):
+ async def _renew_attestations(self):
"""Called periodically to check if we need to update any of our attestations
"""
now = self.clock.time_msec()
- rows = yield self.store.get_attestations_need_renewals(
+ rows = await self.store.get_attestations_need_renewals(
now + UPDATE_ATTESTATION_TIME_MS
)
@defer.inlineCallbacks
- def _renew_attestation(group_id, user_id):
+ def _renew_attestation(group_user: Tuple[str, str]):
+ group_id, user_id = group_user
try:
if not self.is_mine_id(group_id):
destination = get_domain_from_id(group_id)
@@ -207,8 +208,6 @@ class GroupAttestionRenewer(object):
"Error renewing attestation of %r in %r", user_id, group_id
)
- for row in rows:
- group_id = row["group_id"]
- user_id = row["user_id"]
-
- run_in_background(_renew_attestation, group_id, user_id)
+ await yieldable_gather_results(
+ _renew_attestation, ((row["group_id"], row["user_id"]) for row in rows)
+ )
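For context, a hedged sketch of the `yieldable_gather_results` contract relied on here: the callback is invoked once per item of the iterable, the calls run concurrently, and the caller can await the whole batch (unlike the old fire-and-forget `run_in_background` loop). The `square` coroutine is purely illustrative:

```python
from synapse.util.async_helpers import yieldable_gather_results

async def demo():
    async def square(n: int) -> int:
        return n * n

    # Each item is passed as a single positional argument, which is why
    # _renew_attestation above takes one (group_id, user_id) tuple per row.
    results = await yieldable_gather_results(square, [1, 2, 3])
    return results  # a list of the per-item results, per the assumption above
```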
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index dbe165ce1e..7613e5b6ab 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -41,10 +41,10 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.http.server import finish_request
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.push.mailer import load_jinja2_templates
from synapse.types import Requester, UserID
-from synapse.util.caches.expiringcache import ExpiringCache
from ._base import BaseHandler
@@ -69,15 +69,6 @@ class AuthHandler(BaseHandler):
self.bcrypt_rounds = hs.config.bcrypt_rounds
- # This is not a cache per se, but a store of all current sessions that
- # expire after N hours
- self.sessions = ExpiringCache(
- cache_name="register_sessions",
- clock=hs.get_clock(),
- expiry_ms=self.SESSION_EXPIRE_MS,
- reset_expiry_on_get=True,
- )
-
account_handler = ModuleApi(hs, self)
self.password_providers = [
module(config=config, account_handler=account_handler)
@@ -119,6 +110,15 @@ class AuthHandler(BaseHandler):
self._clock = self.hs.get_clock()
+ # Expire old UI auth sessions after a period of time.
+ if hs.config.worker_app is None:
+ self._clock.looping_call(
+ run_as_background_process,
+ 5 * 60 * 1000,
+ "expire_old_sessions",
+ self._expire_old_sessions,
+ )
+
# Load the SSO HTML templates.
# The following template is shown to the user during a client login via SSO,
@@ -301,16 +301,21 @@ class AuthHandler(BaseHandler):
if "session" in authdict:
sid = authdict["session"]
+ # Convert the URI and method to strings.
+ uri = request.uri.decode("utf-8")
+ method = request.method.decode("utf-8")
+
# If there's no session ID, create a new session.
if not sid:
- session = self._create_session(
- clientdict, (request.uri, request.method, clientdict), description
+ session = await self.store.create_ui_auth_session(
+ clientdict, uri, method, description
)
- session_id = session["id"]
else:
- session = self._get_session_info(sid)
- session_id = sid
+ try:
+ session = await self.store.get_ui_auth_session(sid)
+ except StoreError:
+ raise SynapseError(400, "Unknown session ID: %s" % (sid,))
if not clientdict:
# This was designed to allow the client to omit the parameters
@@ -322,15 +327,15 @@ class AuthHandler(BaseHandler):
# on a homeserver.
# Revisit: Assuming the REST APIs do sensible validation, the data
# isn't arbitrary.
- clientdict = session["clientdict"]
+ clientdict = session.clientdict
# Ensure that the queried operation does not vary between stages of
# the UI authentication session. This is done by generating a stable
# comparator based on the URI, method, and body (minus the auth dict)
# and storing it during the initial query. Subsequent queries ensure
# that this comparator has not changed.
- comparator = (request.uri, request.method, clientdict)
- if session["ui_auth"] != comparator:
+ comparator = (uri, method, clientdict)
+ if (session.uri, session.method, session.clientdict) != comparator:
raise SynapseError(
403,
"Requested operation has changed during the UI authentication session.",
@@ -338,11 +343,9 @@ class AuthHandler(BaseHandler):
if not authdict:
raise InteractiveAuthIncompleteError(
- self._auth_dict_for_flows(flows, session_id)
+ self._auth_dict_for_flows(flows, session.session_id)
)
- creds = session["creds"]
-
# check auth type currently being presented
errordict = {} # type: Dict[str, Any]
if "type" in authdict:
@@ -350,8 +353,9 @@ class AuthHandler(BaseHandler):
try:
result = await self._check_auth_dict(authdict, clientip)
if result:
- creds[login_type] = result
- self._save_session(session)
+ await self.store.mark_ui_auth_stage_complete(
+ session.session_id, login_type, result
+ )
except LoginError as e:
if login_type == LoginType.EMAIL_IDENTITY:
# riot used to have a bug where it would request a new
@@ -367,6 +371,7 @@ class AuthHandler(BaseHandler):
# so that the client can have another go.
errordict = e.error_dict()
+ creds = await self.store.get_completed_ui_auth_stages(session.session_id)
for f in flows:
if len(set(f) - set(creds)) == 0:
# it's very useful to know what args are stored, but this can
@@ -380,9 +385,9 @@ class AuthHandler(BaseHandler):
list(clientdict),
)
- return creds, clientdict, session_id
+ return creds, clientdict, session.session_id
- ret = self._auth_dict_for_flows(flows, session_id)
+ ret = self._auth_dict_for_flows(flows, session.session_id)
ret["completed"] = list(creds)
ret.update(errordict)
raise InteractiveAuthIncompleteError(ret)
@@ -399,13 +404,11 @@ class AuthHandler(BaseHandler):
if "session" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
- sess = self._get_session_info(authdict["session"])
- creds = sess["creds"]
-
result = await self.checkers[stagetype].check_auth(authdict, clientip)
if result:
- creds[stagetype] = result
- self._save_session(sess)
+ await self.store.mark_ui_auth_stage_complete(
+ authdict["session"], stagetype, result
+ )
return True
return False
@@ -427,7 +430,7 @@ class AuthHandler(BaseHandler):
sid = authdict["session"]
return sid
- def set_session_data(self, session_id: str, key: str, value: Any) -> None:
+ async def set_session_data(self, session_id: str, key: str, value: Any) -> None:
"""
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
@@ -438,11 +441,12 @@ class AuthHandler(BaseHandler):
key: The key to store the data under
value: The data to store
"""
- sess = self._get_session_info(session_id)
- sess["serverdict"][key] = value
- self._save_session(sess)
+ try:
+ await self.store.set_ui_auth_session_data(session_id, key, value)
+ except StoreError:
+ raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
- def get_session_data(
+ async def get_session_data(
self, session_id: str, key: str, default: Optional[Any] = None
) -> Any:
"""
@@ -453,8 +457,18 @@ class AuthHandler(BaseHandler):
key: The key to store the data under
default: Value to return if the key has not been set
"""
- sess = self._get_session_info(session_id)
- return sess["serverdict"].get(key, default)
+ try:
+ return await self.store.get_ui_auth_session_data(session_id, key, default)
+ except StoreError:
+ raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
+
+ async def _expire_old_sessions(self):
+ """
+ Invalidate any user interactive authentication sessions that have expired.
+ """
+ now = self._clock.time_msec()
+ expiration_time = now - self.SESSION_EXPIRE_MS
+ await self.store.delete_old_ui_auth_sessions(expiration_time)
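A worked example of the expiry arithmetic, assuming `SESSION_EXPIRE_MS` keeps its existing 48-hour value on `AuthHandler` (the timestamp is made up):

```python
now = 1_586_000_000_000                  # clock.time_msec(), hypothetical
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000  # assumed existing constant: 48 hours
expiration_time = now - SESSION_EXPIRE_MS
# Any ui_auth_sessions row with creation_time < expiration_time is deleted;
# the looping call re-runs this every 5 minutes, on the master process only.
```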
async def _check_auth_dict(
self, authdict: Dict[str, Any], clientip: str
@@ -534,67 +548,6 @@ class AuthHandler(BaseHandler):
"params": params,
}
- def _create_session(
- self,
- clientdict: Dict[str, Any],
- ui_auth: Tuple[bytes, bytes, Dict[str, Any]],
- description: str,
- ) -> dict:
- """
- Creates a new user interactive authentication session.
-
- The session can be used to track data across multiple requests, e.g. for
- interactive authentication.
-
- Each session has the following keys:
-
- id:
- A unique identifier for this session. Passed back to the client
- and returned for each stage.
- clientdict:
- The dictionary from the client root level, not the 'auth' key.
- ui_auth:
- A tuple which is checked at each stage of the authentication to
- ensure that the asked for operation has not changed.
- creds:
- A map, which maps each auth-type (str) to the relevant identity
- authenticated by that auth-type (mostly str, but for captcha, bool).
- serverdict:
- A map of data that is stored server-side and cannot be modified
- by the client.
- description:
- A string description of the operation that the current
- authentication is authorising.
- Returns:
- The newly created session.
- """
- session_id = None
- while session_id is None or session_id in self.sessions:
- session_id = stringutils.random_string(24)
-
- self.sessions[session_id] = {
- "id": session_id,
- "clientdict": clientdict,
- "ui_auth": ui_auth,
- "creds": {},
- "serverdict": {},
- "description": description,
- }
-
- return self.sessions[session_id]
-
- def _get_session_info(self, session_id: str) -> dict:
- """
- Gets a session given a session ID.
-
- The session can be used to track data across multiple requests, e.g. for
- interactive authentication.
- """
- try:
- return self.sessions[session_id]
- except KeyError:
- raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
-
async def get_access_token_for_user_id(
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
):
@@ -994,13 +947,6 @@ class AuthHandler(BaseHandler):
await self.store.user_delete_threepid(user_id, medium, address)
return result
- def _save_session(self, session: Dict[str, Any]) -> None:
- """Update the last used time on the session to now and add it back to the session store."""
- # TODO: Persistent storage
- logger.debug("Saving session %s", session)
- session["last_used"] = self.hs.get_clock().time_msec()
- self.sessions[session["id"]] = session
-
async def hash(self, password: str) -> str:
"""Computes a secure hash of password.
@@ -1052,7 +998,7 @@ class AuthHandler(BaseHandler):
else:
return False
- def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
+ async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
"""
Get the HTML for the SSO redirect confirmation page.
@@ -1063,12 +1009,15 @@ class AuthHandler(BaseHandler):
Returns:
The HTML to render.
"""
- session = self._get_session_info(session_id)
+ try:
+ session = await self.store.get_ui_auth_session(session_id)
+ except StoreError:
+ raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
return self._sso_auth_confirm_template.render(
- description=session["description"], redirect_url=redirect_url,
+ description=session.description, redirect_url=redirect_url,
)
- def complete_sso_ui_auth(
+ async def complete_sso_ui_auth(
self, registered_user_id: str, session_id: str, request: SynapseRequest,
):
"""Having figured out a mxid for this user, complete the HTTP request
@@ -1080,13 +1029,11 @@ class AuthHandler(BaseHandler):
process.
"""
# Mark the stage of the authentication as successful.
- sess = self._get_session_info(session_id)
- creds = sess["creds"]
-
# Save the user who authenticated with SSO, this will be used to ensure
# that the account be modified is also the person who logged in.
- creds[LoginType.SSO] = registered_user_id
- self._save_session(sess)
+ await self.store.mark_ui_auth_stage_complete(
+ session_id, LoginType.SSO, registered_user_id
+ )
# Render the HTML and return.
html_bytes = self._sso_auth_success_template.encode("utf-8")
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 5cb3f9d133..64aaa1335c 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -206,7 +206,7 @@ class CasHandler:
registered_user_id = await self._auth_handler.check_user_exists(user_id)
if session:
- self._auth_handler.complete_sso_ui_auth(
+ await self._auth_handler.complete_sso_ui_auth(
registered_user_id, session, request,
)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c7aa7acf3b..41b96c0a73 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -343,7 +343,7 @@ class FederationHandler(BaseHandler):
ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
- state_maps = list(ours.values()) # type: list[StateMap[str]]
+ state_maps = list(ours.values()) # type: List[StateMap[str]]
# we don't need this any more, let's delete it.
del ours
@@ -1694,16 +1694,15 @@ class FederationHandler(BaseHandler):
return None
- @defer.inlineCallbacks
- def get_state_for_pdu(self, room_id, event_id):
+ async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
"""Returns the state at the event. i.e. not including said event.
"""
- event = yield self.store.get_event(
+ event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
- state_groups = yield self.state_store.get_state_groups(room_id, [event_id])
+ state_groups = await self.state_store.get_state_groups(room_id, [event_id])
if state_groups:
_, state = list(iteritems(state_groups)).pop()
@@ -1714,7 +1713,7 @@ class FederationHandler(BaseHandler):
if "replaces_state" in event.unsigned:
prev_id = event.unsigned["replaces_state"]
if prev_id != event.event_id:
- prev_event = yield self.store.get_event(prev_id)
+ prev_event = await self.store.get_event(prev_id)
results[(event.type, event.state_key)] = prev_event
else:
del results[(event.type, event.state_key)]
@@ -1724,15 +1723,14 @@ class FederationHandler(BaseHandler):
else:
return []
- @defer.inlineCallbacks
- def get_state_ids_for_pdu(self, room_id, event_id):
+ async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event.
"""
- event = yield self.store.get_event(
+ event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
- state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id])
+ state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
if state_groups:
_, state = list(state_groups.items()).pop()
@@ -1751,49 +1749,50 @@ class FederationHandler(BaseHandler):
else:
return []
- @defer.inlineCallbacks
@log_function
- def on_backfill_request(self, origin, room_id, pdu_list, limit):
- in_room = yield self.auth.check_host_in_room(room_id, origin)
+ async def on_backfill_request(
+ self, origin: str, room_id: str, pdu_list: List[str], limit: int
+ ) -> List[EventBase]:
+ in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
# Synapse asks for 100 events per backfill request. Do not allow more.
limit = min(limit, 100)
- events = yield self.store.get_backfill_events(room_id, pdu_list, limit)
+ events = await self.store.get_backfill_events(room_id, pdu_list, limit)
- events = yield filter_events_for_server(self.storage, origin, events)
+ events = await filter_events_for_server(self.storage, origin, events)
return events
- @defer.inlineCallbacks
@log_function
- def get_persisted_pdu(self, origin, event_id):
+ async def get_persisted_pdu(
+ self, origin: str, event_id: str
+ ) -> Optional[EventBase]:
"""Get an event from the database for the given server.
Args:
- origin [str]: hostname of server which is requesting the event; we
+ origin: hostname of server which is requesting the event; we
will check that the server is allowed to see it.
- event_id [str]: id of the event being requested
+ event_id: id of the event being requested
Returns:
- Deferred[EventBase|None]: None if we know nothing about the event;
- otherwise the (possibly-redacted) event.
+ None if we know nothing about the event; otherwise the (possibly-redacted) event.
Raises:
AuthError if the server is not currently in the room
"""
- event = yield self.store.get_event(
+ event = await self.store.get_event(
event_id, allow_none=True, allow_rejected=True
)
if event:
- in_room = yield self.auth.check_host_in_room(event.room_id, origin)
+ in_room = await self.auth.check_host_in_room(event.room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
- events = yield filter_events_for_server(self.storage, origin, [event])
+ events = await filter_events_for_server(self.storage, origin, [event])
event = events[0]
return event
else:
@@ -2397,7 +2396,7 @@ class FederationHandler(BaseHandler):
"""
# exclude the state key of the new event from the current_state in the context.
if event.is_state():
- event_key = (event.type, event.state_key)
+ event_key = (event.type, event.state_key) # type: Optional[Tuple[str, str]]
else:
event_key = None
state_updates = {
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 6c52951d7a..2aeceeaa6c 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -91,7 +91,11 @@ class RoomListHandler(BaseHandler):
logger.info("Bypassing cache as search request.")
return self._get_public_room_list(
- limit, since_token, search_filter, network_tuple=network_tuple
+ limit,
+ since_token,
+ search_filter,
+ network_tuple=network_tuple,
+ from_federation=from_federation,
)
key = (limit, since_token, network_tuple)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 7c9454b504..96f2dd36ad 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -149,7 +149,7 @@ class SamlHandler:
# Complete the interactive auth session or the login.
if current_session and current_session.ui_auth_session_id:
- self._auth_handler.complete_sso_ui_auth(
+ await self._auth_handler.complete_sso_ui_auth(
user_id, current_session.ui_auth_session_id, request
)
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 0638cec429..5dddf57008 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -171,7 +171,7 @@ import logging
import re
import types
from functools import wraps
-from typing import Dict
+from typing import TYPE_CHECKING, Dict
from canonicaljson import json
@@ -179,6 +179,9 @@ from twisted.internet import defer
from synapse.config import ConfigError
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
# Helper class
@@ -297,14 +300,11 @@ def _noop_context_manager(*args, **kwargs):
# Setup
-def init_tracer(config):
+def init_tracer(hs: "HomeServer"):
"""Set the whitelists and initialise the JaegerClient tracer
-
- Args:
- config (HomeserverConfig): The config used by the homeserver
"""
global opentracing
- if not config.opentracer_enabled:
+ if not hs.config.opentracer_enabled:
# We don't have a tracer
opentracing = None
return
@@ -315,18 +315,15 @@ def init_tracer(config):
"installed."
)
- # Include the worker name
- name = config.worker_name if config.worker_name else "master"
-
# Pull out the jaeger config if it was given. Otherwise set it to something sensible.
# See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py
- set_homeserver_whitelist(config.opentracer_whitelist)
+ set_homeserver_whitelist(hs.config.opentracer_whitelist)
JaegerConfig(
- config=config.jaeger_config,
- service_name="{} {}".format(config.server_name, name),
- scope_manager=LogContextScopeManager(config),
+ config=hs.config.jaeger_config,
+ service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()),
+ scope_manager=LogContextScopeManager(hs.config),
).initialize_tracer()
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 6132727cbd..88a5a97caf 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -220,12 +220,6 @@ class Notifier(object):
"""
self.replication_callbacks.append(cb)
- def add_remote_server_up_callback(self, cb: Callable[[str], None]):
- """Add a callback that will be called when synapse detects a server
- has been
- """
- self.remote_server_up_callbacks.append(cb)
-
def on_new_room_event(
self, event, room_stream_id, max_room_stream_id, extra_users=[]
):
@@ -544,6 +538,3 @@ class Notifier(object):
# circular dependencies.
if self.federation_sender:
self.federation_sender.wake_destination(server)
-
- for cb in self.remote_server_up_callbacks:
- cb(server)
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index c7880d4b63..f58e384d17 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -95,7 +95,7 @@ class RdataCommand(Command):
Format::
- RDATA <stream_name> <token> <row_json>
+ RDATA <stream_name> <instance_name> <token> <row_json>
The `<token>` may either be a numeric stream id OR "batch". The latter case
is used to support sending multiple updates with the same stream ID. This
@@ -105,33 +105,40 @@ class RdataCommand(Command):
The client should batch all incoming RDATA with a token of "batch" (per
stream_name) until it sees an RDATA with a numeric stream ID.
+ The `<instance_name>` is the source of the new data (usually "master").
+
`<token>` of "batch" maps to the instance variable `token` being None.
An example of a batched series of RDATA::
- RDATA presence batch ["@foo:example.com", "online", ...]
- RDATA presence batch ["@bar:example.com", "online", ...]
- RDATA presence 59 ["@baz:example.com", "online", ...]
+ RDATA presence master batch ["@foo:example.com", "online", ...]
+ RDATA presence master batch ["@bar:example.com", "online", ...]
+ RDATA presence master 59 ["@baz:example.com", "online", ...]
"""
NAME = "RDATA"
- def __init__(self, stream_name, token, row):
+ def __init__(self, stream_name, instance_name, token, row):
self.stream_name = stream_name
+ self.instance_name = instance_name
self.token = token
self.row = row
@classmethod
def from_line(cls, line):
- stream_name, token, row_json = line.split(" ", 2)
+ stream_name, instance_name, token, row_json = line.split(" ", 3)
return cls(
- stream_name, None if token == "batch" else int(token), json.loads(row_json)
+ stream_name,
+ instance_name,
+ None if token == "batch" else int(token),
+ json.loads(row_json),
)
def to_line(self):
return " ".join(
(
self.stream_name,
+ self.instance_name,
str(self.token) if self.token is not None else "batch",
_json_encoder.encode(self.row),
)
@@ -145,23 +152,31 @@ class PositionCommand(Command):
"""Sent by the server to tell the client the stream postition without
needing to send an RDATA.
+ Format::
+
+ POSITION <stream_name> <instance_name> <token>
+
On receipt of a POSITION command clients should check if they have missed
any updates, and if so then fetch them out of band.
+
+ The `<instance_name>` is the process that sent the command and is the source
+ of the stream.
"""
NAME = "POSITION"
- def __init__(self, stream_name, token):
+ def __init__(self, stream_name, instance_name, token):
self.stream_name = stream_name
+ self.instance_name = instance_name
self.token = token
@classmethod
def from_line(cls, line):
- stream_name, token = line.split(" ", 1)
- return cls(stream_name, int(token))
+ stream_name, instance_name, token = line.split(" ", 2)
+ return cls(stream_name, instance_name, int(token))
def to_line(self):
- return " ".join((self.stream_name, str(self.token)))
+ return " ".join((self.stream_name, self.instance_name, str(self.token)))
class ErrorCommand(_SimpleCommand):
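To make the new wire format concrete, a hedged round-trip sketch using the classes above (the values are illustrative, and the exact JSON spacing depends on `_json_encoder`):

```python
line = RdataCommand("presence", "master", 59, ["@baz:example.com", "online"]).to_line()
parsed = RdataCommand.from_line(line)
assert (parsed.stream_name, parsed.instance_name, parsed.token) == ("presence", "master", 59)

pos = PositionCommand.from_line("events master 12345")
assert pos.to_line() == "events master 12345"
```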
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 0db5a3a24d..6f7054d5af 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -79,6 +79,7 @@ class ReplicationCommandHandler:
self._notifier = hs.get_notifier()
self._clock = hs.get_clock()
self._instance_id = hs.get_instance_id()
+ self._instance_name = hs.get_instance_name()
# Set of streams that we've caught up with.
self._streams_connected = set() # type: Set[str]
@@ -87,7 +88,9 @@ class ReplicationCommandHandler:
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
- self._position_linearizer = Linearizer("replication_position")
+ self._position_linearizer = Linearizer(
+ "replication_position", clock=self._clock
+ )
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
@@ -115,7 +118,6 @@ class ReplicationCommandHandler:
self._server_notices_sender = None
if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender()
- self._notifier.add_remote_server_up_callback(self.send_remote_server_up)
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
@@ -155,13 +157,13 @@ class ReplicationCommandHandler:
hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory,
)
else:
- client_name = hs.config.worker_name
+ client_name = hs.get_instance_name()
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self._factory)
- async def on_REPLICATE(self, cmd: ReplicateCommand):
+ async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
# We only want to announce positions by the writer of the streams.
# Currently this is just the master process.
if not self._is_master:
@@ -169,9 +171,11 @@ class ReplicationCommandHandler:
for stream_name, stream in self._streams.items():
current_token = stream.current_token()
- self.send_command(PositionCommand(stream_name, current_token))
+ self.send_command(
+ PositionCommand(stream_name, self._instance_name, current_token)
+ )
- async def on_USER_SYNC(self, cmd: UserSyncCommand):
+ async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
user_sync_counter.inc()
if self._is_master:
@@ -179,17 +183,23 @@ class ReplicationCommandHandler:
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
- async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand):
+ async def on_CLEAR_USER_SYNC(
+ self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
+ ):
if self._is_master:
await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
- async def on_FEDERATION_ACK(self, cmd: FederationAckCommand):
+ async def on_FEDERATION_ACK(
+ self, conn: AbstractConnection, cmd: FederationAckCommand
+ ):
federation_ack_counter.inc()
if self._federation_sender:
self._federation_sender.federation_ack(cmd.token)
- async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand):
+ async def on_REMOVE_PUSHER(
+ self, conn: AbstractConnection, cmd: RemovePusherCommand
+ ):
remove_pusher_counter.inc()
if self._is_master:
@@ -199,7 +209,9 @@ class ReplicationCommandHandler:
self._notifier.on_new_replication_data()
- async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand):
+ async def on_INVALIDATE_CACHE(
+ self, conn: AbstractConnection, cmd: InvalidateCacheCommand
+ ):
invalidate_cache_counter.inc()
if self._is_master:
@@ -209,7 +221,7 @@ class ReplicationCommandHandler:
cmd.cache_func, tuple(cmd.keys)
)
- async def on_USER_IP(self, cmd: UserIpCommand):
+ async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
user_ip_cache_counter.inc()
if self._is_master:
@@ -225,7 +237,11 @@ class ReplicationCommandHandler:
if self._server_notices_sender:
await self._server_notices_sender.on_user_ip(cmd.user_id)
- async def on_RDATA(self, cmd: RdataCommand):
+ async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
+ if cmd.instance_name == self._instance_name:
+ # Ignore RDATA that are just our own echoes
+ return
+
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
@@ -276,7 +292,11 @@ class ReplicationCommandHandler:
logger.debug("Received rdata %s -> %s", stream_name, token)
await self._replication_data_handler.on_rdata(stream_name, token, rows)
- async def on_POSITION(self, cmd: PositionCommand):
+ async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
+ if cmd.instance_name == self._instance_name:
+ # Ignore POSITION that are just our own echoes
+ return
+
stream = self._streams.get(cmd.stream_name)
if not stream:
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
@@ -330,12 +350,30 @@ class ReplicationCommandHandler:
self._streams_connected.add(cmd.stream_name)
- async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+ async def on_REMOTE_SERVER_UP(
+ self, conn: AbstractConnection, cmd: RemoteServerUpCommand
+ ):
""""Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)
- if self._is_master:
- self._notifier.notify_remote_server_up(cmd.data)
+ self._notifier.notify_remote_server_up(cmd.data)
+
+ # We relay to all other connections to ensure every instance gets the
+ # notification.
+ #
+ # When configured to use redis we'll always only have one connection and
+ # so this is a no-op (all instances will have already received the same
+ # REMOTE_SERVER_UP command).
+ #
+ # For direct TCP connections this will relay to all other connections
+ # connected to us. When on master this will correctly fan out to all
+ # other direct TCP clients and on workers there'll only be the one
+ # connection to master.
+ #
+ # (The logic here should also be sound if we have a mix of Redis and
+ # direct TCP connections so long as there is only one traffic route
+ # between two instances, but that is not currently supported).
+ self.send_command(cmd, ignore_conn=conn)
def new_connection(self, connection: AbstractConnection):
"""Called when we have a new connection.
@@ -380,11 +418,21 @@ class ReplicationCommandHandler:
"""
return bool(self._connections)
- def send_command(self, cmd: Command):
+ def send_command(
+ self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
+ ):
"""Send a command to all connected connections.
+
+ Args:
+ cmd
+ ignore_conn: If set don't send command to the given connection.
+ Used when relaying commands from one connection to all others.
"""
if self._connections:
for connection in self._connections:
+ if connection == ignore_conn:
+ continue
+
try:
connection.send_command(cmd)
except Exception:
@@ -448,7 +496,7 @@ class ReplicationCommandHandler:
We need to check if the client is interested in the stream or not
"""
- self.send_command(RdataCommand(stream_name, token, data))
+ self.send_command(RdataCommand(stream_name, self._instance_name, token, data))
UpdateToken = TypeVar("UpdateToken")
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index e3f64eba8f..4198eece71 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -260,7 +260,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# Then call out to the handler.
cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
- await cmd_func(cmd)
+ await cmd_func(self, cmd)
handled = True
if not handled:
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 49b3ed0c5e..617e860f95 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -112,7 +112,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# Then call out to the handler.
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
- await cmd_func(cmd)
+ await cmd_func(self, cmd)
handled = True
if not handled:
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index b2d6baa2a2..33d2f589ac 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,9 +17,7 @@
import logging
import random
-from typing import Dict
-
-from six import itervalues
+from typing import Dict, List
from prometheus_client import Counter
@@ -71,29 +69,28 @@ class ReplicationStreamer(object):
def __init__(self, hs):
self.store = hs.get_datastore()
- self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
- self._server_notices_sender = hs.get_server_notices_sender()
self._replication_torture_level = hs.config.replication_torture_level
- # List of streams that clients can subscribe to.
- # We only support federation stream if federation sending hase been
- # disabled on the master.
- self.streams = [
- stream(hs)
- for stream in itervalues(STREAMS_MAP)
- if stream != FederationStream or not hs.config.send_federation
- ]
+ # Work out list of streams that this instance is the source of.
+ self.streams = [] # type: List[Stream]
+ if hs.config.worker_app is None:
+ for stream in STREAMS_MAP.values():
+ if stream == FederationStream and hs.config.send_federation:
+ # We only support federation stream if federation sending
+ # has been disabled on the master.
+ continue
- self.streams_by_name = {stream.NAME: stream for stream in self.streams}
+ self.streams.append(stream(hs))
- self.federation_sender = None
- if not hs.config.send_federation:
- self.federation_sender = hs.get_federation_sender()
+ self.streams_by_name = {stream.NAME: stream for stream in self.streams}
- self.notifier.add_replication_callback(self.on_notifier_poke)
+ # Only bother registering the notifier callback if we have streams to
+ # publish.
+ if self.streams:
+ self.notifier.add_replication_callback(self.on_notifier_poke)
# Keeps track of whether we are currently checking for updates
self.is_looping = False
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 4ae3cffb1e..4af1afd119 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -176,10 +176,9 @@ def db_query_to_update_function(
rows = await query_function(from_token, upto_token, limit)
updates = [(row[0], row[1:]) for row in rows]
limited = False
- if len(updates) == limit:
+ if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
- assert len(updates) <= limit
return updates, upto_token, limited
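The switch from `==` to `>=` (and the dropped assertion) reflects a new contract: the query function may deliberately overshoot the limit. A small sketch:

```python
# get_all_updated_current_state_deltas (changed below) returns *every* row
# sharing the final stream id, so a limit of 2 can legitimately yield 3 rows:
rows = [(1, "a"), (2, "b"), (2, "c")]
updates = [(row[0], row[1:]) for row in rows]
limit = 2
assert len(updates) >= limit      # '==' would miss the overshoot case
upto_token = updates[-1][0]       # 2: resume from the last stream id returned
```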
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index aa50492569..52df81b1bd 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -170,22 +170,16 @@ class EventsStream(Stream):
limited = False
upper_limit = current_token
- # next up is the state delta table
-
- state_rows = await self._store.get_all_updated_current_state_deltas(
+ # next up is the state delta table.
+ (
+ state_rows,
+ upper_limit,
+ state_rows_limited,
+ ) = await self._store.get_all_updated_current_state_deltas(
from_token, upper_limit, target_row_count
- ) # type: List[Tuple]
-
- # again, if we've hit the limit there, we'll need to limit the other sources
- assert len(state_rows) < target_row_count
- if len(state_rows) == target_row_count:
- assert state_rows[-1][0] <= upper_limit
- upper_limit = state_rows[-1][0]
- limited = True
+ )
- # FIXME: is it a given that there is only one row per stream_id in the
- # state_deltas table (so that we can be sure that we have got all of the
- # rows for upper_limit)?
+ limited = limited or state_rows_limited
# finally, fetch the ex-outliers rows. We assume there are few enough of these
# not to bother with the limit.
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 8551ac19b8..593ce011e8 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -94,10 +94,10 @@ class UsersRestServletV2(RestServlet):
guests = parse_boolean(request, "guests", default=True)
deactivated = parse_boolean(request, "deactivated", default=False)
- users = await self.store.get_users_paginate(
+ users, total = await self.store.get_users_paginate(
start, limit, user_id, guests, deactivated
)
- ret = {"users": users}
+ ret = {"users": users, "total": total}
if len(users) >= limit:
ret["next_token"] = str(start + len(users))
@@ -199,7 +199,7 @@ class UserRestServletV2(RestServlet):
user_id, threepid["medium"], threepid["address"], current_time
)
- if "avatar_url" in body:
+ if "avatar_url" in body and type(body["avatar_url"]) == str:
await self.profile_handler.set_avatar_url(
target_user, requester, body["avatar_url"], True
)
@@ -276,7 +276,7 @@ class UserRestServletV2(RestServlet):
user_id, threepid["medium"], threepid["address"], current_time
)
- if "avatar_url" in body:
+ if "avatar_url" in body and type(body["avatar_url"]) == str:
await self.profile_handler.set_avatar_url(
user_id, requester, body["avatar_url"], True
)
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 11599f5005..24dd3d3e96 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -140,7 +140,7 @@ class AuthRestServlet(RestServlet):
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
- def on_GET(self, request, stagetype):
+ async def on_GET(self, request, stagetype):
session = parse_string(request, "session")
if not session:
raise SynapseError(400, "No session supplied")
@@ -180,7 +180,7 @@ class AuthRestServlet(RestServlet):
else:
raise SynapseError(400, "Homeserver not configured for SSO.")
- html = self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
+ html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
else:
raise SynapseError(404, "Unknown auth stage type")
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index d1b5c49989..af08cc6cce 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -499,7 +499,7 @@ class RegisterRestServlet(RestServlet):
# registered a user for this session, so we could just return the
# user here. We carry on and go through the auth checks though,
# for paranoia.
- registered_user_id = self.auth_handler.get_session_data(
+ registered_user_id = await self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)
@@ -598,7 +598,7 @@ class RegisterRestServlet(RestServlet):
# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
- self.auth_handler.set_session_data(
+ await self.auth_handler.set_session_data(
session_id, "registered_user_id", registered_user_id
)
diff --git a/synapse/server.py b/synapse/server.py
index 9d273c980c..bf97a16c09 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -234,7 +234,8 @@ class HomeServer(object):
self._listening_services = []
self.start_time = None
- self.instance_id = random_string(5)
+ self._instance_id = random_string(5)
+ self._instance_name = config.worker_name or "master"
self.clock = Clock(reactor)
self.distributor = Distributor()
@@ -254,7 +255,15 @@ class HomeServer(object):
This is used to distinguish running instances in worker-based
deployments.
"""
- return self.instance_id
+ return self._instance_id
+
+ def get_instance_name(self) -> str:
+ """A unique name for this synapse process.
+
+ Used to identify the process over replication and in config. Does not
+ change over restarts.
+ """
+ return self._instance_name
def setup(self):
logger.info("Setting up.")
diff --git a/synapse/server.pyi b/synapse/server.pyi
index f1a5717028..18043a2593 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -25,6 +25,7 @@ import synapse.server_notices.server_notices_manager
import synapse.server_notices.server_notices_sender
import synapse.state
import synapse.storage
+from synapse.events.builder import EventBuilderFactory
class HomeServer(object):
@property
@@ -121,3 +122,9 @@ class HomeServer(object):
pass
def get_instance_id(self) -> str:
pass
+ def get_instance_name(self) -> str:
+ pass
+ def get_event_builder_factory(self) -> EventBuilderFactory:
+ pass
+ def get_storage(self) -> synapse.storage.Storage:
+ pass
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index 649e835303..ceba10882c 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -66,6 +66,7 @@ from .stats import StatsStore
from .stream import StreamStore
from .tags import TagsStore
from .transactions import TransactionStore
+from .ui_auth import UIAuthStore
from .user_directory import UserDirectoryStore
from .user_erasure_store import UserErasureStore
@@ -112,6 +113,7 @@ class DataStore(
StatsStore,
RelationsStore,
CacheInvalidationStore,
+ UIAuthStore,
):
def __init__(self, database: Database, db_conn, hs):
self.hs = hs
@@ -503,7 +505,8 @@ class DataStore(
self, start, limit, name=None, guests=True, deactivated=False
):
"""Function to retrieve a paginated list of users from
- users list. This will return a json list of users.
+ the users table. This will return a json list of users and the
+ total number of users matching the filter criteria.
Args:
start (int): start number to begin the query from
@@ -512,35 +515,44 @@ class DataStore(
guests (bool): whether to include guest users
deactivated (bool): whether to include deactivated users
Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
+ defer.Deferred: resolves to list[dict[str, Any]], int
"""
- name_filter = {}
- if name:
- name_filter["name"] = "%" + name + "%"
-
- attr_filter = {}
- if not guests:
- attr_filter["is_guest"] = 0
- if not deactivated:
- attr_filter["deactivated"] = 0
-
- return self.db.simple_select_list_paginate(
- desc="get_users_paginate",
- table="users",
- orderby="name",
- start=start,
- limit=limit,
- filters=name_filter,
- keyvalues=attr_filter,
- retcols=[
- "name",
- "password_hash",
- "is_guest",
- "admin",
- "user_type",
- "deactivated",
- ],
- )
+
+ def get_users_paginate_txn(txn):
+ filters = []
+ args = []
+
+ if name:
+ filters.append("name LIKE ?")
+ args.append("%" + name + "%")
+
+ if not guests:
+ filters.append("is_guest = 0")
+
+ if not deactivated:
+ filters.append("deactivated = 0")
+
+ where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+
+ sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause)
+ txn.execute(sql, args)
+ count = txn.fetchone()[0]
+
+ args = [self.hs.config.server_name] + args + [limit, start]
+ sql = """
+ SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url
+ FROM users as u
+ LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
+ {}
+ ORDER BY u.name LIMIT ? OFFSET ?
+ """.format(
+ where_clause
+ )
+ txn.execute(sql, args)
+ users = self.db.cursor_to_dict(txn)
+ return users, count
+
+ return self.db.runInteraction("get_users_paginate_txn", get_users_paginate_txn)
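A hedged usage sketch of the new return shape (the store handle and filter values are hypothetical):

```python
# One page of user dicts plus the total count matching the filters, so the
# admin API can report "total" independently of the page size.
users, total = await store.get_users_paginate(
    start=0, limit=10, name="alice", guests=False, deactivated=False
)
assert len(users) <= 10 and total >= len(users)
```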
def search_users(self, term):
"""Function to search users list for one or more users with
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index ce8be72bfe..73df6b33ba 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -19,7 +19,7 @@ import itertools
import logging
import threading
from collections import namedtuple
-from typing import List, Optional
+from typing import List, Optional, Tuple
from canonicaljson import json
from constantly import NamedConstant, Names
@@ -1084,7 +1084,28 @@ class EventsWorkerStore(SQLBaseStore):
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)
- def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
+ async def get_all_updated_current_state_deltas(
+ self, from_token: int, to_token: int, target_row_count: int
+ ) -> Tuple[List[Tuple], int, bool]:
+ """Fetch updates from current_state_delta_stream
+
+ Args:
+ from_token: The previous stream token. Updates from this stream id will
+ be excluded.
+
+ to_token: The current stream token (ie the upper limit). Updates up to this
+ stream id will be included (modulo the 'limit' param)
+
+ target_row_count: The number of rows to try to return. If more rows are
+ available, we will set 'limited' in the result. In the event of a large
+ batch, we may return more rows than this.
+ Returns:
+ A triplet `(updates, new_last_token, limited)`, where:
+ * `updates` is a list of database tuples.
+ * `new_last_token` is the new position in stream.
+ * `limited` is whether there are more updates to fetch.
+ """
+
def get_all_updated_current_state_deltas_txn(txn):
sql = """
SELECT stream_id, room_id, type, state_key, event_id
@@ -1092,10 +1113,45 @@ class EventsWorkerStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC LIMIT ?
"""
- txn.execute(sql, (from_token, to_token, limit))
+ txn.execute(sql, (from_token, to_token, target_row_count))
return txn.fetchall()
- return self.db.runInteraction(
+ def get_deltas_for_stream_id_txn(txn, stream_id):
+ sql = """
+ SELECT stream_id, room_id, type, state_key, event_id
+ FROM current_state_delta_stream
+ WHERE stream_id = ?
+ """
+ txn.execute(sql, [stream_id])
+ return txn.fetchall()
+
+ # we need to make sure that, for every stream id in the results, we get *all*
+ # the rows with that stream id.
+
+ rows = await self.db.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
+ ) # type: List[Tuple]
+
+ # if we've got fewer rows than the limit, we're good
+ if len(rows) < target_row_count:
+ return rows, to_token, False
+
+ # we hit the limit, so reduce the upper limit so that we exclude the stream id
+ # of the last row in the result.
+ assert rows[-1][0] <= to_token
+ to_token = rows[-1][0] - 1
+
+ # search backwards through the list for the point to truncate
+ for idx in range(len(rows) - 1, 0, -1):
+ if rows[idx - 1][0] <= to_token:
+ return rows[:idx], to_token, True
+
+ # bother. We didn't get a full set of changes for even a single
+ # stream id. let's run the query again, without a row limit, but for
+ # just one stream id.
+ to_token += 1
+ rows = await self.db.runInteraction(
+ "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token
)
+ return rows, to_token, True
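The truncation rule above, restated as a standalone sketch (a hypothetical helper, not part of the patch): trailing rows that share the last stream id are dropped, so every stream id in the result is complete.

```python
from typing import List, Tuple

def truncate_to_complete_stream_ids(
    rows: List[Tuple], to_token: int, target_row_count: int
) -> Tuple[List[Tuple], int, bool]:
    """rows are ordered by stream_id (element 0); mirrors the logic above."""
    if len(rows) < target_row_count:
        return rows, to_token, False
    to_token = rows[-1][0] - 1
    for idx in range(len(rows) - 1, 0, -1):
        if rows[idx - 1][0] <= to_token:
            return rows[:idx], to_token, True
    # All rows share a single stream id; the real code refetches that id
    # without a row limit rather than returning a partial set.
    return rows, to_token + 1, True

rows = [(1, "a"), (2, "b"), (2, "c"), (3, "d")]
assert truncate_to_complete_stream_ids(rows, 10, 4) == ([(1, "a"), (2, "b"), (2, "c")], 2, True)
```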
diff --git a/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql b/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql
new file mode 100644
index 0000000000..dcb593fc2d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql
@@ -0,0 +1,36 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS ui_auth_sessions(
+ session_id TEXT NOT NULL, -- The session ID passed to the client.
+ creation_time BIGINT NOT NULL, -- The time this session was created (epoch time in milliseconds).
+ serverdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data added by Synapse.
+ clientdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data from the client.
+ uri TEXT NOT NULL, -- The URI the UI authentication session is using.
+ method TEXT NOT NULL, -- The HTTP method the UI authentication session is using.
+ -- The clientdict, uri, and method make up a tuple that must be immutable
+ -- throughout the lifetime of the UI Auth session.
+ description TEXT NOT NULL, -- A human readable description of the operation which caused the UI Auth flow to occur.
+ UNIQUE (session_id)
+);
+
+CREATE TABLE IF NOT EXISTS ui_auth_sessions_credentials(
+ session_id TEXT NOT NULL, -- The corresponding UI Auth session.
+ stage_type TEXT NOT NULL, -- The stage type.
+ result TEXT NOT NULL, -- The result of the stage verification, stored as JSON.
+ UNIQUE (session_id, stage_type),
+ FOREIGN KEY (session_id)
+ REFERENCES ui_auth_sessions (session_id)
+);
diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py
new file mode 100644
index 0000000000..c8eebc9378
--- /dev/null
+++ b/synapse/storage/data_stores/main/ui_auth.py
@@ -0,0 +1,279 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+from typing import Any, Dict, Optional, Union
+
+import attr
+
+import synapse.util.stringutils as stringutils
+from synapse.api.errors import StoreError
+from synapse.storage._base import SQLBaseStore
+from synapse.types import JsonDict
+
+
+@attr.s
+class UIAuthSessionData:
+ session_id = attr.ib(type=str)
+ # The dictionary from the client root level, not the 'auth' key.
+ clientdict = attr.ib(type=JsonDict)
+ # The URI and method the session was initiated with. These are checked at
+ # each stage of the authentication to ensure that the asked for operation
+ # has not changed.
+ uri = attr.ib(type=str)
+ method = attr.ib(type=str)
+ # A string description of the operation that the current authentication is
+ # authorising.
+ description = attr.ib(type=str)
+
+
+class UIAuthWorkerStore(SQLBaseStore):
+ """
+ Manage user interactive authentication sessions.
+ """
+
+ async def create_ui_auth_session(
+ self, clientdict: JsonDict, uri: str, method: str, description: str,
+ ) -> UIAuthSessionData:
+ """
+ Creates a new user interactive authentication session.
+
+ The session can be used to track the stages necessary to authenticate a
+ user across multiple HTTP requests.
+
+ Args:
+ clientdict:
+ The dictionary from the client root level, not the 'auth' key.
+            uri:
+                The URI this session was initiated with; this is checked at
+                each stage of the authentication to ensure that the requested
+                operation has not changed.
+            method:
+                The method this session was initiated with; this is checked
+                at each stage of the authentication to ensure that the
+                requested operation has not changed.
+ description:
+ A string description of the operation that the current
+ authentication is authorising.
+ Returns:
+ The newly created session.
+ Raises:
+ StoreError if a unique session ID cannot be generated.
+ """
+ # The clientdict gets stored as JSON.
+ clientdict_json = json.dumps(clientdict)
+
+        # Auto-generate a session ID and try to create it. We may clash, so
+        # just try a few times until one goes through, giving up eventually.
+ attempts = 0
+ while attempts < 5:
+ session_id = stringutils.random_string(24)
+
+ try:
+ await self.db.simple_insert(
+ table="ui_auth_sessions",
+ values={
+ "session_id": session_id,
+ "clientdict": clientdict_json,
+ "uri": uri,
+ "method": method,
+ "description": description,
+ "serverdict": "{}",
+ "creation_time": self.hs.get_clock().time_msec(),
+ },
+ desc="create_ui_auth_session",
+ )
+ return UIAuthSessionData(
+ session_id, clientdict, uri, method, description
+ )
+ except self.db.engine.module.IntegrityError:
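+                # The session ID clashed with an existing one (the UNIQUE
+                # constraint on session_id); loop round and pick a new ID.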
+ attempts += 1
+ raise StoreError(500, "Couldn't generate a session ID.")
+
+ async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData:
+ """Retrieve a UI auth session.
+
+ Args:
+ session_id: The ID of the session.
+ Returns:
+            The UIAuthSessionData for the given session.
+ Raises:
+ StoreError if the session is not found.
+ """
+ result = await self.db.simple_select_one(
+ table="ui_auth_sessions",
+ keyvalues={"session_id": session_id},
+ retcols=("clientdict", "uri", "method", "description"),
+ desc="get_ui_auth_session",
+ )
+
+ result["clientdict"] = json.loads(result["clientdict"])
+
+ return UIAuthSessionData(session_id, **result)
+
+ async def mark_ui_auth_stage_complete(
+ self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict],
+ ):
+ """
+ Mark a session stage as completed.
+
+ Args:
+ session_id: The ID of the corresponding session.
+ stage_type: The completed stage type.
+ result: The result of the stage verification.
+ Raises:
+ StoreError if the session cannot be found.
+ """
+ # Add (or update) the results of the current stage to the database.
+ #
+ # Note that we need to allow for the same stage to complete multiple
+ # times here so that registration is idempotent.
+ try:
+ await self.db.simple_upsert(
+ table="ui_auth_sessions_credentials",
+ keyvalues={"session_id": session_id, "stage_type": stage_type},
+ values={"result": json.dumps(result)},
+ desc="mark_ui_auth_stage_complete",
+ )
+ except self.db.engine.module.IntegrityError:
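+            # The insert failed the foreign key check on session_id, i.e.
+            # there is no such session (assuming the database enforces
+            # foreign keys).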
+ raise StoreError(400, "Unknown session ID: %s" % (session_id,))
+
+ async def get_completed_ui_auth_stages(
+ self, session_id: str
+ ) -> Dict[str, Union[str, bool, JsonDict]]:
+ """
+ Retrieve the completed stages of a UI authentication session.
+
+ Args:
+ session_id: The ID of the session.
+ Returns:
+ The completed stages mapped to the result of the verification of
+ that auth-type.
+ """
+ results = {}
+ for row in await self.db.simple_select_list(
+ table="ui_auth_sessions_credentials",
+ keyvalues={"session_id": session_id},
+ retcols=("stage_type", "result"),
+ desc="get_completed_ui_auth_stages",
+ ):
+ results[row["stage_type"]] = json.loads(row["result"])
+
+ return results
+
+ async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any):
+ """
+        Store a key-value pair in the data associated with this session. This
+        data is stored server-side and cannot be modified by the client.
+
+ Args:
+ session_id: The ID of this session as returned from check_auth
+ key: The key to store the data under
+ value: The data to store
+ Raises:
+ StoreError if the session cannot be found.
+ """
+ await self.db.runInteraction(
+ "set_ui_auth_session_data",
+ self._set_ui_auth_session_data_txn,
+ session_id,
+ key,
+ value,
+ )
+
+ def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
+ # Get the current value.
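+        # Note: simple_select_one_txn raises a StoreError if the session
+        # cannot be found, as promised in the docstring above.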
+ result = self.db.simple_select_one_txn(
+ txn,
+ table="ui_auth_sessions",
+ keyvalues={"session_id": session_id},
+ retcols=("serverdict",),
+ )
+
+ # Update it and add it back to the database.
+ serverdict = json.loads(result["serverdict"])
+ serverdict[key] = value
+
+ self.db.simple_update_one_txn(
+ txn,
+ table="ui_auth_sessions",
+ keyvalues={"session_id": session_id},
+ updatevalues={"serverdict": json.dumps(serverdict)},
+ )
+
+ async def get_ui_auth_session_data(
+ self, session_id: str, key: str, default: Optional[Any] = None
+ ) -> Any:
+ """
+        Retrieve data stored with set_ui_auth_session_data.
+
+        Args:
+            session_id: The ID of this session as returned from check_auth
+            key: The key of the data to retrieve
+ default: Value to return if the key has not been set
+ Raises:
+ StoreError if the session cannot be found.
+ """
+ result = await self.db.simple_select_one(
+ table="ui_auth_sessions",
+ keyvalues={"session_id": session_id},
+ retcols=("serverdict",),
+ desc="get_ui_auth_session_data",
+ )
+
+ serverdict = json.loads(result["serverdict"])
+
+ return serverdict.get(key, default)
+
+
+class UIAuthStore(UIAuthWorkerStore):
+ def delete_old_ui_auth_sessions(self, expiration_time: int):
+ """
+        Remove sessions which were created at or before the expiration time.
+
+        Args:
+            expiration_time: Sessions created at or before this time are
+                removed. This is an epoch time in milliseconds.
+
+ """
+ return self.db.runInteraction(
+ "delete_old_ui_auth_sessions",
+ self._delete_old_ui_auth_sessions_txn,
+ expiration_time,
+ )
+
+ def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
+ # Get the expired sessions.
+ sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
+ txn.execute(sql, [expiration_time])
+ session_ids = [r[0] for r in txn.fetchall()]
+
+ # Delete the corresponding completed credentials.
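+        # (These rows reference ui_auth_sessions via a foreign key, so they
+        # must be removed before the sessions themselves.)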
+ self.db.simple_delete_many_txn(
+ txn,
+ table="ui_auth_sessions_credentials",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={},
+ )
+
+ # Finally, delete the sessions.
+ self.db.simple_delete_many_txn(
+ txn,
+ table="ui_auth_sessions",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={},
+ )
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 3bc2e8b986..215a949442 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -85,6 +85,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
prepare_database(db_conn, self, config=None)
db_conn.create_function("rank", 1, _rank)
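+ # SQLite ships with foreign key enforcement disabled by default; enable
+ # it so the REFERENCES constraints used by the new UI auth tables are
+ # actually enforced.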
+ db_conn.execute("PRAGMA foreign_keys = ON;")
def is_deadlock(self, error):
return False
|