From e97d1cf0014668b9d4883d4175b783088444b24b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 9 Jan 2020 17:21:30 +0000 Subject: Modify check_database to take a connection rather than a cursor We might not need the cursor at all. --- synapse/storage/data_stores/__init__.py | 2 +- synapse/storage/engines/postgres.py | 17 +++++++++-------- synapse/storage/engines/sqlite.py | 2 +- 3 files changed, 11 insertions(+), 10 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index 092e803799..e1d03429ca 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -47,7 +47,7 @@ class DataStores(object): with make_conn(database_config, engine) as db_conn: logger.info("Preparing database %r...", db_name) - engine.check_database(db_conn.cursor()) + engine.check_database(db_conn) prepare_database( db_conn, engine, hs.config, data_stores=database_config.data_stores, ) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index b7c4eda338..ba19785fd7 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -32,14 +32,15 @@ class PostgresEngine(object): self.synchronous_commit = database_config.get("synchronous_commit", True) self._version = None # unknown as yet - def check_database(self, txn): - txn.execute("SHOW SERVER_ENCODING") - rows = txn.fetchall() - if rows and rows[0][0] != "UTF8": - raise IncorrectDatabaseSetup( - "Database has incorrect encoding: '%s' instead of 'UTF8'\n" - "See docs/postgres.rst for more information." % (rows[0][0],) - ) + def check_database(self, db_conn): + with db_conn.cursor() as txn: + txn.execute("SHOW SERVER_ENCODING") + rows = txn.fetchall() + if rows and rows[0][0] != "UTF8": + raise IncorrectDatabaseSetup( + "Database has incorrect encoding: '%s' instead of 'UTF8'\n" + "See docs/postgres.rst for more information." % (rows[0][0],) + ) def convert_param_style(self, sql): return sql.replace("?", "%s") diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index df039a072d..3b3c13360b 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -53,7 +53,7 @@ class Sqlite3Engine(object): """ return False - def check_database(self, txn): + def check_database(self, db_conn): pass def convert_param_style(self, sql): -- cgit 1.5.1 From e48ba84e0bfe081814941b74e610ddcd168a3ce8 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 9 Jan 2020 17:33:41 +0000 Subject: Check postgres version in check_database this saves doing it on each connection, and will allow us to pass extra options in. --- synapse/storage/engines/postgres.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index ba19785fd7..2a285e018c 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -33,6 +33,16 @@ class PostgresEngine(object): self._version = None # unknown as yet def check_database(self, db_conn): + # Get the version of PostgreSQL that we're using. As per the psycopg2 + # docs: The number is formed by converting the major, minor, and + # revision numbers into two-decimal-digit numbers and appending them + # together. 
For example, version 8.1.5 will be returned as 80105
+        self._version = db_conn.server_version

+
+        # Are we on a supported PostgreSQL version?
+        if self._version < 90500:
+            raise RuntimeError("Synapse requires PostgreSQL 9.5 or above.")
+
         with db_conn.cursor() as txn:
             txn.execute("SHOW SERVER_ENCODING")
             rows = txn.fetchall()
@@ -46,17 +56,6 @@ class PostgresEngine(object):
         return sql.replace("?", "%s")

     def on_new_connection(self, db_conn):
-
-        # Get the version of PostgreSQL that we're using. As per the psycopg2
-        # docs: The number is formed by converting the major, minor, and
-        # revision numbers into two-decimal-digit numbers and appending them
-        # together. For example, version 8.1.5 will be returned as 80105
-        self._version = db_conn.server_version
-
-        # Are we on a supported PostgreSQL version?
-        if self._version < 90500:
-            raise RuntimeError("Synapse requires PostgreSQL 9.5 or above.")
-
         db_conn.set_isolation_level(
             self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
         )
@@ -120,8 +119,8 @@ class PostgresEngine(object):
         Returns:
             string
         """
-        # note that this is a bit of a hack because it relies on on_new_connection
-        # having been called at least once. Still, that should be a safe bet here.
+        # note that this is a bit of a hack because it relies on check_database
+        # having been called. Still, that should be a safe bet here.
         numver = self._version
         assert numver is not None
-- cgit 1.5.1

From bf468211805900e767b6b07a2bfa6046f70efb7a Mon Sep 17 00:00:00 2001
From: Richard van der Hoff
Date: Thu, 9 Jan 2020 17:46:52 +0000
Subject: Refuse to start if sqlite is older than 3.11.0

---
 scripts/synapse_port_db             | 16 ++++++++++++----
 synapse/storage/engines/postgres.py |  4 ++--
 synapse/storage/engines/sqlite.py   |  7 +++++--
 3 files changed, 19 insertions(+), 8 deletions(-)

(limited to 'synapse/storage')

diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index a3dafaffc9..f135c8bc54 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -447,11 +447,15 @@ class Porter(object):
         else:
             return

-    def build_db_store(self, db_config: DatabaseConnectionConfig):
+    def build_db_store(
+        self, db_config: DatabaseConnectionConfig, allow_outdated_version: bool = False,
+    ):
         """Builds and returns a database store using the provided configuration.

         Args:
-            config: The database configuration
+            db_config: The database configuration
+            allow_outdated_version: True to suppress errors about the database server
+                version being too old to run a complete synapse

         Returns:
             The built Store object.
@@ -463,7 +467,9 @@ class Porter(object):
         hs = MockHomeserver(self.hs_config)

         with make_conn(db_config, engine) as db_conn:
-            engine.check_database(db_conn)
+            engine.check_database(
+                db_conn, allow_outdated_version=allow_outdated_version
+            )
             prepare_database(db_conn, engine, config=None)
             store = Store(Database(hs, db_config, engine), db_conn, hs)
             db_conn.commit()
@@ -491,8 +497,10 @@ class Porter(object):
     @defer.inlineCallbacks
     def run(self):
         try:
+            # we allow people to port away from outdated versions of sqlite.
             self.sqlite_store = self.build_db_store(
-                DatabaseConnectionConfig("master-sqlite", self.sqlite_config)
+                DatabaseConnectionConfig("master-sqlite", self.sqlite_config),
+                allow_outdated_version=True,
             )

             # Check if all background updates are done, abort if not.
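[Not part of any patch here — a minimal standalone sketch of the two version gates this series introduces: the PostgreSQL floor above and the SQLite floor in the following hunk. The helper names are invented for illustration.]

    import sqlite3

    def postgres_version_ok(numver):
        # Decode a psycopg2-style integer, e.g. 80105 -> (8, 1, 5), 90500 -> (9, 5, 0),
        # per the encoding described in the comment above.
        major, rest = divmod(numver, 10000)
        minor, patch = divmod(rest, 100)
        return (major, minor, patch) >= (9, 5, 0)

    def sqlite_version_ok():
        # Mirrors the gate added to Sqlite3Engine.check_database below.
        return sqlite3.sqlite_version_info >= (3, 11, 0)

    assert postgres_version_ok(90500)      # 9.5.0 is the minimum
    assert not postgres_version_ok(80105)  # the 8.1.5 example from the psycopg2 docs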
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 2a285e018c..c84cb452b0 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -32,7 +32,7 @@ class PostgresEngine(object):
         self.synchronous_commit = database_config.get("synchronous_commit", True)
         self._version = None  # unknown as yet

-    def check_database(self, db_conn):
+    def check_database(self, db_conn, allow_outdated_version: bool = False):
         # Get the version of PostgreSQL that we're using. As per the psycopg2
         # docs: The number is formed by converting the major, minor, and
         # revision numbers into two-decimal-digit numbers and appending them
@@ -40,7 +40,7 @@
         self._version = db_conn.server_version

         # Are we on a supported PostgreSQL version?
-        if self._version < 90500:
+        if not allow_outdated_version and self._version < 90500:
             raise RuntimeError("Synapse requires PostgreSQL 9.5 or above.")

         with db_conn.cursor() as txn:
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 3b3c13360b..cbf52f5191 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -53,8 +53,11 @@ class Sqlite3Engine(object):
         """
         return False

-    def check_database(self, db_conn):
-        pass
+    def check_database(self, db_conn, allow_outdated_version: bool = False):
+        if not allow_outdated_version:
+            version = self.module.sqlite_version_info
+            if version < (3, 11, 0):
+                raise RuntimeError("Synapse requires sqlite 3.11 or above.")

     def convert_param_style(self, sql):
         return sql
-- cgit 1.5.1

From 1177d3f3a33bd3ae1eef46fba360d319598359ad Mon Sep 17 00:00:00 2001
From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
Date: Mon, 13 Jan 2020 18:10:43 +0000
Subject: Quarantine media by ID or user ID (#6681)

---
 changelog.d/6681.feature                 |   1 +
 docs/admin_api/media_admin_api.md        |  76 ++++++-
 docs/workers.md                          |   4 +-
 synapse/rest/admin/media.py              |  68 +++++-
 synapse/storage/data_stores/main/room.py | 116 ++++++++++-
 tests/rest/admin/test_admin.py           | 341 +++++++++++++++++++++++++++++++
 tests/rest/client/v1/utils.py            |  37 ++++
 7 files changed, 632 insertions(+), 11 deletions(-)
 create mode 100644 changelog.d/6681.feature

(limited to 'synapse/storage')

diff --git a/changelog.d/6681.feature b/changelog.d/6681.feature
new file mode 100644
index 0000000000..5cf19a4e0e
--- /dev/null
+++ b/changelog.d/6681.feature
@@ -0,0 +1 @@
+Add new quarantine media admin APIs to quarantine by media ID or by user who uploaded the media.
diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md
index 8b3666d5f5..46ba7a1a71 100644
--- a/docs/admin_api/media_admin_api.md
+++ b/docs/admin_api/media_admin_api.md
@@ -22,19 +22,81 @@ It returns a JSON body like the following:
 }
 ```

-# Quarantine media in a room
+# Quarantine media

-This API 'quarantines' all the media in a room.
+Quarantining media means that it is marked as inaccessible by users. It applies
+to any local media, and any locally-cached copies of remote media.

-The API is:
+The media file itself (and any thumbnails) is not deleted from the server.
+
+## Quarantining media by ID
+
+This API quarantines a single piece of local or remote media.
+
+Request:

 ```
-POST /_synapse/admin/v1/quarantine_media/<room_id>
+POST /_synapse/admin/v1/media/quarantine/<server_name>/<media_id>

 {}
 ```

-Quarantining media means that it is marked as inaccessible by users. It applies
-to any local media, and any locally-cached copies of remote media.
+Where `server_name` is in the form of `example.org`, and `media_id` is in the
+form of `abcdefg12345...`.
+
+Response:
+
+```
+{}
+```
+
+## Quarantining media in a room
+
+This API quarantines all local and remote media in a room.
+
+Request:
+
+```
+POST /_synapse/admin/v1/room/<room_id>/media/quarantine
+
+{}
+```
+
+Where `room_id` is in the form of `!roomid12345:example.org`.
+
+Response:
+
+```
+{
+  "num_quarantined": 10 # The number of media items successfully quarantined
+}
+```
+
+Note that there is a legacy endpoint, `POST
+/_synapse/admin/v1/quarantine_media/<room_id>`, that operates the same.
+However, it is deprecated and may be removed in a future release.
+
+## Quarantining all media of a user
+
+This API quarantines all *local* media that a *local* user has uploaded. That is to say, if
+you would like to quarantine media uploaded by a user on a remote homeserver, you should
+instead use one of the other APIs.
+
+Request:
+
+```
+POST /_synapse/admin/v1/user/<user_id>/media/quarantine
+
+{}
+```
+
+Where `user_id` is in the form of `@bob:example.org`.
+
+Response:
+
+```
+{
+  "num_quarantined": 10 # The number of media items successfully quarantined
+}
+```

-The media file itself (and any thumbnails) is not deleted from the server.
diff --git a/docs/workers.md b/docs/workers.md
index f4283aeb05..0ab269fd96 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -202,7 +202,9 @@ Handles the media repository. It can handle all endpoints starting with:
 ... and the following regular expressions matching media-specific administration APIs:

     ^/_synapse/admin/v1/purge_media_cache$
-    ^/_synapse/admin/v1/room/.*/media$
+    ^/_synapse/admin/v1/room/.*/media.*$
+    ^/_synapse/admin/v1/user/.*/media.*$
+    ^/_synapse/admin/v1/media/.*$
     ^/_synapse/admin/v1/quarantine_media/.*$

 You should also set `enable_media_repo: False` in the shared configuration
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index fa833e54cf..3a445d6eed 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -32,16 +32,24 @@ class QuarantineMediaInRoom(RestServlet):
     this server.
     """

-    PATTERNS = historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)")
+    PATTERNS = (
+        historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media/quarantine")
+        +
+        # This path kept around for legacy reasons
+        historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)")
+    )

     def __init__(self, hs):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()

-    async def on_POST(self, request, room_id):
+    async def on_POST(self, request, room_id: str):
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)

+        logging.info("Quarantining room: %s", room_id)
+
+        # Quarantine all media in this room
         num_quarantined = await self.store.quarantine_media_ids_in_room(
             room_id, requester.user.to_string()
         )
@@ -49,6 +57,60 @@ class QuarantineMediaInRoom(RestServlet):
         return 200, {"num_quarantined": num_quarantined}


+class QuarantineMediaByUser(RestServlet):
+    """Quarantines all local media by a given user so that no one can download it via
+    this server.
+    """
+
+    PATTERNS = historical_admin_path_patterns(
+        "/user/(?P<user_id>[^/]+)/media/quarantine"
+    )
+
+    def __init__(self, hs):
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    async def on_POST(self, request, user_id: str):
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        logging.info("Quarantining local media by user: %s", user_id)
+
+        # Quarantine all media this user has uploaded
+        num_quarantined = await self.store.quarantine_media_ids_by_user(
+            user_id, requester.user.to_string()
+        )
+
+        return 200, {"num_quarantined": num_quarantined}
+
+
+class QuarantineMediaByID(RestServlet):
+    """Quarantines local or remote media by a given ID so that no one can download
+    it via this server.
+    """
+
+    PATTERNS = historical_admin_path_patterns(
+        "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
+    )
+
+    def __init__(self, hs):
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    async def on_POST(self, request, server_name: str, media_id: str):
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        logging.info("Quarantining local media by ID: %s/%s", server_name, media_id)
+
+        # Quarantine this media id
+        await self.store.quarantine_media_by_id(
+            server_name, media_id, requester.user.to_string()
+        )
+
+        return 200, {}
+
+
 class ListMediaInRoom(RestServlet):
     """Lists all of the media in a given room.
     """
@@ -94,4 +156,6 @@ def register_servlets_for_media_repo(hs, http_server):
     """
     PurgeMediaCacheRestServlet(hs).register(http_server)
     QuarantineMediaInRoom(hs).register(http_server)
+    QuarantineMediaByID(hs).register(http_server)
+    QuarantineMediaByUser(hs).register(http_server)
     ListMediaInRoom(hs).register(http_server)
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 8636d75030..49bab62be3 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -18,7 +18,7 @@ import collections
 import logging
 import re
 from abc import abstractmethod
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple

 from six import integer_types

@@ -399,6 +399,8 @@ class RoomWorkerStore(SQLBaseStore):
             the associated media
         """

+        logger.info("Quarantining media in room: %s", room_id)
+
         def _quarantine_media_in_room_txn(txn):
             local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
             total_media_quarantined = 0
@@ -494,6 +496,118 @@ class RoomWorkerStore(SQLBaseStore):

         return local_media_mxcs, remote_media_mxcs

+    def quarantine_media_by_id(
+        self, server_name: str, media_id: str, quarantined_by: str,
+    ):
+        """quarantines a single local or remote media id
+
+        Args:
+            server_name: The name of the server that holds this media
+            media_id: The ID of the media to be quarantined
+            quarantined_by: The user ID that initiated the quarantine request
+        """
+        logger.info("Quarantining media: %s/%s", server_name, media_id)
+        is_local = server_name == self.config.server_name
+
+        def _quarantine_media_by_id_txn(txn):
+            local_mxcs = [media_id] if is_local else []
+            remote_mxcs = [(server_name, media_id)] if not is_local else []
+
+            return self._quarantine_media_txn(
+                txn, local_mxcs, remote_mxcs, quarantined_by
+            )
+
+        return self.db.runInteraction(
+            "quarantine_media_by_id", _quarantine_media_by_id_txn
+        )
+
+    def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
+        """quarantines all local media associated with a single user
+
+        Args:
+            
user_id: The ID of the user to quarantine media of + quarantined_by: The ID of the user who made the quarantine request + """ + + def _quarantine_media_by_user_txn(txn): + local_media_ids = self._get_media_ids_by_user_txn(txn, user_id) + return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by) + + return self.db.runInteraction( + "quarantine_media_by_user", _quarantine_media_by_user_txn + ) + + def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True): + """Retrieves local media IDs by a given user + + Args: + txn (cursor) + user_id: The ID of the user to retrieve media IDs of + + Returns: + The local and remote media as a lists of tuples where the key is + the hostname and the value is the media ID. + """ + # Local media + sql = """ + SELECT media_id + FROM local_media_repository + WHERE user_id = ? + """ + if filter_quarantined: + sql += "AND quarantined_by IS NULL" + txn.execute(sql, (user_id,)) + + local_media_ids = [row[0] for row in txn] + + # TODO: Figure out all remote media a user has referenced in a message + + return local_media_ids + + def _quarantine_media_txn( + self, + txn, + local_mxcs: List[str], + remote_mxcs: List[Tuple[str, str]], + quarantined_by: str, + ) -> int: + """Quarantine local and remote media items + + Args: + txn (cursor) + local_mxcs: A list of local mxc URLs + remote_mxcs: A list of (remote server, media id) tuples representing + remote mxc URLs + quarantined_by: The ID of the user who initiated the quarantine request + Returns: + The total number of media items quarantined + """ + total_media_quarantined = 0 + + # Update all the tables to set the quarantined_by flag + txn.executemany( + """ + UPDATE local_media_repository + SET quarantined_by = ? + WHERE media_id = ? + """, + ((quarantined_by, media_id) for media_id in local_mxcs), + ) + + txn.executemany( + """ + UPDATE remote_media_cache + SET quarantined_by = ? + WHERE media_origin = ? AND media_id = ? + """, + ((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs), + ) + + total_media_quarantined += len(local_mxcs) + total_media_quarantined += len(remote_mxcs) + + return total_media_quarantined + class RoomBackgroundUpdateStore(SQLBaseStore): REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 6ceb483aa8..7a7e898843 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -14,11 +14,17 @@ # limitations under the License. import json +import os +import urllib.parse +from binascii import unhexlify from mock import Mock +from twisted.internet.defer import Deferred + import synapse.rest.admin from synapse.http.server import JsonResource +from synapse.logging.context import make_deferred_yieldable from synapse.rest.admin import VersionServlet from synapse.rest.client.v1 import events, login, room from synapse.rest.client.v2_alpha import groups @@ -346,3 +352,338 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) test_purge_room.skip = "Disabled because it's currently broken" + + +class QuarantineMediaTestCase(unittest.HomeserverTestCase): + """Test /quarantine_media admin API. 
+ """ + + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_media_repo, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.hs = hs + + # Allow for uploading and downloading to/from the media repo + self.media_repo = hs.get_media_repository_resource() + self.download_resource = self.media_repo.children[b"download"] + self.upload_resource = self.media_repo.children[b"upload"] + self.image_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + def make_homeserver(self, reactor, clock): + + self.fetches = [] + + def get_file(destination, path, output_stream, args=None, max_size=None): + """ + Returns tuple[int,dict,str,int] of file length, response headers, + absolute URI, and response code. + """ + + def write_to(r): + data, response = r + output_stream.write(data) + return response + + d = Deferred() + d.addCallback(write_to) + self.fetches.append((d, destination, path, args)) + return make_deferred_yieldable(d) + + client = Mock() + client.get_file = get_file + + self.storage_path = self.mktemp() + self.media_store_path = self.mktemp() + os.mkdir(self.storage_path) + os.mkdir(self.media_store_path) + + config = self.default_config() + config["media_store_path"] = self.media_store_path + config["thumbnail_requirements"] = {} + config["max_image_pixels"] = 2000000 + + provider_config = { + "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend", + "store_local": True, + "store_synchronous": False, + "store_remote": True, + "config": {"directory": self.storage_path}, + } + config["media_storage_providers"] = [provider_config] + + hs = self.setup_test_homeserver(config=config, http_client=client) + + return hs + + def test_quarantine_media_requires_admin(self): + self.register_user("nonadmin", "pass", admin=False) + non_admin_user_tok = self.login("nonadmin", "pass") + + # Attempt quarantine media APIs as non-admin + url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345" + request, channel = self.make_request( + "POST", url.encode("ascii"), access_token=non_admin_user_tok, + ) + self.render(request) + + # Expect a forbidden error + self.assertEqual( + 403, + int(channel.result["code"]), + msg="Expected forbidden on quarantining media as a non-admin", + ) + + # And the roomID/userID endpoint + url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine" + request, channel = self.make_request( + "POST", url.encode("ascii"), access_token=non_admin_user_tok, + ) + self.render(request) + + # Expect a forbidden error + self.assertEqual( + 403, + int(channel.result["code"]), + msg="Expected forbidden on quarantining media as a non-admin", + ) + + def test_quarantine_media_by_id(self): + self.register_user("id_admin", "pass", admin=True) + admin_user_tok = self.login("id_admin", "pass") + + self.register_user("id_nonadmin", "pass", admin=False) + non_admin_user_tok = self.login("id_nonadmin", "pass") + + # Upload some media into the room + response = self.helper.upload_media( + self.upload_resource, self.image_data, tok=admin_user_tok + ) + + # Extract media ID from the response + server_name_and_media_id = response["content_uri"][ + 6: + ] # Cut off the 'mxc://' bit + server_name, media_id = server_name_and_media_id.split("/") + + # Attempt to access the media + request, channel = self.make_request( + 
"GET", + server_name_and_media_id, + shorthand=False, + access_token=non_admin_user_tok, + ) + request.render(self.download_resource) + self.pump(1.0) + + # Should be successful + self.assertEqual(200, int(channel.code), msg=channel.result["body"]) + + # Quarantine the media + url = "/_synapse/admin/v1/media/quarantine/%s/%s" % ( + urllib.parse.quote(server_name), + urllib.parse.quote(media_id), + ) + request, channel = self.make_request("POST", url, access_token=admin_user_tok,) + self.render(request) + self.pump(1.0) + self.assertEqual(200, int(channel.code), msg=channel.result["body"]) + + # Attempt to access the media + request, channel = self.make_request( + "GET", + server_name_and_media_id, + shorthand=False, + access_token=admin_user_tok, + ) + request.render(self.download_resource) + self.pump(1.0) + + # Should be quarantined + self.assertEqual( + 404, + int(channel.code), + msg=( + "Expected to receive a 404 on accessing quarantined media: %s" + % server_name_and_media_id + ), + ) + + def test_quarantine_all_media_in_room(self): + self.register_user("room_admin", "pass", admin=True) + admin_user_tok = self.login("room_admin", "pass") + + non_admin_user = self.register_user("room_nonadmin", "pass", admin=False) + non_admin_user_tok = self.login("room_nonadmin", "pass") + + room_id = self.helper.create_room_as(non_admin_user, tok=admin_user_tok) + self.helper.join(room_id, non_admin_user, tok=non_admin_user_tok) + + # Upload some media + response_1 = self.helper.upload_media( + self.upload_resource, self.image_data, tok=non_admin_user_tok + ) + response_2 = self.helper.upload_media( + self.upload_resource, self.image_data, tok=non_admin_user_tok + ) + + # Extract mxcs + mxc_1 = response_1["content_uri"] + mxc_2 = response_2["content_uri"] + + # Send it into the room + self.helper.send_event( + room_id, + "m.room.message", + content={"body": "image-1", "msgtype": "m.image", "url": mxc_1}, + txn_id="111", + tok=non_admin_user_tok, + ) + self.helper.send_event( + room_id, + "m.room.message", + content={"body": "image-2", "msgtype": "m.image", "url": mxc_2}, + txn_id="222", + tok=non_admin_user_tok, + ) + + # Quarantine all media in the room + url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote( + room_id + ) + request, channel = self.make_request("POST", url, access_token=admin_user_tok,) + self.render(request) + self.pump(1.0) + self.assertEqual(200, int(channel.code), msg=channel.result["body"]) + self.assertEqual( + json.loads(channel.result["body"].decode("utf-8")), + {"num_quarantined": 2}, + "Expected 2 quarantined items", + ) + + # Convert mxc URLs to server/media_id strings + server_and_media_id_1 = mxc_1[6:] + server_and_media_id_2 = mxc_2[6:] + + # Test that we cannot download any of the media anymore + request, channel = self.make_request( + "GET", + server_and_media_id_1, + shorthand=False, + access_token=non_admin_user_tok, + ) + request.render(self.download_resource) + self.pump(1.0) + + # Should be quarantined + self.assertEqual( + 404, + int(channel.code), + msg=( + "Expected to receive a 404 on accessing quarantined media: %s" + % server_and_media_id_1 + ), + ) + + request, channel = self.make_request( + "GET", + server_and_media_id_2, + shorthand=False, + access_token=non_admin_user_tok, + ) + request.render(self.download_resource) + self.pump(1.0) + + # Should be quarantined + self.assertEqual( + 404, + int(channel.code), + msg=( + "Expected to receive a 404 on accessing quarantined media: %s" + % server_and_media_id_2 + ), + ) + + def 
test_quarantine_all_media_by_user(self): + self.register_user("user_admin", "pass", admin=True) + admin_user_tok = self.login("user_admin", "pass") + + non_admin_user = self.register_user("user_nonadmin", "pass", admin=False) + non_admin_user_tok = self.login("user_nonadmin", "pass") + + # Upload some media + response_1 = self.helper.upload_media( + self.upload_resource, self.image_data, tok=non_admin_user_tok + ) + response_2 = self.helper.upload_media( + self.upload_resource, self.image_data, tok=non_admin_user_tok + ) + + # Extract media IDs + server_and_media_id_1 = response_1["content_uri"][6:] + server_and_media_id_2 = response_2["content_uri"][6:] + + # Quarantine all media by this user + url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( + non_admin_user + ) + request, channel = self.make_request( + "POST", url.encode("ascii"), access_token=admin_user_tok, + ) + self.render(request) + self.pump(1.0) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + json.loads(channel.result["body"].decode("utf-8")), + {"num_quarantined": 2}, + "Expected 2 quarantined items", + ) + + # Attempt to access each piece of media + request, channel = self.make_request( + "GET", + server_and_media_id_1, + shorthand=False, + access_token=non_admin_user_tok, + ) + request.render(self.download_resource) + self.pump(1.0) + + # Should be quarantined + self.assertEqual( + 404, + int(channel.code), + msg=( + "Expected to receive a 404 on accessing quarantined media: %s" + % server_and_media_id_1, + ), + ) + + # Attempt to access each piece of media + request, channel = self.make_request( + "GET", + server_and_media_id_2, + shorthand=False, + access_token=non_admin_user_tok, + ) + request.render(self.download_resource) + self.pump(1.0) + + # Should be quarantined + self.assertEqual( + 404, + int(channel.code), + msg=( + "Expected to receive a 404 on accessing quarantined media: %s" + % server_and_media_id_2 + ), + ) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index e7417b3d14..873d5ef99c 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -21,6 +21,8 @@ import time import attr +from twisted.web.resource import Resource + from synapse.api.constants import Membership from tests.server import make_request, render @@ -160,3 +162,38 @@ class RestHelper(object): ) return channel.json_body + + def upload_media( + self, + resource: Resource, + image_data: bytes, + tok: str, + filename: str = "test.png", + expect_code: int = 200, + ) -> dict: + """Upload a piece of test media to the media repo + Args: + resource: The resource that will handle the upload request + image_data: The image data to upload + tok: The user token to use during the upload + filename: The filename of the media to be uploaded + expect_code: The return code to expect from attempting to upload the media + """ + image_length = len(image_data) + path = "/_matrix/media/r0/upload?filename=%s" % (filename,) + request, channel = make_request( + self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok + ) + request.requestHeaders.addRawHeader( + b"Content-Length", str(image_length).encode("UTF-8") + ) + request.render(resource) + self.hs.get_reactor().pump([100]) + + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], + ) + + return channel.json_body -- cgit 1.5.1 From 28c98e51ffa166bd717646b0b34228e59f253485 Mon Sep 
17 00:00:00 2001 From: Erik Johnston Date: Wed, 15 Jan 2020 14:59:33 +0000 Subject: Add `local_current_membership` table (#6655) Currently we rely on `current_state_events` to figure out what rooms a user was in and their last membership event in there. However, if the server leaves the room then the table may be cleaned up and that information is lost. So lets add a table that separately holds that information. --- changelog.d/6655.misc | 1 + scripts/synapse_port_db | 2 +- synapse/handlers/admin.py | 2 +- synapse/handlers/deactivate_account.py | 2 +- synapse/handlers/initial_sync.py | 2 +- synapse/handlers/room_member.py | 2 +- synapse/handlers/search.py | 2 +- synapse/handlers/sync.py | 2 +- synapse/push/push_tools.py | 2 +- synapse/replication/slave/storage/events.py | 2 +- synapse/server_notices/server_notices_manager.py | 2 +- synapse/storage/data_stores/main/events.py | 30 ++++ synapse/storage/data_stores/main/roommember.py | 189 ++++++++++++--------- .../schema/delta/57/local_current_membership.py | 97 +++++++++++ synapse/storage/prepare_database.py | 2 +- tests/handlers/test_sync.py | 4 +- tests/replication/slave/storage/test_events.py | 4 +- tests/rest/client/v2_alpha/test_account.py | 12 +- tests/rest/client/v2_alpha/test_sync.py | 9 - tests/storage/test_roommember.py | 2 +- 20 files changed, 263 insertions(+), 107 deletions(-) create mode 100644 changelog.d/6655.misc create mode 100644 synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py (limited to 'synapse/storage') diff --git a/changelog.d/6655.misc b/changelog.d/6655.misc new file mode 100644 index 0000000000..01e78bc84e --- /dev/null +++ b/changelog.d/6655.misc @@ -0,0 +1 @@ +Add `local_current_membership` table for tracking local user membership state in rooms. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index f135c8bc54..5e69104b97 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -470,7 +470,7 @@ class Porter(object): engine.check_database( db_conn, allow_outdated_version=allow_outdated_version ) - prepare_database(db_conn, engine, config=None) + prepare_database(db_conn, engine, config=self.hs_config) store = Store(Database(hs, db_config, engine), db_conn, hs) db_conn.commit() diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 76d18a8ba8..a9407553b4 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -134,7 +134,7 @@ class AdminHandler(BaseHandler): The returned value is that returned by `writer.finished()`. """ # Get all rooms the user is in or has been in - rooms = await self.store.get_rooms_for_user_where_membership_is( + rooms = await self.store.get_rooms_for_local_user_where_membership_is( user_id, membership_list=( Membership.JOIN, diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 4426967f88..2afb390a92 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -140,7 +140,7 @@ class DeactivateAccountHandler(BaseHandler): user_id (str): The user ID to reject pending invites for. 
""" user = UserID.from_string(user_id) - pending_invites = await self.store.get_invited_rooms_for_user(user_id) + pending_invites = await self.store.get_invited_rooms_for_local_user(user_id) for room in pending_invites: try: diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 44ec3e66ae..2e6755f19c 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -101,7 +101,7 @@ class InitialSyncHandler(BaseHandler): if include_archived: memberships.append(Membership.LEAVE) - room_list = await self.store.get_rooms_for_user_where_membership_is( + room_list = await self.store.get_rooms_for_local_user_where_membership_is( user_id=user_id, membership_list=memberships ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 03bb52ccfb..15e8aa5249 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -690,7 +690,7 @@ class RoomMemberHandler(object): @defer.inlineCallbacks def _get_inviter(self, user_id, room_id): - invite = yield self.store.get_invite_for_user_in_room( + invite = yield self.store.get_invite_for_local_user_in_room( user_id=user_id, room_id=room_id ) if invite: diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index ef750d1497..110097eab9 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -179,7 +179,7 @@ class SearchHandler(BaseHandler): search_filter = Filter(filter_dict) # TODO: Search through left rooms too - rooms = yield self.store.get_rooms_for_user_where_membership_is( + rooms = yield self.store.get_rooms_for_local_user_where_membership_is( user.to_string(), membership_list=[Membership.JOIN], # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban], diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 2d3b8ba73c..cd95f85e3f 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1662,7 +1662,7 @@ class SyncHandler(object): Membership.BAN, ) - room_list = await self.store.get_rooms_for_user_where_membership_is( + room_list = await self.store.get_rooms_for_local_user_where_membership_is( user_id=user_id, membership_list=membership_list ) diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index de5c101a58..5dae4648c0 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -21,7 +21,7 @@ from synapse.storage import Storage @defer.inlineCallbacks def get_badge_count(store, user_id): - invites = yield store.get_invited_rooms_for_user(user_id) + invites = yield store.get_invited_rooms_for_local_user(user_id) joins = yield store.get_rooms_for_user(user_id) my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read") diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 29f35b9915..3aa6cb8b96 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -152,7 +152,7 @@ class SlavedEventStore( if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) - self.get_invited_rooms_for_user.invalidate((state_key,)) + self.get_invited_rooms_for_local_user.invalidate((state_key,)) if relates_to: self.get_relations_for_event.invalidate_many((relates_to,)) diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 2dac90578c..f7432c8d2f 100644 --- a/synapse/server_notices/server_notices_manager.py +++ 
b/synapse/server_notices/server_notices_manager.py @@ -105,7 +105,7 @@ class ServerNoticesManager(object): assert self._is_mine_id(user_id), "Cannot send server notices to remote users" - rooms = yield self._store.get_rooms_for_user_where_membership_is( + rooms = yield self._store.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE, Membership.JOIN] ) system_mxid = self._config.server_notices_mxid diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 58f35d7f56..e9fe63037b 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -128,6 +128,7 @@ class EventsStore( hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000) self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + self.is_mine_id = hs.is_mine_id @defer.inlineCallbacks def _read_forward_extremities(self): @@ -547,6 +548,34 @@ class EventsStore( ], ) + # Note: Do we really want to delete rows here (that we do not + # subsequently reinsert below)? While technically correct it means + # we have no record of the fact the user *was* a member of the + # room but got, say, state reset out of it. + if to_delete or to_insert: + txn.executemany( + "DELETE FROM local_current_membership" + " WHERE room_id = ? AND user_id = ?", + ( + (room_id, state_key) + for etype, state_key in itertools.chain(to_delete, to_insert) + if etype == EventTypes.Member and self.is_mine_id(state_key) + ), + ) + + if to_insert: + txn.executemany( + """INSERT INTO local_current_membership + (room_id, user_id, event_id, membership) + VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) + """, + [ + (room_id, key[1], ev_id, ev_id) + for key, ev_id in to_insert.items() + if key[0] == EventTypes.Member and self.is_mine_id(key[1]) + ], + ) + txn.call_after( self._curr_state_delta_stream_cache.entity_has_changed, room_id, @@ -1724,6 +1753,7 @@ class EventsStore( "local_invites", "room_account_data", "room_tags", + "local_current_membership", ): logger.info("[purge] removing %s from %s", room_id, table) txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,)) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 70ff5751b6..9acef7c950 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -297,19 +297,22 @@ class RoomMemberWorkerStore(EventsWorkerStore): return {row[0]: row[1] for row in txn} @cached() - def get_invited_rooms_for_user(self, user_id): - """ Get all the rooms the user is invited to + def get_invited_rooms_for_local_user(self, user_id): + """ Get all the rooms the *local* user is invited to + Args: user_id (str): The user ID. Returns: A deferred list of RoomsForUser. """ - return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE]) + return self.get_rooms_for_local_user_where_membership_is( + user_id, [Membership.INVITE] + ) @defer.inlineCallbacks - def get_invite_for_user_in_room(self, user_id, room_id): - """Gets the invite for the given user and room + def get_invite_for_local_user_in_room(self, user_id, room_id): + """Gets the invite for the given *local* user and room Args: user_id (str) @@ -319,15 +322,15 @@ class RoomMemberWorkerStore(EventsWorkerStore): Deferred: Resolves to either a RoomsForUser or None if no invite was found. 
""" - invites = yield self.get_invited_rooms_for_user(user_id) + invites = yield self.get_invited_rooms_for_local_user(user_id) for invite in invites: if invite.room_id == room_id: return invite return None @defer.inlineCallbacks - def get_rooms_for_user_where_membership_is(self, user_id, membership_list): - """ Get all the rooms for this user where the membership for this user + def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list): + """ Get all the rooms for this *local* user where the membership for this user matches one in the membership list. Filters out forgotten rooms. @@ -344,8 +347,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): return defer.succeed(None) rooms = yield self.db.runInteraction( - "get_rooms_for_user_where_membership_is", - self._get_rooms_for_user_where_membership_is_txn, + "get_rooms_for_local_user_where_membership_is", + self._get_rooms_for_local_user_where_membership_is_txn, user_id, membership_list, ) @@ -354,76 +357,42 @@ class RoomMemberWorkerStore(EventsWorkerStore): forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id) return [room for room in rooms if room.room_id not in forgotten_rooms] - def _get_rooms_for_user_where_membership_is_txn( + def _get_rooms_for_local_user_where_membership_is_txn( self, txn, user_id, membership_list ): + # Paranoia check. + if not self.hs.is_mine_id(user_id): + raise Exception( + "Cannot call 'get_rooms_for_local_user_where_membership_is' on non-local user %r" + % (user_id,), + ) - do_invite = Membership.INVITE in membership_list - membership_list = [m for m in membership_list if m != Membership.INVITE] - - results = [] - if membership_list: - if self._current_state_events_membership_up_to_date: - clause, args = make_in_list_sql_clause( - self.database_engine, "c.membership", membership_list - ) - sql = """ - SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering - FROM current_state_events AS c - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND state_key = ? - AND %s - """ % ( - clause, - ) - else: - clause, args = make_in_list_sql_clause( - self.database_engine, "m.membership", membership_list - ) - sql = """ - SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering - FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (room_id, event_id) - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND state_key = ? - AND %s - """ % ( - clause, - ) - - txn.execute(sql, (user_id, *args)) - results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)] + clause, args = make_in_list_sql_clause( + self.database_engine, "c.membership", membership_list + ) - if do_invite: - sql = ( - "SELECT i.room_id, inviter, i.event_id, e.stream_ordering" - " FROM local_invites as i" - " INNER JOIN events as e USING (event_id)" - " WHERE invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) + sql = """ + SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering + FROM local_current_membership AS c + INNER JOIN events AS e USING (room_id, event_id) + WHERE + user_id = ? 
+ AND %s + """ % ( + clause, + ) - txn.execute(sql, (user_id,)) - results.extend( - RoomsForUser( - room_id=r["room_id"], - sender=r["inviter"], - event_id=r["event_id"], - stream_ordering=r["stream_ordering"], - membership=Membership.INVITE, - ) - for r in self.db.cursor_to_dict(txn) - ) + txn.execute(sql, (user_id, *args)) + results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)] return results - @cachedInlineCallbacks(max_entries=500000, iterable=True) + @cached(max_entries=500000, iterable=True) def get_rooms_for_user_with_stream_ordering(self, user_id): - """Returns a set of room_ids the user is currently joined to + """Returns a set of room_ids the user is currently joined to. + + If a remote user only returns rooms this server is currently + participating in. Args: user_id (str) @@ -433,17 +402,49 @@ class RoomMemberWorkerStore(EventsWorkerStore): the rooms the user is in currently, along with the stream ordering of the most recent join for that user and room. """ - rooms = yield self.get_rooms_for_user_where_membership_is( - user_id, membership_list=[Membership.JOIN] - ) - return frozenset( - GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering) - for r in rooms + return self.db.runInteraction( + "get_rooms_for_user_with_stream_ordering", + self._get_rooms_for_user_with_stream_ordering_txn, + user_id, ) + def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id): + # We use `current_state_events` here and not `local_current_membership` + # as a) this gets called with remote users and b) this only gets called + # for rooms the server is participating in. + if self._current_state_events_membership_up_to_date: + sql = """ + SELECT room_id, e.stream_ordering + FROM current_state_events AS c + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND state_key = ? + AND c.membership = ? + """ + else: + sql = """ + SELECT room_id, e.stream_ordering + FROM current_state_events AS c + INNER JOIN room_memberships AS m USING (room_id, event_id) + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND state_key = ? + AND m.membership = ? + """ + + txn.execute(sql, (user_id, Membership.JOIN)) + results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) + + return results + @defer.inlineCallbacks def get_rooms_for_user(self, user_id, on_invalidate=None): - """Returns a set of room_ids the user is currently joined to + """Returns a set of room_ids the user is currently joined to. + + If a remote user only returns rooms this server is currently + participating in. """ rooms = yield self.get_rooms_for_user_with_stream_ordering( user_id, on_invalidate=on_invalidate @@ -1022,7 +1023,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): event.internal_metadata.stream_ordering, ) txn.call_after( - self.get_invited_rooms_for_user.invalidate, (event.state_key,) + self.get_invited_rooms_for_local_user.invalidate, (event.state_key,) ) # We update the local_invites table only if the event is "current", @@ -1064,6 +1065,27 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): ), ) + # We also update the `local_current_membership` table with + # latest invite info. This will usually get updated by the + # `current_state_events` handling, unless its an outlier. + if event.internal_metadata.is_outlier(): + # This should only happen for out of band memberships, so + # we add a paranoia check. 
+            assert event.internal_metadata.is_out_of_band_membership()
+
+            self.db.simple_upsert_txn(
+                txn,
+                table="local_current_membership",
+                keyvalues={
+                    "room_id": event.room_id,
+                    "user_id": event.state_key,
+                },
+                values={
+                    "event_id": event.event_id,
+                    "membership": event.membership,
+                },
+            )
+
     @defer.inlineCallbacks
     def locally_reject_invite(self, user_id, room_id):
         sql = (
@@ -1075,6 +1097,15 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
         def f(txn, stream_ordering):
             txn.execute(sql, (stream_ordering, True, room_id, user_id))

+            # We also clear this entry from `local_current_membership`.
+            # Ideally we'd point to a leave event, but we don't have one, so
+            # nevermind.
+            self.db.simple_delete_txn(
+                txn,
+                table="local_current_membership",
+                keyvalues={"room_id": room_id, "user_id": user_id},
+            )
+
         with self._stream_id_gen.get_next() as stream_ordering:
             yield self.db.runInteraction("locally_reject_invite", f, stream_ordering)

diff --git a/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py
new file mode 100644
index 0000000000..601c236c4a
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# We create a new table called `local_current_membership` that stores the latest
+# membership state of local users in rooms, which helps track leaves/bans/etc
+# even if the server has left the room (and so has deleted the room from
+# `current_state_events`). This will also include outstanding invites for local
+# users for rooms the server isn't in.
+#
+# If the server isn't and hasn't been in the room then it will only include
+# outstanding invites, and not e.g. pre-emptive bans of local users.
+#
+# If the server later rejoins a room `local_current_membership` can simply be
+# replaced with the new current state of the room (which results in the
+# equivalent behaviour as if the server had remained in the room).
+
+
+def run_upgrade(cur, database_engine, config, *args, **kwargs):
+    # We need to do the insert in `run_upgrade` section as we don't have access
+    # to `config` in `run_create`.
+
+    # This upgrade may take a bit of time for large servers (e.g. one minute for
+    # matrix.org) but means we avoid a lot of book keeping required to do it as
+    # a background update.
+
+    # We check if the `current_state_events.membership` is up to date by
+    # checking if the relevant background update has finished. If it has
+    # finished we can avoid doing a join against `room_memberships`, which
+    # speeds things up.
+    cur.execute(
+        """SELECT 1 FROM background_updates
+            WHERE update_name = 'current_state_events_membership'
+        """
+    )
+    current_state_membership_up_to_date = not bool(cur.fetchone())
+
+    # Cheekily drop and recreate indices, as that is faster.
+ cur.execute("DROP INDEX local_current_membership_idx") + cur.execute("DROP INDEX local_current_membership_room_idx") + + if current_state_membership_up_to_date: + sql = """ + INSERT INTO local_current_membership (room_id, user_id, event_id, membership) + SELECT c.room_id, state_key AS user_id, event_id, c.membership + FROM current_state_events AS c + WHERE type = 'm.room.member' AND c.membership IS NOT NULL AND state_key like '%' || ? + """ + else: + # We can't rely on the membership column, so we need to join against + # `room_memberships`. + sql = """ + INSERT INTO local_current_membership (room_id, user_id, event_id, membership) + SELECT c.room_id, state_key AS user_id, event_id, r.membership + FROM current_state_events AS c + INNER JOIN room_memberships AS r USING (event_id) + WHERE type = 'm.room.member' and state_key like '%' || ? + """ + cur.execute(sql, (config.server_name,)) + + cur.execute( + "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)" + ) + cur.execute( + "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)" + ) + + +def run_create(cur, database_engine, *args, **kwargs): + cur.execute( + """ + CREATE TABLE local_current_membership ( + room_id TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + membership TEXT NOT NULL + )""" + ) + + cur.execute( + "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)" + ) + cur.execute( + "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)" + ) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index e70026b80a..e86984cd50 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. 
-SCHEMA_VERSION = 56 +SCHEMA_VERSION = 57 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 758ee071a5..4cbe9784ed 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -32,8 +32,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): def test_wait_for_sync_for_user_auth_blocking(self): - user_id1 = "@user1:server" - user_id2 = "@user2:server" + user_id1 = "@user1:test" + user_id2 = "@user2:test" sync_config = self._generate_sync_config(user_id1) self.reactor.advance(100) # So we get not 0 time diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index b68e9fe082..b1b037006d 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -115,13 +115,13 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): def test_invites(self): self.persist(type="m.room.create", key="", creator=USER_ID) - self.check("get_invited_rooms_for_user", [USER_ID_2], []) + self.check("get_invited_rooms_for_local_user", [USER_ID_2], []) event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite") self.replicate() self.check( - "get_invited_rooms_for_user", + "get_invited_rooms_for_local_user", [USER_ID_2], [ RoomsForUser( diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 0f51895b81..c3facc00eb 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -285,7 +285,9 @@ class DeactivateTestCase(unittest.HomeserverTestCase): ) # Make sure the invite is here. - pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id)) + pending_invites = self.get_success( + store.get_invited_rooms_for_local_user(invitee_id) + ) self.assertEqual(len(pending_invites), 1, pending_invites) self.assertEqual(pending_invites[0].room_id, room_id, pending_invites) @@ -293,12 +295,16 @@ class DeactivateTestCase(unittest.HomeserverTestCase): self.deactivate(invitee_id, invitee_tok) # Check that the invite isn't there anymore. - pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id)) + pending_invites = self.get_success( + store.get_invited_rooms_for_local_user(invitee_id) + ) self.assertEqual(len(pending_invites), 0, pending_invites) # Check that the membership of @invitee:test in the room is now "leave". memberships = self.get_success( - store.get_rooms_for_user_where_membership_is(invitee_id, [Membership.LEAVE]) + store.get_rooms_for_local_user_where_membership_is( + invitee_id, [Membership.LEAVE] + ) ) self.assertEqual(len(memberships), 1, memberships) self.assertEqual(memberships[0].room_id, room_id, memberships) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 661c1f88b9..9c13a13786 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -15,8 +15,6 @@ # limitations under the License. 
import json -from mock import Mock - import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes from synapse.rest.client.v1 import login, room @@ -36,13 +34,6 @@ class FilterTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def make_homeserver(self, reactor, clock): - - hs = self.setup_test_homeserver( - "red", http_client=None, federation_client=Mock() - ) - return hs - def test_sync_argless(self): request, channel = self.make_request("GET", "/sync") self.render(request) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 7840f63fe3..00df0ea68e 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -57,7 +57,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) rooms_for_user = self.get_success( - self.store.get_rooms_for_user_where_membership_is( + self.store.get_rooms_for_local_user_where_membership_is( self.u_alice, [Membership.JOIN] ) ) -- cgit 1.5.1 From 19a1aac48cc83fe41287a97bb0a96280a0e8c565 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 15 Jan 2020 18:13:47 +0000 Subject: Fix purge_room admin API (#6711) --- changelog.d/6711.bugfix | 1 + synapse/storage/purge_events.py | 2 +- tests/rest/admin/test_admin.py | 4 +--- 3 files changed, 3 insertions(+), 4 deletions(-) create mode 100644 changelog.d/6711.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/6711.bugfix b/changelog.d/6711.bugfix new file mode 100644 index 0000000000..c70506bd88 --- /dev/null +++ b/changelog.d/6711.bugfix @@ -0,0 +1 @@ +Fix `purge_room` admin API. diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py index d6a7bd7834..fdc0abf5cf 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/purge_events.py @@ -34,7 +34,7 @@ class PurgeEventsStorage(object): """ state_groups_to_delete = yield self.stores.main.purge_room(room_id) - yield self.stores.main.purge_room_state(room_id, state_groups_to_delete) + yield self.stores.state.purge_room_state(room_id, state_groups_to_delete) @defer.inlineCallbacks def purge_history(self, room_id, token, delete_local_events): diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 7a7e898843..f3b4a31e21 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -337,7 +337,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): "local_invites", "room_account_data", "room_tags", - "state_groups", + # "state_groups", # Current impl leaves orphaned state groups around. "state_groups_state", ): count = self.get_success( @@ -351,8 +351,6 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) - test_purge_room.skip = "Disabled because it's currently broken" - class QuarantineMediaTestCase(unittest.HomeserverTestCase): """Test /quarantine_media admin API. -- cgit 1.5.1 From 855af069a494f826ef941d722c811287b3fc4a8c Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 15 Jan 2020 18:56:18 +0000 Subject: Fix instantiation of message retention purge jobs When figuring out which topological token to start a purge job at, we need to do the following: 1. Figure out a timestamp before which events will be purged 2. Select the first stream ordering after that timestamp 3. Select info about the first event after that stream ordering 4. Build a topological token from that info In some situations (e.g. 
quiet rooms with a short max_lifetime), there might not be an event after the stream ordering at step 3, therefore we abort the purge with the error `No event found`. To mitigate that, this patch fetches the first event _before_ the stream ordering, instead of after. --- synapse/handlers/pagination.py | 2 +- synapse/storage/data_stores/main/stream.py | 59 ++++++++++++++++++++++++------ 2 files changed, 48 insertions(+), 13 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 00a6afc963..3ee6a091c5 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -156,7 +156,7 @@ class PaginationHandler(object): stream_ordering = yield self.store.find_first_stream_ordering_after_ts(ts) - r = yield self.store.get_room_event_after_stream_ordering( + r = yield self.store.get_room_event_before_stream_ordering( room_id, stream_ordering, ) if not r: diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 140da8dad6..223ce7fedb 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -536,20 +536,55 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Deferred[(int, int, str)]: (stream ordering, topological ordering, event_id) """ + return self.db.runInteraction( + "get_room_event_after_stream_ordering", + self.get_room_event_around_stream_ordering_txn, + room_id, stream_ordering, "f", + ) - def _f(txn): - sql = ( - "SELECT stream_ordering, topological_ordering, event_id" - " FROM events" - " WHERE room_id = ? AND stream_ordering >= ?" - " AND NOT outlier" - " ORDER BY stream_ordering" - " LIMIT 1" - ) - txn.execute(sql, (room_id, stream_ordering)) - return txn.fetchone() + def get_room_event_before_stream_ordering(self, room_id, stream_ordering): + """Gets details of the first event in a room at or before a stream ordering + + Args: + room_id (str): + stream_ordering (int): + + Returns: + Deferred[(int, int, str)]: + (stream ordering, topological ordering, event_id) + """ + return self.db.runInteraction( + "get_room_event_before_stream_ordering", + self.get_room_event_around_stream_ordering_txn, + room_id, stream_ordering, "f", + ) + + def get_room_event_around_stream_ordering_txn( + self, txn, room_id, stream_ordering, dir="f" + ): + """Gets details of the first event in a room at or either after or before a + stream ordering, depending on the provided direction. + + Args: + room_id (str): + stream_ordering (int): + dir (str): Direction in which we're looking towards in the room's history, + either "f" (forward) or "b" (backward). - return self.db.runInteraction("get_room_event_after_stream_ordering", _f) + Returns: + Deferred[(int, int, str)]: + (stream ordering, topological ordering, event_id) + """ + sql = ( + "SELECT stream_ordering, topological_ordering, event_id" + " FROM events" + " WHERE room_id = ? AND stream_ordering %s ?" 
+ " AND NOT outlier" + " ORDER BY stream_ordering" + " LIMIT 1" + ) % ("<=" if dir == "b" else ">=",) + txn.execute(sql, (room_id, stream_ordering)) + return txn.fetchone() @defer.inlineCallbacks def get_room_events_max_id(self, room_id=None): -- cgit 1.5.1 From 83635882379ecddb1509ea3d071eefdedefb647e Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 15 Jan 2020 19:13:22 +0000 Subject: Fix typo --- synapse/storage/data_stores/main/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 223ce7fedb..9fa5e1f203 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -556,7 +556,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return self.db.runInteraction( "get_room_event_before_stream_ordering", self.get_room_event_around_stream_ordering_txn, - room_id, stream_ordering, "f", + room_id, stream_ordering, "b", ) def get_room_event_around_stream_ordering_txn( -- cgit 1.5.1 From 066b9f52b80c172eec6074ca01fb24670200fd80 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 15 Jan 2020 19:32:47 +0000 Subject: Correctly order when selecting before stream ordering --- synapse/storage/data_stores/main/stream.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 9fa5e1f203..451f38296b 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -580,9 +580,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): " FROM events" " WHERE room_id = ? AND stream_ordering %s ?" 
" AND NOT outlier" - " ORDER BY stream_ordering" + " ORDER BY stream_ordering %s" " LIMIT 1" - ) % ("<=" if dir == "b" else ">=",) + ) % ( + "<=" if dir == "b" else ">=", + "DESC" if dir == "b" else "ASC", + ) txn.execute(sql, (room_id, stream_ordering)) return txn.fetchone() -- cgit 1.5.1 From e601f35d3b562495b2f8b071bd4c812fd783d6a7 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 16 Jan 2020 09:55:11 +0000 Subject: Lint --- synapse/storage/data_stores/main/stream.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 451f38296b..652cecd59b 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -539,7 +539,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return self.db.runInteraction( "get_room_event_after_stream_ordering", self.get_room_event_around_stream_ordering_txn, - room_id, stream_ordering, "f", + room_id, + stream_ordering, + "f", ) def get_room_event_before_stream_ordering(self, room_id, stream_ordering): @@ -556,7 +558,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return self.db.runInteraction( "get_room_event_before_stream_ordering", self.get_room_event_around_stream_ordering_txn, - room_id, stream_ordering, "b", + room_id, + stream_ordering, + "b", ) def get_room_event_around_stream_ordering_txn( @@ -575,6 +579,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Deferred[(int, int, str)]: (stream ordering, topological ordering, event_id) """ + # Figure out which comparison operation to perform and how to order the results, + # using the provided direction. + op = "<=" if dir == "b" else ">=" + order = "DESC" if dir == "b" else "ASC" + sql = ( "SELECT stream_ordering, topological_ordering, event_id" " FROM events" @@ -582,10 +591,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): " AND NOT outlier" " ORDER BY stream_ordering %s" " LIMIT 1" - ) % ( - "<=" if dir == "b" else ">=", - "DESC" if dir == "b" else "ASC", - ) + ) % (op, order) txn.execute(sql, (room_id, stream_ordering)) return txn.fetchone() -- cgit 1.5.1 From d386f2f339c839ff6ec8d656492dd635dc26f811 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 16 Jan 2020 13:31:22 +0000 Subject: Add StateMap type alias (#6715) --- changelog.d/6715.misc | 1 + synapse/api/auth.py | 8 +--- synapse/events/snapshot.py | 11 +++-- synapse/federation/sender/per_destination_queue.py | 3 +- synapse/handlers/admin.py | 25 ++++------- synapse/handlers/federation.py | 10 ++--- synapse/handlers/room.py | 24 +++++++--- synapse/state/__init__.py | 5 ++- synapse/state/v1.py | 5 ++- synapse/state/v2.py | 9 ++-- synapse/storage/data_stores/main/state.py | 11 ++--- synapse/storage/data_stores/state/store.py | 52 ++++++++++++---------- synapse/storage/state.py | 35 +++++++++------ synapse/types.py | 9 +++- 14 files changed, 115 insertions(+), 93 deletions(-) create mode 100644 changelog.d/6715.misc (limited to 'synapse/storage') diff --git a/changelog.d/6715.misc b/changelog.d/6715.misc new file mode 100644 index 0000000000..8876b0446d --- /dev/null +++ b/changelog.d/6715.misc @@ -0,0 +1 @@ +Add StateMap type alias to simplify types. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index abbc7079a3..2cbfab2569 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -14,7 +14,6 @@ # limitations under the License. 
import logging -from typing import Dict, Tuple from six import itervalues @@ -35,7 +34,7 @@ from synapse.api.errors import ( ResourceLimitError, ) from synapse.config.server import is_threepid_reserved -from synapse.types import UserID +from synapse.types import StateMap, UserID from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches.lrucache import LruCache from synapse.util.metrics import Measure @@ -509,10 +508,7 @@ class Auth(object): return self.store.is_server_admin(user) def compute_auth_events( - self, - event, - current_state_ids: Dict[Tuple[str, str], str], - for_verification: bool = False, + self, event, current_state_ids: StateMap[str], for_verification: bool = False, ): """Given an event and current state return the list of event IDs used to auth an event. diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index a44baea365..9ea85e93e6 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Union from six import iteritems @@ -23,6 +23,7 @@ from twisted.internet import defer from synapse.appservice import ApplicationService from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.types import StateMap @attr.s(slots=True) @@ -106,13 +107,11 @@ class EventContext: _state_group = attr.ib(default=None, type=Optional[int]) state_group_before_event = attr.ib(default=None, type=Optional[int]) prev_group = attr.ib(default=None, type=Optional[int]) - delta_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]]) + delta_ids = attr.ib(default=None, type=Optional[StateMap[str]]) app_service = attr.ib(default=None, type=Optional[ApplicationService]) - _current_state_ids = attr.ib( - default=None, type=Optional[Dict[Tuple[str, str], str]] - ) - _prev_state_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]]) + _current_state_ids = attr.ib(default=None, type=Optional[StateMap[str]]) + _prev_state_ids = attr.ib(default=None, type=Optional[StateMap[str]]) @staticmethod def with_state( diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index a5b36b1827..5012aaea35 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -31,6 +31,7 @@ from synapse.handlers.presence import format_user_presence_state from synapse.metrics import sent_transactions_counter from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.presence import UserPresenceState +from synapse.types import StateMap from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter # This is defined in the Matrix spec and enforced by the receiver. @@ -77,7 +78,7 @@ class PerDestinationQueue(object): # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered # based on their key (e.g. 
typing events by room_id) # Map of (edu_type, key) -> Edu - self._pending_edus_keyed = {} # type: dict[tuple[str, str], Edu] + self._pending_edus_keyed = {} # type: StateMap[Edu] # Map of user_id -> UserPresenceState of pending presence to be sent to this # destination diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index a9407553b4..60a7c938bc 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -14,9 +14,11 @@ # limitations under the License. import logging +from typing import List from synapse.api.constants import Membership -from synapse.types import RoomStreamToken +from synapse.events import FrozenEvent +from synapse.types import RoomStreamToken, StateMap from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -259,35 +261,26 @@ class ExfiltrationWriter(object): """Interface used to specify how to write exported data. """ - def write_events(self, room_id, events): + def write_events(self, room_id: str, events: List[FrozenEvent]): """Write a batch of events for a room. - - Args: - room_id (str) - events (list[FrozenEvent]) """ pass - def write_state(self, room_id, event_id, state): + def write_state(self, room_id: str, event_id: str, state: StateMap[FrozenEvent]): """Write the state at the given event in the room. This only gets called for backward extremities rather than for each event. - - Args: - room_id (str) - event_id (str) - state (dict[tuple[str, str], FrozenEvent]) """ pass - def write_invite(self, room_id, event, state): + def write_invite(self, room_id: str, event: FrozenEvent, state: StateMap[dict]): """Write an invite for the room, with associated invite state. Args: - room_id (str) - event (FrozenEvent) - state (dict[tuple[str, str], dict]): A subset of the state at the + room_id + event + state: A subset of the state at the invite, with a subset of the event keys (type, state_key content and sender) """ diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 61b6713c88..d4f9a792fc 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -64,7 +64,7 @@ from synapse.replication.http.federation import ( from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour -from synapse.types import UserID, get_domain_from_id +from synapse.types import StateMap, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination @@ -89,7 +89,7 @@ class _NewEventInfo: event = attr.ib(type=EventBase) state = attr.ib(type=Optional[Sequence[EventBase]], default=None) - auth_events = attr.ib(type=Optional[Dict[Tuple[str, str], EventBase]], default=None) + auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None) def shortstr(iterable, maxitems=5): @@ -352,9 +352,7 @@ class FederationHandler(BaseHandler): ours = await self.state_store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id - state_maps = list( - ours.values() - ) # type: list[dict[tuple[str, str], str]] + state_maps = list(ours.values()) # type: list[StateMap[str]] # we don't need this any more, let's delete it. 
del ours @@ -1912,7 +1910,7 @@ class FederationHandler(BaseHandler): origin: str, event: EventBase, state: Optional[Iterable[EventBase]], - auth_events: Optional[Dict[Tuple[str, str], EventBase]], + auth_events: Optional[StateMap[EventBase]], backfilled: bool, ): """ diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 9cab2adbfb..9f50196ea7 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -32,7 +32,15 @@ from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, Syna from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.endpoint import parse_and_validate_server_name from synapse.storage.state import StateFilter -from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID +from synapse.types import ( + Requester, + RoomAlias, + RoomID, + RoomStreamToken, + StateMap, + StreamToken, + UserID, +) from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.response_cache import ResponseCache @@ -207,15 +215,19 @@ class RoomCreationHandler(BaseHandler): @defer.inlineCallbacks def _update_upgraded_room_pls( - self, requester, old_room_id, new_room_id, old_room_state, + self, + requester: Requester, + old_room_id: str, + new_room_id: str, + old_room_state: StateMap[str], ): """Send updated power levels in both rooms after an upgrade Args: - requester (synapse.types.Requester): the user requesting the upgrade - old_room_id (str): the id of the room to be replaced - new_room_id (str): the id of the replacement room - old_room_state (dict[tuple[str, str], str]): the state map for the old room + requester: the user requesting the upgrade + old_room_id: the id of the room to be replaced + new_room_id: the id of the replacement room + old_room_state: the state map for the old room Returns: Deferred diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 5accc071ab..cacd0c0c2b 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Dict, Iterable, List, Optional from six import iteritems, itervalues @@ -33,6 +33,7 @@ from synapse.events.snapshot import EventContext from synapse.logging.utils import log_function from synapse.state import v1, v2 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour +from synapse.types import StateMap from synapse.util.async_helpers import Linearizer from synapse.util.caches import get_cache_factor_for from synapse.util.caches.expiringcache import ExpiringCache @@ -594,7 +595,7 @@ def _make_state_cache_entry(new_state, state_groups_ids): def resolve_events_with_store( room_id: str, room_version: str, - state_sets: List[Dict[Tuple[str, str], str]], + state_sets: List[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "StateResolutionStore", ): diff --git a/synapse/state/v1.py b/synapse/state/v1.py index b2f9865f39..d6c34ce3b7 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -15,7 +15,7 @@ import hashlib import logging -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Optional from six import iteritems, iterkeys, itervalues @@ -26,6 +26,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase +from synapse.types 
import StateMap logger = logging.getLogger(__name__) @@ -36,7 +37,7 @@ POWER_KEY = (EventTypes.PowerLevels, "") @defer.inlineCallbacks def resolve_events_with_store( room_id: str, - state_sets: List[Dict[Tuple[str, str], str]], + state_sets: List[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_map_factory: Callable, ): diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 72fb8a6317..6216fdd204 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -16,7 +16,7 @@ import heapq import itertools import logging -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional from six import iteritems, itervalues @@ -27,6 +27,7 @@ from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.events import EventBase +from synapse.types import StateMap logger = logging.getLogger(__name__) @@ -35,7 +36,7 @@ logger = logging.getLogger(__name__) def resolve_events_with_store( room_id: str, room_version: str, - state_sets: List[Dict[Tuple[str, str], str]], + state_sets: List[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "synapse.state.StateResolutionStore", ): @@ -393,12 +394,12 @@ def _iterative_auth_checks( room_id (str) room_version (str) event_ids (list[str]): Ordered list of events to apply auth checks to - base_state (dict[tuple[str, str], str]): The set of state to start with + base_state (StateMap[str]): The set of state to start with event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) Returns: - Deferred[dict[tuple[str, str], str]]: Returns the final updated state + Deferred[StateMap[str]]: Returns the final updated state """ resolved_state = base_state.copy() diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index d07440e3ed..33bebd1c48 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -165,19 +165,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) # FIXME: how should this be cached? - def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()): + def get_filtered_current_state_ids( + self, room_id: str, state_filter: StateFilter = StateFilter.all() + ): """Get the current state event of a given type for a room based on the current_state_events table. This may not be as up-to-date as the result of doing a fresh state resolution as per state_handler.get_current_state Args: - room_id (str) - state_filter (StateFilter): The state filter used to fetch state + room_id + state_filter: The state filter used to fetch state from the database. Returns: - Deferred[dict[tuple[str, str], str]]: Map from type/state_key to - event ID. + defer.Deferred[StateMap[str]]: Map from type/state_key to event ID. 
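+
+            For illustration only (the event IDs here are invented), the
+            returned map might look like::
+
+                {
+                    ("m.room.create", ""): "$create_event_id",
+                    ("m.room.member", "@alice:test"): "$member_event_id",
+                }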
""" where_clause, where_args = state_filter.make_sql_filter_clause() diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py index d53695f238..c4ee9b7ccb 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/data_stores/state/store.py @@ -15,6 +15,7 @@ import logging from collections import namedtuple +from typing import Dict, Iterable, List, Set, Tuple from six import iteritems from six.moves import range @@ -26,6 +27,7 @@ from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore from synapse.storage.database import Database from synapse.storage.state import StateFilter +from synapse.types import StateMap from synapse.util.caches import get_cache_factor_for from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache @@ -133,17 +135,18 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ) @defer.inlineCallbacks - def _get_state_groups_from_groups(self, groups, state_filter): - """Returns the state groups for a given set of groups, filtering on - types of state events. + def _get_state_groups_from_groups( + self, groups: List[int], state_filter: StateFilter + ): + """Returns the state groups for a given set of groups from the + database, filtering on types of state events. Args: - groups(list[int]): list of state group IDs to query - state_filter (StateFilter): The state filter used to fetch state + groups: list of state group IDs to query + state_filter: The state filter used to fetch state from the database. Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. """ results = {} @@ -199,18 +202,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state_filter.filter_state(state_dict_ids), not missing_types @defer.inlineCallbacks - def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): + def _get_state_for_groups( + self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() + ): """Gets the state at each of a list of state groups, optionally filtering by type/state_key Args: - groups (iterable[int]): list of state groups for which we want + groups: list of state groups for which we want to get the state. - state_filter (StateFilter): The state filter used to fetch state + state_filter: The state filter used to fetch state from the database. Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. """ member_filter, non_member_filter = state_filter.get_member_split() @@ -268,24 +272,24 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state - def _get_state_for_groups_using_cache(self, groups, cache, state_filter): + def _get_state_for_groups_using_cache( + self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter + ) -> Tuple[Dict[int, StateMap[str]], Set[int]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key, querying from a specific cache. Args: - groups (iterable[int]): list of state groups for which we want - to get the state. 
- cache (DictionaryCache): the cache of group ids to state dicts which - we will pass through - either the normal state cache or the specific - members state cache. - state_filter (StateFilter): The state filter used to fetch state - from the database. + groups: list of state groups for which we want to get the state. + cache: the cache of group ids to state dicts which + we will pass through - either the normal state cache or the + specific members state cache. + state_filter: The state filter used to fetch state from the + database. Returns: - tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of - dict of state_group_id -> (dict of (type, state_key) -> event id) - of entries in the cache, and the state group ids either missing - from the cache or incomplete. + Tuple of dict of state_group_id to state map of entries in the + cache, and the state group ids either missing from the cache or + incomplete. """ results = {} incomplete_groups = set() diff --git a/synapse/storage/state.py b/synapse/storage/state.py index cbeb586014..c522c80922 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Iterable, List, TypeVar from six import iteritems, itervalues @@ -22,9 +23,13 @@ import attr from twisted.internet import defer from synapse.api.constants import EventTypes +from synapse.types import StateMap logger = logging.getLogger(__name__) +# Used for generic functions below +T = TypeVar("T") + @attr.s(slots=True) class StateFilter(object): @@ -233,14 +238,14 @@ class StateFilter(object): return len(self.concrete_types()) - def filter_state(self, state_dict): + def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]: """Returns the state filtered with by this StateFilter Args: - state (dict[tuple[str, str], Any]): The state map to filter + state: The state map to filter Returns: - dict[tuple[str, str], Any]: The filtered state map + The filtered state map """ if self.is_full(): return dict(state_dict) @@ -333,12 +338,12 @@ class StateGroupStorage(object): def __init__(self, hs, stores): self.stores = stores - def get_state_group_delta(self, state_group): + def get_state_group_delta(self, state_group: int): """Given a state group try to return a previous group and a delta between the old and the new. Returns: - Deferred[Tuple[Optional[int], Optional[list[dict[tuple[str, str], str]]]]]): + Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]: (prev_group, delta_ids) """ @@ -353,7 +358,7 @@ class StateGroupStorage(object): event_ids (iterable[str]): ids of the events Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: + Deferred[dict[int, StateMap[str]]]: dict of state_group_id -> (dict of (type, state_key) -> event id) """ if not event_ids: @@ -410,17 +415,18 @@ class StateGroupStorage(object): for group, event_id_map in iteritems(group_to_ids) } - def _get_state_groups_from_groups(self, groups, state_filter): + def _get_state_groups_from_groups( + self, groups: List[int], state_filter: StateFilter + ): """Returns the state groups for a given set of groups, filtering on types of state events. Args: - groups(list[int]): list of state group IDs to query - state_filter (StateFilter): The state filter used to fetch state + groups: list of state group IDs to query + state_filter: The state filter used to fetch state from the database. 
Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. """ return self.stores.state._get_state_groups_from_groups(groups, state_filter) @@ -519,7 +525,9 @@ class StateGroupStorage(object): state_map = yield self.get_state_ids_for_events([event_id], state_filter) return state_map[event_id] - def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): + def _get_state_for_groups( + self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() + ): """Gets the state at each of a list of state groups, optionally filtering by type/state_key @@ -529,8 +537,7 @@ class StateGroupStorage(object): state_filter (StateFilter): The state filter used to fetch state from the database. Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. """ return self.stores.state._get_state_for_groups(groups, state_filter) diff --git a/synapse/types.py b/synapse/types.py index cd996c0b5a..65e4d8c181 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -17,6 +17,7 @@ import re import string import sys from collections import namedtuple +from typing import Dict, Tuple, TypeVar import attr from signedjson.key import decode_verify_key_bytes @@ -28,7 +29,7 @@ from synapse.api.errors import SynapseError if sys.version_info[:3] >= (3, 6, 0): from typing import Collection else: - from typing import Sized, Iterable, Container, TypeVar + from typing import Sized, Iterable, Container T_co = TypeVar("T_co", covariant=True) @@ -36,6 +37,12 @@ else: __slots__ = () +# Define a state map type from type/state_key to T (usually an event ID or +# event) +T = TypeVar("T") +StateMap = Dict[Tuple[str, str], T] + + class Requester( namedtuple( "Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"] -- cgit 1.5.1 From 842c2cfbf1e9f3e0d9251fa0c572eba9d6af6dbe Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 16 Jan 2020 20:24:17 +0000 Subject: Remove get_room_event_after_stream_ordering entirely --- synapse/rest/admin/__init__.py | 2 +- synapse/storage/data_stores/main/stream.py | 69 ++++++------------------------ 2 files changed, 13 insertions(+), 58 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index a10b4a9b72..2932fe2123 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -107,7 +107,7 @@ class PurgeHistoryRestServlet(RestServlet): stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts) - r = await self.store.get_room_event_after_stream_ordering( + r = await self.store.get_room_event_before_stream_ordering( room_id, stream_ordering ) if not r: diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 652cecd59b..a20c3d1012 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -525,25 +525,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows, token - def get_room_event_after_stream_ordering(self, room_id, stream_ordering): - """Gets details of the first event in a room at or after a stream ordering - - Args: - room_id (str): - stream_ordering (int): - - Returns: - Deferred[(int, int, str)]: - (stream ordering, topological ordering, event_id) - 
""" - return self.db.runInteraction( - "get_room_event_after_stream_ordering", - self.get_room_event_around_stream_ordering_txn, - room_id, - stream_ordering, - "f", - ) - def get_room_event_before_stream_ordering(self, room_id, stream_ordering): """Gets details of the first event in a room at or before a stream ordering @@ -555,45 +536,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Deferred[(int, int, str)]: (stream ordering, topological ordering, event_id) """ - return self.db.runInteraction( - "get_room_event_before_stream_ordering", - self.get_room_event_around_stream_ordering_txn, - room_id, - stream_ordering, - "b", - ) - - def get_room_event_around_stream_ordering_txn( - self, txn, room_id, stream_ordering, dir="f" - ): - """Gets details of the first event in a room at or either after or before a - stream ordering, depending on the provided direction. - - Args: - room_id (str): - stream_ordering (int): - dir (str): Direction in which we're looking towards in the room's history, - either "f" (forward) or "b" (backward). - - Returns: - Deferred[(int, int, str)]: - (stream ordering, topological ordering, event_id) - """ - # Figure out which comparison operation to perform and how to order the results, - # using the provided direction. - op = "<=" if dir == "b" else ">=" - order = "DESC" if dir == "b" else "ASC" + def _f(txn): + sql = ( + "SELECT stream_ordering, topological_ordering, event_id" + " FROM events" + " WHERE room_id = ? AND stream_ordering <= ?" + " AND NOT outlier" + " ORDER BY stream_ordering DESC" + " LIMIT 1" + ) + txn.execute(sql, (room_id, stream_ordering)) + return txn.fetchone() - sql = ( - "SELECT stream_ordering, topological_ordering, event_id" - " FROM events" - " WHERE room_id = ? AND stream_ordering %s ?" - " AND NOT outlier" - " ORDER BY stream_ordering %s" - " LIMIT 1" - ) % (op, order) - txn.execute(sql, (room_id, stream_ordering)) - return txn.fetchone() + return self.db.runInteraction("get_room_event_before_stream_ordering", _f) @defer.inlineCallbacks def get_room_events_max_id(self, room_id=None): -- cgit 1.5.1 From dac148341ba2638cc9486cf0b00005932dab939d Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 16 Jan 2020 20:25:09 +0000 Subject: Fixup diff --- synapse/storage/data_stores/main/stream.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index a20c3d1012..056b25b13a 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -536,14 +536,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Deferred[(int, int, str)]: (stream ordering, topological ordering, event_id) """ + def _f(txn): sql = ( - "SELECT stream_ordering, topological_ordering, event_id" - " FROM events" - " WHERE room_id = ? AND stream_ordering <= ?" - " AND NOT outlier" - " ORDER BY stream_ordering DESC" - " LIMIT 1" + "SELECT stream_ordering, topological_ordering, event_id" + " FROM events" + " WHERE room_id = ? AND stream_ordering <= ?" 
+ " AND NOT outlier" + " ORDER BY stream_ordering DESC" + " LIMIT 1" ) txn.execute(sql, (room_id, stream_ordering)) return txn.fetchone() -- cgit 1.5.1 From 14d8f342d5cae86d93d9ba2b411d486690ff54f5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 14 Jan 2020 11:58:02 +0000 Subject: move batch_iter to a separate module --- synapse/storage/data_stores/main/cache.py | 2 +- synapse/storage/data_stores/main/devices.py | 2 +- synapse/storage/data_stores/main/events.py | 2 +- synapse/storage/data_stores/main/events_worker.py | 2 +- synapse/storage/data_stores/main/keys.py | 2 +- synapse/storage/data_stores/main/presence.py | 2 +- synapse/util/__init__.py | 17 ----------- synapse/util/iterutils.py | 35 +++++++++++++++++++++++ 8 files changed, 41 insertions(+), 23 deletions(-) create mode 100644 synapse/util/iterutils.py (limited to 'synapse/storage') diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index 54ed8574c4..bf91512daf 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -21,7 +21,7 @@ from twisted.internet import defer from synapse.storage._base import SQLBaseStore from synapse.storage.engines import PostgresEngine -from synapse.util import batch_iter +from synapse.util.iterutils import batch_iter logger = logging.getLogger(__name__) diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 9a828231c4..f0a7962dd0 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -33,13 +33,13 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import Database from synapse.types import get_verify_key_from_cross_signing_key -from synapse.util import batch_iter from synapse.util.caches.descriptors import ( Cache, cached, cachedInlineCallbacks, cachedList, ) +from synapse.util.iterutils import batch_iter logger = logging.getLogger(__name__) diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index e9fe63037b..bb69c20448 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -43,9 +43,9 @@ from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.state import StateGroupWorkerStore from synapse.storage.database import Database from synapse.types import RoomStreamToken, get_domain_from_id -from synapse.util import batch_iter from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.frozenutils import frozendict_json_encoder +from synapse.util.iterutils import batch_iter logger = logging.getLogger(__name__) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 0cce5232f5..3b93e0597a 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -37,8 +37,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.database import Database from synapse.types import get_domain_from_id -from synapse.util import batch_iter from synapse.util.caches.descriptors import Cache +from synapse.util.iterutils import batch_iter 
from synapse.util.metrics import Measure logger = logging.getLogger(__name__) diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py index 6b12f5a75f..ba89c68c9f 100644 --- a/synapse/storage/data_stores/main/keys.py +++ b/synapse/storage/data_stores/main/keys.py @@ -23,8 +23,8 @@ from signedjson.key import decode_verify_key_bytes from synapse.storage._base import SQLBaseStore from synapse.storage.keys import FetchKeyResult -from synapse.util import batch_iter from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.iterutils import batch_iter logger = logging.getLogger(__name__) diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py index a2c83e0867..604c8b7ddd 100644 --- a/synapse/storage/data_stores/main/presence.py +++ b/synapse/storage/data_stores/main/presence.py @@ -17,8 +17,8 @@ from twisted.internet import defer from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.presence import UserPresenceState -from synapse.util import batch_iter from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.iterutils import batch_iter class PresenceStore(SQLBaseStore): diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 7856353002..60f0de70f7 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -15,7 +15,6 @@ import logging import re -from itertools import islice import attr @@ -107,22 +106,6 @@ class Clock(object): raise -def batch_iter(iterable, size): - """batch an iterable up into tuples with a maximum size - - Args: - iterable (iterable): the iterable to slice - size (int): the maximum batch size - - Returns: - an iterator over the chunks - """ - # make sure we can deal with iterables like lists too - sourceiter = iter(iterable) - # call islice until it returns an empty tuple - return iter(lambda: tuple(islice(sourceiter, size)), ()) - - def log_failure(failure, msg, consumeErrors=True): """Creates a function suitable for passing to `Deferred.addErrback` that logs any failures that occur. diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py new file mode 100644 index 0000000000..c10016fbc5 --- /dev/null +++ b/synapse/util/iterutils.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from itertools import islice +from typing import Iterable, Iterator, Sequence, Tuple, TypeVar + +T = TypeVar("T") + + +def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T]]: + """batch an iterable up into tuples with a maximum size + + Args: + iterable (iterable): the iterable to slice + size (int): the maximum batch size + + Returns: + an iterator over the chunks + """ + # make sure we can deal with iterables like lists too + sourceiter = iter(iterable) + # call islice until it returns an empty tuple + return iter(lambda: tuple(islice(sourceiter, size)), ()) -- cgit 1.5.1 From 722b4f302d705f497355f206ecb160de1bef2074 Mon Sep 17 00:00:00 2001 From: Satsuki Yanagi <17376330+u1-liquid@users.noreply.github.com> Date: Fri, 17 Jan 2020 23:30:35 +0900 Subject: Fix syntax error in run_upgrade for schema 57 (#6728) Fix #6727 Related #6655 Co-authored-by: Erik Johnston --- changelog.d/6728.bugfix | 1 + .../data_stores/main/schema/delta/57/local_current_membership.py | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) create mode 100644 changelog.d/6728.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/6728.bugfix b/changelog.d/6728.bugfix new file mode 100644 index 0000000000..5a136e17be --- /dev/null +++ b/changelog.d/6728.bugfix @@ -0,0 +1 @@ +Fix a bug causing `ValueError: unsupported format character ''' (0x27) at index 312` error when running the schema 57 upgrade script. diff --git a/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py index 601c236c4a..63b5acdcf7 100644 --- a/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py +++ b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py @@ -56,7 +56,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): INSERT INTO local_current_membership (room_id, user_id, event_id, membership) SELECT c.room_id, state_key AS user_id, event_id, c.membership FROM current_state_events AS c - WHERE type = 'm.room.member' AND c.membership IS NOT NULL AND state_key like '%' || ? + WHERE type = 'm.room.member' AND c.membership IS NOT NULL AND state_key LIKE ? """ else: # We can't rely on the membership column, so we need to join against @@ -66,9 +66,10 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): SELECT c.room_id, state_key AS user_id, event_id, r.membership FROM current_state_events AS c INNER JOIN room_memberships AS r USING (event_id) - WHERE type = 'm.room.member' and state_key like '%' || ? + WHERE type = 'm.room.member' AND state_key LIKE ? 
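+            -- the single ? placeholder above is bound to "%:" + server_name
+            -- below, so the LIKE only matches local users' membership rows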
""" - cur.execute(sql, (config.server_name,)) + sql = database_engine.convert_param_style(sql) + cur.execute(sql, ("%:" + config.server_name,)) cur.execute( "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)" -- cgit 1.5.1 From 0e68760078c0aac57bfaeb681d534231e191315a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 20 Jan 2020 18:07:20 +0000 Subject: Add a DeltaState to track changes to be made to current state (#6716) --- changelog.d/6716.misc | 1 + synapse/storage/data_stores/main/events.py | 87 ++++++++++---------- synapse/storage/persist_events.py | 123 ++++++++++++++++------------- 3 files changed, 112 insertions(+), 99 deletions(-) create mode 100644 changelog.d/6716.misc (limited to 'synapse/storage') diff --git a/changelog.d/6716.misc b/changelog.d/6716.misc new file mode 100644 index 0000000000..319aaa4acb --- /dev/null +++ b/changelog.d/6716.misc @@ -0,0 +1 @@ +Add a `DeltaState` to track changes to be made to current state during event persistence. diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index bb69c20448..596daf8909 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -19,6 +19,7 @@ import itertools import logging from collections import Counter as c_counter, OrderedDict, namedtuple from functools import wraps +from typing import Dict, List, Tuple from six import iteritems, text_type from six.moves import range @@ -41,8 +42,9 @@ from synapse.storage._base import make_in_list_sql_clause from synapse.storage.data_stores.main.event_federation import EventFederationStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.state import StateGroupWorkerStore -from synapse.storage.database import Database -from synapse.types import RoomStreamToken, get_domain_from_id +from synapse.storage.database import Database, LoggingTransaction +from synapse.storage.persist_events import DeltaState +from synapse.types import RoomStreamToken, StateMap, get_domain_from_id from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.iterutils import batch_iter @@ -148,30 +150,26 @@ class EventsStore( @defer.inlineCallbacks def _persist_events_and_state_updates( self, - events_and_contexts, - current_state_for_room, - state_delta_for_room, - new_forward_extremeties, - backfilled=False, - delete_existing=False, + events_and_contexts: List[Tuple[EventBase, EventContext]], + current_state_for_room: Dict[str, StateMap[str]], + state_delta_for_room: Dict[str, DeltaState], + new_forward_extremeties: Dict[str, List[str]], + backfilled: bool = False, + delete_existing: bool = False, ): """Persist a set of events alongside updates to the current state and forward extremities tables. Args: - events_and_contexts (list[(EventBase, EventContext)]): - current_state_for_room (dict[str, dict]): Map from room_id to the - current state of the room based on forward extremities - state_delta_for_room (dict[str, tuple]): Map from room_id to tuple - of `(to_delete, to_insert)` where to_delete is a list - of type/state keys to remove from current state, and to_insert - is a map (type,key)->event_id giving the state delta in each - room. - new_forward_extremities (dict[str, list[str]]): Map from room_id - to list of event IDs that are the new forward extremities of - the room. 
-            backfilled (bool)
-            delete_existing (bool):
+            events_and_contexts:
+            current_state_for_room: Map from room_id to the current state of
+                the room based on forward extremities
+            state_delta_for_room: Map from room_id to the delta to apply to
+                room state
+            new_forward_extremeties: Map from room_id to list of event IDs
+                that are the new forward extremities of the room.
+            backfilled
+            delete_existing

         Returns:
             Deferred: resolves when the events have been persisted
@@ -352,12 +350,12 @@
     @log_function
     def _persist_events_txn(
         self,
-        txn,
-        events_and_contexts,
-        backfilled,
-        delete_existing=False,
-        state_delta_for_room={},
-        new_forward_extremeties={},
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        backfilled: bool,
+        delete_existing: bool = False,
+        state_delta_for_room: Dict[str, DeltaState] = {},
+        new_forward_extremeties: Dict[str, List[str]] = {},
     ):
         """Insert some number of room events into the necessary database
         tables.
@@ -366,21 +364,16 @@
         whether the event was rejected.

         Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            events_and_contexts (list[(EventBase, EventContext)]):
-                events to persist
-            backfilled (bool): True if the events were backfilled
-            delete_existing (bool): True to purge existing table rows for the
-                events from the database. This is useful when retrying due to
+            txn
+            events_and_contexts: events to persist
+            backfilled: True if the events were backfilled
+            delete_existing: True to purge existing table rows for the events
+                from the database. This is useful when retrying due to
                 IntegrityError.
-            state_delta_for_room (dict[str, (list, dict)]):
-                The current-state delta for each room. For each room, a tuple
-                (to_delete, to_insert), being a list of type/state keys to be
-                removed from the current state, and a state set to be added to
-                the current state.
-            new_forward_extremeties (dict[str, list[str]]):
-                The new forward extremities for each room. For each room, a
-                list of the event ids which are the forward extremities.
+            state_delta_for_room: The current-state delta for each room.
+            new_forward_extremeties: The new forward extremities for each room.
+                For each room, a list of the event ids which are the forward
+                extremities.
         """
         all_events_and_contexts = events_and_contexts

@@ -465,9 +458,15 @@
         # room_memberships, where applicable.
         self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)

-    def _update_current_state_txn(self, txn, state_delta_by_room, stream_id):
-        for room_id, current_state_tuple in iteritems(state_delta_by_room):
-            to_delete, to_insert = current_state_tuple
+    def _update_current_state_txn(
+        self,
+        txn: LoggingTransaction,
+        state_delta_by_room: Dict[str, DeltaState],
+        stream_id: int,
+    ):
+        for room_id, delta_state in iteritems(state_delta_by_room):
+            to_delete = delta_state.to_delete
+            to_insert = delta_state.to_insert

            # First we add entries to the current_state_delta_stream.
We # do this before updating the current_state_events table so diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 1ed44925fc..368c457321 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -17,19 +17,24 @@ import logging from collections import deque, namedtuple +from typing import Iterable, List, Optional, Tuple from six import iteritems from six.moves import range +import attr from prometheus_client import Counter, Histogram from twisted.internet import defer from synapse.api.constants import EventTypes +from synapse.events import FrozenEvent +from synapse.events.snapshot import EventContext from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process from synapse.state import StateResolutionStore from synapse.storage.data_stores import DataStores +from synapse.types import StateMap from synapse.util.async_helpers import ObservableDeferred from synapse.util.metrics import Measure @@ -67,6 +72,19 @@ stale_forward_extremities_counter = Histogram( ) +@attr.s(slots=True, frozen=True) +class DeltaState: + """Deltas to use to update the `current_state_events` table. + + Attributes: + to_delete: List of type/state_keys to delete from current state + to_insert: Map of state to upsert into current state + """ + + to_delete = attr.ib(type=List[Tuple[str, str]]) + to_insert = attr.ib(type=StateMap[str]) + + class _EventPeristenceQueue(object): """Queues up events so that they can be persisted in bulk with only one concurrent transaction per room. @@ -138,13 +156,12 @@ class _EventPeristenceQueue(object): self._currently_persisting_rooms.add(room_id) - @defer.inlineCallbacks - def handle_queue_loop(): + async def handle_queue_loop(): try: queue = self._get_drainining_queue(room_id) for item in queue: try: - ret = yield per_item_callback(item) + ret = await per_item_callback(item) except Exception: with PreserveLoggingContext(): item.deferred.errback() @@ -191,12 +208,16 @@ class EventsPersistenceStorage(object): self._state_resolution_handler = hs.get_state_resolution_handler() @defer.inlineCallbacks - def persist_events(self, events_and_contexts, backfilled=False): + def persist_events( + self, + events_and_contexts: List[Tuple[FrozenEvent, EventContext]], + backfilled: bool = False, + ): """ Write events to the database Args: events_and_contexts: list of tuples of (event, context) - backfilled (bool): Whether the results are retrieved from federation + backfilled: Whether the results are retrieved from federation via backfill or not. Used to determine if they're "new" events which might update the current state etc. 
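+
+        A hypothetical call site (the local names are invented for
+        illustration)::
+
+            max_stream_id = yield storage.persistence.persist_events(
+                [(event, context)], backfilled=False
+            )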
@@ -226,16 +247,12 @@ class EventsPersistenceStorage(object): return max_persisted_id @defer.inlineCallbacks - def persist_event(self, event, context, backfilled=False): + def persist_event( + self, event: FrozenEvent, context: EventContext, backfilled: bool = False + ): """ - - Args: - event (EventBase): - context (EventContext): - backfilled (bool): - Returns: - Deferred: resolves to (int, int): the stream ordering of ``event``, + Deferred[Tuple[int, int]]: the stream ordering of ``event``, and the stream ordering of the latest persisted event """ deferred = self._event_persist_queue.add_to_queue( @@ -249,28 +266,22 @@ class EventsPersistenceStorage(object): max_persisted_id = yield self.main_store.get_current_events_token() return (event.internal_metadata.stream_ordering, max_persisted_id) - def _maybe_start_persisting(self, room_id): - @defer.inlineCallbacks - def persisting_queue(item): + def _maybe_start_persisting(self, room_id: str): + async def persisting_queue(item): with Measure(self._clock, "persist_events"): - yield self._persist_events( + await self._persist_events( item.events_and_contexts, backfilled=item.backfilled ) self._event_persist_queue.handle_queue(room_id, persisting_queue) - @defer.inlineCallbacks - def _persist_events(self, events_and_contexts, backfilled=False): + async def _persist_events( + self, + events_and_contexts: List[Tuple[FrozenEvent, EventContext]], + backfilled: bool = False, + ): """Calculates the change to current state and forward extremities, and persists the given events and with those updates. - - Args: - events_and_contexts (list[(EventBase, EventContext)]): - backfilled (bool): - delete_existing (bool): - - Returns: - Deferred: resolves when the events have been persisted """ if not events_and_contexts: return @@ -315,10 +326,10 @@ class EventsPersistenceStorage(object): ) for room_id, ev_ctx_rm in iteritems(events_by_room): - latest_event_ids = yield self.main_store.get_latest_event_ids_in_room( + latest_event_ids = await self.main_store.get_latest_event_ids_in_room( room_id ) - new_latest_event_ids = yield self._calculate_new_extremities( + new_latest_event_ids = await self._calculate_new_extremities( room_id, ev_ctx_rm, latest_event_ids ) @@ -374,7 +385,7 @@ class EventsPersistenceStorage(object): with Measure( self._clock, "persist_events.get_new_state_after_events" ): - res = yield self._get_new_state_after_events( + res = await self._get_new_state_after_events( room_id, ev_ctx_rm, latest_event_ids, @@ -389,12 +400,12 @@ class EventsPersistenceStorage(object): # If there is a delta we know that we've # only added or replaced state, never # removed keys entirely. 
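+                        # Illustrative example (the event ID is invented):
+                        # delta_ids is a StateMap[str] such as
+                        #     {("m.room.member", "@alice:test"): "$event_id"}
+                        # so the DeltaState built below carries an empty
+                        # to_delete list and that map as to_insert.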
- state_delta_for_room[room_id] = ([], delta_ids) + state_delta_for_room[room_id] = DeltaState([], delta_ids) elif current_state is not None: with Measure( self._clock, "persist_events.calculate_state_delta" ): - delta = yield self._calculate_state_delta( + delta = await self._calculate_state_delta( room_id, current_state ) state_delta_for_room[room_id] = delta @@ -404,7 +415,7 @@ class EventsPersistenceStorage(object): if current_state is not None: current_state_for_room[room_id] = current_state - yield self.main_store._persist_events_and_state_updates( + await self.main_store._persist_events_and_state_updates( chunk, current_state_for_room=current_state_for_room, state_delta_for_room=state_delta_for_room, @@ -412,8 +423,12 @@ class EventsPersistenceStorage(object): backfilled=backfilled, ) - @defer.inlineCallbacks - def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids): + async def _calculate_new_extremities( + self, + room_id: str, + event_contexts: List[Tuple[FrozenEvent, EventContext]], + latest_event_ids: List[str], + ): """Calculates the new forward extremities for a room given events to persist. @@ -444,13 +459,13 @@ class EventsPersistenceStorage(object): ) # Remove any events which are prev_events of any existing events. - existing_prevs = yield self.main_store._get_events_which_are_prevs(result) + existing_prevs = await self.main_store._get_events_which_are_prevs(result) result.difference_update(existing_prevs) # Finally handle the case where the new events have soft-failed prev # events. If they do we need to remove them and their prev events, # otherwise we end up with dangling extremities. - existing_prevs = yield self.main_store._get_prevs_before_rejected( + existing_prevs = await self.main_store._get_prevs_before_rejected( e_id for event in new_events for e_id in event.prev_event_ids() ) result.difference_update(existing_prevs) @@ -464,10 +479,13 @@ class EventsPersistenceStorage(object): return result - @defer.inlineCallbacks - def _get_new_state_after_events( - self, room_id, events_context, old_latest_event_ids, new_latest_event_ids - ): + async def _get_new_state_after_events( + self, + room_id: str, + events_context: List[Tuple[FrozenEvent, EventContext]], + old_latest_event_ids: Iterable[str], + new_latest_event_ids: Iterable[str], + ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]: """Calculate the current state dict after adding some new events to a room @@ -485,7 +503,6 @@ class EventsPersistenceStorage(object): the new forward extremities for the room. Returns: - Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]: Returns a tuple of two state maps, the first being the full new current state and the second being the delta to the existing current state. If both are None then there has been no change. @@ -547,7 +564,7 @@ class EventsPersistenceStorage(object): if missing_event_ids: # Now pull out the state groups for any missing events from DB - event_to_groups = yield self.main_store._get_state_group_for_events( + event_to_groups = await self.main_store._get_state_group_for_events( missing_event_ids ) event_id_to_state_group.update(event_to_groups) @@ -588,7 +605,7 @@ class EventsPersistenceStorage(object): # their state IDs so we can resolve to a single state set. 
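+        # Shape sketch (values invented): state_groups_map maps state group
+        # ids to StateMap[str]s, e.g.
+        #     {101: {("m.room.name", ""): "$ev1"},
+        #      102: {("m.room.name", ""): "$ev2"}},
+        # and resolution reduces them to one event ID per (type, state_key).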
         missing_state = new_state_groups - set(state_groups_map)
         if missing_state:
-            group_to_state = yield self.state_store._get_state_for_groups(missing_state)
+            group_to_state = await self.state_store._get_state_for_groups(missing_state)
             state_groups_map.update(group_to_state)
 
         if len(new_state_groups) == 1:
@@ -612,10 +629,10 @@ class EventsPersistenceStorage(object):
                 break
 
         if not room_version:
-            room_version = yield self.main_store.get_room_version(room_id)
+            room_version = await self.main_store.get_room_version(room_id)
 
         logger.debug("calling resolve_state_groups from preserve_events")
-        res = yield self._state_resolution_handler.resolve_state_groups(
+        res = await self._state_resolution_handler.resolve_state_groups(
             room_id,
             room_version,
             state_groups,
@@ -625,18 +642,14 @@ class EventsPersistenceStorage(object):
 
         return res.state, None
 
-    @defer.inlineCallbacks
-    def _calculate_state_delta(self, room_id, current_state):
+    async def _calculate_state_delta(
+        self, room_id: str, current_state: StateMap[str]
+    ) -> DeltaState:
         """Calculate the new state deltas for a room.
 
         Assumes that we are only persisting events for one room at a time.
-
-        Returns:
-            tuple[list, dict] (to_delete, to_insert): where to_delete are the
-            type/state_keys to remove from current_state_events and `to_insert`
-            are the updates to current_state_events.
         """
-        existing_state = yield self.main_store.get_current_state_ids(room_id)
+        existing_state = await self.main_store.get_current_state_ids(room_id)
 
         to_delete = [key for key in existing_state if key not in current_state]
 
@@ -646,4 +659,4 @@ class EventsPersistenceStorage(object):
             if ev_id != existing_state.get(key)
         }
 
-        return to_delete, to_insert
+        return DeltaState(to_delete=to_delete, to_insert=to_insert)
-- cgit 1.5.1


From 5e52d8563bdc0ab6667f0ec2571f35791720a40a Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Wed, 22 Jan 2020 10:37:00 +0000
Subject: Allow streaming cache invalidate all to workers. (#6749)

---
 changelog.d/6749.misc                      |  1 +
 docs/tcp_replication.md                    |  5 +++++
 synapse/replication/slave/storage/_base.py |  7 ++++++-
 synapse/replication/tcp/streams/_base.py   | 26 +++++++++++++++++++++-----
 synapse/storage/_base.py                   | 18 +++++++++++++-----
 synapse/storage/data_stores/main/cache.py  | 27 +++++++++++++++++++++++----
 6 files changed, 69 insertions(+), 15 deletions(-)
 create mode 100644 changelog.d/6749.misc

(limited to 'synapse/storage')

diff --git a/changelog.d/6749.misc b/changelog.d/6749.misc
new file mode 100644
index 0000000000..9fa13cb1d4
--- /dev/null
+++ b/changelog.d/6749.misc
@@ -0,0 +1 @@
+Allow streaming cache 'invalidate all' to workers.
diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md
index a0b1d563ff..e3a4634b14 100644
--- a/docs/tcp_replication.md
+++ b/docs/tcp_replication.md
@@ -254,6 +254,11 @@ and the key to invalidate. For example:
 
     > RDATA caches 550953771 ["get_user_by_id", ["@bob:example.com"], 1550574873251]
 
+Alternatively, an entire cache can be invalidated by sending down a `null`
+instead of the key. For example:
+
+    > RDATA caches 550953772 ["get_user_by_id", null, 1550574873252]
+
 However, there are times when a number of caches need to be invalidated at
 the same time with the same key.
 To reduce traffic we batch those
 invalidations into a single poke by defining a special cache name that
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 704282c800..f45cbd37a0 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -66,11 +66,16 @@ class BaseSlavedStore(SQLBaseStore):
             self._cache_id_gen.advance(token)
             for row in rows:
                 if row.cache_func == CURRENT_STATE_CACHE_NAME:
+                    if row.keys is None:
+                        raise Exception(
+                            "Can't send an 'invalidate all' for current state cache"
+                        )
+
                     room_id = row.keys[0]
                     members_changed = set(row.keys[1:])
                     self._invalidate_state_caches(room_id, members_changed)
                 else:
-                    self._attempt_to_invalidate_cache(row.cache_func, tuple(row.keys))
+                    self._attempt_to_invalidate_cache(row.cache_func, row.keys)
 
     def _invalidate_cache_and_stream(self, txn, cache_func, keys):
         txn.call_after(cache_func.invalidate, keys)
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index e03e77199b..a8d568b14a 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -17,7 +17,9 @@
 import itertools
 import logging
 from collections import namedtuple
-from typing import Any
+from typing import Any, List, Optional
+
+import attr
 
 logger = logging.getLogger(__name__)
 
@@ -65,10 +67,24 @@
 PushersStreamRow = namedtuple(
     "PushersStreamRow",
     ("user_id", "app_id", "pushkey", "deleted"),  # str  # str  # str  # bool
 )
-CachesStreamRow = namedtuple(
-    "CachesStreamRow",
-    ("cache_func", "keys", "invalidation_ts"),  # str  # list(str)  # int
-)
+
+
+@attr.s
+class CachesStreamRow:
+    """Stream row used to inform workers that they should invalidate their cache.
+
+    Attributes:
+        cache_func: Name of the cached function.
+        keys: The entry in the cache to invalidate. If None then the entire
+            cache is invalidated.
+        invalidation_ts: Timestamp of when the invalidation took place.
+    """
+
+    cache_func = attr.ib(type=str)
+    keys = attr.ib(type=Optional[List[Any]])
+    invalidation_ts = attr.ib(type=int)
+
+
 PublicRoomsStreamRow = namedtuple(
     "PublicRoomsStreamRow",
     (
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 3bb9381663..da3b99f93d 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -17,6 +17,7 @@
 import logging
 import random
 from abc import ABCMeta
+from typing import Any, Optional
 
 from six import PY2
 from six.moves import builtins
@@ -26,7 +27,7 @@
 from canonicaljson import json
 
 from synapse.storage.database import LoggingTransaction  # noqa: F401
 from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
 from synapse.storage.database import Database
-from synapse.types import get_domain_from_id
+from synapse.types import Collection, get_domain_from_id
 
 logger = logging.getLogger(__name__)
 
@@ -63,17 +64,24 @@ class SQLBaseStore(metaclass=ABCMeta):
         self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
         self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
 
-    def _attempt_to_invalidate_cache(self, cache_name, key):
+    def _attempt_to_invalidate_cache(
+        self, cache_name: str, key: Optional[Collection[Any]]
+    ):
         """Attempts to invalidate the cache of the given name, ignoring if the
         cache doesn't exist. Mainly used for invalidating caches on workers,
         where they may not have the cache.
 
         Args:
-            cache_name (str)
-            key (tuple)
+            cache_name
+            key: Entry to invalidate. If None then the entire cache is
+                invalidated.
""" + try: - getattr(self, cache_name).invalidate(key) + if key is None: + getattr(self, cache_name).invalidate_all() + else: + getattr(self, cache_name).invalidate(tuple(key)) except AttributeError: # We probably haven't pulled in the cache in this worker, # which is fine. diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index bf91512daf..afa2b41c98 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -16,6 +16,7 @@ import itertools import logging +from typing import Any, Iterable, Optional from twisted.internet import defer @@ -43,6 +44,14 @@ class CacheInvalidationStore(SQLBaseStore): txn.call_after(cache_func.invalidate, keys) self._send_invalidation_to_replication(txn, cache_func.__name__, keys) + def _invalidate_all_cache_and_stream(self, txn, cache_func): + """Invalidates the entire cache and adds it to the cache stream so slaves + will know to invalidate their caches. + """ + + txn.call_after(cache_func.invalidate_all) + self._send_invalidation_to_replication(txn, cache_func.__name__, None) + def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed): """Special case invalidation of caches based on current state. @@ -73,17 +82,24 @@ class CacheInvalidationStore(SQLBaseStore): txn, CURRENT_STATE_CACHE_NAME, [room_id] ) - def _send_invalidation_to_replication(self, txn, cache_name, keys): + def _send_invalidation_to_replication( + self, txn, cache_name: str, keys: Optional[Iterable[Any]] + ): """Notifies replication that given cache has been invalidated. Note that this does *not* invalidate the cache locally. Args: txn - cache_name (str) - keys (iterable[str]) + cache_name + keys: Entry to invalidate. If None will invalidate all. """ + if cache_name == CURRENT_STATE_CACHE_NAME and keys is None: + raise Exception( + "Can't stream invalidate all with magic current state cache" + ) + if isinstance(self.database_engine, PostgresEngine): # get_next() returns a context manager which is designed to wrap # the transaction. However, we want to only get an ID when we want @@ -95,13 +111,16 @@ class CacheInvalidationStore(SQLBaseStore): txn.call_after(ctx.__exit__, None, None, None) txn.call_after(self.hs.get_notifier().on_new_replication_data) + if keys is not None: + keys = list(keys) + self.db.simple_insert_txn( txn, table="cache_invalidation_stream", values={ "stream_id": stream_id, "cache_func": cache_name, - "keys": list(keys), + "keys": keys, "invalidation_ts": self.clock.time_msec(), }, ) -- cgit 1.5.1 From 5e52d8563bdc0ab6667f0ec2571f35791720a40a Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Wed, 22 Jan 2020 11:05:14 +0000 Subject: Allow monthly active user limiting support for worker mode, fixes #4639. (#6742) --- changelog.d/6742.bugfix | 1 + synapse/app/client_reader.py | 4 + synapse/app/event_creator.py | 4 + synapse/app/federation_reader.py | 4 + synapse/app/synchrotron.py | 4 + .../data_stores/main/monthly_active_users.py | 165 +++++++++++---------- 6 files changed, 100 insertions(+), 82 deletions(-) create mode 100644 changelog.d/6742.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/6742.bugfix b/changelog.d/6742.bugfix new file mode 100644 index 0000000000..ca2687c8bb --- /dev/null +++ b/changelog.d/6742.bugfix @@ -0,0 +1 @@ +Fix monthly active user limiting support for worker mode, fixes [#4639](https://github.com/matrix-org/synapse/issues/4639). 
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index 3edfe19567..ca96da6a4a 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -62,6 +62,9 @@ from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.rest.client.versions import VersionsRestServlet from synapse.server import HomeServer +from synapse.storage.data_stores.main.monthly_active_users import ( + MonthlyActiveUsersWorkerStore, +) from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -85,6 +88,7 @@ class ClientReaderSlavedStore( SlavedTransactionStore, SlavedProfileStore, SlavedClientIpStore, + MonthlyActiveUsersWorkerStore, BaseSlavedStore, ): pass diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index d0ddbe38fc..58e5b354f6 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -56,6 +56,9 @@ from synapse.rest.client.v1.room import ( RoomStateEventRestServlet, ) from synapse.server import HomeServer +from synapse.storage.data_stores.main.monthly_active_users import ( + MonthlyActiveUsersWorkerStore, +) from synapse.storage.data_stores.main.user_directory import UserDirectoryStore from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole @@ -81,6 +84,7 @@ class EventCreatorSlavedStore( SlavedEventStore, SlavedRegistrationStore, RoomStore, + MonthlyActiveUsersWorkerStore, BaseSlavedStore, ): pass diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 311523e0ed..1f1cea1416 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -46,6 +46,9 @@ from synapse.replication.slave.storage.transactions import SlavedTransactionStor from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.key.v2 import KeyApiV2Resource from synapse.server import HomeServer +from synapse.storage.data_stores.main.monthly_active_users import ( + MonthlyActiveUsersWorkerStore, +) from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -66,6 +69,7 @@ class FederationReaderSlavedStore( RoomStore, DirectoryStore, SlavedTransactionStore, + MonthlyActiveUsersWorkerStore, BaseSlavedStore, ): pass diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 3218da07bd..8982c0676e 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -54,6 +54,9 @@ from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet from synapse.rest.client.v1.room import RoomInitialSyncRestServlet from synapse.rest.client.v2_alpha import sync from synapse.server import HomeServer +from synapse.storage.data_stores.main.monthly_active_users import ( + MonthlyActiveUsersWorkerStore, +) from synapse.storage.data_stores.main.presence import UserPresenceState from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole @@ -77,6 +80,7 @@ class SynchrotronSlavedStore( SlavedEventStore, SlavedClientIpStore, RoomStore, + MonthlyActiveUsersWorkerStore, BaseSlavedStore, ): pass diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index 27158534cb..89a41542a3 100644 --- 
a/synapse/storage/data_stores/main/monthly_active_users.py
+++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -27,12 +27,76 @@ logger = logging.getLogger(__name__)
 
 LAST_SEEN_GRANULARITY = 60 * 60 * 1000
 
 
-class MonthlyActiveUsersStore(SQLBaseStore):
+class MonthlyActiveUsersWorkerStore(SQLBaseStore):
     def __init__(self, database: Database, db_conn, hs):
-        super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
+        super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs)
         self._clock = hs.get_clock()
         self.hs = hs
+
+    @cached(num_args=0)
+    def get_monthly_active_count(self):
+        """Generates current count of monthly active users
+
+        Returns:
+            Deferred[int]: Number of current monthly active users
+        """
+
+        def _count_users(txn):
+            sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
+
+            txn.execute(sql)
+            (count,) = txn.fetchone()
+            return count
+
+        return self.db.runInteraction("count_users", _count_users)
+
+    @defer.inlineCallbacks
+    def get_registered_reserved_users(self):
+        """Of the reserved threepids defined in config, which are associated
+        with registered users?
+
+        Returns:
+            Deferred[list]: Real reserved users
+        """
+        users = []
+
+        for tp in self.hs.config.mau_limits_reserved_threepids[
+            : self.hs.config.max_mau_value
+        ]:
+            user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+                tp["medium"], tp["address"]
+            )
+            if user_id:
+                users.append(user_id)
+
+        return users
+
+    @cached(num_args=1)
+    def user_last_seen_monthly_active(self, user_id):
+        """
+        Checks if a given user is part of the monthly active user group
+        Arguments:
+            user_id (str): user to add/update
+        Return:
+            Deferred[int]: timestamp since last seen, None if never seen
+
+        """
+
+        return self.db.simple_select_one_onecol(
+            table="monthly_active_users",
+            keyvalues={"user_id": user_id},
+            retcol="timestamp",
+            allow_none=True,
+            desc="user_last_seen_monthly_active",
+        )
+
+
+class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
+
+        # Do not add more reserved users than the total allowable number
+        # cur = LoggingTransaction(
         self.db.new_transaction(
             db_conn, "initialise_mau_threepids",
@@ -146,57 +210,22 @@ class MonthlyActiveUsersStore(SQLBaseStore):
 
             txn.execute(sql, query_args)
 
+            # It seems poor to invalidate the whole cache, Postgres supports
+            # 'Returning' which would allow me to invalidate only the
+            # specific users, but sqlite has no way to do this and instead
+            # I would need to SELECT and then DELETE which without locking
+            # is racy.
+            # Have resolved to invalidate the whole cache for now and do
+            # something about it if and when the perf becomes significant
+            self._invalidate_all_cache_and_stream(
+                txn, self.user_last_seen_monthly_active
+            )
+            self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
+
         reserved_users = yield self.get_registered_reserved_users()
         yield self.db.runInteraction(
             "reap_monthly_active_users", _reap_users, reserved_users
         )
 
-        # It seems poor to invalidate the whole cache, Postgres supports
-        # 'Returning' which would allow me to invalidate only the
-        # specific users, but sqlite has no way to do this and instead
-        # I would need to SELECT and the DELETE which without locking
-        # is racy.
- # Have resolved to invalidate the whole cache for now and do - # something about it if and when the perf becomes significant - self.user_last_seen_monthly_active.invalidate_all() - self.get_monthly_active_count.invalidate_all() - - @cached(num_args=0) - def get_monthly_active_count(self): - """Generates current count of monthly active users - - Returns: - Defered[int]: Number of current monthly active users - """ - - def _count_users(txn): - sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" - - txn.execute(sql) - (count,) = txn.fetchone() - return count - - return self.db.runInteraction("count_users", _count_users) - - @defer.inlineCallbacks - def get_registered_reserved_users(self): - """Of the reserved threepids defined in config, which are associated - with registered users? - - Returns: - Defered[list]: Real reserved users - """ - users = [] - - for tp in self.hs.config.mau_limits_reserved_threepids[ - : self.hs.config.max_mau_value - ]: - user_id = yield self.hs.get_datastore().get_user_id_by_threepid( - tp["medium"], tp["address"] - ) - if user_id: - users.append(user_id) - - return users @defer.inlineCallbacks def upsert_monthly_active_user(self, user_id): @@ -222,23 +251,9 @@ class MonthlyActiveUsersStore(SQLBaseStore): "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id ) - user_in_mau = self.user_last_seen_monthly_active.cache.get( - (user_id,), None, update_metrics=False - ) - if user_in_mau is None: - self.get_monthly_active_count.invalidate(()) - - self.user_last_seen_monthly_active.invalidate((user_id,)) - def upsert_monthly_active_user_txn(self, txn, user_id): """Updates or inserts monthly active user member - Note that, after calling this method, it will generally be necessary - to invalidate the caches on user_last_seen_monthly_active and - get_monthly_active_count. We can't do that here, because we are running - in a database thread rather than the main thread, and we can't call - txn.call_after because txn may not be a LoggingTransaction. - We consciously do not call is_support_txn from this method because it is not possible to cache the response. 
is_support_txn will be false in
        almost all cases, so it seems reasonable to call it only for
@@ -269,27 +284,13 @@ class MonthlyActiveUsersStore(SQLBaseStore):
             values={"timestamp": int(self._clock.time_msec())},
         )
 
-        return is_insert
-
-    @cached(num_args=1)
-    def user_last_seen_monthly_active(self, user_id):
-        """
-        Checks if a given user is part of the monthly active user group
-        Arguments:
-            user_id (str): user to add/update
-        Return:
-            Deferred[int] : timestamp since last seen, None if never seen
-
-        """
-
-        return self.db.simple_select_one_onecol(
-            table="monthly_active_users",
-            keyvalues={"user_id": user_id},
-            retcol="timestamp",
-            allow_none=True,
-            desc="user_last_seen_monthly_active",
+        self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
+        self._invalidate_cache_and_stream(
+            txn, self.user_last_seen_monthly_active, (user_id,)
         )
 
+        return is_insert
+
     @defer.inlineCallbacks
     def populate_monthly_active_users(self, user_id):
         """Checks on the state of monthly active user limits and optionally
-- cgit 1.5.1


From 90a28fb475a29daa9e7a9ee7204f6f76cc8af441 Mon Sep 17 00:00:00 2001
From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
Date: Wed, 22 Jan 2020 13:36:43 +0000
Subject: Admin API to list, filter and sort rooms (#6720)

---
 changelog.d/6720.feature                 |   1 +
 docs/admin_api/rooms.md                  | 173 ++++++++++++++
 synapse/rest/admin/__init__.py           |   3 +-
 synapse/rest/admin/_base.py              |  15 ++
 synapse/rest/admin/rooms.py              |  82 +++++++
 synapse/rest/client/v2_alpha/_base.py    |   2 +-
 synapse/storage/data_stores/main/room.py | 125 +++++++++-
 tests/rest/admin/test_admin.py           | 393 ++++++++++++++++++++++++++++++-
 8 files changed, 787 insertions(+), 7 deletions(-)
 create mode 100644 changelog.d/6720.feature
 create mode 100644 docs/admin_api/rooms.md

(limited to 'synapse/storage')

diff --git a/changelog.d/6720.feature b/changelog.d/6720.feature
new file mode 100644
index 0000000000..dfc1b74d62
--- /dev/null
+++ b/changelog.d/6720.feature
@@ -0,0 +1 @@
+Add a new admin API to list and filter rooms on the server.
\ No newline at end of file
diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md
new file mode 100644
index 0000000000..082721ea95
--- /dev/null
+++ b/docs/admin_api/rooms.md
@@ -0,0 +1,173 @@
+# List Room API
+
+The List Room admin API allows server admins to get a list of rooms on their
+server. There are various parameters available that allow for filtering and
+sorting the returned list. This API supports pagination.
+
+## Parameters
+
+The following query parameters are available:
+
+* `from` - Offset in the returned list. Defaults to `0`.
+* `limit` - Maximum number of rooms to return. Defaults to `100`.
+* `order_by` - The method in which to sort the returned list of rooms. Valid values are:
+  - `alphabetical` - Rooms are ordered alphabetically by room name. This is the default.
+  - `size` - Rooms are ordered by the number of members. Largest to smallest.
+* `dir` - Direction of room order. Either `f` for forwards or `b` for backwards. Setting
+          this value to `b` will reverse the above sort order. Defaults to `f`.
+* `search_term` - Filter rooms by their room name. Search term can be contained in any
+                  part of the room name. Defaults to no filtering.
+
+The following fields are possible in the JSON response body:
+
+* `rooms` - An array of objects, each containing information about a room.
+  - Room objects contain the following fields:
+    - `room_id` - The ID of the room.
+    - `name` - The name of the room.
+    - `canonical_alias` - The canonical (main) alias address of the room.
+    - `joined_members` - How many users are currently in the room.
+* `offset` - The current pagination offset in rooms. This parameter should be
+             used instead of `next_batch` for room offset as `next_batch` is
+             not intended to be parsed.
+* `total_rooms` - The total number of rooms this query can return. Using this
+                  and `offset`, you have enough information to know the current
+                  progression through the list.
+* `next_batch` - If this field is present, we know that there are potentially
+                 more rooms on the server that did not all fit into this response.
+                 We can use `next_batch` to get the "next page" of results. To do
+                 so, simply repeat your request, setting the `from` parameter to
+                 the value of `next_batch`.
+* `prev_batch` - If this field is present, it is possible to paginate backwards.
+                 Use `prev_batch` for the `from` value in the next request to
+                 get the "previous page" of results.
+
+## Usage
+
+A standard request with no filtering:
+
+```
+GET /_synapse/admin/v1/rooms
+
+{}
+```
+
+Response:
+
+```
+{
+  "rooms": [
+    {
+      "room_id": "!OGEhHVWSdvArJzumhm:matrix.org",
+      "name": "Matrix HQ",
+      "canonical_alias": "#matrix:matrix.org",
+      "joined_members": 8326
+    },
+    ... (8 hidden items) ...
+    {
+      "room_id": "!xYvNcQPhnkrdUmYczI:matrix.org",
+      "name": "This Week In Matrix (TWIM)",
+      "canonical_alias": "#twim:matrix.org",
+      "joined_members": 314
+    }
+  ],
+  "offset": 0,
+  "total_rooms": 10
+}
+```
+
+Filtering by room name:
+
+```
+GET /_synapse/admin/v1/rooms?search_term=TWIM
+
+{}
+```
+
+Response:
+
+```
+{
+  "rooms": [
+    {
+      "room_id": "!xYvNcQPhnkrdUmYczI:matrix.org",
+      "name": "This Week In Matrix (TWIM)",
+      "canonical_alias": "#twim:matrix.org",
+      "joined_members": 314
+    }
+  ],
+  "offset": 0,
+  "total_rooms": 1
+}
+```
+
+Paginating through a list of rooms:
+
+```
+GET /_synapse/admin/v1/rooms?order_by=size
+
+{}
+```
+
+Response:
+
+```
+{
+  "rooms": [
+    {
+      "room_id": "!OGEhHVWSdvArJzumhm:matrix.org",
+      "name": "Matrix HQ",
+      "canonical_alias": "#matrix:matrix.org",
+      "joined_members": 8326
+    },
+    ... (98 hidden items) ...
+    {
+      "room_id": "!xYvNcQPhnkrdUmYczI:matrix.org",
+      "name": "This Week In Matrix (TWIM)",
+      "canonical_alias": "#twim:matrix.org",
+      "joined_members": 314
+    }
+  ],
+  "offset": 0,
+  "total_rooms": 150,
+  "next_batch": 100
+}
+```
+
+The presence of the `next_batch` parameter tells us that there are more rooms
+than returned in this request, and we need to make another request to get them.
+To get the next batch of room results, we repeat our request, setting the `from`
+parameter to the value of `next_batch`.
+
+```
+GET /_synapse/admin/v1/rooms?order_by=size&from=100
+
+{}
+```
+
+Response:
+
+```
+{
+  "rooms": [
+    {
+      "room_id": "!mscvqgqpHYjBGDxNym:matrix.org",
+      "name": "Music Theory",
+      "canonical_alias": "#musictheory:matrix.org",
+      "joined_members": 127
+    },
+    ... (48 hidden items) ...
+    {
+      "room_id": "!twcBhHVdZlQWuuxBhN:termina.org.uk",
+      "name": "weechat-matrix",
+      "canonical_alias": "#weechat-matrix:termina.org.uk",
+      "joined_members": 137
+    }
+  ],
+  "offset": 100,
+  "prev_batch": 0,
+  "total_rooms": 150
+}
+```
+
+Once the `next_batch` parameter is no longer present, we know we've reached the
+end of the list.
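As a rough illustration of the pagination contract documented above, the sketch below walks the full room list by feeding each response's `next_batch` back into the `from` parameter until no `next_batch` is returned. The base URL, the admin token, and the use of the `requests` library are assumptions of this example, not part of the patch.

```python
import requests

BASE_URL = "https://localhost:8008"  # assumed homeserver address
ADMIN_TOKEN = "secret"  # assumed admin access token


def list_all_rooms(order_by="size"):
    """Collect every room by following next_batch until it disappears."""
    rooms = []
    params = {"order_by": order_by, "from": 0}
    headers = {"Authorization": "Bearer " + ADMIN_TOKEN}
    while True:
        resp = requests.get(
            BASE_URL + "/_synapse/admin/v1/rooms", params=params, headers=headers
        )
        resp.raise_for_status()
        body = resp.json()
        rooms.extend(body["rooms"])
        if "next_batch" not in body:
            # No next_batch in the response: we have seen the whole list.
            return rooms
        # Repeat the request from where the last page left off.
        params["from"] = body["next_batch"]
```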
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 2932fe2123..42cc2b062a 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -29,7 +29,7 @@ from synapse.rest.admin._base import ( from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet -from synapse.rest.admin.rooms import ShutdownRoomRestServlet +from synapse.rest.admin.rooms import ListRoomRestServlet, ShutdownRoomRestServlet from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.rest.admin.users import ( AccountValidityRenewServlet, @@ -188,6 +188,7 @@ def register_servlets(hs, http_server): Register all the admin servlets. """ register_servlets_for_client_rest_resource(hs, http_server) + ListRoomRestServlet(hs).register(http_server) PurgeRoomServlet(hs).register(http_server) SendServerNoticeServlet(hs).register(http_server) VersionServlet(hs).register(http_server) diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py index afd0647205..459482eb6d 100644 --- a/synapse/rest/admin/_base.py +++ b/synapse/rest/admin/_base.py @@ -40,6 +40,21 @@ def historical_admin_path_patterns(path_regex): ) +def admin_patterns(path_regex: str): + """Returns the list of patterns for an admin endpoint + + Args: + path_regex: The regex string to match. This should NOT have a ^ + as this will be prefixed. + + Returns: + A list of regex patterns. + """ + admin_prefix = "^/_synapse/admin/v1" + patterns = [re.compile(admin_prefix + path_regex)] + return patterns + + async def assert_requester_is_admin(auth, request): """Verify that the requester is an admin user diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index f7cc5e9be9..f9b8c0a4f0 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -15,15 +15,20 @@ import logging from synapse.api.constants import Membership +from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import ( RestServlet, assert_params_in_dict, + parse_integer, parse_json_object_from_request, + parse_string, ) from synapse.rest.admin._base import ( + admin_patterns, assert_user_is_admin, historical_admin_path_patterns, ) +from synapse.storage.data_stores.main.room import RoomSortOrder from synapse.types import create_requester from synapse.util.async_helpers import maybe_awaitable @@ -155,3 +160,80 @@ class ShutdownRoomRestServlet(RestServlet): "new_room_id": new_room_id, }, ) + + +class ListRoomRestServlet(RestServlet): + """ + List all rooms that are known to the homeserver. Results are returned + in a dictionary containing room information. Supports pagination. 
+    """
+
+    PATTERNS = admin_patterns("/rooms")
+
+    def __init__(self, hs):
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+        self.admin_handler = hs.get_handlers().admin_handler
+
+    async def on_GET(self, request):
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        # Extract query parameters
+        start = parse_integer(request, "from", default=0)
+        limit = parse_integer(request, "limit", default=100)
+        order_by = parse_string(request, "order_by", default="alphabetical")
+        if order_by not in (
+            RoomSortOrder.ALPHABETICAL.value,
+            RoomSortOrder.SIZE.value,
+        ):
+            raise SynapseError(
+                400,
+                "Unknown value for order_by: %s" % (order_by,),
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        search_term = parse_string(request, "search_term")
+        if search_term == "":
+            raise SynapseError(
+                400,
+                "search_term cannot be an empty string",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        direction = parse_string(request, "dir", default="f")
+        if direction not in ("f", "b"):
+            raise SynapseError(
+                400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
+            )
+
+        reverse_order = direction == "b"
+
+        # Return list of rooms according to parameters
+        rooms, total_rooms = await self.store.get_rooms_paginate(
+            start, limit, order_by, reverse_order, search_term
+        )
+        response = {
+            # next_batch is intended to be opaque, so also return the current
+            # offset explicitly for clients to use
+            "offset": start,
+            "rooms": rooms,
+            "total_rooms": total_rooms,
+        }
+
+        # Are there more rooms to paginate through after this?
+        if (start + limit) < total_rooms:
+            # There are. Calculate where the query should start from next time
+            # to get the next part of the list
+            response["next_batch"] = start + limit
+
+        # Is it possible to paginate backwards? Check if we currently have an
+        # offset
+        if start > 0:
+            if start > limit:
+                # Going back one iteration won't take us to the start.
+                # Calculate new offset
+                response["prev_batch"] = start - limit
+            else:
+                response["prev_batch"] = 0
+
+        return 200, response
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index 2a3f4dd58f..bc11b4dda4 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -32,7 +32,7 @@ def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
 
     Args:
         path_regex (str): The regex string to match. This should NOT have a ^
-        as this will be prefixed.
+            as this will be prefixed.
     Returns:
         SRE_Pattern
     """
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 49bab62be3..d968803ad2 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -18,7 +18,8 @@ import collections
 import logging
 import re
 from abc import abstractmethod
-from typing import List, Optional, Tuple
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple
 
 from six import integer_types
 
@@ -46,6 +47,18 @@ RatelimitOverride = collections.namedtuple(
 )
 
 
+class RoomSortOrder(Enum):
+    """
+    Enum to define the sorting method used when returning rooms with get_rooms_paginate
+
+    ALPHABETICAL = sort rooms alphabetically by name
+    SIZE = sort rooms by membership size, highest to lowest
+    """
+
+    ALPHABETICAL = "alphabetical"
+    SIZE = "size"
+
+
 class RoomWorkerStore(SQLBaseStore):
     def __init__(self, database: Database, db_conn, hs):
         super(RoomWorkerStore, self).__init__(database, db_conn, hs)
@@ -281,6 +294,116 @@ class RoomWorkerStore(SQLBaseStore):
             desc="is_room_blocked",
         )
 
+    async def get_rooms_paginate(
+        self,
+        start: int,
+        limit: int,
+        order_by: RoomSortOrder,
+        reverse_order: bool,
+        search_term: Optional[str],
+    ) -> Tuple[List[Dict[str, Any]], int]:
+        """Function to retrieve a paginated list of rooms as JSON.
+
+        Args:
+            start: offset in the list
+            limit: maximum number of rooms to retrieve
+            order_by: the sort order of the returned list
+            reverse_order: whether to reverse the room list
+            search_term: a string to filter room names by
+        Returns:
+            A list of room dicts and an integer representing the total number of
+            rooms that exist given this query
+        """
+        # Filter room names by a string
+        where_statement = ""
+        if search_term:
+            where_statement = "WHERE state.name LIKE ?"
+
+            # Our postgres db driver converts ? -> %s in SQL strings as that's the
+            # placeholder for postgres.
+            # HOWEVER, if you put a % into your SQL then everything goes wibbly.
+            # To get around this, we're going to surround search_term with %'s
+            # before giving it to the database in python instead
+            search_term = "%" + search_term + "%"
+
+        # Set ordering
+        if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
+            order_by_column = "curr.joined_members"
+            order_by_asc = False
+        elif RoomSortOrder(order_by) == RoomSortOrder.ALPHABETICAL:
+            # Sort alphabetically
+            order_by_column = "state.name"
+            order_by_asc = True
+        else:
+            raise StoreError(
+                500, "Incorrect value for order_by provided: %s" % order_by
+            )
+
+        # Whether to return the list in reverse order
+        if reverse_order:
+            # Flip the boolean
+            order_by_asc = not order_by_asc
+
+        # Create one query for getting the limited number of rooms that the user asked
+        # for, and another query for getting the total number of rooms that could be
+        # returned. Thus allowing us to see if there are more rooms to paginate through
+        info_sql = """
+            SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members
+            FROM room_stats_state state
+            INNER JOIN room_stats_current curr USING (room_id)
+            %s
+            ORDER BY %s %s
+            LIMIT ?
+            OFFSET ?
+ """ % ( + where_statement, + order_by_column, + "ASC" if order_by_asc else "DESC", + ) + + # Use a nested SELECT statement as SQL can't count(*) with an OFFSET + count_sql = """ + SELECT count(*) FROM ( + SELECT room_id FROM room_stats_state state + %s + ) AS get_room_ids + """ % ( + where_statement, + ) + + def _get_rooms_paginate_txn(txn): + # Execute the data query + sql_values = (limit, start) + if search_term: + # Add the search term into the WHERE clause + sql_values = (search_term,) + sql_values + txn.execute(info_sql, sql_values) + + # Refactor room query data into a structured dictionary + rooms = [] + for room in txn: + rooms.append( + { + "room_id": room[0], + "name": room[1], + "canonical_alias": room[2], + "joined_members": room[3], + } + ) + + # Execute the count query + + # Add the search term into the WHERE clause if present + sql_values = (search_term,) if search_term else () + txn.execute(count_sql, sql_values) + + room_count = txn.fetchone() + return rooms, room_count[0] + + return await self.db.runInteraction( + "get_rooms_paginate", _get_rooms_paginate_txn, + ) + @cachedInlineCallbacks(max_entries=10000) def get_ratelimit_for_user(self, user_id): """Check if there are any overrides for ratelimiting for the given diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index af4d604e50..0342aed416 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -17,6 +17,7 @@ import json import os import urllib.parse from binascii import unhexlify +from typing import List, Optional from mock import Mock @@ -26,7 +27,7 @@ import synapse.rest.admin from synapse.http.server import JsonResource from synapse.logging.context import make_deferred_yieldable from synapse.rest.admin import VersionServlet -from synapse.rest.client.v1 import events, login, room +from synapse.rest.client.v1 import directory, events, login, room from synapse.rest.client.v2_alpha import groups from tests import unittest @@ -468,9 +469,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): ) # Extract media ID from the response - server_name_and_media_id = response["content_uri"][ - 6: - ] # Cut off the 'mxc://' bit + server_name_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' server_name, media_id = server_name_and_media_id.split("/") # Attempt to access the media @@ -692,3 +691,389 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): % server_and_media_id_2 ), ) + + +class RoomTestCase(unittest.HomeserverTestCase): + """Test /room admin API. 
+    """
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+        directory.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
+        # Create user
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+    def test_list_rooms(self):
+        """Test that we can list rooms"""
+        # Create 3 test rooms
+        total_rooms = 3
+        room_ids = []
+        for x in range(total_rooms):
+            room_id = self.helper.create_room_as(
+                self.admin_user, tok=self.admin_user_tok
+            )
+            room_ids.append(room_id)
+
+        # Request the list of rooms
+        url = "/_synapse/admin/v1/rooms"
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        # Check request completed successfully
+        self.assertEqual(200, int(channel.code), msg=channel.json_body)
+
+        # Check that response json body contains a "rooms" key
+        self.assertTrue(
+            "rooms" in channel.json_body,
+            msg="Response body does not contain a 'rooms' key",
+        )
+
+        # Check that 3 rooms were returned
+        self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body)
+
+        # Check their room_ids match
+        returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]]
+        self.assertEqual(room_ids, returned_room_ids)
+
+        # Check that all fields are available
+        for r in channel.json_body["rooms"]:
+            self.assertIn("name", r)
+            self.assertIn("canonical_alias", r)
+            self.assertIn("joined_members", r)
+
+        # Check that the correct number of total rooms was returned
+        self.assertEqual(channel.json_body["total_rooms"], total_rooms)
+
+        # Check that the offset is correct
+        # Should be 0 as we aren't paginating
+        self.assertEqual(channel.json_body["offset"], 0)
+
+        # Check that the prev_batch parameter is not present
+        self.assertNotIn("prev_batch", channel.json_body)
+
+        # We shouldn't receive a next token here as there are no further rooms to show
+        self.assertNotIn("next_batch", channel.json_body)
+
+    def test_list_rooms_pagination(self):
+        """Test that we can get a full list of rooms through pagination"""
+        # Create 5 test rooms
+        total_rooms = 5
+        room_ids = []
+        for x in range(total_rooms):
+            room_id = self.helper.create_room_as(
+                self.admin_user, tok=self.admin_user_tok
+            )
+            room_ids.append(room_id)
+
+        # Set the name of the rooms so we get a consistent returned ordering
+        for idx, room_id in enumerate(room_ids):
+            self.helper.send_state(
+                room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,
+            )
+
+        # Request the list of rooms
+        returned_room_ids = []
+        start = 0
+        limit = 2
+
+        run_count = 0
+        should_repeat = True
+        while should_repeat:
+            run_count += 1
+
+            url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % (
+                start,
+                limit,
+                "alphabetical",
+            )
+            request, channel = self.make_request(
+                "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+            )
+            self.render(request)
+            self.assertEqual(
+                200, int(channel.result["code"]), msg=channel.result["body"]
+            )
+
+            self.assertTrue("rooms" in channel.json_body)
+            for r in channel.json_body["rooms"]:
+                returned_room_ids.append(r["room_id"])
+
+            # Check that the correct number of total rooms was returned
+            self.assertEqual(channel.json_body["total_rooms"], total_rooms)
+
+            # Check that the offset is correct
+            # We're only getting 2 rooms each page, so should be 2 * last run_count
+            self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1))
+
+            if 
run_count > 1: + # Check the value of prev_batch is correct + self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2)) + + if "next_batch" not in channel.json_body: + # We have reached the end of the list + should_repeat = False + else: + # Make another query with an updated start value + start = channel.json_body["next_batch"] + + # We should've queried the endpoint 3 times + self.assertEqual( + run_count, + 3, + msg="Should've queried 3 times for 5 rooms with limit 2 per query", + ) + + # Check that we received all of the room ids + self.assertEqual(room_ids, returned_room_ids) + + url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + def test_correct_room_attributes(self): + """Test the correct attributes for a room are returned""" + # Create a test room + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + test_alias = "#test:test" + test_room_name = "something" + + # Have another user join the room + user_2 = self.register_user("user4", "pass") + user_tok_2 = self.login("user4", "pass") + self.helper.join(room_id, user_2, tok=user_tok_2) + + # Create a new alias to this room + url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + {"room_id": room_id}, + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Set this new alias as the canonical alias for this room + self.helper.send_state( + room_id, + "m.room.aliases", + {"aliases": [test_alias]}, + tok=self.admin_user_tok, + state_key="test", + ) + self.helper.send_state( + room_id, + "m.room.canonical_alias", + {"alias": test_alias}, + tok=self.admin_user_tok, + ) + + # Set a name for the room + self.helper.send_state( + room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok, + ) + + # Request the list of rooms + url = "/_synapse/admin/v1/rooms" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check that only one room was returned + self.assertEqual(len(rooms), 1) + + # And that the value of the total_rooms key was correct + self.assertEqual(channel.json_body["total_rooms"], 1) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + # Check that all provided attributes are set + r = rooms[0] + self.assertEqual(room_id, r["room_id"]) + self.assertEqual(test_room_name, r["name"]) + self.assertEqual(test_alias, r["canonical_alias"]) + + def test_room_list_sort_order(self): + """Test room list sort ordering. alphabetical versus number of members, + reversing the order, etc. 
+ """ + # Create 3 test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C + self.helper.send_state( + room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok, + ) + + # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3 + user_1 = self.register_user("bob1", "pass") + user_1_tok = self.login("bob1", "pass") + self.helper.join(room_id_2, user_1, tok=user_1_tok) + + user_2 = self.register_user("bob2", "pass") + user_2_tok = self.login("bob2", "pass") + self.helper.join(room_id_3, user_2, tok=user_2_tok) + + user_3 = self.register_user("bob3", "pass") + user_3_tok = self.login("bob3", "pass") + self.helper.join(room_id_3, user_3, tok=user_3_tok) + + def _order_test( + order_type: str, expected_room_list: List[str], reverse: bool = False, + ): + """Request the list of rooms in a certain order. Assert that order is what + we expect + + Args: + order_type: The type of ordering to give the server + expected_room_list: The list of room_ids in the order we expect to get + back from the server + """ + # Request the list of rooms in the given order + url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,) + if reverse: + url += "&dir=b" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check for the correct total_rooms value + self.assertEqual(channel.json_body["total_rooms"], 3) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + # Check that rooms were returned in alphabetical order + returned_order = [r["room_id"] for r in rooms] + self.assertListEqual(expected_room_list, returned_order) # order is checked + + # Test different sort orders, with forward and reverse directions + _order_test("alphabetical", [room_id_1, room_id_2, room_id_3]) + _order_test("alphabetical", [room_id_3, room_id_2, room_id_1], reverse=True) + + _order_test("size", [room_id_3, room_id_2, room_id_1]) + _order_test("size", [room_id_1, room_id_2, room_id_3], reverse=True) + + def test_search_term(self): + """Test that searching for a room works correctly""" + # Create two test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + room_name_1 = "something" + room_name_2 = "else" + + # Set the name for each room + self.helper.send_state( + room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, + ) + + def _search_test( + expected_room_id: 
Optional[str], + search_term: str, + expected_http_code: int = 200, + ): + """Search for a room and check that the returned room's id is a match + + Args: + expected_room_id: The room_id expected to be returned by the API. Set + to None to expect zero results for the search + search_term: The term to search for room names with + expected_http_code: The expected http code for the request + """ + url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) + + if expected_http_code != 200: + return + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check that the expected number of rooms were returned + expected_room_count = 1 if expected_room_id else 0 + self.assertEqual(len(rooms), expected_room_count) + self.assertEqual(channel.json_body["total_rooms"], expected_room_count) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + if expected_room_id: + # Check that the first returned room id is correct + r = rooms[0] + self.assertEqual(expected_room_id, r["room_id"]) + + # Perform search tests + _search_test(room_id_1, "something") + _search_test(room_id_1, "thing") + + _search_test(room_id_2, "else") + _search_test(room_id_2, "se") + + _search_test(None, "foo") + _search_test(None, "bar") + _search_test(None, "", expected_http_code=400) -- cgit 1.5.1 From ce84dd9e207d9ae88e4cf9ca8a9731fcac043969 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 22 Jan 2020 15:09:57 +0000 Subject: Remove unnecessary abstractions in admin handler (#6751) --- changelog.d/6751.misc | 1 + synapse/handlers/admin.py | 62 ------------------------ synapse/rest/admin/users.py | 19 ++++---- synapse/storage/data_stores/main/registration.py | 2 +- 4 files changed, 11 insertions(+), 73 deletions(-) create mode 100644 changelog.d/6751.misc (limited to 'synapse/storage') diff --git a/changelog.d/6751.misc b/changelog.d/6751.misc new file mode 100644 index 0000000000..7222520528 --- /dev/null +++ b/changelog.d/6751.misc @@ -0,0 +1 @@ +Remove some unnecessary admin handler abstraction methods. \ No newline at end of file diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 60a7c938bc..9205865231 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -62,68 +62,6 @@ class AdminHandler(BaseHandler): ret["avatar_url"] = profile.avatar_url return ret - async def get_users(self): - """Function to retrieve a list of users in users table. - - Args: - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - ret = await self.store.get_users() - - return ret - - async def get_users_paginate(self, start, limit, name, guests, deactivated): - """Function to retrieve a paginated list of users from - users list. This will return a json list of users. 
-
-        Args:
-            start (int): start number to begin the query from
-            limit (int): number of rows to retrieve
-            name (string): filter for user names
-            guests (bool): whether to in include guest users
-            deactivated (bool): whether to include deactivated users
-        Returns:
-            defer.Deferred: resolves to json list[dict[str, Any]]
-        """
-        ret = await self.store.get_users_paginate(
-            start, limit, name, guests, deactivated
-        )
-
-        return ret
-
-    async def search_users(self, term):
-        """Function to search users list for one or more users with
-        the matched term.
-
-        Args:
-            term (str): search term
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
-        """
-        ret = await self.store.search_users(term)
-
-        return ret
-
-    def get_user_server_admin(self, user):
-        """
-        Get the admin bit on a user.
-
-        Args:
-            user_id (UserID): the (necessarily local) user to manipulate
-        """
-        return self.store.is_server_admin(user)
-
-    def set_user_server_admin(self, user, admin):
-        """
-        Set the admin bit on a user.
-
-        Args:
-            user_id (UserID): the (necessarily local) user to manipulate
-            admin (bool): whether or not the user should be an admin of this server
-        """
-        return self.store.set_server_admin(user, admin)
-
     async def export_user_data(self, user_id, writer):
         """Write all data we have on the user to the given writer.
 
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 52d27fa3e3..927e9ca9ee 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -45,6 +45,7 @@ class UsersRestServlet(RestServlet):
 
     def __init__(self, hs):
         self.hs = hs
+        self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_handlers().admin_handler
 
@@ -55,7 +56,7 @@ class UsersRestServlet(RestServlet):
         if not self.hs.is_mine(target_user):
             raise SynapseError(400, "Can only look up local users")
 
-        ret = await self.admin_handler.get_users()
+        ret = await self.store.get_users()
 
         return 200, ret
 
@@ -80,6 +81,7 @@ class UsersRestServletV2(RestServlet):
 
     def __init__(self, hs):
         self.hs = hs
+        self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_handlers().admin_handler
 
@@ -92,7 +94,7 @@ class UsersRestServletV2(RestServlet):
         guests = parse_boolean(request, "guests", default=True)
         deactivated = parse_boolean(request, "deactivated", default=False)
 
-        users = await self.admin_handler.get_users_paginate(
+        users = await self.store.get_users_paginate(
             start, limit, user_id, guests, deactivated
         )
         ret = {"users": users}
@@ -516,8 +518,8 @@ class SearchUsersRestServlet(RestServlet):
     PATTERNS = historical_admin_path_patterns("/search_users/(?P<target_user_id>[^/]*)")
 
     def __init__(self, hs):
-        self.store = hs.get_datastore()
         self.hs = hs
+        self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.handlers = hs.get_handlers()
 
@@ -540,7 +542,7 @@ class SearchUsersRestServlet(RestServlet):
         term = parse_string(request, "term", required=True)
         logger.info("term: %s ", term)
 
-        ret = await self.handlers.admin_handler.search_users(term)
+        ret = await self.store.search_users(term)
 
         return 200, ret
 
@@ -574,8 +576,8 @@ class UserAdminServlet(RestServlet):
 
     def __init__(self, hs):
         self.hs = hs
+        self.store = hs.get_datastore()
         self.auth = hs.get_auth()
-        self.handlers = hs.get_handlers()
 
     async def on_GET(self, request, user_id):
         await assert_requester_is_admin(self.auth, request)
@@ -585,8 +587,7 @@ class UserAdminServlet(RestServlet):
         if not self.hs.is_mine(target_user):
             raise SynapseError(400, "Only local users can be admins of this homeserver")
 
-        is_admin = await self.handlers.admin_handler.get_user_server_admin(target_user)
-        is_admin = bool(is_admin)
+        is_admin = await self.store.is_server_admin(target_user)
 
         return 200, {"admin": is_admin}
 
@@ -609,8 +610,6 @@ class UserAdminServlet(RestServlet):
         if target_user == auth_user and not set_admin_to:
             raise SynapseError(400, "You may not demote yourself.")
 
-        await self.handlers.admin_handler.set_user_server_admin(
-            target_user, set_admin_to
-        )
+        await self.store.set_server_admin(target_user, set_admin_to)
 
         return 200, {}
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index cb4b2b39a0..49306642ed 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -291,7 +291,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="is_server_admin",
         )
 
-        return res if res else False
+        return bool(res)
 
     def set_server_admin(self, user, admin):
         """Sets whether a user is an admin of this homeserver.
-- cgit 1.5.1
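Returning to the `get_rooms_paginate` change earlier in this series: the two-query pattern it uses (a LIMIT/OFFSET query for the page itself, plus a nested-SELECT count so the caller knows whether to hand out a `next_batch` token) can be seen end to end in the sketch below. It runs against an in-memory SQLite table standing in for `room_stats_state`; the table contents and the `get_rooms_page` helper are invented for illustration and are not Synapse code.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE room_stats_state (room_id TEXT, name TEXT);
    INSERT INTO room_stats_state VALUES
        ('!a:example.com', 'Apples'),
        ('!b:example.com', 'Bananas'),
        ('!c:example.com', 'Cherries');
    """
)


def get_rooms_page(start, limit):
    # Data query: one ordered page of rooms, bounded by LIMIT/OFFSET.
    rooms = conn.execute(
        "SELECT room_id, name FROM room_stats_state"
        " ORDER BY name ASC LIMIT ? OFFSET ?",
        (limit, start),
    ).fetchall()

    # Count query: count(*) can't be combined with an OFFSET directly,
    # hence the nested SELECT, mirroring count_sql in room.py above.
    (total,) = conn.execute(
        "SELECT count(*) FROM"
        " (SELECT room_id FROM room_stats_state) AS get_room_ids"
    ).fetchone()

    page = {"rooms": rooms, "offset": start, "total_rooms": total}
    if start + limit < total:
        # More rooms exist, so hand back the next offset as next_batch,
        # just as the admin API response does.
        page["next_batch"] = start + limit
    return page


print(get_rooms_page(0, 2))  # first page, includes next_batch == 2
print(get_rooms_page(2, 2))  # final page, no next_batch
```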