From ae38e0569fdb592875b45d4b37f200af5f6a86fc Mon Sep 17 00:00:00 2001
From: Half-Shot
Date: Fri, 23 Aug 2019 09:15:10 +0100
Subject: Ignore consent for support users

---
 synapse/storage/registration.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 55e4e84d71..938fd00717 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -56,6 +56,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             "consent_server_notice_sent",
             "appservice_id",
             "creation_ts",
+            "user_type",
         ],
         allow_none=True,
         desc="get_user_by_id",
-- cgit 1.5.1

From 1a7e6eb63387704ef379bf962318f710ce5ae5f3 Mon Sep 17 00:00:00 2001
From: reivilibre <38398653+reivilibre@users.noreply.github.com>
Date: Tue, 27 Aug 2019 10:14:00 +0100
Subject: Add Admin API capability to set adminship of a user (#5878)

Admin API: Set adminship of a user
---
 changelog.d/5878.feature          |  1 +
 docs/admin_api/user_admin_api.rst | 20 +++++++++++
 synapse/handlers/admin.py         | 10 ++++++
 synapse/rest/admin/__init__.py    |  2 ++
 synapse/rest/admin/users.py       | 76 +++++++++++++++++++++++++++++++++++++++
 synapse/storage/registration.py   | 23 ++++++++++++
 6 files changed, 132 insertions(+)
 create mode 100644 changelog.d/5878.feature
 create mode 100644 synapse/rest/admin/users.py

diff --git a/changelog.d/5878.feature b/changelog.d/5878.feature
new file mode 100644
index 0000000000..d9d6df880e
--- /dev/null
+++ b/changelog.d/5878.feature
@@ -0,0 +1 @@
+Add admin API endpoint for setting whether or not a user is a server administrator.
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index 213359d0c0..6ee5080eed 100644
--- a/docs/admin_api/user_admin_api.rst
+++ b/docs/admin_api/user_admin_api.rst
@@ -84,3 +84,23 @@ with a body of:
     }

 including an ``access_token`` of a server admin.
+
+
+Change whether a user is a server administrator or not
+======================================================
+
+Note that you cannot demote yourself.
+
+The API is::
+
+    PUT /_synapse/admin/v1/users/<user_id>/admin
+
+with a body of:
+
+.. code:: json
+
+    {
+        "admin": true
+    }
+
+including an ``access_token`` of a server admin.
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 2f22f56ca4..d30a68b650 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -94,6 +94,16 @@ class AdminHandler(BaseHandler):

         return ret

+    def set_user_server_admin(self, user, admin):
+        """
+        Set the admin bit on a user.
+
+        Args:
+            user (UserID): the (necessarily local) user to manipulate
+            admin (bool): whether or not the user should be an admin of this server
+        """
+        return self.store.set_server_admin(user, admin)
+
     @defer.inlineCallbacks
     def export_user_data(self, user_id, writer):
         """Write all data we have on the user to the given writer.
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 0dce256840..9ab1c2c9e0 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -44,6 +44,7 @@ from synapse.rest.admin._base import (
 from synapse.rest.admin.media import register_servlets_for_media_repo
 from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
 from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
+from synapse.rest.admin.users import UserAdminServlet
 from synapse.types import UserID, create_requester
 from synapse.util.versionstring import get_version_string

@@ -742,6 +743,7 @@ def register_servlets(hs, http_server):
     PurgeRoomServlet(hs).register(http_server)
     SendServerNoticeServlet(hs).register(http_server)
     VersionServlet(hs).register(http_server)
+    UserAdminServlet(hs).register(http_server)


 def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
new file mode 100644
index 0000000000..b0fddb6898
--- /dev/null
+++ b/synapse/rest/admin/users.py
@@ -0,0 +1,76 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import re
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import (
+    RestServlet,
+    assert_params_in_dict,
+    parse_json_object_from_request,
+)
+from synapse.rest.admin import assert_requester_is_admin
+from synapse.types import UserID
+
+
+class UserAdminServlet(RestServlet):
+    """
+    Set whether or not a user is a server administrator.
+
+    Note that only local users can be server administrators, and that an
+    administrator may not demote themselves.
+
+    Only server administrators can use this API.
+
+    Example:
+        PUT /_synapse/admin/v1/users/@reivilibre:librepush.net/admin
+        {
+            "admin": true
+        }
+    """
+
+    PATTERNS = (re.compile("^/_synapse/admin/v1/users/(?P<user_id>@[^/]*)/admin$"),)
+
+    def __init__(self, hs):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.handlers = hs.get_handlers()
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, user_id):
+        yield assert_requester_is_admin(self.auth, request)
+        requester = yield self.auth.get_user_by_req(request)
+        auth_user = requester.user
+
+        target_user = UserID.from_string(user_id)
+
+        body = parse_json_object_from_request(request)
+
+        assert_params_in_dict(body, ["admin"])
+
+        if not self.hs.is_mine(target_user):
+            raise SynapseError(400, "Only local users can be admins of this homeserver")
+
+        set_admin_to = bool(body["admin"])
+
+        if target_user == auth_user and not set_admin_to:
+            raise SynapseError(400, "You may not demote yourself.")
+
+        yield self.handlers.admin_handler.set_user_server_admin(
+            target_user, set_admin_to
+        )
+
+        return (200, {})
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 55e4e84d71..9027b917c1 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -272,6 +272,14 @@ class RegistrationWorkerStore(SQLBaseStore):

     @defer.inlineCallbacks
     def is_server_admin(self, user):
+        """Determines if a user is an admin of this homeserver.
+
+        Args:
+            user (UserID): user ID of the user to test
+
+        Returns (bool):
+            true iff the user is a server admin, false otherwise.
+        """
         res = yield self._simple_select_one_onecol(
             table="users",
             keyvalues={"name": user.to_string()},
@@ -282,6 +290,21 @@ class RegistrationWorkerStore(SQLBaseStore):

         return res if res else False

+    def set_server_admin(self, user, admin):
+        """Sets whether a user is an admin of this homeserver.
+
+        Args:
+            user (UserID): user ID of the user to modify
+            admin (bool): true iff the user is to be a server admin,
+                false otherwise.
+        """
+        return self._simple_update_one(
+            table="users",
+            keyvalues={"name": user.to_string()},
+            updatevalues={"admin": 1 if admin else 0},
+            desc="set_server_admin",
+        )
+
     def _query_for_auth(self, txn, token):
         sql = (
             "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
-- cgit 1.5.1

From 6e834e94fcc97811e4cc8185e86c6b9da06eb28e Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Wed, 4 Sep 2019 13:04:27 +0100
Subject: Fix and refactor room and user stats (#5971)

Previously the stats were not being correctly populated.
---
 changelog.d/5971.bugfix                            |    1 +
 docs/room_and_user_statistics.md                   |   62 ++
 synapse/config/stats.py                            |   13 +-
 synapse/handlers/stats.py                          |  307 +++---
 synapse/storage/events.py                          |    5 +-
 synapse/storage/registration.py                    |   12 +
 synapse/storage/roommember.py                      |   44 +-
 .../storage/schema/delta/56/stats_separated.sql    |  152 +++
 synapse/storage/stats.py                           | 1036 ++++++++++++++------
 tests/handlers/test_stats.py                       |  643 +++++++++---
 tests/rest/client/v1/utils.py                      |    8 +-
 11 files changed, 1642 insertions(+), 641 deletions(-)
 create mode 100644 changelog.d/5971.bugfix
 create mode 100644 docs/room_and_user_statistics.md
 create mode 100644 synapse/storage/schema/delta/56/stats_separated.sql

diff --git a/changelog.d/5971.bugfix b/changelog.d/5971.bugfix
new file mode 100644
index 0000000000..9ea095103b
--- /dev/null
+++ b/changelog.d/5971.bugfix
@@ -0,0 +1 @@
+Fix room and user stats tracking.
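For illustration, a minimal Python sketch of driving the PUT /_synapse/admin/v1/users/<user_id>/admin endpoint introduced in #5878 above. The homeserver URL, target user and access token are placeholders, and the `requests` library is used purely for demonstration; it is not part of these patches.

import requests

HOMESERVER = "https://homeserver.example.com"  # placeholder base URL
TARGET_USER = "@alice:example.com"  # placeholder; must be a local user
ADMIN_TOKEN = "placeholder_access_token"  # access_token of an existing server admin

# Grant server adminship; a body of {"admin": false} would revoke it instead.
resp = requests.put(
    "%s/_synapse/admin/v1/users/%s/admin" % (HOMESERVER, TARGET_USER),
    params={"access_token": ADMIN_TOKEN},
    json={"admin": True},
)
resp.raise_for_status()  # the servlet replies 200 with an empty JSON object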
diff --git a/docs/room_and_user_statistics.md b/docs/room_and_user_statistics.md
new file mode 100644
index 0000000000..e1facb38d4
--- /dev/null
+++ b/docs/room_and_user_statistics.md
@@ -0,0 +1,62 @@
+Room and User Statistics
+========================
+
+Synapse maintains room and user statistics (as well as a cache of room state)
+in various tables. These can be used for administrative purposes but are also
+used when generating the public room directory.
+
+
+# Synapse Developer Documentation
+
+## High-Level Concepts
+
+### Definitions
+
+* **subject**: Something we are tracking stats about – currently a room or user.
+* **current row**: An entry for a subject in the appropriate current statistics
+  table. Each subject can have only one.
+* **historical row**: An entry for a subject in the appropriate historical
+  statistics table. Each subject can have any number of these.
+
+### Overview
+
+Stats are maintained as time series. There are two kinds of column:
+
+* absolute columns – where the value is correct for the time given by `end_ts`
+  in the stats row. (Imagine a line graph for these values)
+    * They can also be thought of as 'gauges' in Prometheus, if you are familiar.
+* per-slice columns – where the value corresponds to how many of the occurrences
+  occurred within the time slice given by `(end_ts − bucket_size)…end_ts`
+  or `start_ts…end_ts`. (Imagine a histogram for these values)
+
+Stats are maintained in two tables (for each type): current and historical.
+
+Current stats correspond to the present values. Each subject can only have one
+entry.
+
+Historical stats correspond to values in the past. Subjects may have multiple
+entries.
+
+## Concepts around the management of stats
+
+### Current rows
+
+Current rows contain the most up-to-date statistics for a subject.
+They only contain absolute columns.
+
+### Historical rows
+
+Historical rows can always be considered to be valid for the time slice and
+end time specified.
+
+* historical rows will not exist for every time slice – they will be omitted
+  if there were no changes. In this case, the following assumptions can be
+  made to interpolate/recreate missing rows:
+    - absolute fields have the same values as in the preceding row
+    - per-slice fields are zero (`0`)
+* historical rows will not be retained forever – rows older than a configurable
+  time will be purged.
+
+#### Purge
+
+The purging of historical rows is not yet implemented.
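The time-bucket quantisation that the document above leans on is implemented later in this patch as StatsStore.quantise_stats_time. A self-contained sketch of the behaviour, assuming the default one-day bucket size (in milliseconds) set in synapse/config/stats.py below:

# Sketch of the bucket quantisation described in the document above.
# DEFAULT_BUCKET_SIZE_MS mirrors the default from synapse/config/stats.py
# ("1d", stored in milliseconds); the function mirrors quantise_stats_time.
DEFAULT_BUCKET_SIZE_MS = 86400 * 1000


def quantise_stats_time(ts_ms, bucket_size=DEFAULT_BUCKET_SIZE_MS):
    """Return the largest multiple of bucket_size that is no later than ts_ms."""
    return (ts_ms // bucket_size) * bucket_size


start_ts = quantise_stats_time(1567296123456)
end_ts = start_ts + DEFAULT_BUCKET_SIZE_MS
# A historical row with this end_ts covers (end_ts - bucket_size)...end_ts:
# absolute columns hold the value as of end_ts, while per-slice columns
# count occurrences within the slice.
print(start_ts, end_ts)  # 1567296000000 1567382400000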
diff --git a/synapse/config/stats.py b/synapse/config/stats.py
index b518a3ed9c..b18ddbd1fa 100644
--- a/synapse/config/stats.py
+++ b/synapse/config/stats.py
@@ -27,19 +27,16 @@ class StatsConfig(Config):

     def read_config(self, config, **kwargs):
         self.stats_enabled = True
-        self.stats_bucket_size = 86400
+        self.stats_bucket_size = 86400 * 1000
         self.stats_retention = sys.maxsize
         stats_config = config.get("stats", None)
         if stats_config:
             self.stats_enabled = stats_config.get("enabled", self.stats_enabled)
-            self.stats_bucket_size = (
-                self.parse_duration(stats_config.get("bucket_size", "1d")) / 1000
+            self.stats_bucket_size = self.parse_duration(
+                stats_config.get("bucket_size", "1d")
             )
-            self.stats_retention = (
-                self.parse_duration(
-                    stats_config.get("retention", "%ds" % (sys.maxsize,))
-                )
-                / 1000
+            self.stats_retention = self.parse_duration(
+                stats_config.get("retention", "%ds" % (sys.maxsize,))
             )

     def generate_config_section(self, config_dir_path, server_name, **kwargs):
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 4449da6669..921735edb3 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -14,15 +14,14 @@
 # limitations under the License.

 import logging
+from collections import Counter

 from twisted.internet import defer

-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import EventTypes, Membership
 from synapse.handlers.state_deltas import StateDeltasHandler
 from synapse.metrics import event_processing_positions
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import UserID
-from synapse.util.metrics import Measure

 logger = logging.getLogger(__name__)

@@ -62,11 +61,10 @@ class StatsHandler(StateDeltasHandler):
     def notify_new_event(self):
         """Called when there may be more deltas to process
         """
-        if not self.hs.config.stats_enabled:
+        if not self.hs.config.stats_enabled or self._is_processing:
             return

-        if self._is_processing:
-            return
+        self._is_processing = True

         @defer.inlineCallbacks
         def process():
@@ -75,39 +73,72 @@ class StatsHandler(StateDeltasHandler):
             finally:
                 self._is_processing = False

-        self._is_processing = True
         run_as_background_process("stats.notify_new_event", process)

     @defer.inlineCallbacks
     def _unsafe_process(self):
         # If self.pos is None then means we haven't fetched it from DB
         if self.pos is None:
-            self.pos = yield self.store.get_stats_stream_pos()
-
-        # If still None then the initial background update hasn't happened yet
-        if self.pos is None:
-            return None
+            self.pos = yield self.store.get_stats_positions()

         # Loop round handling deltas until we're up to date
+
         while True:
-            with Measure(self.clock, "stats_delta"):
-                deltas = yield self.store.get_current_state_deltas(self.pos)
-                if not deltas:
-                    return
+            deltas = yield self.store.get_current_state_deltas(self.pos)
+
+            if deltas:
+                logger.debug("Handling %d state deltas", len(deltas))
+                room_deltas, user_deltas = yield self._handle_deltas(deltas)
+
+                max_pos = deltas[-1]["stream_id"]
+            else:
+                room_deltas = {}
+                user_deltas = {}
+                max_pos = yield self.store.get_room_max_stream_ordering()

-                logger.info("Handling %d state deltas", len(deltas))
-                yield self._handle_deltas(deltas)
+            # Then count deltas for total_events and total_event_bytes.
+            room_count, user_count = yield self.store.get_changes_room_total_events_and_bytes(
+                self.pos, max_pos
+            )
+
+            for room_id, fields in room_count.items():
+                room_deltas.setdefault(room_id, {}).update(fields)
+
+            for user_id, fields in user_count.items():
+                user_deltas.setdefault(user_id, {}).update(fields)
+
+            logger.debug("room_deltas: %s", room_deltas)
+            logger.debug("user_deltas: %s", user_deltas)

-            self.pos = deltas[-1]["stream_id"]
-            yield self.store.update_stats_stream_pos(self.pos)
+            # Always call this so that we update the stats position.
+            yield self.store.bulk_update_stats_delta(
+                self.clock.time_msec(),
+                updates={"room": room_deltas, "user": user_deltas},
+                stream_id=max_pos,
+            )
+
+            event_processing_positions.labels("stats").set(max_pos)

-            event_processing_positions.labels("stats").set(self.pos)
+            if self.pos == max_pos:
+                break
+
+            self.pos = max_pos

     @defer.inlineCallbacks
     def _handle_deltas(self, deltas):
+        """Called with the state deltas to process
+
+        Returns:
+            Deferred[tuple[dict[str, Counter], dict[str, Counter]]]
+                Resolves to two dicts, the room deltas and the user deltas,
+                mapping from room/user ID to changes in the various fields.
         """
-        Called with the state deltas to process
-        """
+
+        room_to_stats_deltas = {}
+        user_to_stats_deltas = {}
+
+        room_to_state_updates = {}
+
         for delta in deltas:
             typ = delta["type"]
             state_key = delta["state_key"]
@@ -115,11 +146,10 @@ class StatsHandler(StateDeltasHandler):
             event_id = delta["event_id"]
             stream_id = delta["stream_id"]
             prev_event_id = delta["prev_event_id"]
-            stream_pos = delta["stream_id"]

-            logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
+            logger.debug("Handling: %r, %r %r, %s", room_id, typ, state_key, event_id)

-            token = yield self.store.get_earliest_token_for_room_stats(room_id)
+            token = yield self.store.get_earliest_token_for_stats("room", room_id)

             # If the earliest token to begin from is larger than our current
             # stream ID, skip processing this delta.
@@ -131,203 +161,130 @@ class StatsHandler(StateDeltasHandler):
                 continue

             if event_id is None and prev_event_id is None:
-                # Errr...
+                logger.error(
+                    "event ID is None and so is the previous event ID. stream_id: %s",
+                    stream_id,
+                )
                 continue

             event_content = {}
+            sender = None
             if event_id is not None:
                 event = yield self.store.get_event(event_id, allow_none=True)
                 if event:
                     event_content = event.content or {}
+                    sender = event.sender
+
+            # All the values in this dict are deltas (RELATIVE changes)
+            room_stats_delta = room_to_stats_deltas.setdefault(room_id, Counter())

-            # We use stream_pos here rather than fetch by event_id as event_id
-            # may be None
-            now = yield self.store.get_received_ts_by_stream_pos(stream_pos)
+            room_state = room_to_state_updates.setdefault(room_id, {})

-            # quantise time to the nearest bucket
-            now = (now // 1000 // self.stats_bucket_size) * self.stats_bucket_size
+            if prev_event_id is None:
+                # this state event doesn't overwrite another,
+                # so it is a new effective/current state event
+                room_stats_delta["current_state_events"] += 1

             if typ == EventTypes.Member:
                 # we could use _get_key_change here but it's a bit inefficient
                 # given we're not testing for a specific result; might as well
                 # just grab the prev_membership and membership strings and
                 # compare them.
-                prev_event_content = {}
+                # We take None rather than leave as a previous membership
+                # in the absence of a previous event because we do not want to
+                # reduce the leave count when a new-to-the-room user joins.
+                prev_membership = None
                 if prev_event_id is not None:
                     prev_event = yield self.store.get_event(
                         prev_event_id, allow_none=True
                     )
                     if prev_event:
                         prev_event_content = prev_event.content
+                        prev_membership = prev_event_content.get(
+                            "membership", Membership.LEAVE
+                        )

                 membership = event_content.get("membership", Membership.LEAVE)
-                prev_membership = prev_event_content.get("membership", Membership.LEAVE)
-
-                if prev_membership == membership:
-                    continue

-                if prev_membership == Membership.JOIN:
-                    yield self.store.update_stats_delta(
-                        now, "room", room_id, "joined_members", -1
-                    )
+                if prev_membership is None:
+                    logger.debug("No previous membership for this user.")
+                elif membership == prev_membership:
+                    pass  # noop
+                elif prev_membership == Membership.JOIN:
+                    room_stats_delta["joined_members"] -= 1
                 elif prev_membership == Membership.INVITE:
-                    yield self.store.update_stats_delta(
-                        now, "room", room_id, "invited_members", -1
-                    )
+                    room_stats_delta["invited_members"] -= 1
                 elif prev_membership == Membership.LEAVE:
-                    yield self.store.update_stats_delta(
-                        now, "room", room_id, "left_members", -1
-                    )
+                    room_stats_delta["left_members"] -= 1
                 elif prev_membership == Membership.BAN:
-                    yield self.store.update_stats_delta(
-                        now, "room", room_id, "banned_members", -1
-                    )
+                    room_stats_delta["banned_members"] -= 1
                 else:
-                    err = "%s is not a valid prev_membership" % (repr(prev_membership),)
-                    logger.error(err)
-                    raise ValueError(err)
+                    raise ValueError(
+                        "%r is not a valid prev_membership" % (prev_membership,)
+                    )

+                if membership == prev_membership:
+                    pass  # noop
-                if membership == Membership.JOIN:
+                elif membership == Membership.JOIN:
-                    yield self.store.update_stats_delta(
-                        now, "room", room_id, "joined_members", +1
-                    )
+                    room_stats_delta["joined_members"] += 1
                 elif membership == Membership.INVITE:
-                    yield self.store.update_stats_delta(
-                        now, "room", room_id, "invited_members", +1
-                    )
+                    room_stats_delta["invited_members"] += 1
+
+                    if sender and self.is_mine_id(sender):
+                        user_to_stats_deltas.setdefault(sender, Counter())[
+                            "invites_sent"
+                        ] += 1
+
                 elif membership == Membership.LEAVE:
-                    yield self.store.update_stats_delta(
-                        now, "room", room_id, "left_members", +1
-                    )
+                    room_stats_delta["left_members"] += 1
                 elif membership == Membership.BAN:
-                    yield self.store.update_stats_delta(
-                        now, "room", room_id, "banned_members", +1
-                    )
+                    room_stats_delta["banned_members"] += 1
                 else:
-                    err = "%s is not a valid membership" % (repr(membership),)
-                    logger.error(err)
-                    raise ValueError(err)
+                    raise ValueError("%r is not a valid membership" % (membership,))

                 user_id = state_key
                 if self.is_mine_id(user_id):
-                    # update user_stats as it's one of our users
-                    public = yield self._is_public_room(room_id)
-
-                    if membership == Membership.LEAVE:
-                        yield self.store.update_stats_delta(
-                            now,
-                            "user",
-                            user_id,
-                            "public_rooms" if public else "private_rooms",
-                            -1,
-                        )
-                    elif membership == Membership.JOIN:
-                        yield self.store.update_stats_delta(
-                            now,
-                            "user",
-                            user_id,
-                            "public_rooms" if public else "private_rooms",
-                            +1,
-                        )
+                    # this accounts for transitions like leave → ban and so on.
+                    has_changed_joinedness = (prev_membership == Membership.JOIN) != (
+                        membership == Membership.JOIN
+                    )

-            elif typ == EventTypes.Create:
-                # Newly created room. Add it with all blank portions.
- yield self.store.update_room_state( - room_id, - { - "join_rules": None, - "history_visibility": None, - "encryption": None, - "name": None, - "topic": None, - "avatar": None, - "canonical_alias": None, - }, - ) + if has_changed_joinedness: + delta = +1 if membership == Membership.JOIN else -1 - elif typ == EventTypes.JoinRules: - yield self.store.update_room_state( - room_id, {"join_rules": event_content.get("join_rule")} - ) + user_to_stats_deltas.setdefault(user_id, Counter())[ + "joined_rooms" + ] += delta - is_public = yield self._get_key_change( - prev_event_id, event_id, "join_rule", JoinRules.PUBLIC - ) - if is_public is not None: - yield self.update_public_room_stats(now, room_id, is_public) + room_stats_delta["local_users_in_room"] += delta + elif typ == EventTypes.Create: + room_state["is_federatable"] = event_content.get("m.federate", True) + if sender and self.is_mine_id(sender): + user_to_stats_deltas.setdefault(sender, Counter())[ + "rooms_created" + ] += 1 + elif typ == EventTypes.JoinRules: + room_state["join_rules"] = event_content.get("join_rule") elif typ == EventTypes.RoomHistoryVisibility: - yield self.store.update_room_state( - room_id, - {"history_visibility": event_content.get("history_visibility")}, - ) - - is_public = yield self._get_key_change( - prev_event_id, event_id, "history_visibility", "world_readable" + room_state["history_visibility"] = event_content.get( + "history_visibility" ) - if is_public is not None: - yield self.update_public_room_stats(now, room_id, is_public) - elif typ == EventTypes.Encryption: - yield self.store.update_room_state( - room_id, {"encryption": event_content.get("algorithm")} - ) + room_state["encryption"] = event_content.get("algorithm") elif typ == EventTypes.Name: - yield self.store.update_room_state( - room_id, {"name": event_content.get("name")} - ) + room_state["name"] = event_content.get("name") elif typ == EventTypes.Topic: - yield self.store.update_room_state( - room_id, {"topic": event_content.get("topic")} - ) + room_state["topic"] = event_content.get("topic") elif typ == EventTypes.RoomAvatar: - yield self.store.update_room_state( - room_id, {"avatar": event_content.get("url")} - ) + room_state["avatar"] = event_content.get("url") elif typ == EventTypes.CanonicalAlias: - yield self.store.update_room_state( - room_id, {"canonical_alias": event_content.get("alias")} - ) + room_state["canonical_alias"] = event_content.get("alias") + elif typ == EventTypes.GuestAccess: + room_state["guest_access"] = event_content.get("guest_access") - @defer.inlineCallbacks - def update_public_room_stats(self, ts, room_id, is_public): - """ - Increment/decrement a user's number of public rooms when a room they are - in changes to/from public visibility. 
+ for room_id, state in room_to_state_updates.items(): + yield self.store.update_room_state(room_id, state) - Args: - ts (int): Timestamp in seconds - room_id (str) - is_public (bool) - """ - # For now, blindly iterate over all local users in the room so that - # we can handle the whole problem of copying buckets over as needed - user_ids = yield self.store.get_users_in_room(room_id) - - for user_id in user_ids: - if self.hs.is_mine(UserID.from_string(user_id)): - yield self.store.update_stats_delta( - ts, "user", user_id, "public_rooms", +1 if is_public else -1 - ) - yield self.store.update_stats_delta( - ts, "user", user_id, "private_rooms", -1 if is_public else +1 - ) - - @defer.inlineCallbacks - def _is_public_room(self, room_id): - join_rules = yield self.state.get_current_state(room_id, EventTypes.JoinRules) - history_visibility = yield self.state.get_current_state( - room_id, EventTypes.RoomHistoryVisibility - ) - - if (join_rules and join_rules.content.get("join_rule") == JoinRules.PUBLIC) or ( - ( - history_visibility - and history_visibility.content.get("history_visibility") - == "world_readable" - ) - ): - return True - else: - return False + return room_to_stats_deltas, user_to_stats_deltas diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 32050868ff..1958afe1d7 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -2270,8 +2270,9 @@ class EventsStore( "room_aliases", "room_depth", "room_memberships", - "room_state", - "room_stats", + "room_stats_state", + "room_stats_current", + "room_stats_historical", "room_stats_earliest_token", "rooms", "stream_ordering_to_exterm", diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 3f50324253..2d3c7e2dc9 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -869,6 +869,17 @@ class RegistrationStore( (user_id_obj.localpart, create_profile_with_displayname), ) + if self.hs.config.stats_enabled: + # we create a new completed user statistics row + + # we don't strictly need current_token since this user really can't + # have any state deltas before now (as it is a new user), but still, + # we include it for completeness. + current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn) + self._update_stats_delta_txn( + txn, now, "user", user_id, {}, complete_with_stream_id=current_token + ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,)) @@ -1140,6 +1151,7 @@ class RegistrationStore( deferred str|None: A str representing a link to redirect the user to if there is one. """ + # Insert everything into a transaction in order to run atomically def validate_threepid_session_txn(txn): row = self._simple_select_one_txn( diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index eecb276465..f8b682ebd9 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -112,29 +112,31 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached(max_entries=100000, iterable=True) def get_users_in_room(self, room_id): - def f(txn): - # If we can assume current_state_events.membership is up to date - # then we can avoid a join, which is a Very Good Thing given how - # frequently this function gets called. - if self._current_state_events_membership_up_to_date: - sql = """ - SELECT state_key FROM current_state_events - WHERE type = 'm.room.member' AND room_id = ? AND membership = ? 
- """ - else: - sql = """ - SELECT state_key FROM room_memberships as m - INNER JOIN current_state_events as c - ON m.event_id = c.event_id - AND m.room_id = c.room_id - AND m.user_id = c.state_key - WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? - """ + return self.runInteraction( + "get_users_in_room", self.get_users_in_room_txn, room_id + ) - txn.execute(sql, (room_id, Membership.JOIN)) - return [to_ascii(r[0]) for r in txn] + def get_users_in_room_txn(self, txn, room_id): + # If we can assume current_state_events.membership is up to date + # then we can avoid a join, which is a Very Good Thing given how + # frequently this function gets called. + if self._current_state_events_membership_up_to_date: + sql = """ + SELECT state_key FROM current_state_events + WHERE type = 'm.room.member' AND room_id = ? AND membership = ? + """ + else: + sql = """ + SELECT state_key FROM room_memberships as m + INNER JOIN current_state_events as c + ON m.event_id = c.event_id + AND m.room_id = c.room_id + AND m.user_id = c.state_key + WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? + """ - return self.runInteraction("get_users_in_room", f) + txn.execute(sql, (room_id, Membership.JOIN)) + return [to_ascii(r[0]) for r in txn] @cached(max_entries=100000) def get_room_summary(self, room_id): diff --git a/synapse/storage/schema/delta/56/stats_separated.sql b/synapse/storage/schema/delta/56/stats_separated.sql new file mode 100644 index 0000000000..163529c071 --- /dev/null +++ b/synapse/storage/schema/delta/56/stats_separated.sql @@ -0,0 +1,152 @@ +/* Copyright 2018 New Vector Ltd + * Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +----- First clean up from previous versions of room stats. + +-- First remove old stats stuff +DROP TABLE IF EXISTS room_stats; +DROP TABLE IF EXISTS room_state; +DROP TABLE IF EXISTS room_stats_state; +DROP TABLE IF EXISTS user_stats; +DROP TABLE IF EXISTS room_stats_earliest_tokens; +DROP TABLE IF EXISTS _temp_populate_stats_position; +DROP TABLE IF EXISTS _temp_populate_stats_rooms; +DROP TABLE IF EXISTS stats_stream_pos; + +-- Unschedule old background updates if they're still scheduled +DELETE FROM background_updates WHERE update_name IN ( + 'populate_stats_createtables', + 'populate_stats_process_rooms', + 'populate_stats_process_users', + 'populate_stats_cleanup' +); + +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_process_rooms', '{}', ''); + +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_process_users', '{}', 'populate_stats_process_rooms'); + +----- Create tables for our version of room stats. + +-- single-row table to track position of incremental updates +DROP TABLE IF EXISTS stats_incremental_position; +CREATE TABLE stats_incremental_position ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. 
+ stream_id BIGINT NOT NULL, + CHECK (Lock='X') +); + +-- insert a null row and make sure it is the only one. +INSERT INTO stats_incremental_position ( + stream_id +) SELECT COALESCE(MAX(stream_ordering), 0) from events; + +-- represents PRESENT room statistics for a room +-- only holds absolute fields +DROP TABLE IF EXISTS room_stats_current; +CREATE TABLE room_stats_current ( + room_id TEXT NOT NULL PRIMARY KEY, + + -- These are absolute counts + current_state_events INT NOT NULL, + joined_members INT NOT NULL, + invited_members INT NOT NULL, + left_members INT NOT NULL, + banned_members INT NOT NULL, + + local_users_in_room INT NOT NULL, + + -- The maximum delta stream position that this row takes into account. + completed_delta_stream_id BIGINT NOT NULL +); + + +-- represents HISTORICAL room statistics for a room +DROP TABLE IF EXISTS room_stats_historical; +CREATE TABLE room_stats_historical ( + room_id TEXT NOT NULL, + -- These stats cover the time from (end_ts - bucket_size)...end_ts (in ms). + -- Note that end_ts is quantised. + end_ts BIGINT NOT NULL, + bucket_size BIGINT NOT NULL, + + -- These stats are absolute counts + current_state_events BIGINT NOT NULL, + joined_members BIGINT NOT NULL, + invited_members BIGINT NOT NULL, + left_members BIGINT NOT NULL, + banned_members BIGINT NOT NULL, + local_users_in_room BIGINT NOT NULL, + + -- These stats are per time slice + total_events BIGINT NOT NULL, + total_event_bytes BIGINT NOT NULL, + + PRIMARY KEY (room_id, end_ts) +); + +-- We use this index to speed up deletion of ancient room stats. +CREATE INDEX room_stats_historical_end_ts ON room_stats_historical (end_ts); + +-- represents PRESENT statistics for a user +-- only holds absolute fields +DROP TABLE IF EXISTS user_stats_current; +CREATE TABLE user_stats_current ( + user_id TEXT NOT NULL PRIMARY KEY, + + joined_rooms BIGINT NOT NULL, + + -- The maximum delta stream position that this row takes into account. + completed_delta_stream_id BIGINT NOT NULL +); + +-- represents HISTORICAL statistics for a user +DROP TABLE IF EXISTS user_stats_historical; +CREATE TABLE user_stats_historical ( + user_id TEXT NOT NULL, + end_ts BIGINT NOT NULL, + bucket_size BIGINT NOT NULL, + + joined_rooms BIGINT NOT NULL, + + invites_sent BIGINT NOT NULL, + rooms_created BIGINT NOT NULL, + total_events BIGINT NOT NULL, + total_event_bytes BIGINT NOT NULL, + + PRIMARY KEY (user_id, end_ts) +); + +-- We use this index to speed up deletion of ancient user stats. +CREATE INDEX user_stats_historical_end_ts ON user_stats_historical (end_ts); + + +CREATE TABLE room_stats_state ( + room_id TEXT NOT NULL, + name TEXT, + canonical_alias TEXT, + join_rules TEXT, + history_visibility TEXT, + encryption TEXT, + avatar TEXT, + guest_access TEXT, + is_federatable BOOLEAN, + topic TEXT +); + +CREATE UNIQUE INDEX room_stats_state_room ON room_stats_state(room_id); diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py index e13efed417..6560173c08 100644 --- a/synapse/storage/stats.py +++ b/synapse/storage/stats.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2018, 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,17 +15,22 @@ # limitations under the License. 
import logging +from itertools import chain from twisted.internet import defer +from twisted.internet.defer import DeferredLock from synapse.api.constants import EventTypes, Membership -from synapse.storage.prepare_database import get_statements +from synapse.storage import PostgresEngine from synapse.storage.state_deltas import StateDeltasStore from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) # these fields track absolutes (e.g. total number of rooms on the server) +# You can think of these as Prometheus Gauges. +# You can draw these stats on a line graph. +# Example: number of users in a room ABSOLUTE_STATS_FIELDS = { "room": ( "current_state_events", @@ -32,14 +38,23 @@ ABSOLUTE_STATS_FIELDS = { "invited_members", "left_members", "banned_members", - "state_events", + "local_users_in_room", ), - "user": ("public_rooms", "private_rooms"), + "user": ("joined_rooms",), } -TYPE_TO_ROOM = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")} +# these fields are per-timeslice and so should be reset to 0 upon a new slice +# You can draw these stats on a histogram. +# Example: number of events sent locally during a time slice +PER_SLICE_FIELDS = { + "room": ("total_events", "total_event_bytes"), + "user": ("invites_sent", "rooms_created", "total_events", "total_event_bytes"), +} + +TYPE_TO_TABLE = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")} -TEMP_TABLE = "_temp_populate_stats" +# these are the tables (& ID columns) which contain our actual subjects +TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")} class StatsStore(StateDeltasStore): @@ -51,136 +66,102 @@ class StatsStore(StateDeltasStore): self.stats_enabled = hs.config.stats_enabled self.stats_bucket_size = hs.config.stats_bucket_size - self.register_background_update_handler( - "populate_stats_createtables", self._populate_stats_createtables - ) + self.stats_delta_processing_lock = DeferredLock() + self.register_background_update_handler( "populate_stats_process_rooms", self._populate_stats_process_rooms ) self.register_background_update_handler( - "populate_stats_cleanup", self._populate_stats_cleanup + "populate_stats_process_users", self._populate_stats_process_users ) + # we no longer need to perform clean-up, but we will give ourselves + # the potential to reintroduce it in the future – so documentation + # will still encourage the use of this no-op handler. + self.register_noop_background_update("populate_stats_cleanup") + self.register_noop_background_update("populate_stats_prepare") - @defer.inlineCallbacks - def _populate_stats_createtables(self, progress, batch_size): - - if not self.stats_enabled: - yield self._end_background_update("populate_stats_createtables") - return 1 - - # Get all the rooms that we want to process. - def _make_staging_area(txn): - # Create the temporary tables - stmts = get_statements( - """ - -- We just recreate the table, we'll be reinserting the - -- correct entries again later anyway. 
-                DROP TABLE IF EXISTS {temp}_rooms;
-
-                CREATE TABLE IF NOT EXISTS {temp}_rooms(
-                    room_id TEXT NOT NULL,
-                    events BIGINT NOT NULL
-                );
-
-                CREATE INDEX {temp}_rooms_events
-                    ON {temp}_rooms(events);
-                CREATE INDEX {temp}_rooms_id
-                    ON {temp}_rooms(room_id);
-                """.format(
-                    temp=TEMP_TABLE
-                ).splitlines()
-            )
-
-            for statement in stmts:
-                txn.execute(statement)
-
-            sql = (
-                "CREATE TABLE IF NOT EXISTS "
-                + TEMP_TABLE
-                + "_position(position TEXT NOT NULL)"
-            )
-            txn.execute(sql)
-
-            # Get rooms we want to process from the database, only adding
-            # those that we haven't (i.e. those not in room_stats_earliest_token)
-            sql = """
-                INSERT INTO %s_rooms (room_id, events)
-                SELECT c.room_id, count(*) FROM current_state_events AS c
-                LEFT JOIN room_stats_earliest_token AS t USING (room_id)
-                WHERE t.room_id IS NULL
-                GROUP BY c.room_id
-            """ % (
-                TEMP_TABLE,
-            )
-            txn.execute(sql)
+    def quantise_stats_time(self, ts):
+        """
+        Quantises a timestamp to be a multiple of the bucket size.

-        new_pos = yield self.get_max_stream_id_in_current_state_deltas()
-        yield self.runInteraction("populate_stats_temp_build", _make_staging_area)
-        yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
-        self.get_earliest_token_for_room_stats.invalidate_all()
+        Args:
+            ts (int): the timestamp to quantise, in milliseconds since the Unix
+                Epoch

-        yield self._end_background_update("populate_stats_createtables")
-        return 1
+        Returns:
+            int: a timestamp which
+              - is divisible by the bucket size;
+              - is no later than `ts`; and
+              - is the largest such timestamp.
+        """
+        return (ts // self.stats_bucket_size) * self.stats_bucket_size

     @defer.inlineCallbacks
-    def _populate_stats_cleanup(self, progress, batch_size):
+    def _populate_stats_process_users(self, progress, batch_size):
         """
-        Update the user directory stream position, then clean up the old tables.
+        This is a background update which regenerates statistics for users.
         """
         if not self.stats_enabled:
-            yield self._end_background_update("populate_stats_cleanup")
+            yield self._end_background_update("populate_stats_process_users")
             return 1

-        position = yield self._simple_select_one_onecol(
-            TEMP_TABLE + "_position", None, "position"
+        last_user_id = progress.get("last_user_id", "")
+
+        def _get_next_batch(txn):
+            sql = """
+                    SELECT DISTINCT name FROM users
+                    WHERE name > ?
+                    ORDER BY name ASC
+                    LIMIT ?
+                """
+            txn.execute(sql, (last_user_id, batch_size))
+            return [r for r, in txn]
+
+        users_to_work_on = yield self.runInteraction(
+            "_populate_stats_process_users", _get_next_batch
         )
-        yield self.update_stats_stream_pos(position)

-        def _delete_staging_area(txn):
-            txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
-            txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
+        # No more users -- complete the transaction.
+ if not users_to_work_on: + yield self._end_background_update("populate_stats_process_users") + return 1 - yield self.runInteraction("populate_stats_cleanup", _delete_staging_area) + for user_id in users_to_work_on: + yield self._calculate_and_set_initial_state_for_user(user_id) + progress["last_user_id"] = user_id - yield self._end_background_update("populate_stats_cleanup") - return 1 + yield self.runInteraction( + "populate_stats_process_users", + self._background_update_progress_txn, + "populate_stats_process_users", + progress, + ) + + return len(users_to_work_on) @defer.inlineCallbacks def _populate_stats_process_rooms(self, progress, batch_size): - + """ + This is a background update which regenerates statistics for rooms. + """ if not self.stats_enabled: yield self._end_background_update("populate_stats_process_rooms") return 1 - # If we don't have progress filed, delete everything. - if not progress: - yield self.delete_all_stats() + last_room_id = progress.get("last_room_id", "") def _get_next_batch(txn): - # Only fetch 250 rooms, so we don't fetch too many at once, even - # if those 250 rooms have less than batch_size state events. sql = """ - SELECT room_id, events FROM %s_rooms - ORDER BY events DESC - LIMIT 250 - """ % ( - TEMP_TABLE, - ) - txn.execute(sql) - rooms_to_work_on = txn.fetchall() - - if not rooms_to_work_on: - return None - - # Get how many are left to process, so we can give status on how - # far we are in processing - txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms") - progress["remaining"] = txn.fetchone()[0] - - return rooms_to_work_on + SELECT DISTINCT room_id FROM current_state_events + WHERE room_id > ? + ORDER BY room_id ASC + LIMIT ? + """ + txn.execute(sql, (last_room_id, batch_size)) + return [r for r, in txn] rooms_to_work_on = yield self.runInteraction( - "populate_stats_temp_read", _get_next_batch + "populate_stats_rooms_get_batch", _get_next_batch ) # No more rooms -- complete the transaction. 
@@ -188,154 +169,28 @@ class StatsStore(StateDeltasStore): yield self._end_background_update("populate_stats_process_rooms") return 1 - logger.info( - "Processing the next %d rooms of %d remaining", - len(rooms_to_work_on), - progress["remaining"], - ) - - # Number of state events we've processed by going through each room - processed_event_count = 0 - - for room_id, event_count in rooms_to_work_on: - - current_state_ids = yield self.get_current_state_ids(room_id) - - join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) - history_visibility_id = current_state_ids.get( - (EventTypes.RoomHistoryVisibility, "") - ) - encryption_id = current_state_ids.get((EventTypes.RoomEncryption, "")) - name_id = current_state_ids.get((EventTypes.Name, "")) - topic_id = current_state_ids.get((EventTypes.Topic, "")) - avatar_id = current_state_ids.get((EventTypes.RoomAvatar, "")) - canonical_alias_id = current_state_ids.get((EventTypes.CanonicalAlias, "")) - - event_ids = [ - join_rules_id, - history_visibility_id, - encryption_id, - name_id, - topic_id, - avatar_id, - canonical_alias_id, - ] - - state_events = yield self.get_events( - [ev for ev in event_ids if ev is not None] - ) - - def _get_or_none(event_id, arg): - event = state_events.get(event_id) - if event: - return event.content.get(arg) - return None - - yield self.update_room_state( - room_id, - { - "join_rules": _get_or_none(join_rules_id, "join_rule"), - "history_visibility": _get_or_none( - history_visibility_id, "history_visibility" - ), - "encryption": _get_or_none(encryption_id, "algorithm"), - "name": _get_or_none(name_id, "name"), - "topic": _get_or_none(topic_id, "topic"), - "avatar": _get_or_none(avatar_id, "url"), - "canonical_alias": _get_or_none(canonical_alias_id, "alias"), - }, - ) + for room_id in rooms_to_work_on: + yield self._calculate_and_set_initial_state_for_room(room_id) + progress["last_room_id"] = room_id - now = self.hs.get_reactor().seconds() - - # quantise time to the nearest bucket - now = (now // self.stats_bucket_size) * self.stats_bucket_size - - def _fetch_data(txn): - - # Get the current token of the room - current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn) - - current_state_events = len(current_state_ids) - - membership_counts = self._get_user_counts_in_room_txn(txn, room_id) - - total_state_events = self._get_total_state_event_counts_txn( - txn, room_id - ) - - self._update_stats_txn( - txn, - "room", - room_id, - now, - { - "bucket_size": self.stats_bucket_size, - "current_state_events": current_state_events, - "joined_members": membership_counts.get(Membership.JOIN, 0), - "invited_members": membership_counts.get(Membership.INVITE, 0), - "left_members": membership_counts.get(Membership.LEAVE, 0), - "banned_members": membership_counts.get(Membership.BAN, 0), - "state_events": total_state_events, - }, - ) - self._simple_insert_txn( - txn, - "room_stats_earliest_token", - {"room_id": room_id, "token": current_token}, - ) - - # We've finished a room. Delete it from the table. - self._simple_delete_one_txn( - txn, TEMP_TABLE + "_rooms", {"room_id": room_id} - ) - - yield self.runInteraction("update_room_stats", _fetch_data) - - # Update the remaining counter. - progress["remaining"] -= 1 - yield self.runInteraction( - "populate_stats", - self._background_update_progress_txn, - "populate_stats_process_rooms", - progress, - ) - - processed_event_count += event_count - - if processed_event_count > batch_size: - # Don't process any more rooms, we've hit our batch size. 
- return processed_event_count + yield self.runInteraction( + "_populate_stats_process_rooms", + self._background_update_progress_txn, + "populate_stats_process_rooms", + progress, + ) - return processed_event_count + return len(rooms_to_work_on) - def delete_all_stats(self): + def get_stats_positions(self): """ - Delete all statistics records. + Returns the stats processor positions. """ - - def _delete_all_stats_txn(txn): - txn.execute("DELETE FROM room_state") - txn.execute("DELETE FROM room_stats") - txn.execute("DELETE FROM room_stats_earliest_token") - txn.execute("DELETE FROM user_stats") - - return self.runInteraction("delete_all_stats", _delete_all_stats_txn) - - def get_stats_stream_pos(self): return self._simple_select_one_onecol( - table="stats_stream_pos", + table="stats_incremental_position", keyvalues={}, retcol="stream_id", - desc="stats_stream_pos", - ) - - def update_stats_stream_pos(self, stream_id): - return self._simple_update_one( - table="stats_stream_pos", - keyvalues={}, - updatevalues={"stream_id": stream_id}, - desc="update_stats_stream_pos", + desc="stats_incremental_position", ) def update_room_state(self, room_id, fields): @@ -361,42 +216,87 @@ class StatsStore(StateDeltasStore): fields[col] = None return self._simple_upsert( - table="room_state", + table="room_stats_state", keyvalues={"room_id": room_id}, values=fields, desc="update_room_state", ) - def get_deltas_for_room(self, room_id, start, size=100): + def get_statistics_for_subject(self, stats_type, stats_id, start, size=100): """ - Get statistics deltas for a given room. + Get statistics for a given subject. Args: - room_id (str) + stats_type (str): The type of subject + stats_id (str): The ID of the subject (e.g. room_id or user_id) start (int): Pagination start. Number of entries, not timestamp. size (int): How many entries to return. Returns: Deferred[list[dict]], where the dict has the keys of - ABSOLUTE_STATS_FIELDS["room"] and "ts". + ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts". """ - return self._simple_select_list_paginate( - "room_stats", - {"room_id": room_id}, - "ts", + return self.runInteraction( + "get_statistics_for_subject", + self._get_statistics_for_subject_txn, + stats_type, + stats_id, + start, + size, + ) + + def _get_statistics_for_subject_txn( + self, txn, stats_type, stats_id, start, size=100 + ): + """ + Transaction-bound version of L{get_statistics_for_subject}. + """ + + table, id_col = TYPE_TO_TABLE[stats_type] + selected_columns = list( + ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type] + ) + + slice_list = self._simple_select_list_paginate_txn( + txn, + table + "_historical", + {id_col: stats_id}, + "end_ts", start, size, - retcols=(list(ABSOLUTE_STATS_FIELDS["room"]) + ["ts"]), + retcols=selected_columns + ["bucket_size", "end_ts"], order_direction="DESC", ) - def get_all_room_state(self): - return self._simple_select_list( - "room_state", None, retcols=("name", "topic", "canonical_alias") + return slice_list + + def get_room_stats_state(self, room_id): + """ + Returns the current room_stats_state for a room. + + Args: + room_id (str): The ID of the room to return state for. 
+ + Returns (dict): + Dictionary containing these keys: + "name", "topic", "canonical_alias", "avatar", "join_rules", + "history_visibility" + """ + return self._simple_select_one( + "room_stats_state", + {"room_id": room_id}, + retcols=( + "name", + "topic", + "canonical_alias", + "avatar", + "join_rules", + "history_visibility", + ), ) @cached() - def get_earliest_token_for_room_stats(self, room_id): + def get_earliest_token_for_stats(self, stats_type, id): """ Fetch the "earliest token". This is used by the room stats delta processor to ignore deltas that have been processed between the @@ -406,79 +306,571 @@ class StatsStore(StateDeltasStore): Returns: Deferred[int] """ + table, id_col = TYPE_TO_TABLE[stats_type] + return self._simple_select_one_onecol( - "room_stats_earliest_token", - {"room_id": room_id}, - retcol="token", + "%s_current" % (table,), + keyvalues={id_col: id}, + retcol="completed_delta_stream_id", allow_none=True, ) - def update_stats(self, stats_type, stats_id, ts, fields): - table, id_col = TYPE_TO_ROOM[stats_type] - return self._simple_upsert( - table=table, - keyvalues={id_col: stats_id, "ts": ts}, - values=fields, - desc="update_stats", + def bulk_update_stats_delta(self, ts, updates, stream_id): + """Bulk update stats tables for a given stream_id and updates the stats + incremental position. + + Args: + ts (int): Current timestamp in ms + updates(dict[str, dict[str, dict[str, Counter]]]): The updates to + commit as a mapping stats_type -> stats_id -> field -> delta. + stream_id (int): Current position. + + Returns: + Deferred + """ + + def _bulk_update_stats_delta_txn(txn): + for stats_type, stats_updates in updates.items(): + for stats_id, fields in stats_updates.items(): + self._update_stats_delta_txn( + txn, + ts=ts, + stats_type=stats_type, + stats_id=stats_id, + fields=fields, + complete_with_stream_id=stream_id, + ) + + self._simple_update_one_txn( + txn, + table="stats_incremental_position", + keyvalues={}, + updatevalues={"stream_id": stream_id}, + ) + + return self.runInteraction( + "bulk_update_stats_delta", _bulk_update_stats_delta_txn ) - def _update_stats_txn(self, txn, stats_type, stats_id, ts, fields): - table, id_col = TYPE_TO_ROOM[stats_type] - return self._simple_upsert_txn( - txn, table=table, keyvalues={id_col: stats_id, "ts": ts}, values=fields + def update_stats_delta( + self, + ts, + stats_type, + stats_id, + fields, + complete_with_stream_id, + absolute_field_overrides=None, + ): + """ + Updates the statistics for a subject, with a delta (difference/relative + change). + + Args: + ts (int): timestamp of the change + stats_type (str): "room" or "user" – the kind of subject + stats_id (str): the subject's ID (room ID or user ID) + fields (dict[str, int]): Deltas of stats values. + complete_with_stream_id (int, optional): + If supplied, converts an incomplete row into a complete row, + with the supplied stream_id marked as the stream_id where the + row was completed. + absolute_field_overrides (dict[str, int]): Current stats values + (i.e. not deltas) of absolute fields. + Does not work with per-slice fields. + """ + + return self.runInteraction( + "update_stats_delta", + self._update_stats_delta_txn, + ts, + stats_type, + stats_id, + fields, + complete_with_stream_id=complete_with_stream_id, + absolute_field_overrides=absolute_field_overrides, ) - def update_stats_delta(self, ts, stats_type, stats_id, field, value): - def _update_stats_delta(txn): - table, id_col = TYPE_TO_ROOM[stats_type] - - sql = ( - "SELECT * FROM %s" - " WHERE %s=? 
and ts=(" - " SELECT MAX(ts) FROM %s" - " WHERE %s=?" - ")" - ) % (table, id_col, table, id_col) - txn.execute(sql, (stats_id, stats_id)) - rows = self.cursor_to_dict(txn) - if len(rows) == 0: - # silently skip as we don't have anything to apply a delta to yet. - # this tries to minimise any race between the initial sync and - # subsequent deltas arriving. - return - - current_ts = ts - latest_ts = rows[0]["ts"] - if current_ts < latest_ts: - # This one is in the past, but we're just encountering it now. - # Mark it as part of the current bucket. - current_ts = latest_ts - elif ts != latest_ts: - # we have to copy our absolute counters over to the new entry. - values = { - key: rows[0][key] for key in ABSOLUTE_STATS_FIELDS[stats_type] - } - values[id_col] = stats_id - values["ts"] = ts - values["bucket_size"] = self.stats_bucket_size - - self._simple_insert_txn(txn, table=table, values=values) - - # actually update the new value - if stats_type in ABSOLUTE_STATS_FIELDS[stats_type]: - self._simple_update_txn( - txn, - table=table, - keyvalues={id_col: stats_id, "ts": current_ts}, - updatevalues={field: value}, + def _update_stats_delta_txn( + self, + txn, + ts, + stats_type, + stats_id, + fields, + complete_with_stream_id, + absolute_field_overrides=None, + ): + if absolute_field_overrides is None: + absolute_field_overrides = {} + + table, id_col = TYPE_TO_TABLE[stats_type] + + quantised_ts = self.quantise_stats_time(int(ts)) + end_ts = quantised_ts + self.stats_bucket_size + + # Lets be paranoid and check that all the given field names are known + abs_field_names = ABSOLUTE_STATS_FIELDS[stats_type] + slice_field_names = PER_SLICE_FIELDS[stats_type] + for field in chain(fields.keys(), absolute_field_overrides.keys()): + if field not in abs_field_names and field not in slice_field_names: + # guard against potential SQL injection dodginess + raise ValueError( + "%s is not a recognised field" + " for stats type %s" % (field, stats_type) ) + + # Per slice fields do not get added to the _current table + + # This calculates the deltas (`field = field + ?` values) + # for absolute fields, + # * defaulting to 0 if not specified + # (required for the INSERT part of upserting to work) + # * omitting overrides specified in `absolute_field_overrides` + deltas_of_absolute_fields = { + key: fields.get(key, 0) + for key in abs_field_names + if key not in absolute_field_overrides + } + + # Keep the delta stream ID field up to date + absolute_field_overrides = absolute_field_overrides.copy() + absolute_field_overrides["completed_delta_stream_id"] = complete_with_stream_id + + # first upsert the `_current` table + self._upsert_with_additive_relatives_txn( + txn=txn, + table=table + "_current", + keyvalues={id_col: stats_id}, + absolutes=absolute_field_overrides, + additive_relatives=deltas_of_absolute_fields, + ) + + per_slice_additive_relatives = { + key: fields.get(key, 0) for key in slice_field_names + } + self._upsert_copy_from_table_with_additive_relatives_txn( + txn=txn, + into_table=table + "_historical", + keyvalues={id_col: stats_id}, + extra_dst_insvalues={"bucket_size": self.stats_bucket_size}, + extra_dst_keyvalues={"end_ts": end_ts}, + additive_relatives=per_slice_additive_relatives, + src_table=table + "_current", + copy_columns=abs_field_names, + ) + + def _upsert_with_additive_relatives_txn( + self, txn, table, keyvalues, absolutes, additive_relatives + ): + """Used to update values in the stats tables. + + This is basically a slightly convoluted upsert that *adds* to any + existing rows. 
+ + Args: + txn + table (str): Table name + keyvalues (dict[str, any]): Row-identifying key values + absolutes (dict[str, any]): Absolute (set) fields + additive_relatives (dict[str, int]): Fields that will be added onto + if existing row present. + """ + if self.database_engine.can_native_upsert: + absolute_updates = [ + "%(field)s = EXCLUDED.%(field)s" % {"field": field} + for field in absolutes.keys() + ] + + relative_updates = [ + "%(field)s = EXCLUDED.%(field)s + %(table)s.%(field)s" + % {"table": table, "field": field} + for field in additive_relatives.keys() + ] + + insert_cols = [] + qargs = [] + + for (key, val) in chain( + keyvalues.items(), absolutes.items(), additive_relatives.items() + ): + insert_cols.append(key) + qargs.append(val) + + sql = """ + INSERT INTO %(table)s (%(insert_cols_cs)s) + VALUES (%(insert_vals_qs)s) + ON CONFLICT (%(key_columns)s) DO UPDATE SET %(updates)s + """ % { + "table": table, + "insert_cols_cs": ", ".join(insert_cols), + "insert_vals_qs": ", ".join( + ["?"] * (len(keyvalues) + len(absolutes) + len(additive_relatives)) + ), + "key_columns": ", ".join(keyvalues), + "updates": ", ".join(chain(absolute_updates, relative_updates)), + } + + txn.execute(sql, qargs) + else: + self.database_engine.lock_table(txn, table) + retcols = list(chain(absolutes.keys(), additive_relatives.keys())) + current_row = self._simple_select_one_txn( + txn, table, keyvalues, retcols, allow_none=True + ) + if current_row is None: + merged_dict = {**keyvalues, **absolutes, **additive_relatives} + self._simple_insert_txn(txn, table, merged_dict) + else: + for (key, val) in additive_relatives.items(): + current_row[key] += val + current_row.update(absolutes) + self._simple_update_one_txn(txn, table, keyvalues, current_row) + + def _upsert_copy_from_table_with_additive_relatives_txn( + self, + txn, + into_table, + keyvalues, + extra_dst_keyvalues, + extra_dst_insvalues, + additive_relatives, + src_table, + copy_columns, + ): + """Updates the historic stats table with latest updates. + + This involves copying "absolute" fields from the `_current` table, and + adding relative fields to any existing values. + + Args: + txn: Transaction + into_table (str): The destination table to UPSERT the row into + keyvalues (dict[str, any]): Row-identifying key values + extra_dst_keyvalues (dict[str, any]): Additional keyvalues + for `into_table`. + extra_dst_insvalues (dict[str, any]): Additional values to insert + on new row creation for `into_table`. + additive_relatives (dict[str, any]): Fields that will be added onto + if existing row present. (Must be disjoint from copy_columns.) + src_table (str): The source table to copy from + copy_columns (iterable[str]): The list of columns to copy + """ + if self.database_engine.can_native_upsert: + ins_columns = chain( + keyvalues, + copy_columns, + additive_relatives, + extra_dst_keyvalues, + extra_dst_insvalues, + ) + sel_exprs = chain( + keyvalues, + copy_columns, + ( + "?" + for _ in chain( + additive_relatives, extra_dst_keyvalues, extra_dst_insvalues + ) + ), + ) + keyvalues_where = ("%s = ?" 
% f for f in keyvalues)
+
+            sets_cc = ("%s = EXCLUDED.%s" % (f, f) for f in copy_columns)
+            sets_ar = (
+                "%s = EXCLUDED.%s + %s.%s" % (f, f, into_table, f)
+                for f in additive_relatives
+            )
+
+            sql = """
+                INSERT INTO %(into_table)s (%(ins_columns)s)
+                SELECT %(sel_exprs)s
+                FROM %(src_table)s
+                WHERE %(keyvalues_where)s
+                ON CONFLICT (%(keyvalues)s)
+                DO UPDATE SET %(sets)s
+            """ % {
+                "into_table": into_table,
+                "ins_columns": ", ".join(ins_columns),
+                "sel_exprs": ", ".join(sel_exprs),
+                "keyvalues_where": " AND ".join(keyvalues_where),
+                "src_table": src_table,
+                "keyvalues": ", ".join(
+                    chain(keyvalues.keys(), extra_dst_keyvalues.keys())
+                ),
+                "sets": ", ".join(chain(sets_cc, sets_ar)),
+            }
+
+            qargs = list(
+                chain(
+                    additive_relatives.values(),
+                    extra_dst_keyvalues.values(),
+                    extra_dst_insvalues.values(),
+                    keyvalues.values(),
+                )
+            )
+            txn.execute(sql, qargs)
+        else:
+            self.database_engine.lock_table(txn, into_table)
+            src_row = self._simple_select_one_txn(
+                txn, src_table, keyvalues, copy_columns
+            )
+            all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues}
+            dest_current_row = self._simple_select_one_txn(
+                txn,
+                into_table,
+                keyvalues=all_dest_keyvalues,
+                retcols=list(chain(additive_relatives.keys(), copy_columns)),
+                allow_none=True,
+            )
+
+            if dest_current_row is None:
+                merged_dict = {
+                    **keyvalues,
+                    **extra_dst_keyvalues,
+                    **extra_dst_insvalues,
+                    **src_row,
+                    **additive_relatives,
+                }
+                self._simple_insert_txn(txn, into_table, merged_dict)
             else:
-                sql = ("UPDATE %s SET %s=%s+? WHERE %s=? AND ts=?") % (
-                    table,
-                    field,
-                    field,
-                    id_col,
+                for (key, val) in additive_relatives.items():
+                    src_row[key] = dest_current_row[key] + val
+                self._simple_update_txn(txn, into_table, all_dest_keyvalues, src_row)
+
+    def get_changes_room_total_events_and_bytes(self, min_pos, max_pos):
+        """Fetches the counts of events in the given range of stream IDs.
+
+        Args:
+            min_pos (int)
+            max_pos (int)
+
+        Returns:
+            Deferred[tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]]:
+                Mapping of room ID to field changes, and mapping of user ID to
+                field changes.
+        """
+
+        return self.runInteraction(
+            "stats_incremental_total_events_and_bytes",
+            self.get_changes_room_total_events_and_bytes_txn,
+            min_pos,
+            max_pos,
+        )
+
+    def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos):
+        """Gets the total_events and total_event_bytes counts for rooms and
+        senders, in a range of stream_orderings (including backfilled events).
+
+        Args:
+            txn
+            low_pos (int): Low stream ordering
+            high_pos (int): High stream ordering
+
+        Returns:
+            tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The
+                room and user deltas for total_events/total_event_bytes in the
+                format of `stats_id` -> fields
+        """
+
+        if low_pos >= high_pos:
+            # nothing to do here.
+            return {}, {}
+
+        if isinstance(self.database_engine, PostgresEngine):
+            new_bytes_expression = "OCTET_LENGTH(json)"
+        else:
+            new_bytes_expression = "LENGTH(CAST(json AS BLOB))"
+
+        sql = """
+            SELECT events.room_id, COUNT(*) AS new_events, SUM(%s) AS new_bytes
+            FROM events INNER JOIN event_json USING (event_id)
+            WHERE (? < stream_ordering AND stream_ordering <= ?)
+                OR (? <= stream_ordering AND stream_ordering <= ?)
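+                -- note: backfilled events are stored with negative stream
+                -- orderings, so the second (negated) range is needed to pick
+                -- them up as well; see the bound parameters below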
+ GROUP BY events.room_id + """ % ( + new_bytes_expression, + ) + + txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos)) + + room_deltas = { + room_id: {"total_events": new_events, "total_event_bytes": new_bytes} + for room_id, new_events, new_bytes in txn + } + + sql = """ + SELECT events.sender, COUNT(*) AS new_events, SUM(%s) AS new_bytes + FROM events INNER JOIN event_json USING (event_id) + WHERE (? < stream_ordering AND stream_ordering <= ?) + OR (? <= stream_ordering AND stream_ordering <= ?) + GROUP BY events.sender + """ % ( + new_bytes_expression, + ) + + txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos)) + + user_deltas = { + user_id: {"total_events": new_events, "total_event_bytes": new_bytes} + for user_id, new_events, new_bytes in txn + if self.hs.is_mine_id(user_id) + } + + return room_deltas, user_deltas + + @defer.inlineCallbacks + def _calculate_and_set_initial_state_for_room(self, room_id): + """Calculate and insert an entry into room_stats_current. + + Args: + room_id (str) + + Returns: + Deferred[tuple[dict, dict, int]]: A tuple of room state, membership + counts and stream position. + """ + + def _fetch_current_state_stats(txn): + pos = self.get_room_max_stream_ordering() + + rows = self._simple_select_many_txn( + txn, + table="current_state_events", + column="type", + iterable=[ + EventTypes.Create, + EventTypes.JoinRules, + EventTypes.RoomHistoryVisibility, + EventTypes.Encryption, + EventTypes.Name, + EventTypes.Topic, + EventTypes.RoomAvatar, + EventTypes.CanonicalAlias, + ], + keyvalues={"room_id": room_id, "state_key": ""}, + retcols=["event_id"], + ) + + event_ids = [row["event_id"] for row in rows] + + txn.execute( + """ + SELECT membership, count(*) FROM current_state_events + WHERE room_id = ? AND type = 'm.room.member' + GROUP BY membership + """, + (room_id,), + ) + membership_counts = {membership: cnt for membership, cnt in txn} + + txn.execute( + """ + SELECT COALESCE(count(*), 0) FROM current_state_events + WHERE room_id = ? 
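+                    -- this count feeds the room's current_state_events
+                    -- absolute stat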
+ """, + (room_id,), + ) + + current_state_events_count, = txn.fetchone() + + users_in_room = self.get_users_in_room_txn(txn, room_id) + + return ( + event_ids, + membership_counts, + current_state_events_count, + users_in_room, + pos, + ) + + ( + event_ids, + membership_counts, + current_state_events_count, + users_in_room, + pos, + ) = yield self.runInteraction( + "get_initial_state_for_room", _fetch_current_state_stats + ) + + state_event_map = yield self.get_events(event_ids, get_prev_content=False) + + room_state = { + "join_rules": None, + "history_visibility": None, + "encryption": None, + "name": None, + "topic": None, + "avatar": None, + "canonical_alias": None, + "is_federatable": True, + } + + for event in state_event_map.values(): + if event.type == EventTypes.JoinRules: + room_state["join_rules"] = event.content.get("join_rule") + elif event.type == EventTypes.RoomHistoryVisibility: + room_state["history_visibility"] = event.content.get( + "history_visibility" ) - txn.execute(sql, (value, stats_id, current_ts)) + elif event.type == EventTypes.Encryption: + room_state["encryption"] = event.content.get("algorithm") + elif event.type == EventTypes.Name: + room_state["name"] = event.content.get("name") + elif event.type == EventTypes.Topic: + room_state["topic"] = event.content.get("topic") + elif event.type == EventTypes.RoomAvatar: + room_state["avatar"] = event.content.get("url") + elif event.type == EventTypes.CanonicalAlias: + room_state["canonical_alias"] = event.content.get("alias") + elif event.type == EventTypes.Create: + room_state["is_federatable"] = event.content.get("m.federate", True) + + yield self.update_room_state(room_id, room_state) + + local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)] + + yield self.update_stats_delta( + ts=self.clock.time_msec(), + stats_type="room", + stats_id=room_id, + fields={}, + complete_with_stream_id=pos, + absolute_field_overrides={ + "current_state_events": current_state_events_count, + "joined_members": membership_counts.get(Membership.JOIN, 0), + "invited_members": membership_counts.get(Membership.INVITE, 0), + "left_members": membership_counts.get(Membership.LEAVE, 0), + "banned_members": membership_counts.get(Membership.BAN, 0), + "local_users_in_room": len(local_users_in_room), + }, + ) + + @defer.inlineCallbacks + def _calculate_and_set_initial_state_for_user(self, user_id): + def _calculate_and_set_initial_state_for_user_txn(txn): + pos = self._get_max_stream_id_in_current_state_deltas_txn(txn) - return self.runInteraction("update_stats_delta", _update_stats_delta) + txn.execute( + """ + SELECT COUNT(distinct room_id) FROM current_state_events + WHERE type = 'm.room.member' AND state_key = ? + AND membership = 'join' + """, + (user_id,), + ) + count, = txn.fetchone() + return count, pos + + joined_rooms, pos = yield self.runInteraction( + "calculate_and_set_initial_state_for_user", + _calculate_and_set_initial_state_for_user_txn, + ) + + yield self.update_stats_delta( + ts=self.clock.time_msec(), + stats_type="user", + stats_id=user_id, + fields={}, + complete_with_stream_id=pos, + absolute_field_overrides={"joined_rooms": joined_rooms}, + ) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index a8b858eb4f..7569b6fab5 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -13,16 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from mock import Mock - -from twisted.internet import defer - -from synapse.api.constants import EventTypes, Membership +from synapse import storage from synapse.rest import admin from synapse.rest.client.v1 import login, room from tests import unittest +# The expected number of state events in a fresh public room. +EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM = 5 +# The expected number of state events in a fresh private room. +EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 6 + class StatsRoomTests(unittest.HomeserverTestCase): @@ -33,7 +34,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() self.handler = self.hs.get_stats_handler() @@ -47,7 +47,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.get_success( self.store._simple_insert( "background_updates", - {"update_name": "populate_stats_createtables", "progress_json": "{}"}, + {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) self.get_success( @@ -56,7 +56,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): { "update_name": "populate_stats_process_rooms", "progress_json": "{}", - "depends_on": "populate_stats_createtables", + "depends_on": "populate_stats_prepare", }, ) ) @@ -64,18 +64,58 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.store._simple_insert( "background_updates", { - "update_name": "populate_stats_cleanup", + "update_name": "populate_stats_process_users", "progress_json": "{}", "depends_on": "populate_stats_process_rooms", }, ) ) + self.get_success( + self.store._simple_insert( + "background_updates", + { + "update_name": "populate_stats_cleanup", + "progress_json": "{}", + "depends_on": "populate_stats_process_users", + }, + ) + ) + + def get_all_room_state(self): + return self.store._simple_select_list( + "room_stats_state", None, retcols=("name", "topic", "canonical_alias") + ) + + def _get_current_stats(self, stats_type, stat_id): + table, id_col = storage.stats.TYPE_TO_TABLE[stats_type] + + cols = list(storage.stats.ABSOLUTE_STATS_FIELDS[stats_type]) + list( + storage.stats.PER_SLICE_FIELDS[stats_type] + ) + + end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000) + + return self.get_success( + self.store._simple_select_one( + table + "_historical", + {id_col: stat_id, end_ts: end_ts}, + cols, + allow_none=True, + ) + ) + + def _perform_background_initial_update(self): + # Do the initial population of the stats via the background update + self._add_background_updates() + + while not self.get_success(self.store.has_completed_background_updates()): + self.get_success(self.store.do_next_background_update(100), by=0.1) def test_initial_room(self): """ The background updates will build the table from scratch. 
""" - r = self.get_success(self.store.get_all_room_state()) + r = self.get_success(self.get_all_room_state()) self.assertEqual(len(r), 0) # Disable stats @@ -91,7 +131,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) # Stats disabled, shouldn't have done anything - r = self.get_success(self.store.get_all_room_state()) + r = self.get_success(self.get_all_room_state()) self.assertEqual(len(r), 0) # Enable stats @@ -104,7 +144,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): while not self.get_success(self.store.has_completed_background_updates()): self.get_success(self.store.do_next_background_update(100), by=0.1) - r = self.get_success(self.store.get_all_room_state()) + r = self.get_success(self.get_all_room_state()) self.assertEqual(len(r), 1) self.assertEqual(r[0]["topic"], "foo") @@ -114,6 +154,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): Ingestion via notify_new_event will ignore tokens that the background update have already processed. """ + self.reactor.advance(86401) self.hs.config.stats_enabled = False @@ -138,12 +179,18 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.hs.config.stats_enabled = True self.handler.stats_enabled = True self.store._all_done = False - self.get_success(self.store.update_stats_stream_pos(None)) + self.get_success( + self.store._simple_update_one( + table="stats_incremental_position", + keyvalues={}, + updatevalues={"stream_id": 0}, + ) + ) self.get_success( self.store._simple_insert( "background_updates", - {"update_name": "populate_stats_createtables", "progress_json": "{}"}, + {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) @@ -154,6 +201,8 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token) self.helper.join(room=room_1, user=u2, tok=u2_token) + # orig_delta_processor = self.store. + # Now do the initial ingestion. self.get_success( self.store._simple_insert( @@ -185,8 +234,15 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.helper.invite(room=room_1, src=u1, targ=u3, tok=u1_token) self.helper.join(room=room_1, user=u3, tok=u3_token) - # Get the deltas! There should be two -- day 1, and day 2. - r = self.get_success(self.store.get_deltas_for_room(room_1, 0)) + # self.handler.notify_new_event() + + # We need to let the delta processor advance… + self.pump(10 * 60) + + # Get the slices! There should be two -- day 1, and day 2. + r = self.get_success(self.store.get_statistics_for_subject("room", room_1, 0)) + + self.assertEqual(len(r), 2) # The oldest has 2 joined members self.assertEqual(r[-1]["joined_members"], 2) @@ -194,111 +250,476 @@ class StatsRoomTests(unittest.HomeserverTestCase): # The newest has 3 self.assertEqual(r[0]["joined_members"], 3) - def test_incorrect_state_transition(self): - """ - If the state transition is not one of (JOIN, INVITE, LEAVE, BAN) to - (JOIN, INVITE, LEAVE, BAN), an error is raised. 
- """ - events = { - "a1": {"membership": Membership.LEAVE}, - "a2": {"membership": "not a real thing"}, - } - - def get_event(event_id, allow_none=True): - m = Mock() - m.content = events[event_id] - d = defer.Deferred() - self.reactor.callLater(0.0, d.callback, m) - return d - - def get_received_ts(event_id): - return defer.succeed(1) - - self.store.get_received_ts = get_received_ts - self.store.get_event = get_event - - deltas = [ - { - "type": EventTypes.Member, - "state_key": "some_user", - "room_id": "room", - "event_id": "a1", - "prev_event_id": "a2", - "stream_id": 60, - } - ] - - f = self.get_failure(self.handler._handle_deltas(deltas), ValueError) + def test_create_user(self): + """ + When we create a user, it should have statistics already ready. + """ + + u1 = self.register_user("u1", "pass") + + u1stats = self._get_current_stats("user", u1) + + self.assertIsNotNone(u1stats) + + # not in any rooms by default + self.assertEqual(u1stats["joined_rooms"], 0) + + def test_create_room(self): + """ + When we create a room, it should have statistics already ready. + """ + + self._perform_background_initial_update() + + u1 = self.register_user("u1", "pass") + u1token = self.login("u1", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token) + r1stats = self._get_current_stats("room", r1) + r2 = self.helper.create_room_as(u1, tok=u1token, is_public=False) + r2stats = self._get_current_stats("room", r2) + + self.assertIsNotNone(r1stats) + self.assertIsNotNone(r2stats) + + # contains the default things you'd expect in a fresh room self.assertEqual( - f.value.args[0], "'not a real thing' is not a valid prev_membership" - ) - - # And the other way... - deltas = [ - { - "type": EventTypes.Member, - "state_key": "some_user", - "room_id": "room", - "event_id": "a2", - "prev_event_id": "a1", - "stream_id": 100, - } - ] - - f = self.get_failure(self.handler._handle_deltas(deltas), ValueError) + r1stats["total_events"], + EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM, + "Wrong number of total_events in new room's stats!" + " You may need to update this if more state events are added to" + " the room creation process.", + ) self.assertEqual( - f.value.args[0], "'not a real thing' is not a valid membership" + r2stats["total_events"], + EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM, + "Wrong number of total_events in new room's stats!" + " You may need to update this if more state events are added to" + " the room creation process.", ) - def test_redacted_prev_event(self): + self.assertEqual( + r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM + ) + self.assertEqual( + r2stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM + ) + + self.assertEqual(r1stats["joined_members"], 1) + self.assertEqual(r1stats["invited_members"], 0) + self.assertEqual(r1stats["banned_members"], 0) + + self.assertEqual(r2stats["joined_members"], 1) + self.assertEqual(r2stats["invited_members"], 0) + self.assertEqual(r2stats["banned_members"], 0) + + def test_send_message_increments_total_events(self): """ - If the prev_event does not exist, then it is assumed to be a LEAVE. + When we send a message, it increments total_events. 
""" + + self._perform_background_initial_update() + u1 = self.register_user("u1", "pass") - u1_token = self.login("u1", "pass") + u1token = self.login("u1", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token) + r1stats_ante = self._get_current_stats("room", r1) - room_1 = self.helper.create_room_as(u1, tok=u1_token) + self.helper.send(r1, "hiss", tok=u1token) - # Do the initial population of the user directory via the background update - self._add_background_updates() + r1stats_post = self._get_current_stats("room", r1) + + self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) + + def test_send_state_event_nonoverwriting(self): + """ + When we send a non-overwriting state event, it increments total_events AND current_state_events + """ + + self._perform_background_initial_update() + + u1 = self.register_user("u1", "pass") + u1token = self.login("u1", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token) + + self.helper.send_state( + r1, "cat.hissing", {"value": True}, tok=u1token, state_key="tabby" + ) + + r1stats_ante = self._get_current_stats("room", r1) + + self.helper.send_state( + r1, "cat.hissing", {"value": False}, tok=u1token, state_key="moggy" + ) + + r1stats_post = self._get_current_stats("room", r1) + + self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) + self.assertEqual( + r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], + 1, + ) + + def test_send_state_event_overwriting(self): + """ + When we send an overwriting state event, it increments total_events ONLY + """ + + self._perform_background_initial_update() + + u1 = self.register_user("u1", "pass") + u1token = self.login("u1", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token) + + self.helper.send_state( + r1, "cat.hissing", {"value": True}, tok=u1token, state_key="tabby" + ) + + r1stats_ante = self._get_current_stats("room", r1) + + self.helper.send_state( + r1, "cat.hissing", {"value": False}, tok=u1token, state_key="tabby" + ) + + r1stats_post = self._get_current_stats("room", r1) + + self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) + self.assertEqual( + r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], + 0, + ) + + def test_join_first_time(self): + """ + When a user joins a room for the first time, total_events, current_state_events and + joined_members should increase by exactly 1. + """ + + self._perform_background_initial_update() + + u1 = self.register_user("u1", "pass") + u1token = self.login("u1", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token) + + u2 = self.register_user("u2", "pass") + u2token = self.login("u2", "pass") + + r1stats_ante = self._get_current_stats("room", r1) + + self.helper.join(r1, u2, tok=u2token) + + r1stats_post = self._get_current_stats("room", r1) + + self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) + self.assertEqual( + r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], + 1, + ) + self.assertEqual( + r1stats_post["joined_members"] - r1stats_ante["joined_members"], 1 + ) + + def test_join_after_leave(self): + """ + When a user joins a room after being previously left, total_events and + joined_members should increase by exactly 1. + current_state_events should not increase. + left_members should decrease by exactly 1. 
+ """ + + self._perform_background_initial_update() + + u1 = self.register_user("u1", "pass") + u1token = self.login("u1", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token) + + u2 = self.register_user("u2", "pass") + u2token = self.login("u2", "pass") + + self.helper.join(r1, u2, tok=u2token) + self.helper.leave(r1, u2, tok=u2token) + + r1stats_ante = self._get_current_stats("room", r1) + + self.helper.join(r1, u2, tok=u2token) + + r1stats_post = self._get_current_stats("room", r1) + + self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) + self.assertEqual( + r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], + 0, + ) + self.assertEqual( + r1stats_post["joined_members"] - r1stats_ante["joined_members"], +1 + ) + self.assertEqual( + r1stats_post["left_members"] - r1stats_ante["left_members"], -1 + ) + + def test_invited(self): + """ + When a user invites another user, current_state_events, total_events and + invited_members should increase by exactly 1. + """ + + self._perform_background_initial_update() + + u1 = self.register_user("u1", "pass") + u1token = self.login("u1", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token) + + u2 = self.register_user("u2", "pass") + + r1stats_ante = self._get_current_stats("room", r1) + + self.helper.invite(r1, u1, u2, tok=u1token) + + r1stats_post = self._get_current_stats("room", r1) + + self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) + self.assertEqual( + r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], + 1, + ) + self.assertEqual( + r1stats_post["invited_members"] - r1stats_ante["invited_members"], +1 + ) + + def test_join_after_invite(self): + """ + When a user joins a room after being invited, total_events and + joined_members should increase by exactly 1. + current_state_events should not increase. + invited_members should decrease by exactly 1. + """ + + self._perform_background_initial_update() + + u1 = self.register_user("u1", "pass") + u1token = self.login("u1", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token) + + u2 = self.register_user("u2", "pass") + u2token = self.login("u2", "pass") + + self.helper.invite(r1, u1, u2, tok=u1token) + + r1stats_ante = self._get_current_stats("room", r1) + + self.helper.join(r1, u2, tok=u2token) + + r1stats_post = self._get_current_stats("room", r1) + + self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) + self.assertEqual( + r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], + 0, + ) + self.assertEqual( + r1stats_post["joined_members"] - r1stats_ante["joined_members"], +1 + ) + self.assertEqual( + r1stats_post["invited_members"] - r1stats_ante["invited_members"], -1 + ) + + def test_left(self): + """ + When a user leaves a room after joining, total_events and + left_members should increase by exactly 1. + current_state_events should not increase. + joined_members should decrease by exactly 1. 
+        """
+
+        self._perform_background_initial_update()
+
+        u1 = self.register_user("u1", "pass")
+        u1token = self.login("u1", "pass")
+        r1 = self.helper.create_room_as(u1, tok=u1token)
+
+        u2 = self.register_user("u2", "pass")
+        u2token = self.login("u2", "pass")
+
+        self.helper.join(r1, u2, tok=u2token)
+
+        r1stats_ante = self._get_current_stats("room", r1)
+
+        self.helper.leave(r1, u2, tok=u2token)
+
+        r1stats_post = self._get_current_stats("room", r1)
+
+        self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+        self.assertEqual(
+            r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
+            0,
+        )
+        self.assertEqual(
+            r1stats_post["left_members"] - r1stats_ante["left_members"], +1
+        )
+        self.assertEqual(
+            r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1
+        )
+
+    def test_banned(self):
+        """
+        When a user is banned from a room after joining, total_events and
+        banned_members should increase by exactly 1.
+        current_state_events should not increase.
+        joined_members should decrease by exactly 1.
+        """
+
+        self._perform_background_initial_update()
+
+        u1 = self.register_user("u1", "pass")
+        u1token = self.login("u1", "pass")
+        r1 = self.helper.create_room_as(u1, tok=u1token)
+
+        u2 = self.register_user("u2", "pass")
+        u2token = self.login("u2", "pass")
+
+        self.helper.join(r1, u2, tok=u2token)
+
+        r1stats_ante = self._get_current_stats("room", r1)
+
+        self.helper.change_membership(r1, u1, u2, "ban", tok=u1token)
+
+        r1stats_post = self._get_current_stats("room", r1)
+
+        self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+        self.assertEqual(
+            r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
+            0,
+        )
+        self.assertEqual(
+            r1stats_post["banned_members"] - r1stats_ante["banned_members"], +1
+        )
+        self.assertEqual(
+            r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1
+        )
+
+    def test_initial_background_update(self):
+        """
+        Test that statistics can be generated by the initial background update
+        handler.
+
+        This test also tests that stats rows are not created for new subjects
+        when stats are disabled. However, it may be desirable to change this
+        behaviour eventually to still keep current rows.
+        """
+
+        self.hs.config.stats_enabled = False
+
+        u1 = self.register_user("u1", "pass")
+        u1token = self.login("u1", "pass")
+        r1 = self.helper.create_room_as(u1, tok=u1token)
+
+        # test that these subjects, which were created during a time of disabled
+        # stats, do not have stats.
+        self.assertIsNone(self._get_current_stats("room", r1))
+        self.assertIsNone(self._get_current_stats("user", u1))
+
+        self.hs.config.stats_enabled = True
+
+        self._perform_background_initial_update()
+
+        r1stats = self._get_current_stats("room", r1)
+        u1stats = self._get_current_stats("user", u1)
+
+        self.assertEqual(r1stats["joined_members"], 1)
+        self.assertEqual(
+            r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM
+        )
+
+        self.assertEqual(u1stats["joined_rooms"], 1)
+
+    def test_incomplete_stats(self):
+        """
+        This tests that we track incomplete statistics.
+
+        We first test that incomplete stats are incrementally generated,
+        following the preparation of a background regen.
+
+        We then test that these incomplete rows are completed by the background
+        regen.
+ """ + + u1 = self.register_user("u1", "pass") + u1token = self.login("u1", "pass") + u2 = self.register_user("u2", "pass") + u2token = self.login("u2", "pass") + u3 = self.register_user("u3", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token, is_public=False) + + # preparation stage of the initial background update + # Ugh, have to reset this flag + self.store._all_done = False + + self.get_success( + self.store._simple_delete( + "room_stats_current", {"1": 1}, "test_delete_stats" + ) + ) + self.get_success( + self.store._simple_delete( + "user_stats_current", {"1": 1}, "test_delete_stats" + ) + ) + + self.helper.invite(r1, u1, u2, tok=u1token) + self.helper.join(r1, u2, tok=u2token) + self.helper.invite(r1, u1, u3, tok=u1token) + self.helper.send(r1, "thou shalt yield", tok=u1token) + + # now do the background updates + + self.store._all_done = False + self.get_success( + self.store._simple_insert( + "background_updates", + { + "update_name": "populate_stats_process_rooms", + "progress_json": "{}", + "depends_on": "populate_stats_prepare", + }, + ) + ) + self.get_success( + self.store._simple_insert( + "background_updates", + { + "update_name": "populate_stats_process_users", + "progress_json": "{}", + "depends_on": "populate_stats_process_rooms", + }, + ) + ) + self.get_success( + self.store._simple_insert( + "background_updates", + { + "update_name": "populate_stats_cleanup", + "progress_json": "{}", + "depends_on": "populate_stats_process_users", + }, + ) + ) while not self.get_success(self.store.has_completed_background_updates()): self.get_success(self.store.do_next_background_update(100), by=0.1) - events = {"a1": None, "a2": {"membership": Membership.JOIN}} - - def get_event(event_id, allow_none=True): - if events.get(event_id): - m = Mock() - m.content = events[event_id] - else: - m = None - d = defer.Deferred() - self.reactor.callLater(0.0, d.callback, m) - return d - - def get_received_ts(event_id): - return defer.succeed(1) - - self.store.get_received_ts = get_received_ts - self.store.get_event = get_event - - deltas = [ - { - "type": EventTypes.Member, - "state_key": "some_user:test", - "room_id": room_1, - "event_id": "a2", - "prev_event_id": "a1", - "stream_id": 100, - } - ] - - # Handle our fake deltas, which has a user going from LEAVE -> JOIN. - self.get_success(self.handler._handle_deltas(deltas)) - - # One delta, with two joined members -- the room creator, and our fake - # user. 
- r = self.get_success(self.store.get_deltas_for_room(room_1, 0)) - self.assertEqual(len(r), 1) - self.assertEqual(r[0]["joined_members"], 2) + r1stats_complete = self._get_current_stats("room", r1) + u1stats_complete = self._get_current_stats("user", u1) + u2stats_complete = self._get_current_stats("user", u2) + + # now we make our assertions + + # check that _complete rows are complete and correct + self.assertEqual(r1stats_complete["joined_members"], 2) + self.assertEqual(r1stats_complete["invited_members"], 1) + + self.assertEqual( + r1stats_complete["current_state_events"], + 2 + EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM, + ) + + self.assertEqual(u1stats_complete["joined_rooms"], 1) + self.assertEqual(u2stats_complete["joined_rooms"], 1) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 9915367144..cdded88b7f 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -128,8 +128,12 @@ class RestHelper(object): return channel.json_body - def send_state(self, room_id, event_type, body, tok, expect_code=200): - path = "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, event_type) + def send_state(self, room_id, event_type, body, tok, expect_code=200, state_key=""): + path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % ( + room_id, + event_type, + state_key, + ) if tok: path = path + "?access_token=%s" % tok -- cgit 1.5.1 From e059c5e6487f0e86e395f1e79d400ae0b0d24a82 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Fri, 6 Sep 2019 13:10:11 +0100 Subject: Move get_threepid_validation_session into RegistrationWorkerStore --- synapse/storage/registration.py | 108 ++++++++++++++++++++-------------------- 1 file changed, 54 insertions(+), 54 deletions(-) (limited to 'synapse/storage/registration.py') diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 2d3c7e2dc9..24509bd455 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -614,6 +614,60 @@ class RegistrationWorkerStore(SQLBaseStore): # Convert the integer into a boolean. return res == 1 + def get_threepid_validation_session( + self, medium, client_secret, address=None, sid=None, validated=True + ): + """Gets a session_id and last_send_attempt (if available) for a + client_secret/medium/(address|session_id) combo + + Args: + medium (str|None): The medium of the 3PID + address (str|None): The address of the 3PID + sid (str|None): The ID of the validation session + client_secret (str|None): A unique string provided by the client to + help identify this validation attempt + validated (bool|None): Whether sessions should be filtered by + whether they have been validated already or not. None to + perform no filtering + + Returns: + deferred {str, int}|None: A dict containing the + latest session_id and send_attempt count for this 3PID. + Otherwise None if there hasn't been a previous attempt + """ + keyvalues = {"medium": medium, "client_secret": client_secret} + if address: + keyvalues["address"] = address + if sid: + keyvalues["session_id"] = sid + + assert address or sid + + def get_threepid_validation_session_txn(txn): + sql = """ + SELECT address, session_id, medium, client_secret, + last_send_attempt, validated_at + FROM threepid_validation_session WHERE %s + """ % ( + " AND ".join("%s = ?" 
% k for k in iterkeys(keyvalues)), + ) + + if validated is not None: + sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL") + + sql += " LIMIT 1" + + txn.execute(sql, list(keyvalues.values())) + rows = self.cursor_to_dict(txn) + if not rows: + return None + + return rows[0] + + return self.runInteraction( + "get_threepid_validation_session", get_threepid_validation_session_txn + ) + class RegistrationStore( RegistrationWorkerStore, background_updates.BackgroundUpdateStore @@ -1082,60 +1136,6 @@ class RegistrationStore( return 1 - def get_threepid_validation_session( - self, medium, client_secret, address=None, sid=None, validated=True - ): - """Gets a session_id and last_send_attempt (if available) for a - client_secret/medium/(address|session_id) combo - - Args: - medium (str|None): The medium of the 3PID - address (str|None): The address of the 3PID - sid (str|None): The ID of the validation session - client_secret (str|None): A unique string provided by the client to - help identify this validation attempt - validated (bool|None): Whether sessions should be filtered by - whether they have been validated already or not. None to - perform no filtering - - Returns: - deferred {str, int}|None: A dict containing the - latest session_id and send_attempt count for this 3PID. - Otherwise None if there hasn't been a previous attempt - """ - keyvalues = {"medium": medium, "client_secret": client_secret} - if address: - keyvalues["address"] = address - if sid: - keyvalues["session_id"] = sid - - assert address or sid - - def get_threepid_validation_session_txn(txn): - sql = """ - SELECT address, session_id, medium, client_secret, - last_send_attempt, validated_at - FROM threepid_validation_session WHERE %s - """ % ( - " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)), - ) - - if validated is not None: - sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL") - - sql += " LIMIT 1" - - txn.execute(sql, list(keyvalues.values())) - rows = self.cursor_to_dict(txn) - if not rows: - return None - - return rows[0] - - return self.runInteraction( - "get_threepid_validation_session", get_threepid_validation_session_txn - ) - def validate_threepid_session(self, session_id, client_secret, token, current_ts): """Attempt to validate a threepid session using a token -- cgit 1.5.1 From 6ddda8152ed4f8111d8108f6f2f92f365b9069ba Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Fri, 6 Sep 2019 13:23:10 +0100 Subject: Move delete_threepid_session into RegistrationWorkerStore --- synapse/storage/registration.py | 50 ++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 25 deletions(-) (limited to 'synapse/storage/registration.py') diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 24509bd455..5138792a5f 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -668,6 +668,31 @@ class RegistrationWorkerStore(SQLBaseStore): "get_threepid_validation_session", get_threepid_validation_session_txn ) + def delete_threepid_session(self, session_id): + """Removes a threepid validation session from the database. 
This can + be done after validation has been performed and whatever action was + waiting on it has been carried out + + Args: + session_id (str): The ID of the session to delete + """ + + def delete_threepid_session_txn(txn): + self._simple_delete_txn( + txn, + table="threepid_validation_token", + keyvalues={"session_id": session_id}, + ) + self._simple_delete_txn( + txn, + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + ) + + return self.runInteraction( + "delete_threepid_session", delete_threepid_session_txn + ) + class RegistrationStore( RegistrationWorkerStore, background_updates.BackgroundUpdateStore @@ -1323,31 +1348,6 @@ class RegistrationStore( self.clock.time_msec(), ) - def delete_threepid_session(self, session_id): - """Removes a threepid validation session from the database. This can - be done after validation has been performed and whatever action was - waiting on it has been carried out - - Args: - session_id (str): The ID of the session to delete - """ - - def delete_threepid_session_txn(txn): - self._simple_delete_txn( - txn, - table="threepid_validation_token", - keyvalues={"session_id": session_id}, - ) - self._simple_delete_txn( - txn, - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - ) - - return self.runInteraction( - "delete_threepid_session", delete_threepid_session_txn - ) - def set_user_deactivated_status_txn(self, txn, user_id, deactivated): self._simple_update_one_txn( txn=txn, -- cgit 1.5.1 From be618e055178f4aa9865ab426182218312bed07f Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Mon, 9 Sep 2019 14:43:51 +0300 Subject: Only count real users when checking for auto-creation of auto-join room Previously if the first registered user was a "support" or "bot" user, when the first real user registers, the auto-join rooms were not created. Fix to exclude non-real (ie users with a special user type) users when counting how many users there are to determine whether we should auto-create a room. Signed-off-by: Jason Robinson --- changelog.d/6004.bugfix | 1 + synapse/handlers/register.py | 12 ++++-------- synapse/storage/registration.py | 39 +++++++++++++++++++++++++++++++++++++++ tests/handlers/test_register.py | 29 +++++++++++++++++++++++++++-- 4 files changed, 71 insertions(+), 10 deletions(-) create mode 100644 changelog.d/6004.bugfix (limited to 'synapse/storage/registration.py') diff --git a/changelog.d/6004.bugfix b/changelog.d/6004.bugfix new file mode 100644 index 0000000000..45c179c8fd --- /dev/null +++ b/changelog.d/6004.bugfix @@ -0,0 +1 @@ +Only count real users when checking for auto-creation of auto-join room. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 975da57ffd..06bd03b77c 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -275,16 +275,12 @@ class RegistrationHandler(BaseHandler): fake_requester = create_requester(user_id) # try to create the room if we're the first real user on the server. Note - # that an auto-generated support user is not a real user and will never be + # that an auto-generated support or bot user is not a real user and will never be # the user to create the room should_auto_create_rooms = False - is_support = yield self.store.is_support_user(user_id) - # There is an edge case where the first user is the support user, then - # the room is never created, though this seems unlikely and - # recoverable from given the support user being involved in the first - # place. 
- if self.hs.config.autocreate_auto_join_rooms and not is_support: - count = yield self.store.count_all_users() + is_real_user = yield self.store.is_real_user(user_id) + if self.hs.config.autocreate_auto_join_rooms and is_real_user: + count = yield self.store.count_real_users() should_auto_create_rooms = count == 1 for r in self.hs.config.auto_join_rooms: logger.info("Auto-joining %s to %s", user_id, r) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 5138792a5f..b054d86ae0 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -322,6 +322,21 @@ class RegistrationWorkerStore(SQLBaseStore): return None + @cachedInlineCallbacks() + def is_real_user(self, user_id): + """Determines if the user is a real user, ie does not have a 'user_type'. + + Args: + user_id (str): user id to test + + Returns: + Deferred[bool]: True if user 'user_type' is null or empty string + """ + res = yield self.runInteraction( + "is_real_user", self.is_real_user_txn, user_id + ) + return res + @cachedInlineCallbacks() def is_support_user(self, user_id): """Determines if the user is of type UserTypes.SUPPORT @@ -337,6 +352,16 @@ class RegistrationWorkerStore(SQLBaseStore): ) return res + def is_real_user_txn(self, txn, user_id): + res = self._simple_select_one_onecol_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + retcol="user_type", + allow_none=True, + ) + return True if res is None or res == "" else False + def is_support_user_txn(self, txn, user_id): res = self._simple_select_one_onecol_txn( txn=txn, @@ -421,6 +446,20 @@ class RegistrationWorkerStore(SQLBaseStore): ret = yield self.runInteraction("count_users", _count_users) return ret + @defer.inlineCallbacks + def count_real_users(self): + """Counts all users without a special user_type registered on the homeserver.""" + + def _count_users(txn): + txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null or user_type = ''") + rows = self.cursor_to_dict(txn) + if rows: + return rows[0]["users"] + return 0 + + ret = yield self.runInteraction("count_real_users", _count_users) + return ret + @defer.inlineCallbacks def find_next_generated_user_id_localpart(self): """ diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index e10296a5e4..1e9ba3a201 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -171,11 +171,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase): rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) - def test_auto_create_auto_join_rooms_when_support_user_exists(self): + def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.is_support_user = Mock(return_value=True) + self.store.is_real_user = Mock(return_value=False) user_id = self.get_success(self.handler.register_user(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -183,6 +183,31 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias = RoomAlias.from_string(room_alias_str) self.get_failure(directory_handler.get_association(room_alias), SynapseError) + def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self): + room_alias_str = "#room:test" + self.hs.config.auto_join_rooms = [room_alias_str] + + self.store.count_real_users = Mock(return_value=1) + self.store.is_real_user = 
Mock(return_value=True) + user_id = self.get_success(self.handler.register_user(localpart="real")) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + directory_handler = self.hs.get_handlers().directory_handler + room_alias = RoomAlias.from_string(room_alias_str) + room_id = self.get_success(directory_handler.get_association(room_alias)) + + self.assertTrue(room_id["room_id"] in rooms) + self.assertEqual(len(rooms), 1) + + def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(self): + room_alias_str = "#room:test" + self.hs.config.auto_join_rooms = [room_alias_str] + + self.store.count_real_users = Mock(return_value=2) + self.store.is_real_user = Mock(return_value=True) + user_id = self.get_success(self.handler.register_user(localpart="real")) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertEqual(len(rooms), 0) + def test_auto_create_auto_join_where_no_consent(self): """Test to ensure that the first user is not auto-joined to a room if they have not given general consent. -- cgit 1.5.1 From 62fac9d969cea98694093a5f80bed6bdd4848968 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Mon, 9 Sep 2019 14:59:35 +0300 Subject: Auto-fix a few code style issues Signed-off-by: Jason Robinson --- synapse/storage/registration.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'synapse/storage/registration.py') diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index b054d86ae0..9387b29503 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -332,9 +332,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred[bool]: True if user 'user_type' is null or empty string """ - res = yield self.runInteraction( - "is_real_user", self.is_real_user_txn, user_id - ) + res = yield self.runInteraction("is_real_user", self.is_real_user_txn, user_id) return res @cachedInlineCallbacks() @@ -451,7 +449,9 @@ class RegistrationWorkerStore(SQLBaseStore): """Counts all users without a special user_type registered on the homeserver.""" def _count_users(txn): - txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null or user_type = ''") + txn.execute( + "SELECT COUNT(*) AS users FROM users where user_type is null or user_type = ''" + ) rows = self.cursor_to_dict(txn) if rows: return rows[0]["users"] -- cgit 1.5.1 From 8c03cd0e5f73fb59ee773dc6cce77f2dc4dab827 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Mon, 9 Sep 2019 16:40:40 +0300 Subject: Simplify is_real_user_txn check to trust user_type is null if real user Signed-off-by: Jason Robinson --- synapse/storage/registration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage/registration.py') diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 9387b29503..54b0846c54 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -358,7 +358,7 @@ class RegistrationWorkerStore(SQLBaseStore): retcol="user_type", allow_none=True, ) - return True if res is None or res == "" else False + return res is None def is_support_user_txn(self, txn, user_id): res = self._simple_select_one_onecol_txn( -- cgit 1.5.1 From e89fea4f04c6fc7df41c5cade63609b513a98073 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Mon, 9 Sep 2019 16:43:32 +0300 Subject: Simplify count_real_users SQL to only count user_type is null rows Signed-off-by: Jason Robinson --- synapse/storage/registration.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) (limited to 'synapse/storage/registration.py') diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 54b0846c54..c0ca25733b 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -450,7 +450,7 @@ class RegistrationWorkerStore(SQLBaseStore): def _count_users(txn): txn.execute( - "SELECT COUNT(*) AS users FROM users where user_type is null or user_type = ''" + "SELECT COUNT(*) AS users FROM users where user_type is null" ) rows = self.cursor_to_dict(txn) if rows: -- cgit 1.5.1 From aaed6b39e140195a0f2b48e4de0519e08f16a119 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Mon, 9 Sep 2019 17:10:02 +0300 Subject: Fix code style, again Signed-off-by: Jason Robinson --- synapse/storage/registration.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'synapse/storage/registration.py') diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index c0ca25733b..109052fa41 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -449,9 +449,7 @@ class RegistrationWorkerStore(SQLBaseStore): """Counts all users without a special user_type registered on the homeserver.""" def _count_users(txn): - txn.execute( - "SELECT COUNT(*) AS users FROM users where user_type is null" - ) + txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") rows = self.cursor_to_dict(txn) if rows: return rows[0]["users"] -- cgit 1.5.1 From a8ac40445c98b9e1fc2538d7d4ec49c80b0298ac Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 13 Sep 2019 15:20:49 +0100 Subject: Record mappings from saml users in an external table We want to assign unique mxids to saml users based on an incrementing suffix. For that to work, we need to record the allocated mxid in a separate table. --- docs/sample_config.yaml | 26 ++++++ synapse/config/saml2_config.py | 78 +++++++++++++++- synapse/handlers/saml_handler.py | 103 +++++++++++++++++++-- synapse/rest/client/v1/login.py | 14 +++ synapse/storage/registration.py | 41 ++++++++ .../storage/schema/delta/56/user_external_ids.sql | 24 +++++ 6 files changed, 276 insertions(+), 10 deletions(-) create mode 100644 synapse/storage/schema/delta/56/user_external_ids.sql (limited to 'synapse/storage/registration.py') diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 8cfc5c312a..9021fe2cb8 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1099,6 +1099,32 @@ saml2_config: # #saml_session_lifetime: 5m + # The SAML attribute (after mapping via the attribute maps) to use to derive + # the Matrix ID from. 'uid' by default. + # + #mxid_source_attribute: displayName + + # The mapping system to use for mapping the saml attribute onto a matrix ID. + # Options include: + # * 'hexencode' (which maps unpermitted characters to '=xx') + # * 'dotreplace' (which replaces unpermitted characters with '.'). + # The default is 'hexencode'. + # + #mxid_mapping: dotreplace + + # In previous versions of synapse, the mapping from SAML attribute to MXID was + # always calculated dynamically rather than stored in a table. For backwards- + # compatibility, we will look for user_ids matching such a pattern before + # creating a new account. + # + # This setting controls the SAML attribute which will be used for this + # backwards-compatibility lookup. Typically it should be 'uid', but if the + # attribute maps are changed, it may be necessary to change it. + # + # The default is 'uid'. 
+ # + #grandfathered_mxid_source_attribute: upn + # Enable CAS for registration and login. diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index c46ac087db..a022470702 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -12,7 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re + from synapse.python_dependencies import DependencyException, check_requirements +from synapse.types import ( + map_username_to_mxid_localpart, + mxid_localpart_allowed_characters, +) from ._base import Config, ConfigError @@ -36,6 +42,14 @@ class SAML2Config(Config): self.saml2_enabled = True + self.saml2_mxid_source_attribute = saml2_config.get( + "mxid_source_attribute", "uid" + ) + + self.saml2_grandfathered_mxid_source_attribute = saml2_config.get( + "grandfathered_mxid_source_attribute", "uid" + ) + import saml2.config self.saml2_sp_config = saml2.config.SPConfig() @@ -51,6 +65,12 @@ class SAML2Config(Config): saml2_config.get("saml_session_lifetime", "5m") ) + mapping = saml2_config.get("mxid_mapping", "hexencode") + try: + self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping] + except KeyError: + raise ConfigError("%s is not a known mxid_mapping" % (mapping,)) + def _default_saml_config_dict(self): import saml2 @@ -58,6 +78,13 @@ class SAML2Config(Config): if public_baseurl is None: raise ConfigError("saml2_config requires a public_baseurl to be set") + required_attributes = {"uid", self.saml2_mxid_source_attribute} + + optional_attributes = {"displayName"} + if self.saml2_grandfathered_mxid_source_attribute: + optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute) + optional_attributes -= required_attributes + metadata_url = public_baseurl + "_matrix/saml2/metadata.xml" response_url = public_baseurl + "_matrix/saml2/authn_response" return { @@ -69,8 +96,9 @@ class SAML2Config(Config): (response_url, saml2.BINDING_HTTP_POST) ] }, - "required_attributes": ["uid"], - "optional_attributes": ["mail", "surname", "givenname"], + "required_attributes": list(required_attributes), + "optional_attributes": list(optional_attributes), + # "name_id_format": saml2.saml.NAMEID_FORMAT_PERSISTENT, } }, } @@ -146,6 +174,52 @@ class SAML2Config(Config): # The default is 5 minutes. # #saml_session_lifetime: 5m + + # The SAML attribute (after mapping via the attribute maps) to use to derive + # the Matrix ID from. 'uid' by default. + # + #mxid_source_attribute: displayName + + # The mapping system to use for mapping the saml attribute onto a matrix ID. + # Options include: + # * 'hexencode' (which maps unpermitted characters to '=xx') + # * 'dotreplace' (which replaces unpermitted characters with '.'). + # The default is 'hexencode'. + # + #mxid_mapping: dotreplace + + # In previous versions of synapse, the mapping from SAML attribute to MXID was + # always calculated dynamically rather than stored in a table. For backwards- + # compatibility, we will look for user_ids matching such a pattern before + # creating a new account. + # + # This setting controls the SAML attribute which will be used for this + # backwards-compatibility lookup. Typically it should be 'uid', but if the + # attribute maps are changed, it may be necessary to change it. + # + # The default is 'uid'. 
+ # + #grandfathered_mxid_source_attribute: upn """ % { "config_dir_path": config_dir_path } + + +DOT_REPLACE_PATTERN = re.compile( + ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) +) + + +def dot_replace_for_mxid(username: str) -> str: + username = username.lower() + username = DOT_REPLACE_PATTERN.sub(".", username) + + # regular mxids aren't allowed to start with an underscore either + username = re.sub("^_", "", username) + return username + + +MXID_MAPPER_MAP = { + "hexencode": map_username_to_mxid_localpart, + "dotreplace": dot_replace_for_mxid, +} diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index a1ce6929cf..5fa8272dc9 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -21,6 +21,8 @@ from saml2.client import Saml2Client from synapse.api.errors import SynapseError from synapse.http.servlet import parse_string from synapse.rest.client.v1.login import SSOAuthHandler +from synapse.types import UserID, map_username_to_mxid_localpart +from synapse.util.async_helpers import Linearizer logger = logging.getLogger(__name__) @@ -29,12 +31,26 @@ class SamlHandler: def __init__(self, hs): self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._sso_auth_handler = SSOAuthHandler(hs) + self._registration_handler = hs.get_registration_handler() + + self._clock = hs.get_clock() + self._datastore = hs.get_datastore() + self._hostname = hs.hostname + self._saml2_session_lifetime = hs.config.saml2_session_lifetime + self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute + self._grandfathered_mxid_source_attribute = ( + hs.config.saml2_grandfathered_mxid_source_attribute + ) + self._mxid_mapper = hs.config.saml2_mxid_mapper + + # identifier for the external_ids table + self._auth_provider_id = "saml" # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict = {} - self._clock = hs.get_clock() - self._saml2_session_lifetime = hs.config.saml2_session_lifetime + # a lock on the mappings + self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock) def handle_redirect_request(self, client_redirect_url): """Handle an incoming request to /login/sso/redirect @@ -60,7 +76,7 @@ class SamlHandler: # this shouldn't happen! raise Exception("prepare_for_authenticate didn't return a Location header") - def handle_saml_response(self, request): + async def handle_saml_response(self, request): """Handle an incoming request to /_matrix/saml2/authn_response Args: @@ -77,6 +93,10 @@ class SamlHandler: # the dict. 
self.expire_sessions() + user_id = await self._map_saml_response_to_user(resp_bytes) + self._sso_auth_handler.complete_sso_login(user_id, request, relay_state) + + async def _map_saml_response_to_user(self, resp_bytes): try: saml2_auth = self._saml_client.parse_authn_request_response( resp_bytes, @@ -91,18 +111,85 @@ class SamlHandler: logger.warning("SAML2 response was not signed") raise SynapseError(400, "SAML2 response was not signed") - if "uid" not in saml2_auth.ava: + try: + remote_user_id = saml2_auth.ava["uid"][0] + except KeyError: logger.warning("SAML2 response lacks a 'uid' attestation") raise SynapseError(400, "uid not in SAML2 response") + try: + mxid_source = saml2_auth.ava[self._mxid_source_attribute][0] + except KeyError: + logger.warning( + "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute + ) + raise SynapseError( + 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) + ) + self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None) - username = saml2_auth.ava["uid"][0] displayName = saml2_auth.ava.get("displayName", [None])[0] - return self._sso_auth_handler.on_successful_auth( - username, request, relay_state, user_display_name=displayName - ) + with (await self._mapping_lock.queue(self._auth_provider_id)): + # first of all, check if we already have a mapping for this user + logger.info( + "Looking for existing mapping for user %s:%s", + self._auth_provider_id, + remote_user_id, + ) + registered_user_id = await self._datastore.get_user_by_external_id( + self._auth_provider_id, remote_user_id + ) + if registered_user_id is not None: + logger.info("Found existing mapping %s", registered_user_id) + return registered_user_id + + # backwards-compatibility hack: see if there is an existing user with a + # suitable mapping from the uid + if ( + self._grandfathered_mxid_source_attribute + and self._grandfathered_mxid_source_attribute in saml2_auth.ava + ): + attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0] + user_id = UserID( + map_username_to_mxid_localpart(attrval), self._hostname + ).to_string() + logger.info( + "Looking for existing account based on mapped %s %s", + self._grandfathered_mxid_source_attribute, + user_id, + ) + + users = await self._datastore.get_users_by_id_case_insensitive(user_id) + if users: + registered_user_id = list(users.keys())[0] + logger.info("Grandfathering mapping to %s", registered_user_id) + await self._datastore.record_user_external_id( + self._auth_provider_id, remote_user_id, registered_user_id + ) + return registered_user_id + + # figure out a new mxid for this user + base_mxid_localpart = self._mxid_mapper(mxid_source) + + suffix = 0 + while True: + localpart = base_mxid_localpart + (str(suffix) if suffix else "") + if not await self._datastore.get_users_by_id_case_insensitive( + UserID(localpart, self._hostname).to_string() + ): + break + suffix += 1 + logger.info("Allocating mxid for new user with localpart %s", localpart) + + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, default_display_name=displayName + ) + await self._datastore.record_user_external_id( + self._auth_provider_id, remote_user_id, registered_user_id + ) + return registered_user_id def expire_sessions(self): expire_before = self._clock.time_msec() - self._saml2_session_lifetime diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 5762b9fd06..eeaa72b205 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py 
@@ -29,6 +29,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder from synapse.types import UserID, map_username_to_mxid_localpart @@ -507,6 +508,19 @@ class SSOAuthHandler(object): localpart=localpart, default_display_name=user_display_name ) + self.complete_sso_login(registered_user_id, request, client_redirect_url) + + def complete_sso_login( + self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str + ): + """Having figured out a mxid for this user, complete the HTTP request + + Args: + registered_user_id: + request: + client_redirect_url: + """ + login_token = self._macaroon_gen.generate_short_term_login_token( registered_user_id )
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 55e4e84d71..1e3c2148f6 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -22,6 +22,7 @@ from six import iterkeys from six.moves import range from twisted.internet import defer +from twisted.internet.defer import Deferred from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, ThreepidValidationError @@ -337,6 +338,26 @@ class RegistrationWorkerStore(SQLBaseStore): return self.runInteraction("get_users_by_id_case_insensitive", f) + async def get_user_by_external_id( + self, auth_provider: str, external_id: str + ) -> str: + """Look up a user by their external auth id + + Args: + auth_provider: identifier for the remote auth provider + external_id: id on that system + + Returns: + str|None: the mxid of the user, or None if they are not known + """ + return await self._simple_select_one_onecol( + table="user_external_ids", + keyvalues={"auth_provider": auth_provider, "external_id": external_id}, + retcol="user_id", + allow_none=True, + desc="get_user_by_external_id", + ) + @defer.inlineCallbacks def count_all_users(self): """Counts all users registered on the homeserver.""" @@ -848,6 +869,26 @@ class RegistrationStore( self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,)) + def record_user_external_id( + self, auth_provider: str, external_id: str, user_id: str + ) -> Deferred: + """Record a mapping from an external user id to a mxid + + Args: + auth_provider: identifier for the remote auth provider + external_id: id on that system + user_id: complete mxid that it is mapped to + """ + return self._simple_insert( + table="user_external_ids", + values={ + "auth_provider": auth_provider, + "external_id": external_id, + "user_id": user_id, + }, + desc="record_user_external_id", + ) + def user_set_password_hash(self, user_id, password_hash): """ NB. This does *not* evict any cache because the one use for this
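Taken together, the two new registration-store methods above give SSO code a persistent mapping from a remote provider's user ID to a local mxid. As a rough illustration of the intended call pattern — a minimal sketch under assumptions, not the actual Synapse handler code; `store` stands in for the registration store and all IDs are placeholders:

# Sketch only: `store` is assumed to expose the two methods added above.
async def map_remote_user(store, auth_provider_id, remote_user_id, new_mxid):
    # Reuse the previously allocated mxid if this remote user has logged in before.
    existing = await store.get_user_by_external_id(auth_provider_id, remote_user_id)
    if existing is not None:
        return existing

    # First login: persist the mapping so the next attempt is a pure lookup.
    await store.record_user_external_id(auth_provider_id, remote_user_id, new_mxid)
    return new_mxid

Note that the SAML handler above wraps this check-then-insert in a per-provider lock (_mapping_lock), which avoids racing two simultaneous first logins for the same remote user.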
diff --git a/synapse/storage/schema/delta/56/user_external_ids.sql b/synapse/storage/schema/delta/56/user_external_ids.sql
new file mode 100644
index 0000000000..91390c4527
--- /dev/null
+++ b/synapse/storage/schema/delta/56/user_external_ids.sql
@@ -0,0 +1,24 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * a table which records mappings from external auth providers to mxids
+ */
+CREATE TABLE IF NOT EXISTS user_external_ids (
+    auth_provider TEXT NOT NULL,
+    external_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    UNIQUE (auth_provider, external_id)
+);
-- cgit 1.5.1

From df3401a71d78088da36a03c73d35bc116c712df6 Mon Sep 17 00:00:00 2001
From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
Date: Fri, 20 Sep 2019 15:21:30 +0100
Subject: Allow HS to send emails when adding an email to the HS (#6042)

---
 changelog.d/6042.feature | 1 +
 docs/sample_config.yaml | 12 ++
 synapse/config/emailconfig.py | 36 ++++
 synapse/handlers/identity.py | 17 +-
 synapse/push/mailer.py | 29 +++
 synapse/res/templates/add_threepid.html | 9 +
 synapse/res/templates/add_threepid.txt | 6 +
 synapse/res/templates/add_threepid_failure.html | 8 +
 synapse/res/templates/add_threepid_success.html | 6 +
 synapse/rest/client/v2_alpha/account.py | 252 ++++++++++++++++++++----
 synapse/rest/client/v2_alpha/register.py | 24 +--
 synapse/storage/registration.py | 31 ++-
 12 files changed, 359 insertions(+), 72 deletions(-)
 create mode 100644 changelog.d/6042.feature
 create mode 100644 synapse/res/templates/add_threepid.html
 create mode 100644 synapse/res/templates/add_threepid.txt
 create mode 100644 synapse/res/templates/add_threepid_failure.html
 create mode 100644 synapse/res/templates/add_threepid_success.html
(limited to 'synapse/storage/registration.py')

diff --git a/changelog.d/6042.feature b/changelog.d/6042.feature
new file mode 100644
index 0000000000..a737760363
--- /dev/null
+++ b/changelog.d/6042.feature
@@ -0,0 +1 @@
+Allow homeserver to handle or delegate email validation when adding an email to a user's account.
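The commit below wires this up end to end: the client asks the homeserver (or its configured delegate) to email a validation link, and that link points back at the submit_token endpoint added further down. A hedged client-side sketch — the homeserver URL, secret, and address are placeholders, and the requestToken path is taken from the Matrix client-server spec rather than from this diff:

import requests

BASE = "https://hs.example.com"  # placeholder homeserver URL

# Ask the homeserver to send a validation email for the address being added.
resp = requests.post(
    BASE + "/_matrix/client/r0/account/3pid/email/requestToken",
    json={
        "client_secret": "some-client-secret",  # generated by the client
        "email": "alice@example.com",
        "send_attempt": 1,
    },
)
sid = resp.json()["sid"]  # session ID, quoted back when completing the add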
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 3e4edc6b0b..61d9f09a99 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1261,6 +1261,12 @@ password_config: # #registration_template_html: registration.html # #registration_template_text: registration.txt # +# # Templates for validation emails sent by the homeserver when adding an email to +# # your user account +# # +# #add_threepid_template_html: add_threepid.html +# #add_threepid_template_text: add_threepid.txt +# # # Templates for password reset success and failure pages that a user # # will see after attempting to reset their password # # @@ -1272,6 +1278,12 @@ password_config: # # # #registration_template_success_html: registration_success.html # #registration_template_failure_html: registration_failure.html +# +# # Templates for success and failure pages that a user will see after attempting +# # to add an email or phone to their account +# # +# #add_threepid_success_html: add_threepid_success.html +# #add_threepid_failure_html: add_threepid_failure.html #password_providers: diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index e5de768b0c..d9b43de660 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -169,12 +169,22 @@ class EmailConfig(Config): self.email_registration_template_text = email_config.get( "registration_template_text", "registration.txt" ) + self.email_add_threepid_template_html = email_config.get( + "add_threepid_template_html", "add_threepid.html" + ) + self.email_add_threepid_template_text = email_config.get( + "add_threepid_template_text", "add_threepid.txt" + ) + self.email_password_reset_template_failure_html = email_config.get( "password_reset_template_failure_html", "password_reset_failure.html" ) self.email_registration_template_failure_html = email_config.get( "registration_template_failure_html", "registration_failure.html" ) + self.email_add_threepid_template_failure_html = email_config.get( + "add_threepid_template_failure_html", "add_threepid_failure.html" + ) # These templates do not support any placeholder variables, so we # will read them from disk once during setup @@ -184,6 +194,9 @@ class EmailConfig(Config): email_registration_template_success_html = email_config.get( "registration_template_success_html", "registration_success.html" ) + email_add_threepid_template_success_html = email_config.get( + "add_threepid_template_success_html", "add_threepid_success.html" + ) # Check templates exist for f in [ @@ -191,9 +204,14 @@ class EmailConfig(Config): self.email_password_reset_template_text, self.email_registration_template_html, self.email_registration_template_text, + self.email_add_threepid_template_html, + self.email_add_threepid_template_text, self.email_password_reset_template_failure_html, + self.email_registration_template_failure_html, + self.email_add_threepid_template_failure_html, email_password_reset_template_success_html, email_registration_template_success_html, + email_add_threepid_template_success_html, ]: p = os.path.join(self.email_template_dir, f) if not os.path.isfile(p): @@ -212,6 +230,12 @@ class EmailConfig(Config): self.email_registration_template_success_html_content = self.read_file( filepath, "email.registration_template_success_html" ) + filepath = os.path.join( + self.email_template_dir, email_add_threepid_template_success_html + ) + self.email_add_threepid_template_success_html_content = self.read_file( + filepath, "email.add_threepid_template_success_html" + ) if 
self.email_enable_notifs: required = [ @@ -328,6 +352,12 @@ class EmailConfig(Config): # #registration_template_html: registration.html # #registration_template_text: registration.txt # + # # Templates for validation emails sent by the homeserver when adding an email to + # # your user account + # # + # #add_threepid_template_html: add_threepid.html + # #add_threepid_template_text: add_threepid.txt + # # # Templates for password reset success and failure pages that a user # # will see after attempting to reset their password # # @@ -339,6 +369,12 @@ class EmailConfig(Config): # # # #registration_template_success_html: registration_success.html # #registration_template_failure_html: registration_failure.html + # + # # Templates for success and failure pages that a user will see after attempting + # # to add an email or phone to their account + # # + # #add_threepid_success_html: add_threepid_success.html + # #add_threepid_failure_html: add_threepid_failure.html """ diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 512f38e5a6..156719e308 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -81,11 +81,10 @@ class IdentityHandler(BaseHandler): given identity server Args: - id_server (str|None): The identity server to validate 3PIDs against. If None, - we will attempt to extract id_server creds + id_server (str): The identity server to validate 3PIDs against. Must be a + complete URL including the protocol (http(s)://) creds (dict[str, str]): Dictionary containing the following keys: - * id_server|idServer: An optional domain name of an identity server * client_secret|clientSecret: A unique secret str provided by the client * sid: The ID of the validation session @@ -104,20 +103,10 @@ class IdentityHandler(BaseHandler): raise SynapseError( 400, "Missing param session_id in creds", errcode=Codes.MISSING_PARAM ) - if not id_server: - # Attempt to get the id_server from the creds dict - id_server = creds.get("id_server") or creds.get("idServer") - if not id_server: - raise SynapseError( - 400, "Missing param id_server in creds", errcode=Codes.MISSING_PARAM - ) query_params = {"sid": session_id, "client_secret": client_secret} - url = "https://%s%s" % ( - id_server, - "/_matrix/identity/api/v1/3pid/getValidated3pid", - ) + url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid" data = yield self.http_client.get_json(url, query_params) return data if "medium" in data else None diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 2437235dc4..5a4fc78b4c 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -179,6 +179,35 @@ class Mailer(object): template_vars, ) + @defer.inlineCallbacks + def send_add_threepid_mail(self, email_address, token, client_secret, sid): + """Send an email with a validation link to a user for adding a 3pid to their account + + Args: + email_address (str): Email address we're sending the validation link to + + token (str): Unique token generated by the server to verify the email was received + + client_secret (str): Unique token generated by the client to group together + multiple email sending attempts + + sid (str): The generated session ID + """ + params = {"token": token, "client_secret": client_secret, "sid": sid} + link = ( + self.hs.config.public_baseurl + + "_matrix/client/unstable/add_threepid/email/submit_token?%s" + % urllib.parse.urlencode(params) + ) + + template_vars = {"link": link} + + yield self.send_email( + email_address, + "[%s] Validate Your Email" % 
self.hs.config.server_name, + template_vars, + ) + @defer.inlineCallbacks def send_notification_mail( self, app_id, user_id, email_address, push_actions, reason
diff --git a/synapse/res/templates/add_threepid.html b/synapse/res/templates/add_threepid.html
new file mode 100644
index 0000000000..cc4ab07e09
--- /dev/null
+++ b/synapse/res/templates/add_threepid.html
@@ -0,0 +1,9 @@
+<html>
+<body>
+    <p>A request to add an email address to your Matrix account has been received. If this was you, please click the link below to confirm adding this email:</p>
+
+    <a href="{{ link }}">{{ link }}</a>
+
+    <p>If this was not you, you can safely ignore this email. Thank you.</p>
+</body>
+</html>
diff --git a/synapse/res/templates/add_threepid.txt b/synapse/res/templates/add_threepid.txt
new file mode 100644
index 0000000000..a60c1ff659
--- /dev/null
+++ b/synapse/res/templates/add_threepid.txt
@@ -0,0 +1,6 @@
+A request to add an email address to your Matrix account has been received. If this was you,
+please click the link below to confirm adding this email:
+
+{{ link }}
+
+If this was not you, you can safely ignore this email. Thank you.
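Both templates render a single link variable. To preview the resulting mail outside of Synapse, a standalone sketch — plain Jinja2 here, whereas Synapse itself loads these files through its load_jinja2_templates helper; the path and link value are illustrative:

from jinja2 import Environment, FileSystemLoader

env = Environment(loader=FileSystemLoader("synapse/res/templates"))
template = env.get_template("add_threepid.html")

# The real link is generated by Mailer.send_add_threepid_mail above.
print(template.render(link="https://hs.example.com/_matrix/client/unstable/add_threepid/email/submit_token?token=abc&client_secret=xyz&sid=1"))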

diff --git a/synapse/res/templates/add_threepid_failure.html b/synapse/res/templates/add_threepid_failure.html
new file mode 100644
index 0000000000..441d11c846
--- /dev/null
+++ b/synapse/res/templates/add_threepid_failure.html
@@ -0,0 +1,8 @@
+<html>
+<head></head>
+<body>
+<p>The request failed for the following reason: {{ failure_reason }}.</p>
+
+<p>No changes have been made to your account.</p>
+</body>
+</html>
diff --git a/synapse/res/templates/add_threepid_success.html b/synapse/res/templates/add_threepid_success.html
new file mode 100644
index 0000000000..fbd6e4018f
--- /dev/null
+++ b/synapse/res/templates/add_threepid_success.html
@@ -0,0 +1,6 @@
+<html>
+<head></head>
+<body>
+<p>Your email has now been validated, please return to your client. You may now close this window.</p>
+</body>
+</html>
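For reference, the {{ link }} these templates receive is the submit_token URL assembled by send_add_threepid_mail earlier in this commit; in isolation the construction looks like this (placeholder values, and assuming the configured public_baseurl ends in a slash, as the mailer's concatenation implies):

import urllib.parse

public_baseurl = "https://hs.example.com/"  # placeholder public_baseurl
params = {"token": "abc123", "client_secret": "xyz", "sid": "1"}
link = (
    public_baseurl
    + "_matrix/client/unstable/add_threepid/email/submit_token?%s"
    % urllib.parse.urlencode(params)
)
print(link)

AddThreepidSubmitTokenServlet below unpacks exactly these three query parameters when the user clicks the link.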

diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 3c5b23dc80..1139bb156c 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -21,7 +21,12 @@ from six.moves import http_client from twisted.internet import defer from synapse.api.constants import LoginType -from synapse.api.errors import Codes, SynapseError, ThreepidValidationError +from synapse.api.errors import ( + Codes, + HttpResponseException, + SynapseError, + ThreepidValidationError, +) from synapse.config.emailconfig import ThreepidBehaviour from synapse.http.server import finish_request from synapse.http.servlet import ( @@ -103,16 +108,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - # Have the configured identity server handle the request - if not self.hs.config.account_threepid_delegate_email: - logger.warn( - "No upstream email account_threepid_delegate configured on the server to " - "handle this request" - ) - raise SynapseError( - 400, "Password reset by email is not supported on this homeserver" - ) + assert self.hs.config.account_threepid_delegate_email + # Have the configured identity server handle the request ret = yield self.identity_handler.requestEmailToken( self.hs.config.account_threepid_delegate_email, email, @@ -214,6 +212,11 @@ class PasswordResetSubmitTokenServlet(RestServlet): self.config = hs.config self.clock = hs.get_clock() self.store = hs.get_datastore() + if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + self.failure_email_template, = load_jinja2_templates( + self.config.email_template_dir, + [self.config.email_password_reset_template_failure_html], + ) @defer.inlineCallbacks def on_GET(self, request, medium): @@ -261,13 +264,8 @@ class PasswordResetSubmitTokenServlet(RestServlet): request.setResponseCode(e.code) # Show a failure page with a reason - html_template, = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_password_reset_template_failure_html], - ) - template_vars = {"failure_reason": e.msg} - html = html_template.render(**template_vars) + html = self.failure_email_template.render(**template_vars) request.write(html.encode("utf-8")) finish_request(request) @@ -399,13 +397,35 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): self.identity_handler = hs.get_handlers().identity_handler self.store = self.hs.get_datastore() + if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + template_html, template_text = load_jinja2_templates( + self.config.email_template_dir, + [ + self.config.email_add_threepid_template_html, + self.config.email_add_threepid_template_text, + ], + public_baseurl=self.config.public_baseurl, + ) + self.mailer = Mailer( + hs=self.hs, + app_name=self.config.email_app_name, + template_html=template_html, + template_text=template_text, + ) + @defer.inlineCallbacks def on_POST(self, request): + if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.config.local_threepid_handling_disabled_due_to_email_config: + logger.warn( + "Adding emails have been disabled due to lack of an email config" + ) + raise SynapseError( + 400, "Adding an email to your account is disabled on this server" + ) + body = parse_json_object_from_request(request) - assert_params_in_dict( - body, ["id_server", "client_secret", "email", "send_attempt"] - ) - id_server =
"https://" + body["id_server"] # Assume https + assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) client_secret = body["client_secret"] email = body["email"] send_attempt = body["send_attempt"] @@ -425,9 +445,30 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): if existing_user_id is not None: raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) - ret = yield self.identity_handler.requestEmailToken( - id_server, email, client_secret, send_attempt, next_link - ) + if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + assert self.hs.config.account_threepid_delegate_email + + # Have the configured identity server handle the request + ret = yield self.identity_handler.requestEmailToken( + self.hs.config.account_threepid_delegate_email, + email, + client_secret, + send_attempt, + next_link, + ) + else: + # Send threepid validation emails from Synapse + sid = yield self.identity_handler.send_threepid_validation( + email, + client_secret, + send_attempt, + self.mailer.send_add_threepid_mail, + next_link, + ) + + # Wrap the session id in a JSON object + ret = {"sid": sid} + return 200, ret @@ -471,9 +512,86 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): ret = yield self.identity_handler.requestMsisdnToken( id_server, country, phone_number, client_secret, send_attempt, next_link ) + return 200, ret +class AddThreepidSubmitTokenServlet(RestServlet): + """Handles 3PID validation token submission for adding an email to a user's account""" + + PATTERNS = client_patterns( + "/add_threepid/email/submit_token$", releases=(), unstable=True + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.config = hs.config + self.clock = hs.get_clock() + self.store = hs.get_datastore() + if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + self.failure_email_template, = load_jinja2_templates( + self.config.email_template_dir, + [self.config.email_add_threepid_template_failure_html], + ) + + @defer.inlineCallbacks + def on_GET(self, request): + if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.config.local_threepid_handling_disabled_due_to_email_config: + logger.warn( + "Adding emails have been disabled due to lack of an email config" + ) + raise SynapseError( + 400, "Adding an email to your account is disabled on this server" + ) + elif self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + raise SynapseError( + 400, + "This homeserver is not validating threepids. 
Use an identity server " + "instead.", + ) + + sid = parse_string(request, "sid", required=True) + client_secret = parse_string(request, "client_secret", required=True) + token = parse_string(request, "token", required=True) + + # Attempt to validate a 3PID session + try: + # Mark the session as valid + next_link = yield self.store.validate_threepid_session( + sid, client_secret, token, self.clock.time_msec() + ) + + # Perform a 302 redirect if next_link is set + if next_link: + if next_link.startswith("file:///"): + logger.warn( + "Not redirecting to next_link as it is a local file: address" + ) + else: + request.setResponseCode(302) + request.setHeader("Location", next_link) + finish_request(request) + return None + + # Otherwise show the success template + html = self.config.email_add_threepid_template_success_html_content + request.setResponseCode(200) + except ThreepidValidationError as e: + request.setResponseCode(e.code) + + # Show a failure page with a reason + template_vars = {"failure_reason": e.msg} + html = self.failure_email_template.render(**template_vars) + + request.write(html.encode("utf-8")) + finish_request(request) + + class ThreepidRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid$") @@ -495,6 +613,8 @@ class ThreepidRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): + requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() body = parse_json_object_from_request(request) threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds") @@ -502,26 +622,85 @@ class ThreepidRestServlet(RestServlet): raise SynapseError( 400, "Missing param three_pid_creds", Codes.MISSING_PARAM ) + assert_params_in_dict(threepid_creds, ["client_secret", "sid"]) - requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + client_secret = threepid_creds["client_secret"] + sid = threepid_creds["sid"] - # Specify None as the identity server to retrieve it from the request body instead - threepid = yield self.identity_handler.threepid_from_creds(None, threepid_creds) + # We don't actually know which medium this 3PID is. Thus we first assume it's email, + # and if validation fails we try msisdn + validation_session = None - if not threepid: - raise SynapseError(400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED) + # Try to validate as email + if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + # Ask our delegated email identity server + try: + validation_session = yield self.identity_handler.threepid_from_creds( + self.hs.config.account_threepid_delegate_email, threepid_creds + ) + except HttpResponseException: + logger.debug( + "%s reported non-validated threepid: %s", + self.hs.config.account_threepid_delegate_email, + threepid_creds, + ) + elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + # Get a validated session matching these details + validation_session = yield self.datastore.get_threepid_validation_session( + "email", client_secret, sid=sid, validated=True + ) - for reqd in ["medium", "address", "validated_at"]: - if reqd not in threepid: - logger.warn("Couldn't add 3pid: invalid response from ID server") - raise SynapseError(500, "Invalid response from ID Server") + # Old versions of Sydent return a 200 http code even on a failed validation check. 
+ # Thus, in addition to the HttpResponseException check above (which checks for + # non-200 errors), we need to make sure validation_session isn't actually an error, + # identified by containing an "error" key + # See https://github.com/matrix-org/sydent/issues/215 for details + if validation_session and "error" not in validation_session: + yield self._add_threepid_to_account(user_id, validation_session) + return 200, {} - yield self.auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"] + # Try to validate as msisdn + if self.hs.config.account_threepid_delegate_msisdn: + # Ask our delegated msisdn identity server + try: + validation_session = yield self.identity_handler.threepid_from_creds( + self.hs.config.account_threepid_delegate_msisdn, threepid_creds + ) + except HttpResponseException: + logger.debug( + "%s reported non-validated threepid: %s", + self.hs.config.account_threepid_delegate_email, + threepid_creds, + ) + + # Check that validation_session isn't actually an error due to old Sydent instances + # See explanatory comment above + if validation_session and "error" not in validation_session: + yield self._add_threepid_to_account(user_id, validation_session) + return 200, {} + + raise SynapseError( + 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED ) - return 200, {} + @defer.inlineCallbacks + def _add_threepid_to_account(self, user_id, validation_session): + """Add a threepid wrapped in a validation_session dict to an account + + Args: + user_id (str): The mxid of the user to add this 3PID to + + validation_session (dict): A dict containing the following: + * medium - medium of the threepid + * address - address of the threepid + * validated_at - timestamp of when the validation occurred + """ + yield self.auth_handler.add_threepid( + user_id, + validation_session["medium"], + validation_session["address"], + validation_session["validated_at"], + ) class ThreepidUnbindRestServlet(RestServlet): @@ -613,6 +792,7 @@ def register_servlets(hs, http_server): DeactivateAccountRestServlet(hs).register(http_server) EmailThreepidRequestTokenRestServlet(hs).register(http_server) MsisdnThreepidRequestTokenRestServlet(hs).register(http_server) + AddThreepidSubmitTokenServlet(hs).register(http_server) ThreepidRestServlet(hs).register(http_server) ThreepidUnbindRestServlet(hs).register(http_server) ThreepidDeleteRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 5c7a5f3579..34276ea3fa 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -131,15 +131,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - if not self.hs.config.account_threepid_delegate_email: - logger.warn( - "No upstream email account_threepid_delegate configured on the server to " - "handle this request" - ) - raise SynapseError( - 400, "Registration by email is not supported on this homeserver" - ) + assert self.hs.config.account_threepid_delegate_email + # Have the configured identity server handle the request ret = yield self.identity_handler.requestEmailToken( self.hs.config.account_threepid_delegate_email, email, @@ -246,6 +240,12 @@ class RegistrationSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() + if 
self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + self.failure_email_template, = load_jinja2_templates( + self.config.email_template_dir, + [self.config.email_registration_template_failure_html], + ) + @defer.inlineCallbacks def on_GET(self, request, medium): if medium != "email": @@ -289,17 +289,11 @@ class RegistrationSubmitTokenServlet(RestServlet): request.setResponseCode(200) except ThreepidValidationError as e: - # Show a failure page with a reason request.setResponseCode(e.code) # Show a failure page with a reason - html_template, = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_registration_template_failure_html], - ) - template_vars = {"failure_reason": e.msg} - html = html_template.render(**template_vars) + html = self.failure_email_template.render(**template_vars) request.write(html.encode("utf-8")) finish_request(request) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 109052fa41..da27ad76b6 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -24,7 +24,7 @@ from six.moves import range from twisted.internet import defer from synapse.api.constants import UserTypes -from synapse.api.errors import Codes, StoreError, ThreepidValidationError +from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage import background_updates from synapse.storage._base import SQLBaseStore @@ -661,18 +661,31 @@ class RegistrationWorkerStore(SQLBaseStore): medium (str|None): The medium of the 3PID address (str|None): The address of the 3PID sid (str|None): The ID of the validation session - client_secret (str|None): A unique string provided by the client to - help identify this validation attempt + client_secret (str): A unique string provided by the client to help identify this + validation attempt validated (bool|None): Whether sessions should be filtered by whether they have been validated already or not. None to perform no filtering Returns: - deferred {str, int}|None: A dict containing the - latest session_id and send_attempt count for this 3PID. - Otherwise None if there hasn't been a previous attempt + Deferred[dict|None]: A dict containing the following: + * address - address of the 3pid + * medium - medium of the 3pid + * client_secret - a secret provided by the client for this validation session + * session_id - ID of the validation session + * send_attempt - a number serving to dedupe send attempts for this session + * validated_at - timestamp of when this session was validated if so + + Otherwise None if a validation session is not found """ - keyvalues = {"medium": medium, "client_secret": client_secret} + if not client_secret: + raise SynapseError( + 400, "Missing parameter: client_secret", errcode=Codes.MISSING_PARAM + ) + + keyvalues = {"client_secret": client_secret} + if medium: + keyvalues["medium"] = medium if address: keyvalues["address"] = address if sid: @@ -1209,6 +1222,10 @@ class RegistrationStore( current_ts (int): The current unix time in milliseconds. Used for checking token expiry status + Raises: + ThreepidValidationError: if a matching validation token was not found or has + expired + Returns: deferred str|None: A str representing a link to redirect the user to if there is one. 
-- cgit 1.5.1 From 30af161af27146cc44152292060c7005a6b8546b Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 23 Sep 2019 17:50:27 +0200 Subject: Implement MSC2290 (#6043) Implements MSC2290. This PR adds two new endpoints, /unstable/account/3pid/add and /unstable/account/3pid/bind. Depending on the progress of that MSC the unstable prefix may go away. This PR also removes the blacklist on some 3PID tests which occurs in #6042, as the corresponding Sytest PR changes them to use the new endpoints. Finally, it also modifies the account deactivation code such that it doesn't just try to deactivate 3PIDs that were bound to the user's account, but any 3PIDs that were bound through the homeserver on that user's account. --- changelog.d/6043.feature | 1 + synapse/handlers/deactivate_account.py | 4 +- synapse/handlers/identity.py | 134 +++++++++++++++---------- synapse/rest/client/v2_alpha/account.py | 161 +++++++++++++++++-------------- synapse/rest/client/v2_alpha/register.py | 6 ++ synapse/storage/registration.py | 22 ++++- sytest-blacklist | 9 -- 7 files changed, 203 insertions(+), 134 deletions(-) create mode 100644 changelog.d/6043.feature (limited to 'synapse/storage/registration.py') diff --git a/changelog.d/6043.feature b/changelog.d/6043.feature new file mode 100644 index 0000000000..cd27b0400b --- /dev/null +++ b/changelog.d/6043.feature @@ -0,0 +1 @@ +Implement new Client Server API endpoints `/account/3pid/add` and `/account/3pid/bind` as per [MSC2290](https://github.com/matrix-org/matrix-doc/pull/2290). \ No newline at end of file diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 5f804d1f13..d83912c9a4 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -73,7 +73,9 @@ class DeactivateAccountHandler(BaseHandler): # unbinding identity_server_supports_unbinding = True - threepids = yield self.store.user_get_threepids(user_id) + # Retrieve the 3PIDs this user has bound to an identity server + threepids = yield self.store.user_get_bound_threepids(user_id) + for threepid in threepids: try: result = yield self._identity_handler.try_unbind_threepid( diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index cd4700b521..d50d485e06 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -30,6 +30,7 @@ from synapse.api.errors import ( HttpResponseException, SynapseError, ) +from synapse.config.emailconfig import ThreepidBehaviour from synapse.util.stringutils import random_string from ._base import BaseHandler @@ -45,36 +46,6 @@ class IdentityHandler(BaseHandler): self.federation_http_client = hs.get_http_client() self.hs = hs - def _extract_items_from_creds_dict(self, creds): - """ - Retrieve entries from a "credentials" dictionary - - Args: - creds (dict[str, str]): Dictionary of credentials that contain the following keys: - * client_secret|clientSecret: A unique secret str provided by the client - * id_server|idServer: the domain of the identity server to query - * id_access_token: The access token to authenticate to the identity - server with. - - Returns: - tuple(str, str, str|None): A tuple containing the client_secret, the id_server, - and the id_access_token value if available. 
- """ - client_secret = creds.get("client_secret") or creds.get("clientSecret") - if not client_secret: - raise SynapseError( - 400, "No client_secret in creds", errcode=Codes.MISSING_PARAM - ) - - id_server = creds.get("id_server") or creds.get("idServer") - if not id_server: - raise SynapseError( - 400, "No id_server in creds", errcode=Codes.MISSING_PARAM - ) - - id_access_token = creds.get("id_access_token") - return client_secret, id_server, id_access_token - @defer.inlineCallbacks def threepid_from_creds(self, id_server, creds): """ @@ -113,35 +84,50 @@ class IdentityHandler(BaseHandler): data = yield self.http_client.get_json(url, query_params) except TimeoutError: raise SynapseError(500, "Timed out contacting identity server") - return data if "medium" in data else None + except HttpResponseException as e: + logger.info( + "%s returned %i for threepid validation for: %s", + id_server, + e.code, + creds, + ) + return None + + # Old versions of Sydent return a 200 http code even on a failed validation + # check. Thus, in addition to the HttpResponseException check above (which + # checks for non-200 errors), we need to make sure validation_session isn't + # actually an error, identified by the absence of a "medium" key + # See https://github.com/matrix-org/sydent/issues/215 for details + if "medium" in data: + return data + + logger.info("%s reported non-validated threepid: %s", id_server, creds) + return None @defer.inlineCallbacks - def bind_threepid(self, creds, mxid, use_v2=True): + def bind_threepid( + self, client_secret, sid, mxid, id_server, id_access_token=None, use_v2=True + ): """Bind a 3PID to an identity server Args: - creds (dict[str, str]): Dictionary of credentials that contain the following keys: - * client_secret|clientSecret: A unique secret str provided by the client - * id_server|idServer: the domain of the identity server to query - * id_access_token: The access token to authenticate to the identity - server with. Required if use_v2 is true + client_secret (str): A unique secret provided by the client + + sid (str): The ID of the validation session + mxid (str): The MXID to bind the 3PID to - use_v2 (bool): Whether to use v2 Identity Service API endpoints + + id_server (str): The domain of the identity server to query + + id_access_token (str): The access token to authenticate to the identity + server with, if necessary. Required if use_v2 is true + + use_v2 (bool): Whether to use v2 Identity Service API endpoints. 
Defaults to True Returns: Deferred[dict]: The response from the identity server """ - logger.debug("binding threepid %r to %s", creds, mxid) - - client_secret, id_server, id_access_token = self._extract_items_from_creds_dict( - creds - ) - - sid = creds.get("sid") - if not sid: - raise SynapseError( - 400, "No sid in three_pid_creds", errcode=Codes.MISSING_PARAM - ) + logger.debug("Proxying threepid bind request for %s to %s", mxid, id_server) # If an id_access_token is not supplied, force usage of v1 if id_access_token is None: @@ -160,7 +146,6 @@ class IdentityHandler(BaseHandler): data = yield self.http_client.post_json_get_json( bind_url, bind_data, headers=headers ) - logger.debug("bound threepid %r to %s", creds, mxid) # Remember where we bound the threepid yield self.store.add_user_bound_threepid( @@ -182,7 +167,10 @@ class IdentityHandler(BaseHandler): return data logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url) - return (yield self.bind_threepid(creds, mxid, use_v2=False)) + res = yield self.bind_threepid( + client_secret, sid, mxid, id_server, id_access_token, use_v2=False + ) + return res @defer.inlineCallbacks def try_unbind_threepid(self, mxid, threepid): @@ -459,6 +447,50 @@ class IdentityHandler(BaseHandler): except TimeoutError: raise SynapseError(500, "Timed out contacting identity server") + @defer.inlineCallbacks + def validate_threepid_session(self, client_secret, sid): + """Validates a threepid session with only the client secret and session ID + Tries validating against any configured account_threepid_delegates as well as locally. + + Args: + client_secret (str): A secret provided by the client + + sid (str): The ID of the session + + Returns: + Dict[str, str|int] if validation was successful, otherwise None + """ + # XXX: We shouldn't need to keep wrapping and unwrapping this value + threepid_creds = {"client_secret": client_secret, "sid": sid} + + # We don't actually know which medium this 3PID is. 
Thus we first assume it's email, + # and if validation fails we try msisdn + validation_session = None + + # Try to validate as email + if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + # Ask our delegated email identity server + validation_session = yield self.threepid_from_creds( + self.hs.config.account_threepid_delegate_email, threepid_creds + ) + elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + # Get a validated session matching these details + validation_session = yield self.store.get_threepid_validation_session( + "email", client_secret, sid=sid, validated=True + ) + + if validation_session: + return validation_session + + # Try to validate as msisdn + if self.hs.config.account_threepid_delegate_msisdn: + # Ask our delegated msisdn identity server + validation_session = yield self.threepid_from_creds( + self.hs.config.account_threepid_delegate_msisdn, threepid_creds + ) + + return validation_session + def create_id_access_token_header(id_access_token): """Create an Authorization header for passing to SimpleHttpClient as the header value diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 1139bb156c..b8c48dc8f1 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -21,12 +21,7 @@ from six.moves import http_client from twisted.internet import defer from synapse.api.constants import LoginType -from synapse.api.errors import ( - Codes, - HttpResponseException, - SynapseError, - ThreepidValidationError, -) +from synapse.api.errors import Codes, SynapseError, ThreepidValidationError from synapse.config.emailconfig import ThreepidBehaviour from synapse.http.server import finish_request from synapse.http.servlet import ( @@ -485,10 +480,8 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) assert_params_in_dict( - body, - ["id_server", "client_secret", "country", "phone_number", "send_attempt"], + body, ["client_secret", "country", "phone_number", "send_attempt"] ) - id_server = "https://" + body["id_server"] # Assume https client_secret = body["client_secret"] country = body["country"] phone_number = body["phone_number"] @@ -509,8 +502,23 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): if existing_user_id is not None: raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) + if not self.hs.config.account_threepid_delegate_msisdn: + logger.warn( + "No upstream msisdn account_threepid_delegate configured on the server to " + "handle this request" + ) + raise SynapseError( + 400, + "Adding phone numbers to user account is not supported by this homeserver", + ) + ret = yield self.identity_handler.requestMsisdnToken( - id_server, country, phone_number, client_secret, send_attempt, next_link + self.hs.config.account_threepid_delegate_msisdn, + country, + phone_number, + client_secret, + send_attempt, + next_link, ) return 200, ret @@ -627,81 +635,88 @@ class ThreepidRestServlet(RestServlet): client_secret = threepid_creds["client_secret"] sid = threepid_creds["sid"] - # We don't actually know which medium this 3PID is. 
Thus we first assume it's email, - # and if validation fails we try msisdn - validation_session = None - - # Try to validate as email - if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - # Ask our delegated email identity server - try: - validation_session = yield self.identity_handler.threepid_from_creds( - self.hs.config.account_threepid_delegate_email, threepid_creds - ) - except HttpResponseException: - logger.debug( - "%s reported non-validated threepid: %s", - self.hs.config.account_threepid_delegate_email, - threepid_creds, - ) - elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - # Get a validated session matching these details - validation_session = yield self.datastore.get_threepid_validation_session( - "email", client_secret, sid=sid, validated=True - ) - - # Old versions of Sydent return a 200 http code even on a failed validation check. - # Thus, in addition to the HttpResponseException check above (which checks for - # non-200 errors), we need to make sure validation_session isn't actually an error, - # identified by containing an "error" key - # See https://github.com/matrix-org/sydent/issues/215 for details - if validation_session and "error" not in validation_session: - yield self._add_threepid_to_account(user_id, validation_session) + validation_session = yield self.identity_handler.validate_threepid_session( + client_secret, sid + ) + if validation_session: + yield self.auth_handler.add_threepid( + user_id, + validation_session["medium"], + validation_session["address"], + validation_session["validated_at"], + ) return 200, {} - # Try to validate as msisdn - if self.hs.config.account_threepid_delegate_msisdn: - # Ask our delegated msisdn identity server - try: - validation_session = yield self.identity_handler.threepid_from_creds( - self.hs.config.account_threepid_delegate_msisdn, threepid_creds - ) - except HttpResponseException: - logger.debug( - "%s reported non-validated threepid: %s", - self.hs.config.account_threepid_delegate_email, - threepid_creds, - ) + raise SynapseError( + 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED + ) + - # Check that validation_session isn't actually an error due to old Sydent instances - # See explanatory comment above - if validation_session and "error" not in validation_session: - yield self._add_threepid_to_account(user_id, validation_session) - return 200, {} +class ThreepidAddRestServlet(RestServlet): + PATTERNS = client_patterns("/account/3pid/add$", releases=(), unstable=True) + + def __init__(self, hs): + super(ThreepidAddRestServlet, self).__init__() + self.hs = hs + self.identity_handler = hs.get_handlers().identity_handler + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + + @defer.inlineCallbacks + def on_POST(self, request): + requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + + assert_params_in_dict(body, ["client_secret", "sid"]) + client_secret = body["client_secret"] + sid = body["sid"] + + validation_session = yield self.identity_handler.validate_threepid_session( + client_secret, sid + ) + if validation_session: + yield self.auth_handler.add_threepid( + user_id, + validation_session["medium"], + validation_session["address"], + validation_session["validated_at"], + ) + return 200, {} raise SynapseError( 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED ) + +class ThreepidBindRestServlet(RestServlet): + PATTERNS = 
client_patterns("/account/3pid/bind$", releases=(), unstable=True) + + def __init__(self, hs): + super(ThreepidBindRestServlet, self).__init__() + self.hs = hs + self.identity_handler = hs.get_handlers().identity_handler + self.auth = hs.get_auth() + @defer.inlineCallbacks - def _add_threepid_to_account(self, user_id, validation_session): - """Add a threepid wrapped in a validation_session dict to an account + def on_POST(self, request): + body = parse_json_object_from_request(request) - Args: - user_id (str): The mxid of the user to add this 3PID to + assert_params_in_dict(body, ["id_server", "sid", "client_secret"]) + id_server = body["id_server"] + sid = body["sid"] + client_secret = body["client_secret"] + id_access_token = body.get("id_access_token") # optional - validation_session (dict): A dict containing the following: - * medium - medium of the threepid - * address - address of the threepid - * validated_at - timestamp of when the validation occurred - """ - yield self.auth_handler.add_threepid( - user_id, - validation_session["medium"], - validation_session["address"], - validation_session["validated_at"], + requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + yield self.identity_handler.bind_threepid( + client_secret, sid, user_id, id_server, id_access_token ) + return 200, {} + class ThreepidUnbindRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/unbind$", releases=(), unstable=True) @@ -794,6 +809,8 @@ def register_servlets(hs, http_server): MsisdnThreepidRequestTokenRestServlet(hs).register(http_server) AddThreepidSubmitTokenServlet(hs).register(http_server) ThreepidRestServlet(hs).register(http_server) + ThreepidAddRestServlet(hs).register(http_server) + ThreepidBindRestServlet(hs).register(http_server) ThreepidUnbindRestServlet(hs).register(http_server) ThreepidDeleteRestServlet(hs).register(http_server) WhoamiRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index e99b1f5c45..135a70808f 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -246,6 +246,12 @@ class RegistrationSubmitTokenServlet(RestServlet): [self.config.email_registration_template_failure_html], ) + if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + self.failure_email_template, = load_jinja2_templates( + self.config.email_template_dir, + [self.config.email_registration_template_failure_html], + ) + @defer.inlineCallbacks def on_GET(self, request, medium): if medium != "email": diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index da27ad76b6..805411a6b2 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -586,6 +586,26 @@ class RegistrationWorkerStore(SQLBaseStore): desc="add_user_bound_threepid", ) + def user_get_bound_threepids(self, user_id): + """Get the threepids that a user has bound to an identity server through the homeserver + The homeserver remembers where binds to an identity server occurred. Using this + method can retrieve those threepids. 
+ + Args: + user_id (str): The ID of the user to retrieve threepids for + + Returns: + Deferred[list[dict]]: List of dictionaries containing the following: + medium (str): The medium of the threepid (e.g "email") + address (str): The address of the threepid (e.g "bob@example.com") + """ + return self._simple_select_list( + table="user_threepid_id_server", + keyvalues={"user_id": user_id}, + retcols=["medium", "address"], + desc="user_get_bound_threepids", + ) + def remove_user_bound_threepid(self, user_id, medium, address, id_server): """The server proxied an unbind request to the given identity server on behalf of the given user, so we remove the mapping of threepid to @@ -655,7 +675,7 @@ class RegistrationWorkerStore(SQLBaseStore): self, medium, client_secret, address=None, sid=None, validated=True ): """Gets a session_id and last_send_attempt (if available) for a - client_secret/medium/(address|session_id) combo + combination of validation metadata Args: medium (str|None): The medium of the 3PID diff --git a/sytest-blacklist b/sytest-blacklist index 04698cb068..11785fd43f 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -29,12 +29,3 @@ Enabling an unknown default rule fails with 404 # Blacklisted due to https://github.com/matrix-org/synapse/issues/1663 New federated private chats get full presence information (SYN-115) - -# Blacklisted temporarily due to https://github.com/matrix-org/matrix-doc/pull/2290 -# These sytests need to be updated with new endpoints, which will come in a later PR -# That PR will also remove this blacklist -Can bind 3PID via home server -Can bind and unbind 3PID via homeserver -3PIDs are unbound after account deactivation -Can bind and unbind 3PID via /unbind by specifying the identity server -Can bind and unbind 3PID via /unbind without specifying the identity server -- cgit 1.5.1 From 2858d10671f889796ca79712e29b8f35b8c9cd98 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 23 Sep 2019 17:22:01 +0100 Subject: Fix the return value in the users_set_deactivated_flag background job --- synapse/storage/registration.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'synapse/storage/registration.py') diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 805411a6b2..92dd42fb87 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -860,18 +860,18 @@ class RegistrationStore( ) if batch_size > len(rows): - return True + return (True, rows_processed_nb) else: - return False + return (True, rows_processed_nb) - end = yield self.runInteraction( + end, nb_processed = yield self.runInteraction( "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn ) if end: yield self._end_background_update("users_set_deactivated_flag") - return batch_size + return nb_processed @defer.inlineCallbacks def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms): -- cgit 1.5.1 From 323d685bf743a660339be574aba68c0a63cb6483 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 23 Sep 2019 17:23:49 +0100 Subject: Typo --- synapse/storage/registration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage/registration.py') diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 92dd42fb87..4c84d804fe 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -862,7 +862,7 @@ class RegistrationStore( if batch_size > len(rows): return (True, rows_processed_nb) else: - 
return (True, rows_processed_nb) + return (False, rows_processed_nb) end, nb_processed = yield self.runInteraction( "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn -- cgit 1.5.1 From 12fe2a29bc4b15f667022c4ceb5428405ca3da9d Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 24 Sep 2019 14:43:38 +0100 Subject: Incorporate review --- synapse/storage/registration.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'synapse/storage/registration.py') diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 4c84d804fe..c3acc8ecaf 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -844,7 +844,7 @@ class RegistrationStore( rows = self.cursor_to_dict(txn) if not rows: - return True + return True, 0 rows_processed_nb = 0 @@ -860,9 +860,9 @@ class RegistrationStore( ) if batch_size > len(rows): - return (True, rows_processed_nb) + return True, len(rows) else: - return (False, rows_processed_nb) + return False, len(rows) end, nb_processed = yield self.runInteraction( "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn -- cgit 1.5.1 From 566ac40939404649d58c053e97dba75810f95339 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 24 Sep 2019 17:01:09 +0100 Subject: remove unused parameter to get_user_id_by_threepid (#6099) Added in #5377, apparently in error --- changelog.d/6099.misc | 1 + synapse/storage/registration.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/6099.misc (limited to 'synapse/storage/registration.py') diff --git a/changelog.d/6099.misc b/changelog.d/6099.misc new file mode 100644 index 0000000000..8415c6759b --- /dev/null +++ b/changelog.d/6099.misc @@ -0,0 +1 @@ +Remove unused parameter to get_user_id_by_threepid. diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 805411a6b2..5cf2c893aa 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -495,7 +495,7 @@ class RegistrationWorkerStore(SQLBaseStore): ) @defer.inlineCallbacks - def get_user_id_by_threepid(self, medium, address, require_verified=False): + def get_user_id_by_threepid(self, medium, address): """Returns user id from threepid Args: -- cgit 1.5.1
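The last three registration.py fixes above settle a convention for this background update: each batch transaction returns (finished, processed_count), and the update is ended only once finished is true. A schematic, self-contained sketch of that convention — the row data and driving loop are stand-ins, not Synapse's actual background-update machinery:

def run_batch(pending_rows, batch_size):
    """Process one batch; return (finished, processed), per the convention above."""
    rows = pending_rows[:batch_size]
    if not rows:
        return True, 0

    del pending_rows[: len(rows)]  # stand-in for setting the deactivated flag

    # A short batch means the table is drained; otherwise more work remains.
    return len(rows) < batch_size, len(rows)

pending = list(range(250))  # stand-in for rows still needing the flag
finished = False
while not finished:
    finished, processed = run_batch(pending, batch_size=100)
    print("processed %d rows, finished=%s" % (processed, finished))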