summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9498.bugfix1
-rw-r--r--changelog.d/9521.misc1
-rw-r--r--changelog.d/9529.misc1
-rw-r--r--changelog.d/9537.bugfix1
-rwxr-xr-xsetup.py2
-rw-r--r--synapse/replication/tcp/streams/_base.py2
-rw-r--r--synapse/rest/admin/users.py85
-rw-r--r--synapse/storage/databases/main/__init__.py10
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py4
-rw-r--r--synapse/storage/databases/main/media_repository.py2
-rw-r--r--synapse/storage/databases/main/purge_events.py42
-rw-r--r--synapse/storage/purge_events.py5
-rw-r--r--synapse/storage/roommember.py2
13 files changed, 109 insertions, 49 deletions
diff --git a/changelog.d/9498.bugfix b/changelog.d/9498.bugfix
new file mode 100644
index 0000000000..dce0ad0920
--- /dev/null
+++ b/changelog.d/9498.bugfix
@@ -0,0 +1 @@
+Properly purge the event chain cover index when purging history.
diff --git a/changelog.d/9521.misc b/changelog.d/9521.misc
new file mode 100644
index 0000000000..1424d9c188
--- /dev/null
+++ b/changelog.d/9521.misc
@@ -0,0 +1 @@
+Add type hints to user admin API.
\ No newline at end of file
diff --git a/changelog.d/9529.misc b/changelog.d/9529.misc
new file mode 100644
index 0000000000..b9021a26b4
--- /dev/null
+++ b/changelog.d/9529.misc
@@ -0,0 +1 @@
+Bump the versions of mypy and mypy-zope used for static type checking.
diff --git a/changelog.d/9537.bugfix b/changelog.d/9537.bugfix
new file mode 100644
index 0000000000..033ab1c939
--- /dev/null
+++ b/changelog.d/9537.bugfix
@@ -0,0 +1 @@
+Fix rare edge case that caused a background update to fail if the server had rejected an event that had duplicate auth events.
diff --git a/setup.py b/setup.py
index 08ba4eb764..bbd9e7862a 100755
--- a/setup.py
+++ b/setup.py
@@ -102,7 +102,7 @@ CONDITIONAL_REQUIREMENTS["lint"] = [
     "flake8",
 ]
 
-CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.790", "mypy-zope==0.2.8"]
+CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.812", "mypy-zope==0.2.11"]
 
 # Dependencies which are exclusively required by unit test code. This is
 # NOT a list of all modules that are necessary to run the unit tests.
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 38809b5b7c..f45e7a8c89 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -502,7 +502,7 @@ class AccountDataStream(Stream):
     """Global or per room account data was changed"""
 
     AccountDataStreamRow = namedtuple(
-        "AccountDataStream",
+        "AccountDataStreamRow",
         ("user_id", "room_id", "data_type"),  # str  # Optional[str]  # str
     )
 
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 9c701c7348..267a993430 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -16,7 +16,7 @@ import hashlib
 import hmac
 import logging
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -47,13 +47,15 @@ logger = logging.getLogger(__name__)
 class UsersRestServlet(RestServlet):
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
 
-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, List[JsonDict]]:
         target_user = UserID.from_string(user_id)
         await assert_requester_is_admin(self.auth, request)
 
@@ -153,7 +155,7 @@ class UserRestServletV2(RestServlet):
         otherwise an error.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
@@ -165,7 +167,9 @@ class UserRestServletV2(RestServlet):
         self.registration_handler = hs.get_registration_handler()
         self.pusher_pool = hs.get_pusherpool()
 
-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)
@@ -179,7 +183,9 @@ class UserRestServletV2(RestServlet):
 
         return 200, ret
 
-    async def on_PUT(self, request, user_id):
+    async def on_PUT(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -273,6 +279,8 @@ class UserRestServletV2(RestServlet):
                     )
 
             user = await self.admin_handler.get_user(target_user)
+            assert user is not None
+
             return 200, user
 
         else:  # create user
@@ -330,9 +338,10 @@ class UserRestServletV2(RestServlet):
                     target_user, requester, body["avatar_url"], True
                 )
 
-            ret = await self.admin_handler.get_user(target_user)
+            user = await self.admin_handler.get_user(target_user)
+            assert user is not None
 
-            return 201, ret
+            return 201, user
 
 
 class UserRegisterServlet(RestServlet):
@@ -346,10 +355,10 @@ class UserRegisterServlet(RestServlet):
     PATTERNS = admin_patterns("/register")
     NONCE_TIMEOUT = 60
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.auth_handler = hs.get_auth_handler()
         self.reactor = hs.get_reactor()
