-rw-r--r--  changelog.d/9498.bugfix                                          1
-rw-r--r--  changelog.d/9518.misc                                            1
-rw-r--r--  changelog.d/9521.misc                                            1
-rw-r--r--  changelog.d/9529.misc                                            1
-rw-r--r--  changelog.d/9536.bugfix                                          1
-rw-r--r--  changelog.d/9537.bugfix                                          1
-rw-r--r--  changelog.d/9539.feature                                         1
-rwxr-xr-x  setup.py                                                         2
-rw-r--r--  synapse/app/generic_worker.py                                   15
-rw-r--r--  synapse/http/federation/matrix_federation_agent.py             18
-rw-r--r--  synapse/http/matrixfederationclient.py                           6
-rw-r--r--  synapse/http/server.py                                          29
-rw-r--r--  synapse/http/site.py                                            35
-rw-r--r--  synapse/logging/_remote.py                                       6
-rw-r--r--  synapse/metrics/__init__.py                                     11
-rw-r--r--  synapse/module_api/__init__.py                                   4
-rw-r--r--  synapse/push/httppusher.py                                       5
-rw-r--r--  synapse/replication/tcp/client.py                                4
-rw-r--r--  synapse/replication/tcp/streams/_base.py                         2
-rw-r--r--  synapse/rest/admin/users.py                                     85
-rw-r--r--  synapse/server.py                                                3
-rw-r--r--  synapse/storage/databases/main/__init__.py                      10
-rw-r--r--  synapse/storage/databases/main/events_bg_updates.py              4
-rw-r--r--  synapse/storage/databases/main/media_repository.py               2
-rw-r--r--  synapse/storage/databases/main/purge_events.py                  42
-rw-r--r--  synapse/storage/databases/main/pusher.py                        52
-rw-r--r--  synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql  3
-rw-r--r--  synapse/storage/purge_events.py                                  5
-rw-r--r--  synapse/storage/roommember.py                                    2
-rw-r--r--  tests/rest/client/v1/test_login.py                              35
30 files changed, 274 insertions, 113 deletions
diff --git a/changelog.d/9498.bugfix b/changelog.d/9498.bugfix
new file mode 100644
index 0000000000..dce0ad0920
--- /dev/null
+++ b/changelog.d/9498.bugfix
@@ -0,0 +1 @@
+Properly purge the event chain cover index when purging history.
diff --git a/changelog.d/9518.misc b/changelog.d/9518.misc
new file mode 100644
index 0000000000..14c7b78dd9
--- /dev/null
+++ b/changelog.d/9518.misc
@@ -0,0 +1 @@
+Fix incorrect type hints.
diff --git a/changelog.d/9521.misc b/changelog.d/9521.misc
new file mode 100644
index 0000000000..1424d9c188
--- /dev/null
+++ b/changelog.d/9521.misc
@@ -0,0 +1 @@
+Add type hints to user admin API.
\ No newline at end of file
diff --git a/changelog.d/9529.misc b/changelog.d/9529.misc
new file mode 100644
index 0000000000..b9021a26b4
--- /dev/null
+++ b/changelog.d/9529.misc
@@ -0,0 +1 @@
+Bump the versions of mypy and mypy-zope used for static type checking.
diff --git a/changelog.d/9536.bugfix b/changelog.d/9536.bugfix
new file mode 100644
index 0000000000..2ab4f315c1
--- /dev/null
+++ b/changelog.d/9536.bugfix
@@ -0,0 +1 @@
+Fix deleting pushers when using sharded pushers.
diff --git a/changelog.d/9537.bugfix b/changelog.d/9537.bugfix
new file mode 100644
index 0000000000..033ab1c939
--- /dev/null
+++ b/changelog.d/9537.bugfix
@@ -0,0 +1 @@
+Fix rare edge case that caused a background update to fail if the server had rejected an event that had duplicate auth events.
diff --git a/changelog.d/9539.feature b/changelog.d/9539.feature
new file mode 100644
index 0000000000..06cfd5d199
--- /dev/null
+++ b/changelog.d/9539.feature
@@ -0,0 +1 @@
+Add support for `X-Forwarded-Proto` header when using a reverse proxy.
diff --git a/setup.py b/setup.py
index 08ba4eb764..bbd9e7862a 100755
--- a/setup.py
+++ b/setup.py
@@ -102,7 +102,7 @@ CONDITIONAL_REQUIREMENTS["lint"] = [
     "flake8",
 ]
 
-CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.790", "mypy-zope==0.2.8"]
+CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.812", "mypy-zope==0.2.11"]
 
 # Dependencies which are exclusively required by unit test code. This is
 # NOT a list of all modules that are necessary to run the unit tests.
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index d9423349e1..52314db95c 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -23,6 +23,7 @@ from typing_extensions import ContextManager
 
 from twisted.internet import address
 from twisted.web.resource import IResource
+from twisted.web.server import Request
 
 import synapse
 import synapse.events
@@ -190,7 +191,7 @@ class KeyUploadServlet(RestServlet):
         self.http_client = hs.get_simple_http_client()
         self.main_uri = hs.config.worker_main_http_uri
 
-    async def on_POST(self, request, device_id):
+    async def on_POST(self, request: Request, device_id: Optional[str]):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         user_id = requester.user.to_string()
         body = parse_json_object_from_request(request)
@@ -223,10 +224,12 @@ class KeyUploadServlet(RestServlet):
                 header: request.requestHeaders.getRawHeaders(header, [])
                 for header in (b"Authorization", b"User-Agent")
             }
-            # Add the previous hop the the X-Forwarded-For header.
+            # Add the previous hop to the X-Forwarded-For header.
             x_forwarded_for = request.requestHeaders.getRawHeaders(
                 b"X-Forwarded-For", []
             )
+            # we use request.client here, since we want the previous hop, not the
+            # original client (as returned by request.getClientAddress()).
             if isinstance(request.client, (address.IPv4Address, address.IPv6Address)):
                 previous_host = request.client.host.encode("ascii")
                 # If the header exists, add to the comma-separated list of the first
@@ -239,6 +242,14 @@ class KeyUploadServlet(RestServlet):
                     x_forwarded_for = [previous_host]
             headers[b"X-Forwarded-For"] = x_forwarded_for
 
+            # Replicate the original X-Forwarded-Proto header. Note that
+            # XForwardedForRequest overrides isSecure() to give us the original protocol
+            # used by the client, as opposed to the protocol used by our upstream proxy
+            # - which is what we want here.
+            headers[b"X-Forwarded-Proto"] = [
+                b"https" if request.isSecure() else b"http"
+            ]
+
             try:
                 result = await self.http_client.post_json_get_json(
                     self.main_uri + request.uri.decode("ascii"), body, headers=headers
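
The hunk above is the heart of the X-Forwarded-Proto feature: when a worker proxies a request on to the main process, it replays both forwarding headers. A minimal standalone sketch of the same header arithmetic, outside of Twisted (the function name and test values are illustrative, not Synapse code):

from typing import Dict, List


def build_forwarded_headers(
    incoming: Dict[bytes, List[bytes]], peer_host: str, is_secure: bool
) -> Dict[bytes, List[bytes]]:
    headers = dict(incoming)
    previous_hop = peer_host.encode("ascii")
    x_forwarded_for = list(headers.get(b"X-Forwarded-For", []))
    if x_forwarded_for:
        # Extend the comma-separated list in the first header (Twisted
        # combines repeated headers into the first one).
        x_forwarded_for = [x_forwarded_for[0] + b", " + previous_hop] + x_forwarded_for[1:]
    else:
        x_forwarded_for = [previous_hop]
    headers[b"X-Forwarded-For"] = x_forwarded_for
    # Record the scheme the client actually used on the inbound hop.
    headers[b"X-Forwarded-Proto"] = [b"https" if is_secure else b"http"]
    return headers


assert build_forwarded_headers({}, "10.0.0.1", True) == {
    b"X-Forwarded-For": [b"10.0.0.1"],
    b"X-Forwarded-Proto": [b"https"],
}
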
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 2e83fa6773..b07aa59c08 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import urllib.parse
-from typing import List, Optional
+from typing import Any, Generator, List, Optional
 
 from netaddr import AddrFormatError, IPAddress, IPSet
 from zope.interface import implementer
@@ -116,7 +116,7 @@ class MatrixFederationAgent:
         uri: bytes,
         headers: Optional[Headers] = None,
         bodyProducer: Optional[IBodyProducer] = None,
-    ) -> defer.Deferred:
+    ) -> Generator[defer.Deferred, Any, defer.Deferred]:
         """
         Args:
             method: HTTP method: GET/POST/etc
@@ -177,17 +177,17 @@ class MatrixFederationAgent:
         # We need to make sure the host header is set to the netloc of the
         # server and that a user-agent is provided.
         if headers is None:
-            headers = Headers()
+            request_headers = Headers()
         else:
-            headers = headers.copy()
+            request_headers = headers.copy()
 
-        if not headers.hasHeader(b"host"):
-            headers.addRawHeader(b"host", parsed_uri.netloc)
-        if not headers.hasHeader(b"user-agent"):
-            headers.addRawHeader(b"user-agent", self.user_agent)
+        if not request_headers.hasHeader(b"host"):
+            request_headers.addRawHeader(b"host", parsed_uri.netloc)
+        if not request_headers.hasHeader(b"user-agent"):
+            request_headers.addRawHeader(b"user-agent", self.user_agent)
 
         res = yield make_deferred_yieldable(
-            self._agent.request(method, uri, headers, bodyProducer)
+            self._agent.request(method, uri, request_headers, bodyProducer)
         )
         return res
 
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index cde42e9f5e..0f107714ea 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -1049,14 +1049,14 @@ def check_content_type_is_json(headers: Headers) -> None:
         RequestSendFailed: if the Content-Type header is missing or isn't JSON
 
     """
-    c_type = headers.getRawHeaders(b"Content-Type")
-    if c_type is None:
+    content_type_headers = headers.getRawHeaders(b"Content-Type")
+    if content_type_headers is None:
         raise RequestSendFailed(
             RuntimeError("No Content-Type header received from remote server"),
             can_retry=False,
         )
 
-    c_type = c_type[0].decode("ascii")  # only the first header
+    c_type = content_type_headers[0].decode("ascii")  # only the first header
     val, options = cgi.parse_header(c_type)
     if val != "application/json":
         raise RequestSendFailed(
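
For reference, `cgi.parse_header` (used above; part of the standard library through Python 3.12, removed in 3.13) splits a header value into the media type and its parameters. A quick standalone check:

# Assumes a Python version that still ships the cgi module.
import cgi

val, options = cgi.parse_header("application/json; charset=utf-8")
assert val == "application/json"
assert options == {"charset": "utf-8"}
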
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 845db9b78d..fa89260850 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -21,6 +21,7 @@ import logging
 import types
 import urllib
 from http import HTTPStatus
+from inspect import isawaitable
 from io import BytesIO
 from typing import (
     Any,
@@ -30,6 +31,7 @@ from typing import (
     Iterable,
     Iterator,
     List,
+    Optional,
     Pattern,
     Tuple,
     Union,
@@ -79,10 +81,12 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
     """Sends a JSON error response to clients."""
 
     if f.check(SynapseError):
-        error_code = f.value.code
-        error_dict = f.value.error_dict()
+        # mypy doesn't understand that f.check asserts the type.
+        exc = f.value  # type: SynapseError  # type: ignore
+        error_code = exc.code
+        error_dict = exc.error_dict()
 
-        logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
+        logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
     else:
         error_code = 500
         error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
@@ -91,7 +95,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
             "Failed handle request via %r: %r",
             request.request_metrics.name,
             request,
-            exc_info=(f.type, f.value, f.getTracebackObject()),
+            exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
         )
 
     # Only respond with an error response if we haven't already started writing,
@@ -128,7 +132,8 @@ def return_html_error(
             `{msg}` placeholders), or a jinja2 template
     """
     if f.check(CodeMessageException):
-        cme = f.value
+        # mypy doesn't understand that f.check asserts the type.
+        cme = f.value  # type: CodeMessageException  # type: ignore
         code = cme.code
         msg = cme.msg
 
@@ -142,7 +147,7 @@ def return_html_error(
             logger.error(
                 "Failed handle request %r",
                 request,
-                exc_info=(f.type, f.value, f.getTracebackObject()),
+                exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
             )
     else:
         code = HTTPStatus.INTERNAL_SERVER_ERROR
@@ -151,7 +156,7 @@ def return_html_error(
         logger.error(
             "Failed handle request %r",
             request,
-            exc_info=(f.type, f.value, f.getTracebackObject()),
+            exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
         )
 
     if isinstance(error_template, str):
@@ -278,7 +283,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
         raw_callback_return = method_handler(request)
 
         # Is it synchronous? We'll allow this for now.
-        if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
+        if isawaitable(raw_callback_return):
             callback_return = await raw_callback_return
         else:
             callback_return = raw_callback_return  # type: ignore
@@ -399,8 +404,10 @@ class JsonResource(DirectServeJsonResource):
             A tuple of the callback to use, the name of the servlet, and the
             key word arguments to pass to the callback
         """
+        # At this point the path must be bytes.
+        request_path_bytes = request.path  # type: bytes  # type: ignore
+        request_path = request_path_bytes.decode("ascii")
         # Treat HEAD requests as GET requests.
-        request_path = request.path.decode("ascii")
         request_method = request.method
         if request_method == b"HEAD":
             request_method = b"GET"
@@ -551,7 +558,7 @@ class _ByteProducer:
         request: Request,
         iterator: Iterator[bytes],
     ):
-        self._request = request
+        self._request = request  # type: Optional[Request]
         self._iterator = iterator
         self._paused = False
 
@@ -563,7 +570,7 @@ class _ByteProducer:
         """
         Send a list of bytes as a chunk of a response.
         """
-        if not data:
+        if not data or not self._request:
             return
 
         self._request.write(b"".join(data))
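
The `isawaitable()` switch above is the interesting part of this file: it accepts plain coroutines, Twisted Deferreds (which implement `__await__`), and synchronous return values alike, where the old isinstance check had to enumerate types. A minimal sketch of that dispatch pattern, using asyncio in place of Twisted:

import asyncio
from inspect import isawaitable


async def run_handler(handler):
    raw = handler()
    # Awaitable results (coroutines, Deferreds, ...) get awaited;
    # synchronous results pass straight through.
    if isawaitable(raw):
        return await raw
    return raw


async def async_handler():
    return 1


def sync_handler():
    return 2


assert asyncio.run(run_handler(async_handler)) == 1
assert asyncio.run(run_handler(sync_handler)) == 2
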
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 30153237e3..47754aff43 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,7 +14,7 @@
 import contextlib
 import logging
 import time
-from typing import Optional, Union
+from typing import Optional, Type, Union
 
 import attr
 from zope.interface import implementer
@@ -57,7 +57,7 @@ class SynapseRequest(Request):
     def __init__(self, channel, *args, **kw):
         Request.__init__(self, channel, *args, **kw)
-        self.site = channel.site
+        self.site = channel.site  # type: SynapseSite
         self._channel = channel  # this is used by the tests
         self.start_time = 0.0
@@ -96,25 +96,34 @@ class SynapseRequest(Request):
     def get_request_id(self):
         return "%s-%i" % (self.get_method(), self.request_seq)
 
-    def get_redacted_uri(self):
-        uri = self.uri
+    def get_redacted_uri(self) -> str:
+        """Gets the redacted URI associated with the request (or placeholder if the URI
+        has not yet been received).
+
+        Note: This is necessary as the placeholder value in twisted is str
+        rather than bytes, so we need to sanitise `self.uri`.
+
+        Returns:
+            The redacted URI as a string.
+        """
+        uri = self.uri  # type: Union[bytes, str]
         if isinstance(uri, bytes):
-            uri = self.uri.decode("ascii", errors="replace")
+            uri = uri.decode("ascii", errors="replace")
         return redact_uri(uri)
 
-    def get_method(self):
-        """Gets the method associated with the request (or placeholder if not
-        method has yet been received).
+    def get_method(self) -> str:
+        """Gets the method associated with the request (or placeholder if method
+        has not yet been received).
 
         Note: This is necessary as the placeholder value in twisted is str
         rather than bytes, so we need to sanitise `self.method`.
 
         Returns:
-            str
+            The request method as a string.
         """
-        method = self.method
+        method = self.method  # type: Union[bytes, str]
         if isinstance(method, bytes):
-            method = self.method.decode("ascii")
+            return self.method.decode("ascii")
         return method
 
     def render(self, resrc):
@@ -432,7 +441,9 @@ class SynapseSite(Site):
         assert config.http_options is not None
         proxied = config.http_options.x_forwarded
-        self.requestFactory = XForwardedForRequest if proxied else SynapseRequest
+        self.requestFactory = (
+            XForwardedForRequest if proxied else SynapseRequest
+        )  # type: Type[Request]
         self.access_logger = logging.getLogger(logger_name)
         self.server_version_string = server_version_string.encode("ascii")
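
The `Union[bytes, str]` annotations above exist because Twisted initialises `method` and `uri` to str placeholders before a request is fully parsed, while real received values are bytes. The sanitising step in isolation (the placeholder string here is illustrative):

from typing import Union


def sanitise_method(method: Union[bytes, str]) -> str:
    # Received values arrive as bytes; placeholders are already str.
    if isinstance(method, bytes):
        return method.decode("ascii")
    return method


assert sanitise_method(b"GET") == "GET"
assert sanitise_method("(no method yet)") == "(no method yet)"
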
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index f8e9112b56..174ca7be5a 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -32,7 +32,7 @@ from twisted.internet.endpoints import (
     TCP4ClientEndpoint,
     TCP6ClientEndpoint,
 )
-from twisted.internet.interfaces import IPushProducer, ITransport
+from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
 from twisted.internet.protocol import Factory, Protocol
 from twisted.python.failure import Failure
 
@@ -121,7 +121,9 @@ class RemoteHandler(logging.Handler):
         try:
             ip = ip_address(self.host)
             if isinstance(ip, IPv4Address):
-                endpoint = TCP4ClientEndpoint(_reactor, self.host, self.port)
+                endpoint = TCP4ClientEndpoint(
+                    _reactor, self.host, self.port
+                )  # type: IStreamClientEndpoint
             elif isinstance(ip, IPv6Address):
                 endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port)
             else:
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index a8cb49d5b4..3b499efc07 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -527,7 +527,7 @@ class ReactorLastSeenMetric:
 
 REGISTRY.register(ReactorLastSeenMetric())
 
-def runUntilCurrentTimer(func):
+def runUntilCurrentTimer(reactor, func):
     @functools.wraps(func)
     def f(*args, **kwargs):
         now = reactor.seconds()
@@ -590,13 +590,14 @@ def runUntilCurrentTimer(func):
 
 try:
     # Ensure the reactor has all the attributes we expect
-    reactor.runUntilCurrent
-    reactor._newTimedCalls
-    reactor.threadCallQueue
+    reactor.seconds  # type: ignore
+    reactor.runUntilCurrent  # type: ignore
+    reactor._newTimedCalls  # type: ignore
+    reactor.threadCallQueue  # type: ignore
 
     # runUntilCurrent is called when we have pending calls. It is called once
     # per iteratation after fd polling.
-    reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)
+    reactor.runUntilCurrent = runUntilCurrentTimer(reactor, reactor.runUntilCurrent)  # type: ignore
 
     # We manually run the GC each reactor tick so that we can get some metrics
     # about time spent doing GC,
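
Threading `reactor` through as a parameter, rather than closing over a module-level name, is what lets mypy check the wrapper. The shape of that timing wrapper, sketched with `time.monotonic` standing in for the reactor clock (names here are illustrative):

import functools
import time


def run_until_current_timer(clock, func):
    # Wrap a reactor tick function so every call is timed; the clock is
    # passed in explicitly instead of being picked up from an outer scope.
    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        start = clock()
        result = func(*args, **kwargs)
        print("tick took %.6fs" % (clock() - start))
        return result

    return wrapped


tick = run_until_current_timer(time.monotonic, lambda: None)
tick()
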
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 2e3b311c4a..db2d400b7e 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Iterable, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Generator, Iterable, Optional, Tuple
 
 from twisted.internet import defer
 
@@ -307,7 +307,7 @@ class ModuleApi:
     @defer.inlineCallbacks
     def get_state_events_in_room(
         self, room_id: str, types: Iterable[Tuple[str, Optional[str]]]
-    ) -> defer.Deferred:
+    ) -> Generator[defer.Deferred, Any, defer.Deferred]:
         """Gets current state events for the given room.
 
         (This is exposed for compatibility with the old SpamCheckerApi. We should
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index f4d7e199e9..eb6de8ba72 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -15,11 +15,12 @@
 # limitations under the License.
 import logging
 import urllib.parse
-from typing import TYPE_CHECKING, Any, Dict, Iterable, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
 
 from prometheus_client import Counter
 
 from twisted.internet.error import AlreadyCalled, AlreadyCancelled
+from twisted.internet.interfaces import IDelayedCall
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
@@ -71,7 +72,7 @@ class HttpPusher(Pusher):
         self.data = pusher_config.data
         self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
         self.failing_since = pusher_config.failing_since
-        self.timed_call = None
+        self.timed_call = None  # type: Optional[IDelayedCall]
         self._is_processing = False
         self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
         self._pusherpool = hs.get_pusherpool()
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 2618eb1e53..3455839d67 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -108,9 +108,7 @@ class ReplicationDataHandler:
 
         # Map from stream to list of deferreds waiting for the stream to
         # arrive at a particular position. The lists are sorted by stream position.
-        self._streams_to_waiters = (
-            {}
-        )  # type: Dict[str, List[Tuple[int, Deferred[None]]]]
+        self._streams_to_waiters = {}  # type: Dict[str, List[Tuple[int, Deferred]]]
 
     async def on_rdata(
         self, stream_name: str, instance_name: str, token: int, rows: list
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 38809b5b7c..f45e7a8c89 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -502,7 +502,7 @@ class AccountDataStream(Stream):
     """Global or per room account data was changed"""
 
     AccountDataStreamRow = namedtuple(
-        "AccountDataStream",
+        "AccountDataStreamRow",
         ("user_id", "room_id", "data_type"),  # str  # Optional[str]  # str
     )
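
The rename above (and the matching one in synapse/storage/roommember.py further down) matters beyond cosmetics: tooling finds a namedtuple class by the name recorded in its first argument, so a mismatch bites, for example, as soon as an instance is pickled. A quick illustration of that one concrete failure mode:

import pickle
from collections import namedtuple

# Name matches the attribute it is bound to: round-trips fine.
Good = namedtuple("Good", ("a", "b"))
assert pickle.loads(pickle.dumps(Good(1, 2))) == Good(1, 2)

# Name does not match: pickle can't locate the class again.
Bad = namedtuple("SomethingElse", ("a", "b"))
try:
    pickle.dumps(Bad(1, 2))
except pickle.PicklingError:
    pass  # attribute lookup "SomethingElse" fails on this module
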
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 9c701c7348..267a993430 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -16,7 +16,7 @@ import hashlib
 import hmac
 import logging
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -47,13 +47,15 @@ logger = logging.getLogger(__name__)
 
 class UsersRestServlet(RestServlet):
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
 
-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, List[JsonDict]]:
         target_user = UserID.from_string(user_id)
         await assert_requester_is_admin(self.auth, request)
@@ -153,7 +155,7 @@ class UserRestServletV2(RestServlet):
         otherwise an error.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
@@ -165,7 +167,9 @@ class UserRestServletV2(RestServlet):
         self.registration_handler = hs.get_registration_handler()
         self.pusher_pool = hs.get_pusherpool()
 
-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)
@@ -179,7 +183,9 @@ class UserRestServletV2(RestServlet):
 
         return 200, ret
 
-    async def on_PUT(self, request, user_id):
+    async def on_PUT(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
@@ -273,6 +279,8 @@ class UserRestServletV2(RestServlet):
             )
 
             user = await self.admin_handler.get_user(target_user)
+            assert user is not None
+
             return 200, user
 
         else:  # create user
@@ -330,9 +338,10 @@ class UserRestServletV2(RestServlet):
                     target_user, requester, body["avatar_url"], True
                 )
 
-            ret = await self.admin_handler.get_user(target_user)
+            user = await self.admin_handler.get_user(target_user)
+            assert user is not None
 
-            return 201, ret
+            return 201, user
 
 
 class UserRegisterServlet(RestServlet):
@@ -346,10 +355,10 @@ class UserRegisterServlet(RestServlet):
     PATTERNS = admin_patterns("/register")
     NONCE_TIMEOUT = 60
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.auth_handler = hs.get_auth_handler()
         self.reactor = hs.get_reactor()
-        self.nonces = {}
+        self.nonces = {}  # type: Dict[str, int]
         self.hs = hs
 
     def _clear_old_nonces(self):
@@ -362,7 +371,7 @@ class UserRegisterServlet(RestServlet):
             if now - v > self.NONCE_TIMEOUT:
                 del self.nonces[k]
 
-    def on_GET(self, request):
+    def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         """
         Generate a new nonce.
         """
@@ -372,7 +381,7 @@ class UserRegisterServlet(RestServlet):
         self.nonces[nonce] = int(self.reactor.seconds())
         return 200, {"nonce": nonce}
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         self._clear_old_nonces()
 
         if not self.hs.config.registration_shared_secret:
@@ -478,12 +487,14 @@ class WhoisRestServlet(RestServlet):
         client_patterns("/admin" + path_regex, v1=True)
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
 
-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         target_user = UserID.from_string(user_id)
         requester = await self.auth.get_user_by_req(request)
         auth_user = requester.user
@@ -508,7 +519,9 @@ class DeactivateAccountRestServlet(RestServlet):
         self.is_mine = hs.is_mine
         self.store = hs.get_datastore()
 
-    async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]:
+    async def on_POST(
+        self, request: SynapseRequest, target_user_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
@@ -550,7 +563,7 @@ class AccountValidityRenewServlet(RestServlet):
         self.account_activity_handler = hs.get_account_validity_handler()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         body = parse_json_object_from_request(request)
@@ -584,14 +597,16 @@ class ResetPasswordRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
         self._set_password_handler = hs.get_set_password_handler()
 
-    async def on_POST(self, request, target_user_id):
+    async def on_POST(
+        self, request: SynapseRequest, target_user_id: str
+    ) -> Tuple[int, JsonDict]:
         """Post request to allow an administrator reset password for a user.
         This needs user to have administrator access in Synapse.
         """
@@ -626,12 +641,14 @@ class SearchUsersRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request, target_user_id):
+    async def on_GET(
+        self, request: SynapseRequest, target_user_id: str
+    ) -> Tuple[int, Optional[List[JsonDict]]]:
         """Get request to search user table for specific users according to
         search term.
         This needs user to have a administrator access in Synapse.
@@ -682,12 +699,14 @@ class UserAdminServlet(RestServlet):
 
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)
@@ -699,7 +718,9 @@ class UserAdminServlet(RestServlet):
 
         return 200, {"admin": is_admin}
 
-    async def on_PUT(self, request, user_id):
+    async def on_PUT(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
         auth_user = requester.user
@@ -730,12 +751,14 @@ class UserMembershipRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.is_mine = hs.is_mine
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         room_ids = await self.store.get_rooms_for_user(user_id)
@@ -758,7 +781,7 @@ class PushersRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.is_mine = hs.is_mine
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
@@ -799,7 +822,7 @@ class UserMediaRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.is_mine = hs.is_mine
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
@@ -891,7 +914,9 @@ class UserTokenRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
 
-    async def on_POST(self, request, user_id):
+    async def on_POST(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
         auth_user = requester.user
@@ -943,7 +968,9 @@ class ShadowBanRestServlet(RestServlet):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, user_id):
+    async def on_POST(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         if not self.hs.is_mine_id(user_id):
diff --git a/synapse/server.py b/synapse/server.py
index 1d4370e0ba..afd7cd72e7 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -38,6 +38,7 @@ from typing import (
 
 import twisted.internet.base
 import twisted.internet.tcp
+from twisted.internet import defer
 from twisted.mail.smtp import sendmail
 from twisted.web.iweb import IPolicyForHTTPS
 
@@ -403,7 +404,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         return RoomShutdownHandler(self)
 
     @cache_in_self
-    def get_sendmail(self) -> sendmail:
+    def get_sendmail(self) -> Callable[..., defer.Deferred]:
         return sendmail
 
     @cache_in_self
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 70b49854cf..1d44c3aa2c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 from synapse.api.constants import PresenceState
 from synapse.config.homeserver import HomeServerConfig
@@ -27,7 +27,7 @@ from synapse.storage.util.id_generators import (
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
-from synapse.types import get_domain_from_id
+from synapse.types import JsonDict, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from .account_data import AccountDataStore
@@ -264,7 +264,7 @@ class DataStore(
 
         return [UserPresenceState(**row) for row in rows]
 
-    async def get_users(self) -> List[Dict[str, Any]]:
+    async def get_users(self) -> List[JsonDict]:
         """Function to retrieve a list of users in users table.
 
         Returns:
@@ -292,7 +292,7 @@ class DataStore(
         name: Optional[str] = None,
         guests: bool = True,
         deactivated: bool = False,
-    ) -> Tuple[List[Dict[str, Any]], int]:
+    ) -> Tuple[List[JsonDict], int]:
         """Function to retrieve a paginated list of users from
         users list. This will return a json list of users and the
         total number of users matching the filter criteria.
@@ -353,7 +353,7 @@ class DataStore(
             "get_users_paginate_txn", get_users_paginate_txn
         )
 
-    async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]:
+    async def search_users(self, term: str) -> Optional[List[JsonDict]]:
         """Function to search users list for one or more users with
         the matched term.
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index c1626ccf28..cb6b1f8a0c 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -696,7 +696,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                 )
 
                 if not has_event_auth:
-                    for auth_id in event.auth_event_ids():
+                    # Old, dodgy, events may have duplicate auth events, which we
+                    # need to deduplicate as we have a unique constraint.
+                    for auth_id in set(event.auth_event_ids()):
                         auth_events.append(
                             {
                                 "room_id": event.room_id,
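
The fix above in miniature: old events can legitimately list the same auth event twice, and inserting the rows unchanged would trip the table's unique constraint, so the IDs are passed through `set()` first. The sample data below is made up:

# Duplicates are legal in old, dodgy events.
auth_event_ids = ["$a", "$b", "$a"]

rows = [
    {"room_id": "!room:example.com", "auth_id": auth_id}
    for auth_id in set(auth_event_ids)  # dedupe before the batch insert
]
assert len(rows) == 2
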
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 274f8de595..4f3d192562 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -139,7 +139,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         start: int,
         limit: int,
         user_id: str,
-        order_by: MediaSortOrder = MediaSortOrder.CREATED_TS.value,
+        order_by: str = MediaSortOrder.CREATED_TS.value,
         direction: str = "f",
     ) -> Tuple[List[Dict[str, Any]], int]:
         """Get a paginated list of metadata for a local piece of media
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index ecfc9f20b1..0836e4af49 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -28,7 +28,10 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
     async def purge_history(
         self, room_id: str, token: str, delete_local_events: bool
     ) -> Set[int]:
-        """Deletes room history before a certain point
+        """Deletes room history before a certain point.
+
+        Note that only a single purge can occur at once, this is guaranteed via
+        a higher level (in the PaginationHandler).
 
         Args:
             room_id:
@@ -52,7 +55,9 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
             delete_local_events,
         )
 
-    def _purge_history_txn(self, txn, room_id, token, delete_local_events):
+    def _purge_history_txn(
+        self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool
+    ) -> Set[int]:
         # Tables that should be pruned:
         #     event_auth
         #     event_backward_extremities
@@ -103,7 +108,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         if max_depth < token.topological:
             # We need to ensure we don't delete all the events from the database
             # otherwise we wouldn't be able to send any events (due to not
-            # having any backwards extremeties)
+            # having any backwards extremities)
             raise SynapseError(
                 400, "topological_ordering is greater than forward extremeties"
             )
@@ -154,7 +159,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
 
         logger.info("[purge] Finding new backward extremities")
 
-        # We calculate the new entries for the backward extremeties by finding
+        # We calculate the new entries for the backward extremities by finding
         # events to be purged that are pointed to by events we're not going to
         # purge.
         txn.execute(
@@ -296,7 +301,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
             "purge_room", self._purge_room_txn, room_id
         )
 
-    def _purge_room_txn(self, txn, room_id):
+    def _purge_room_txn(self, txn, room_id: str) -> List[int]:
         # First we fetch all the state groups that should be deleted, before
         # we delete that information.
         txn.execute(
@@ -310,6 +315,31 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
 
         state_groups = [row[0] for row in txn]
 
+        # Get all the auth chains that are referenced by events that are to be
+        # deleted.
+        txn.execute(
+            """
+            SELECT chain_id, sequence_number FROM events
+            LEFT JOIN event_auth_chains USING (event_id)
+            WHERE room_id = ?
+            """,
+            (room_id,),
+        )
+        referenced_chain_id_tuples = list(txn)
+
+        logger.info("[purge] removing events from event_auth_chain_links")
+        txn.executemany(
+            """
+            DELETE FROM event_auth_chain_links WHERE
+                (origin_chain_id = ? AND origin_sequence_number = ?) OR
+                (target_chain_id = ? AND target_sequence_number = ?)
+            """,
+            (
+                (chain_id, seq_num, chain_id, seq_num)
+                for (chain_id, seq_num) in referenced_chain_id_tuples
+            ),
+        )
+
         # Now we delete tables which lack an index on room_id but have one on event_id
         for table in (
             "event_auth",
@@ -319,6 +349,8 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
             "event_reference_hashes",
             "event_relations",
             "event_to_state_groups",
+            "event_auth_chains",
+            "event_auth_chain_to_calculate",
             "redactions",
             "rejections",
             "state_events",
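
The `executemany()` call added above, with a generator of parameter tuples, is a standard DB-API pattern. A runnable miniature against sqlite3 (the real code goes through Synapse's db_pool; the table and values here are invented for the demo):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE links (origin_chain_id INT, origin_sequence_number INT)")
# executemany accepts any iterable of parameter tuples, including a generator.
conn.executemany(
    "INSERT INTO links VALUES (?, ?)",
    ((chain_id, seq) for chain_id, seq in [(1, 1), (1, 2), (2, 1)]),
)
conn.executemany(
    "DELETE FROM links WHERE origin_chain_id = ? AND origin_sequence_number = ?",
    ((chain_id, seq) for chain_id, seq in [(1, 1), (2, 1)]),
)
assert conn.execute("SELECT COUNT(*) FROM links").fetchone()[0] == 1
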
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 6b608ebc9b..85f1ebac98 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -44,6 +44,11 @@ class PusherWorkerStore(SQLBaseStore):
             self._remove_deactivated_pushers,
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            "remove_stale_pushers",
+            self._remove_stale_pushers,
+        )
+
     def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
         """JSON-decode the data in the rows returned from the `pushers` table
 
@@ -337,6 +342,53 @@ class PusherWorkerStore(SQLBaseStore):
 
         return number_deleted
 
+    async def _remove_stale_pushers(self, progress: dict, batch_size: int) -> int:
+        """A background update that deletes all pushers for logged out devices.
+
+        Note that we don't proactively tell the pusherpool that we've deleted
+        these (just because it's a bit of a faff to do from here), but they will
+        get cleaned up at the next restart.
+        """
+
+        last_pusher = progress.get("last_pusher", 0)
+
+        def _delete_pushers(txn) -> int:
+
+            sql = """
+                SELECT p.id, access_token FROM pushers AS p
+                LEFT JOIN access_tokens AS a ON (p.access_token = a.id)
+                WHERE p.id > ?
+                ORDER BY p.id ASC
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_pusher, batch_size))
+            pushers = [(row[0], row[1]) for row in txn]
+
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="pushers",
+                column="id",
+                iterable=(pusher_id for pusher_id, token in pushers if token is None),
+                keyvalues={},
+            )
+
+            if pushers:
+                self.db_pool.updates._background_update_progress_txn(
+                    txn, "remove_stale_pushers", {"last_pusher": pushers[-1][0]}
+                )
+
+            return len(pushers)
+
+        number_deleted = await self.db_pool.runInteraction(
+            "_remove_stale_pushers", _delete_pushers
+        )
+
+        if number_deleted < batch_size:
+            await self.db_pool.updates._end_background_update("remove_stale_pushers")
+
+        return number_deleted
+
 
 class PusherStore(PusherWorkerStore):
     def get_pushers_stream_token(self) -> int:
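
The new background update follows the usual batching contract visible above: process up to `batch_size` rows past the saved position, record the new high-water mark in the progress dict, and finish when a batch comes up short. That control flow in isolation, with a plain list standing in for the database table (all names here are illustrative):

from typing import Dict, List, Optional, Tuple


def run_batch(
    pushers: List[Tuple[int, Optional[str]]],
    progress: Dict[str, int],
    batch_size: int,
) -> int:
    # Take the next batch_size rows after the last-seen id.
    last = progress.get("last_pusher", 0)
    batch = [p for p in pushers if p[0] > last][:batch_size]
    if batch:
        # Persist the high-water mark so the update can resume.
        progress["last_pusher"] = batch[-1][0]
    return len(batch)


progress = {}  # type: Dict[str, int]
data = [(1, "tok"), (2, None), (3, "tok")]
while run_batch(data, progress, batch_size=2) == 2:
    pass  # a short batch signals that the update is finished
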
diff --git a/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql b/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql
index 2442eea6bc..85196db288 100644
--- a/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql
+++ b/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql
@@ -16,4 +16,5 @@
 -- Delete all pushers associated with deleted devices. This is to clear up after
 -- a bug where they weren't correctly deleted when using workers.
-DELETE FROM pushers WHERE access_token NOT IN (SELECT id FROM access_tokens);
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (5908, 'remove_stale_pushers', '{}');
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index 3c4908865f..4dcd848c59 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -73,9 +73,6 @@ class PurgeEventsStorage:
         Returns:
             The set of state groups that can be deleted.
         """
-        # Graph of state group -> previous group
-        graph = {}
-
         # Set of events that we have found to be referenced by events
         referenced_groups = set()
@@ -111,8 +108,6 @@ class PurgeEventsStorage:
             next_to_search |= prevs
             state_groups_seen |= prevs
 
-            graph.update(edges)
-
         to_delete = state_groups_seen - referenced_groups
 
         return to_delete
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index f152f63321..d2ff4da6b9 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -25,7 +25,7 @@ RoomsForUser = namedtuple(
 )
 
 GetRoomsForUserWithStreamOrdering = namedtuple(
-    "_GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
+    "GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
 )
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 744d8d0941..20af3285bd 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -522,7 +522,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             shorthand=False,
         )
         self.assertEqual(channel.code, 302, channel.result)
-        cas_uri = channel.headers.getRawHeaders("Location")[0]
+        location_headers = channel.headers.getRawHeaders("Location")
+        assert location_headers
+        cas_uri = location_headers[0]
         cas_uri_path, cas_uri_query = cas_uri.split("?", 1)
 
         # it should redirect us to the login page of the cas server
@@ -545,7 +547,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             + "&idp=saml",
         )
         self.assertEqual(channel.code, 302, channel.result)
-        saml_uri = channel.headers.getRawHeaders("Location")[0]
+        location_headers = channel.headers.getRawHeaders("Location")
+        assert location_headers
+        saml_uri = location_headers[0]
         saml_uri_path, saml_uri_query = saml_uri.split("?", 1)
 
         # it should redirect us to the login page of the SAML server
@@ -567,17 +571,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             + "&idp=oidc",
         )
         self.assertEqual(channel.code, 302, channel.result)
-        oidc_uri = channel.headers.getRawHeaders("Location")[0]
+        location_headers = channel.headers.getRawHeaders("Location")
+        assert location_headers
+        oidc_uri = location_headers[0]
         oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
 
         # it should redirect us to the auth page of the OIDC server
         self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
 
         # ... and should have set a cookie including the redirect url
-        cookies = dict(
-            h.split(";")[0].split("=", maxsplit=1)
-            for h in channel.headers.getRawHeaders("Set-Cookie")
-        )
+        cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
+        assert cookie_headers
+        cookies = {}  # type: Dict[str, str]
+        for h in cookie_headers:
+            key, value = h.split(";")[0].split("=", maxsplit=1)
+            cookies[key] = value
 
         oidc_session_cookie = cookies["oidc_session"]
         macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
@@ -590,9 +598,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
 
         # that should serve a confirmation page
         self.assertEqual(channel.code, 200, channel.result)
-        self.assertTrue(
-            channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html")
-        )
+        content_type_headers = channel.headers.getRawHeaders("Content-Type")
+        assert content_type_headers
+        self.assertTrue(content_type_headers[-1].startswith("text/html"))
         p = TestHtmlParser()
         p.feed(channel.text_body)
         p.close()
@@ -806,6 +814,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 302)
 
         location_headers = channel.headers.getRawHeaders("Location")
+        assert location_headers
 
         self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)
 
     @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}})
@@ -1248,7 +1257,9 @@ class UsernamePickerTestCase(HomeserverTestCase):
 
         # that should redirect to the username picker
         self.assertEqual(channel.code, 302, channel.result)
-        picker_url = channel.headers.getRawHeaders("Location")[0]
+        location_headers = channel.headers.getRawHeaders("Location")
+        assert location_headers
+        picker_url = location_headers[0]
         self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
 
         # ... with a username_mapping_session cookie
@@ -1291,6 +1302,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
         )
         self.assertEqual(chan.code, 302, chan.result)
         location_headers = chan.headers.getRawHeaders("Location")
+        assert location_headers
 
         # send a request to the completion page, which should 302 to the client redirectUrl
         chan = self.make_request(
@@ -1300,6 +1312,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
         )
         self.assertEqual(chan.code, 302, chan.result)
         location_headers = chan.headers.getRawHeaders("Location")
+        assert location_headers
 
         # ensure that the returned location matches the requested redirect URL
         path, query = location_headers[0].split("?", 1)
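
A pattern worth noting in these test changes: `getRawHeaders()` returns an Optional list, and the bare `assert` both fails the test early with a clear traceback and narrows the type for mypy before the list is indexed. In miniature (the stub function below is hypothetical):

from typing import List, Optional


def get_raw_headers(name: str) -> Optional[List[str]]:
    # Stand-in for Headers.getRawHeaders: None when the header is absent.
    return ["/after-login"] if name == "Location" else None


location_headers = get_raw_headers("Location")
assert location_headers  # narrows Optional[List[str]] to List[str] for mypy
assert location_headers[0] == "/after-login"
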