summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/groups_local.py5
-rw-r--r--synapse/http/servlet.py3
-rw-r--r--synapse/rest/client/v1/events.py3
-rw-r--r--synapse/rest/client/v1/login.py3
-rw-r--r--synapse/rest/client/v1/logout.py6
-rw-r--r--synapse/rest/client/v1/presence.py3
-rw-r--r--synapse/rest/client/v1/profile.py6
-rw-r--r--synapse/rest/client/v1/push_rule.py3
-rw-r--r--synapse/rest/client/v1/pusher.py9
-rw-r--r--synapse/rest/client/v1/room.py17
-rw-r--r--synapse/rest/client/v1/voip.py3
-rw-r--r--synapse/rest/client/v2_alpha/account.py3
-rw-r--r--synapse/rest/client/v2_alpha/auth.py3
-rw-r--r--synapse/rest/client/v2_alpha/register.py3
-rw-r--r--synapse/storage/databases/main/__init__.py1
-rw-r--r--synapse/storage/databases/main/registration.py156
16 files changed, 88 insertions, 139 deletions
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 9684e60fc8..b2def93bb1 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -17,7 +17,7 @@
 import logging
 
 from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
-from synapse.types import get_domain_from_id
+from synapse.types import GroupID, get_domain_from_id
 
 logger = logging.getLogger(__name__)
 
@@ -28,6 +28,9 @@ def _create_rerouter(func_name):
     """
 
     async def f(self, group_id, *args, **kwargs):
+        if not GroupID.is_valid(group_id):
+            raise SynapseError(400, "%s was not legal group ID" % (group_id,))
+
         if self.is_mine_id(group_id):
             return await getattr(self.groups_server_handler, func_name)(
                 group_id, *args, **kwargs
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index fd90ba7828..b361b7cbaf 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -272,7 +272,6 @@ class RestServlet:
       on_PUT
       on_POST
       on_DELETE
-      on_OPTIONS
 
     Automatically handles turning CodeMessageExceptions thrown by these methods
     into the appropriate HTTP response.
@@ -283,7 +282,7 @@ class RestServlet:
         if hasattr(self, "PATTERNS"):
             patterns = self.PATTERNS
 
-            for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
+            for method in ("GET", "PUT", "POST", "DELETE"):
                 if hasattr(self, "on_%s" % (method,)):
                     servlet_classname = self.__class__.__name__
                     method_handler = getattr(self, "on_%s" % (method,))
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index 1ecb77aa26..6de4078290 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -67,9 +67,6 @@ class EventStreamRestServlet(RestServlet):
 
         return 200, chunk
 
-    def on_OPTIONS(self, request):
-        return 200, {}
-
 
 class EventRestServlet(RestServlet):
     PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index b82a4e978a..94452fcbf5 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -114,9 +114,6 @@ class LoginRestServlet(RestServlet):
 
         return 200, {"flows": flows}
 
-    def on_OPTIONS(self, request: SynapseRequest):
-        return 200, {}
-
     async def on_POST(self, request: SynapseRequest):
         self._address_ratelimiter.ratelimit(request.getClientIP())
 
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index f792b50cdc..ad8cea49c6 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -30,9 +30,6 @@ class LogoutRestServlet(RestServlet):
         self._auth_handler = hs.get_auth_handler()
         self._device_handler = hs.get_device_handler()
 
-    def on_OPTIONS(self, request):
-        return 200, {}
-
     async def on_POST(self, request):
         requester = await self.auth.get_user_by_req(request, allow_expired=True)
 
@@ -58,9 +55,6 @@ class LogoutAllRestServlet(RestServlet):
         self._auth_handler = hs.get_auth_handler()
         self._device_handler = hs.get_device_handler()
 
-    def on_OPTIONS(self, request):
-        return 200, {}
-
     async def on_POST(self, request):
         requester = await self.auth.get_user_by_req(request, allow_expired=True)
         user_id = requester.user.to_string()
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 79d8e3057f..23a529f8e3 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -86,9 +86,6 @@ class PresenceStatusRestServlet(RestServlet):
 
         return 200, {}
 
-    def on_OPTIONS(self, request):
-        return 200, {}
-
 
 def register_servlets(hs, http_server):
     PresenceStatusRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index e7fcd2b1ff..85a66458c5 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -67,9 +67,6 @@ class ProfileDisplaynameRestServlet(RestServlet):
 
         return 200, {}
 
-    def on_OPTIONS(self, request, user_id):
-        return 200, {}
-
 
 class ProfileAvatarURLRestServlet(RestServlet):
     PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
@@ -118,9 +115,6 @@ class ProfileAvatarURLRestServlet(RestServlet):
 
         return 200, {}
 
-    def on_OPTIONS(self, request, user_id):
-        return 200, {}
-
 
 class ProfileRestServlet(RestServlet):
     PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index f9eecb7cf5..241e535917 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -155,9 +155,6 @@ class PushRuleRestServlet(RestServlet):
         else:
             raise UnrecognizedRequestError()
 
-    def on_OPTIONS(self, request, path):
-        return 200, {}
-
     def notify_user(self, user_id):
         stream_id = self.store.get_max_push_rules_stream_id()
         self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 28dabf1c7a..8fe83f321a 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -60,9 +60,6 @@ class PushersRestServlet(RestServlet):
 
         return 200, {"pushers": filtered_pushers}
 
-    def on_OPTIONS(self, _):
-        return 200, {}
-
 
 class PushersSetRestServlet(RestServlet):
     PATTERNS = client_patterns("/pushers/set$", v1=True)
@@ -140,9 +137,6 @@ class PushersSetRestServlet(RestServlet):
 
         return 200, {}
 
-    def on_OPTIONS(self, _):
-        return 200, {}
-
 
 class PushersRemoveRestServlet(RestServlet):
     """