-        self.nonces = {}
+        self.nonces = {}  # type: Dict[str, int]
         self.hs = hs
 
     def _clear_old_nonces(self):
@@ -362,7 +371,7 @@ class UserRegisterServlet(RestServlet):
             if now - v > self.NONCE_TIMEOUT:
                 del self.nonces[k]
 
-    def on_GET(self, request):
+    def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         """
         Generate a new nonce.
         """
@@ -372,7 +381,7 @@ class UserRegisterServlet(RestServlet):
         self.nonces[nonce] = int(self.reactor.seconds())
         return 200, {"nonce": nonce}
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         self._clear_old_nonces()
 
         if not self.hs.config.registration_shared_secret:
@@ -478,12 +487,14 @@ class WhoisRestServlet(RestServlet):
         client_patterns("/admin" + path_regex, v1=True)
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
 
-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         target_user = UserID.from_string(user_id)
         requester = await self.auth.get_user_by_req(request)
         auth_user = requester.user
@@ -508,7 +519,9 @@ class DeactivateAccountRestServlet(RestServlet):
         self.is_mine = hs.is_mine
         self.store = hs.get_datastore()
 
-    async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]:
+    async def on_POST(
+        self, request: SynapseRequest, target_user_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -550,7 +563,7 @@ class AccountValidityRenewServlet(RestServlet):
         self.account_activity_handler = hs.get_account_validity_handler()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         body = parse_json_object_from_request(request)
@@ -584,14 +597,16 @@ class ResetPasswordRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
         self._set_password_handler = hs.get_set_password_handler()
 
-    async def on_POST(self, request, target_user_id):
+    async def on_POST(
+        self, request: SynapseRequest, target_user_id: str
+    ) -> Tuple[int, JsonDict]:
         """Post request to allow an administrator reset password for a user.
         This needs user to have administrator access in Synapse.
         """
@@ -626,12 +641,14 @@ class SearchUsersRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request, target_user_id):
+    async def on_GET(
+        self, request: SynapseRequest, target_user_id: str
+    ) -> Tuple[int, Optional[List[JsonDict]]]:
         """Get request to search user table for specific users according to
         search term.
         This needs user to have a administrator access in Synapse.
@@ -682,12 +699,14 @@ class UserAdminServlet(RestServlet):
 
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)
@@ -699,7 +718,9 @@ class UserAdminServlet(RestServlet):
 
         return 200, {"admin": is_admin}
 
-    async def on_PUT(self, request, user_id):
+    async def on_PUT(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
         auth_user = requester.user
@@ -730,12 +751,14 @@ class UserMembershipRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.is_mine = hs.is_mine
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         room_ids = await self.store.get_rooms_for_user(user_id)
@@ -758,7 +781,7 @@ class PushersRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.is_mine = hs.is_mine
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
@@ -799,7 +822,7 @@ class UserMediaRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.is_mine = hs.is_mine
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
@@ -891,7 +914,9 @@ class UserTokenRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
 
-    async def on_POST(self, request, user_id):
+    async def on_POST(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
         auth_user = requester.user
@@ -943,7 +968,9 @@ class ShadowBanRestServlet(RestServlet):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, user_id):
+    async def on_POST(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         if not self.hs.is_mine_id(user_id):
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 70b49854cf..1d44c3aa2c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 from synapse.api.constants import PresenceState
 from synapse.config.homeserver import HomeServerConfig
@@ -27,7 +27,7 @@ from synapse.storage.util.id_generators import (
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
-from synapse.types import get_domain_from_id
+from synapse.types import JsonDict, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from .account_data import AccountDataStore
@@ -264,7 +264,7 @@ class DataStore(
 
         return [UserPresenceState(**row) for row in rows]
 
-    async def get_users(self) -> List[Dict[str, Any]]:
+    async def get_users(self) -> List[JsonDict]:
         """Function to retrieve a list of users in users table.
 
         Returns:
@@ -292,7 +292,7 @@ class DataStore(
         name: Optional[str] = None,
         guests: bool = True,
         deactivated: bool = False,
-    ) -> Tuple[List[Dict[str, Any]], int]:
+    ) -> Tuple[List[JsonDict], int]:
         """Function to retrieve a paginated list of users from
         users list. This will return a json list of users and the
         total number of users matching the filter criteria.
@@ -353,7 +353,7 @@ class DataStore(
             "get_users_paginate_txn", get_users_paginate_txn
         )
 
-    async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]:
+    async def search_users(self, term: str) -> Optional[List[JsonDict]]:
         """Function to search users list for one or more users with
         the matched term.
 
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index c1626ccf28..cb6b1f8a0c 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -696,7 +696,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                 )
 
             if not has_event_auth:
-                for auth_id in event.auth_event_ids():
+                # Old, dodgy, events may have duplicate auth events, which we
+                # need to deduplicate as we have a unique constraint.
+                for auth_id in set(event.auth_event_ids()):
                     auth_events.append(
                         {
                             "room_id": event.room_id,
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 274f8de595..4f3d192562 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -139,7 +139,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         start: int,
         limit: int,
         user_id: str,
-        order_by: MediaSortOrder = MediaSortOrder.CREATED_TS.value,
+        order_by: str = MediaSortOrder.CREATED_TS.value,
         direction: str = "f",
     ) -> Tuple[List[Dict[str, Any]], int]:
         """Get a paginated list of metadata for a local piece of media
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index ecfc9f20b1..0836e4af49 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -28,7 +28,10 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
     async def purge_history(
         self, room_id: str, token: str, delete_local_events: bool
     ) -> Set[int]:
-        """Deletes room history before a certain point
+        """Deletes room history before a certain point.
+
+        Note that only a single purge can occur at once, this is guaranteed via
+        a higher level (in the PaginationHandler).
 
         Args:
             room_id:
@@ -52,7 +55,9 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
             delete_local_events,
         )
 
-    def _purge_history_txn(self, txn, room_id, token, delete_local_events):
+    def _purge_history_txn(
+        self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool
+    ) -> Set[int]:
         # Tables that should be pruned:
         #     event_auth
         #     event_backward_extremities
@@ -103,7 +108,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         if max_depth < token.topological:
             # We need to ensure we don't delete all the events from the database
             # otherwise we wouldn't be able to send any events (due to not
-            # having any backwards extremeties)
+            # having any backwards extremities)
             raise SynapseError(
                 400, "topological_ordering is greater than forward extremeties"
             )
@@ -154,7 +159,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
 
         logger.info("[purge] Finding new backward extremities")
 
-        # We calculate the new entries for the backward extremeties by finding
+        # We calculate the new entries for the backward extremities by finding
         # events to be purged that are pointed to by events we're not going to
         # purge.
         txn.execute(
@@ -296,7 +301,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
             "purge_room", self._purge_room_txn, room_id
         )
 
-    def _purge_room_txn(self, txn, room_id):
+    def _purge_room_txn(self, txn, room_id: str) -> List[int]:
         # First we fetch all the state groups that should be deleted, before
         # we delete that information.
         txn.execute(
@@ -310,6 +315,31 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
 
         state_groups = [row[0] for row in txn]
 
+        # Get all the auth chains that are referenced by events that are to be
+        # deleted.
+        txn.execute(
+            """
+            SELECT chain_id, sequence_number FROM events
+            LEFT JOIN event_auth_chains USING (event_id)
+            WHERE room_id = ?
+            """,
+            (room_id,),
+        )
+        referenced_chain_id_tuples = list(txn)
+
+        logger.info("[purge] removing events from event_auth_chain_links")
+        txn.executemany(
+            """
+            DELETE FROM event_auth_chain_links WHERE
+            (origin_chain_id = ? AND origin_sequence_number = ?) OR
+            (target_chain_id = ? AND target_sequence_number = ?)
+            """,
+            (
+                (chain_id, seq_num, chain_id, seq_num)
+                for (chain_id, seq_num) in referenced_chain_id_tuples
+            ),
+        )
+
         # Now we delete tables which lack an index on room_id but have one on event_id
         for table in (
             "event_auth",
@@ -319,6 +349,8 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
             "event_reference_hashes",
             "event_relations",
             "event_to_state_groups",
+            "event_auth_chains",
+            "event_auth_chain_to_calculate",
             "redactions",
             "rejections",
             "state_events",
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index 3c4908865f..4dcd848c59 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -73,9 +73,6 @@ class PurgeEventsStorage:
         Returns:
             The set of state groups that can be deleted.
         """
-        # Graph of state group -> previous group
-        graph = {}
-
         # Set of events that we have found to be referenced by events
         referenced_groups = set()
 
@@ -111,8 +108,6 @@ class PurgeEventsStorage:
             next_to_search |= prevs
             state_groups_seen |= prevs
 
-            graph.update(edges)
-
         to_delete = state_groups_seen - referenced_groups
 
         return to_delete
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index f152f63321..d2ff4da6b9 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -25,7 +25,7 @@ RoomsForUser = namedtuple(
 )
 
 GetRoomsForUserWithStreamOrdering = namedtuple(
-    "_GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
+    "GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
 )