summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8614.misc1
-rw-r--r--changelog.d/8615.misc1
-rw-r--r--changelog.d/8628.bugfix1
-rw-r--r--mypy.ini1
-rw-r--r--synapse/handlers/groups_local.py5
-rw-r--r--synapse/storage/databases/main/__init__.py1
-rw-r--r--synapse/storage/databases/main/registration.py156
-rw-r--r--tests/storage/test_cleanup_extrems.py6
-rw-r--r--tests/storage/test_event_metrics.py4
-rw-r--r--tests/storage/test_roommember.py4
-rw-r--r--tests/test_federation.py4
-rw-r--r--tests/unittest.py4
12 files changed, 102 insertions, 86 deletions
diff --git a/changelog.d/8614.misc b/changelog.d/8614.misc
new file mode 100644

index 0000000000..1bf9ea08f0 --- /dev/null +++ b/changelog.d/8614.misc
@@ -0,0 +1 @@ +Don't instansiate Requester directly. diff --git a/changelog.d/8615.misc b/changelog.d/8615.misc new file mode 100644
index 0000000000..79fa7b7ff8 --- /dev/null +++ b/changelog.d/8615.misc
@@ -0,0 +1 @@ +Type hints for `RegistrationStore`. diff --git a/changelog.d/8628.bugfix b/changelog.d/8628.bugfix new file mode 100644
index 0000000000..1316136ca2 --- /dev/null +++ b/changelog.d/8628.bugfix
@@ -0,0 +1 @@ +Fix handling of invalid group IDs to return a 400 rather than log an exception and return a 500. diff --git a/mypy.ini b/mypy.ini
index 5e9f7b1259..59d9074c3b 100644 --- a/mypy.ini +++ b/mypy.ini
@@ -57,6 +57,7 @@ files = synapse/spam_checker_api, synapse/state, synapse/storage/databases/main/events.py, + synapse/storage/databases/main/registration.py, synapse/storage/databases/main/stream.py, synapse/storage/databases/main/ui_auth.py, synapse/storage/database.py, diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 9684e60fc8..b2def93bb1 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py
@@ -17,7 +17,7 @@ import logging from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError -from synapse.types import get_domain_from_id +from synapse.types import GroupID, get_domain_from_id logger = logging.getLogger(__name__) @@ -28,6 +28,9 @@ def _create_rerouter(func_name): """ async def f(self, group_id, *args, **kwargs): + if not GroupID.is_valid(group_id): + raise SynapseError(400, "%s was not legal group ID" % (group_id,)) + if self.is_mine_id(group_id): return await getattr(self.groups_server_handler, func_name)( group_id, *args, **kwargs diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 9b16f45f3e..43660ec4fb 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py
@@ -146,7 +146,6 @@ class DataStore( db_conn, "e2e_cross_signing_keys", "stream_id" ) - self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index d7774a8bef..6867961c3c 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -16,29 +16,33 @@ # limitations under the License. import logging import re -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool -from synapse.storage.types import Cursor +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore +from synapse.storage.databases.main.stats import StatsStore +from synapse.storage.types import Connection, Cursor +from synapse.storage.util.id_generators import IdGenerator from synapse.storage.util.sequence import build_sequence_generator from synapse.types import UserID from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 logger = logging.getLogger(__name__) -class RegistrationWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): +class RegistrationWorkerStore(CacheInvalidationWorkerStore): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.config = hs.config - self.clock = hs.get_clock() # Note: we don't check this sequence for consistency as we'd have to # call `find_max_generated_user_id_localpart` each time, which is @@ -61,7 +65,7 @@ class RegistrationWorkerStore(SQLBaseStore): # Create a background job for culling expired 3PID validity tokens if hs.config.run_background_tasks: - self.clock.looping_call( + self._clock.looping_call( self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS ) @@ -98,7 +102,7 @@ class RegistrationWorkerStore(SQLBaseStore): if not info: return False - now = self.clock.time_msec() + now = self._clock.time_msec() trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000 is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms return is_trial @@ -312,7 +316,7 @@ class RegistrationWorkerStore(SQLBaseStore): return await self.db_pool.runInteraction( "get_users_expiring_soon", select_users_txn, - self.clock.time_msec(), + self._clock.time_msec(), self.config.account_validity_renew_at, ) @@ -431,13 +435,17 @@ class RegistrationWorkerStore(SQLBaseStore): await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) def _query_for_auth(self, txn, token): - sql = ( - "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id," - " access_tokens.device_id, access_tokens.valid_until_ms" - " FROM users" - " INNER JOIN access_tokens on users.name = access_tokens.user_id" - " WHERE token = ?" - ) + sql = """ + SELECT users.name, + users.is_guest, + users.shadow_banned, + access_tokens.id as token_id, + access_tokens.device_id, + access_tokens.valid_until_ms + FROM users + INNER JOIN access_tokens on users.name = access_tokens.user_id + WHERE token = ? + """ txn.execute(sql, (token,)) rows = self.db_pool.cursor_to_dict(txn) @@ -906,7 +914,7 @@ class RegistrationWorkerStore(SQLBaseStore): await self.db_pool.runInteraction( "cull_expired_threepid_validation_tokens", cull_expired_threepid_validation_tokens_txn, - self.clock.time_msec(), + self._clock.time_msec(), ) @wrap_as_background_process("account_validity_set_expiration_dates") @@ -993,10 +1001,10 @@ class RegistrationWorkerStore(SQLBaseStore): class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) - self.clock = hs.get_clock() + self._clock = hs.get_clock() self.config = hs.config self.db_pool.updates.register_background_index_update( @@ -1119,13 +1127,56 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): return 1 + async def set_user_deactivated_status( + self, user_id: str, deactivated: bool + ) -> None: + """Set the `deactivated` property for the provided user to the provided value. + + Args: + user_id: The ID of the user to set the status for. + deactivated: The value to set for `deactivated`. + """ + + await self.db_pool.runInteraction( + "set_user_deactivated_status", + self.set_user_deactivated_status_txn, + user_id, + deactivated, + ) + + def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool): + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"deactivated": 1 if deactivated else 0}, + ) + self._invalidate_cache_and_stream( + txn, self.get_user_deactivated_status, (user_id,) + ) + txn.call_after(self.is_guest.invalidate, (user_id,)) + + @cached() + async def is_guest(self, user_id: str) -> bool: + res = await self.db_pool.simple_select_one_onecol( + table="users", + keyvalues={"name": user_id}, + retcol="is_guest", + allow_none=True, + desc="is_guest", + ) + + return res if res else False + -class RegistrationStore(RegistrationBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs): +class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors + self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") + async def add_access_token_to_user( self, user_id: str, @@ -1241,19 +1292,19 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): def _register_user( self, txn, - user_id, - password_hash, - was_guest, - make_guest, - appservice_id, - create_profile_with_displayname, - admin, - user_type, - shadow_banned, + user_id: str, + password_hash: Optional[str], + was_guest: bool, + make_guest: bool, + appservice_id: Optional[str], + create_profile_with_displayname: Optional[str], + admin: bool, + user_type: Optional[str], + shadow_banned: bool, ): user_id_obj = UserID.from_string(user_id) - now = int(self.clock.time()) + now = int(self._clock.time()) try: if was_guest: @@ -1477,18 +1528,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): await self.db_pool.runInteraction("delete_access_token", f) - @cached() - async def is_guest(self, user_id: str) -> bool: - res = await self.db_pool.simple_select_one_onecol( - table="users", - keyvalues={"name": user_id}, - retcol="is_guest", - allow_none=True, - desc="is_guest", - ) - - return res if res else False - async def add_user_pending_deactivation(self, user_id: str) -> None: """ Adds a user to the table of users who need to be parted from all the rooms they're @@ -1582,7 +1621,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, - updatevalues={"validated_at": self.clock.time_msec()}, + updatevalues={"validated_at": self._clock.time_msec()}, ) return next_link @@ -1650,35 +1689,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): start_or_continue_validation_session_txn, ) - async def set_user_deactivated_status( - self, user_id: str, deactivated: bool - ) -> None: - """Set the `deactivated` property for the provided user to the provided value. - - Args: - user_id: The ID of the user to set the status for. - deactivated: The value to set for `deactivated`. - """ - - await self.db_pool.runInteraction( - "set_user_deactivated_status", - self.set_user_deactivated_status_txn, - user_id, - deactivated, - ) - - def set_user_deactivated_status_txn(self, txn, user_id, deactivated): - self.db_pool.simple_update_one_txn( - txn=txn, - table="users", - keyvalues={"name": user_id}, - updatevalues={"deactivated": 1 if deactivated else 0}, - ) - self._invalidate_cache_and_stream( - txn, self.get_user_deactivated_status, (user_id,) - ) - txn.call_after(self.is_guest.invalidate, (user_id,)) - def find_max_generated_user_id_localpart(cur: Cursor) -> int: """ diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 080761d1d2..5a1e5c4e66 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py
@@ -22,7 +22,7 @@ import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.rest.client.v1 import login, room from synapse.storage import prepare_database -from synapse.types import Requester, UserID +from synapse.types import UserID, create_requester from tests.unittest import HomeserverTestCase @@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID("alice", "test") - self.requester = Requester(self.user, None, False, False, None, None) + self.requester = create_requester(self.user) info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] @@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID.from_string(self.register_user("user1", "password")) self.token1 = self.login("user1", "password") - self.requester = Requester(self.user, None, False, False, None, None) + self.requester = create_requester(self.user) info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] self.event_creator = homeserver.get_event_creation_handler() diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index 3957471f3f..7691f2d790 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py
@@ -14,7 +14,7 @@ # limitations under the License. from synapse.metrics import REGISTRY, generate_latest -from synapse.types import Requester, UserID +from synapse.types import UserID, create_requester from tests.unittest import HomeserverTestCase @@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase): room_creator = self.hs.get_room_creation_handler() user = UserID("alice", "test") - requester = Requester(user, None, False, False, None, None) + requester = create_requester(user) # Real events, forward extremities events = [(3, 2), (6, 2), (4, 6)] diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 12ccc1f53e..ff972daeaa 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py
@@ -19,7 +19,7 @@ from unittest.mock import Mock from synapse.api.constants import Membership from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.client.v1 import login, room -from synapse.types import Requester, UserID +from synapse.types import UserID, create_requester from tests import unittest from tests.test_utils import event_injection @@ -187,7 +187,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): # Now let's create a room, which will insert a membership user = UserID("alice", "test") - requester = Requester(user, None, False, False, None, None) + requester = create_requester(user) self.get_success(self.room_creator.create_room(requester, {})) # Register the background update to run again. diff --git a/tests/test_federation.py b/tests/test_federation.py
index d39e792580..1ce4ea3a01 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py
@@ -20,7 +20,7 @@ from twisted.internet.defer import succeed from synapse.api.errors import FederationError from synapse.events import make_event_from_dict from synapse.logging.context import LoggingContext -from synapse.types import Requester, UserID +from synapse.types import UserID, create_requester from synapse.util import Clock from synapse.util.retryutils import NotRetryingDestination @@ -43,7 +43,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) user_id = UserID("us", "test") - our_user = Requester(user_id, None, False, False, None, None) + our_user = create_requester(user_id) room_creator = self.homeserver.get_room_creation_handler() self.room_id = self.get_success( room_creator.create_room( diff --git a/tests/unittest.py b/tests/unittest.py
index 040b126a27..257f465897 100644 --- a/tests/unittest.py +++ b/tests/unittest.py
@@ -44,7 +44,7 @@ from synapse.logging.context import ( set_current_context, ) from synapse.server import HomeServer -from synapse.types import Requester, UserID, create_requester +from synapse.types import UserID, create_requester from synapse.util.ratelimitutils import FederationRateLimiter from tests.server import ( @@ -627,7 +627,7 @@ class HomeserverTestCase(TestCase): """ event_creator = self.hs.get_event_creation_handler() secrets = self.hs.get_secrets() - requester = Requester(user, None, False, False, None, None) + requester = create_requester(user) event, context = self.get_success( event_creator.create_event(