@@ -182,9 +176,6 @@ class PushersRemoveRestServlet(RestServlet):
         )
         return None
 
-    def on_OPTIONS(self, _):
-        return 200, {}
-
 
 def register_servlets(hs, http_server):
     PushersRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 00b4397082..25d3cc6148 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -72,20 +72,6 @@ class RoomCreateRestServlet(TransactionRestServlet):
     def register(self, http_server):
         PATTERNS = "/createRoom"
         register_txn_path(self, PATTERNS, http_server)
-        # define CORS for all of /rooms in RoomCreateRestServlet for simplicity
-        http_server.register_paths(
-            "OPTIONS",
-            client_patterns("/rooms(?:/.*)?$", v1=True),
-            self.on_OPTIONS,
-            self.__class__.__name__,
-        )
-        # define CORS for /createRoom[/txnid]
-        http_server.register_paths(
-            "OPTIONS",
-            client_patterns("/createRoom(?:/.*)?$", v1=True),
-            self.on_OPTIONS,
-            self.__class__.__name__,
-        )
 
     def on_PUT(self, request, txn_id):
         set_tag("txn_id", txn_id)
@@ -104,9 +90,6 @@ class RoomCreateRestServlet(TransactionRestServlet):
         user_supplied_config = parse_json_object_from_request(request)
         return user_supplied_config
 
-    def on_OPTIONS(self, request):
-        return 200, {}
-
 
 # TODO: Needs unit testing for generic events
 class RoomStateEventRestServlet(TransactionRestServlet):
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index b8d491ca5c..d07ca2c47c 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -69,9 +69,6 @@ class VoipRestServlet(RestServlet):
             },
         )
 
-    def on_OPTIONS(self, request):
-        return 200, {}
-
 
 def register_servlets(hs, http_server):
     VoipRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index e857cff176..51effc4d8e 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -268,9 +268,6 @@ class PasswordRestServlet(RestServlet):
 
         return 200, {}
 
-    def on_OPTIONS(self, _):
-        return 200, {}
-
 
 class DeactivateAccountRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/deactivate$")
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 5fbfae5991..fab077747f 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -176,9 +176,6 @@ class AuthRestServlet(RestServlet):
         respond_with_html(request, 200, html)
         return None
 
-    def on_OPTIONS(self, _):
-        return 200, {}
-
 
 def register_servlets(hs, http_server):
     AuthRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 395b6a82a9..8f2c8cd991 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -642,9 +642,6 @@ class RegisterRestServlet(RestServlet):
 
         return 200, return_dict
 
-    def on_OPTIONS(self, _):
-        return 200, {}
-
     async def _do_appservice_registration(self, username, as_token, body):
         user_id = await self.registration_handler.appservice_register(
             username, as_token
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 9b16f45f3e..43660ec4fb 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -146,7 +146,6 @@ class DataStore(
             db_conn, "e2e_cross_signing_keys", "stream_id"
         )
 
-        self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
         self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 4c843b7679..b0329e17ec 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -16,29 +16,33 @@
 # limitations under the License.
 import logging
 import re
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
-from synapse.storage.types import Cursor
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.databases.main.stats import StatsStore
+from synapse.storage.types import Connection, Cursor
+from synapse.storage.util.id_generators import IdGenerator
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import UserID
 from synapse.util.caches.descriptors import cached
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
 
 logger = logging.getLogger(__name__)
 
 
-class RegistrationWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+class RegistrationWorkerStore(CacheInvalidationWorkerStore):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
-        self.clock = hs.get_clock()
 
         # Note: we don't check this sequence for consistency as we'd have to
         # call `find_max_generated_user_id_localpart` each time, which is
@@ -55,7 +59,7 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         # Create a background job for culling expired 3PID validity tokens
         if hs.config.run_background_tasks:
-            self.clock.looping_call(
+            self._clock.looping_call(
                 self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
             )
 
@@ -92,7 +96,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         if not info:
             return False
 
-        now = self.clock.time_msec()
+        now = self._clock.time_msec()
         trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
         is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
         return is_trial
@@ -257,7 +261,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         return await self.db_pool.runInteraction(
             "get_users_expiring_soon",
             select_users_txn,
-            self.clock.time_msec(),
+            self._clock.time_msec(),
             self.config.account_validity.renew_at,
         )
 
@@ -328,13 +332,17 @@ class RegistrationWorkerStore(SQLBaseStore):
         await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
 
     def _query_for_auth(self, txn, token):
-        sql = (
-            "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
-            " access_tokens.device_id, access_tokens.valid_until_ms"
-            " FROM users"
-            " INNER JOIN access_tokens on users.name = access_tokens.user_id"
-            " WHERE token = ?"
-        )
+        sql = """
+            SELECT users.name,
+                users.is_guest,
+                users.shadow_banned,
+                access_tokens.id as token_id,
+                access_tokens.device_id,
+                access_tokens.valid_until_ms
+            FROM users
+            INNER JOIN access_tokens on users.name = access_tokens.user_id
+            WHERE token = ?
+        """
 
         txn.execute(sql, (token,))
         rows = self.db_pool.cursor_to_dict(txn)
@@ -803,7 +811,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         await self.db_pool.runInteraction(
             "cull_expired_threepid_validation_tokens",
             cull_expired_threepid_validation_tokens_txn,
-            self.clock.time_msec(),
+            self._clock.time_msec(),
         )
 
     @wrap_as_background_process("account_validity_set_expiration_dates")
@@ -890,10 +898,10 @@ class RegistrationWorkerStore(SQLBaseStore):
 
 
 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
-        self.clock = hs.get_clock()
+        self._clock = hs.get_clock()
         self.config = hs.config
 
         self.db_pool.updates.register_background_index_update(
@@ -1016,13 +1024,56 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
 
         return 1
 
+    async def set_user_deactivated_status(
+        self, user_id: str, deactivated: bool
+    ) -> None:
+        """Set the `deactivated` property for the provided user to the provided value.
+
+        Args:
+            user_id: The ID of the user to set the status for.
+            deactivated: The value to set for `deactivated`.
+        """
+
+        await self.db_pool.runInteraction(
+            "set_user_deactivated_status",
+            self.set_user_deactivated_status_txn,
+            user_id,
+            deactivated,
+        )
+
+    def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
+        self.db_pool.simple_update_one_txn(
+            txn=txn,
+            table="users",
+            keyvalues={"name": user_id},
+            updatevalues={"deactivated": 1 if deactivated else 0},
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.get_user_deactivated_status, (user_id,)
+        )
+        txn.call_after(self.is_guest.invalidate, (user_id,))
+
+    @cached()
+    async def is_guest(self, user_id: str) -> bool:
+        res = await self.db_pool.simple_select_one_onecol(
+            table="users",
+            keyvalues={"name": user_id},
+            retcol="is_guest",
+            allow_none=True,
+            desc="is_guest",
+        )
+
+        return res if res else False
+
 
-class RegistrationStore(RegistrationBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
 
+        self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
+
     async def add_access_token_to_user(
         self,
         user_id: str,
@@ -1138,19 +1189,19 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
     def _register_user(
         self,
         txn,
-        user_id,
-        password_hash,
-        was_guest,
-        make_guest,
-        appservice_id,
-        create_profile_with_displayname,
-        admin,
-        user_type,
-        shadow_banned,
+        user_id: str,
+        password_hash: Optional[str],
+        was_guest: bool,
+        make_guest: bool,
+        appservice_id: Optional[str],
+        create_profile_with_displayname: Optional[str],
+        admin: bool,
+        user_type: Optional[str],
+        shadow_banned: bool,
     ):
         user_id_obj = UserID.from_string(user_id)
 
-        now = int(self.clock.time())
+        now = int(self._clock.time())
 
         try:
             if was_guest:
@@ -1374,18 +1425,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
         await self.db_pool.runInteraction("delete_access_token", f)
 
-    @cached()
-    async def is_guest(self, user_id: str) -> bool:
-        res = await self.db_pool.simple_select_one_onecol(
-            table="users",
-            keyvalues={"name": user_id},
-            retcol="is_guest",
-            allow_none=True,
-            desc="is_guest",
-        )
-
-        return res if res else False
-
     async def add_user_pending_deactivation(self, user_id: str) -> None:
         """
         Adds a user to the table of users who need to be parted from all the rooms they're
@@ -1479,7 +1518,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                 txn,
                 table="threepid_validation_session",
                 keyvalues={"session_id": session_id},
-                updatevalues={"validated_at": self.clock.time_msec()},
+                updatevalues={"validated_at": self._clock.time_msec()},
             )
 
             return next_link
@@ -1547,35 +1586,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             start_or_continue_validation_session_txn,
         )
 
-    async def set_user_deactivated_status(
-        self, user_id: str, deactivated: bool
-    ) -> None:
-        """Set the `deactivated` property for the provided user to the provided value.
-
-        Args:
-            user_id: The ID of the user to set the status for.
-            deactivated: The value to set for `deactivated`.
-        """
-
-        await self.db_pool.runInteraction(
-            "set_user_deactivated_status",
-            self.set_user_deactivated_status_txn,
-            user_id,
-            deactivated,
-        )
-
-    def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
-        self.db_pool.simple_update_one_txn(
-            txn=txn,
-            table="users",
-            keyvalues={"name": user_id},
-            updatevalues={"deactivated": 1 if deactivated else 0},
-        )
-        self._invalidate_cache_and_stream(
-            txn, self.get_user_deactivated_status, (user_id,)
-        )
-        txn.call_after(self.is_guest.invalidate, (user_id,))
-
 
 def find_max_generated_user_id_localpart(cur: Cursor) -> int:
     """