diff options
70 files changed, 1603 insertions, 666 deletions
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 239553ae13..75c2976a25 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,7 +8,7 @@ on: concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true - + jobs: lint: runs-on: ubuntu-latest @@ -374,6 +374,11 @@ jobs: rc=0 results=$(jq -r 'to_entries[] | [.key,.value.result] | join(" ")' <<< $NEEDS_CONTEXT) while read job result ; do + # The newsfile lint may be skipped on non PR builds + if [ $result == "skipped" ] && [ $job == "lint-newsfile" ]; then + continue + fi + if [ "$result" != "success" ]; then echo "::set-failed ::Job $job returned $result" rc=1 diff --git a/changelog.d/10119.misc b/changelog.d/10119.misc new file mode 100644 index 0000000000..f70dc6496f --- /dev/null +++ b/changelog.d/10119.misc @@ -0,0 +1 @@ +Improve event caching mechanism to avoid having multiple copies of an event in memory at a time. diff --git a/changelog.d/10129.bugfix b/changelog.d/10129.bugfix new file mode 100644 index 0000000000..292676ec8d --- /dev/null +++ b/changelog.d/10129.bugfix @@ -0,0 +1 @@ +Add some clarification to the sample config file. Contributed by @Kentokamoto. diff --git a/changelog.d/10435.feature b/changelog.d/10435.feature new file mode 100644 index 0000000000..f93ef5b415 --- /dev/null +++ b/changelog.d/10435.feature @@ -0,0 +1 @@ +Experimental support for [MSC3288](https://github.com/matrix-org/matrix-doc/pull/3288), sending `room_type` to the identity server for 3pid invites over the `/store-invite` API. diff --git a/changelog.d/10443.doc b/changelog.d/10443.doc new file mode 100644 index 0000000000..3588e5487f --- /dev/null +++ b/changelog.d/10443.doc @@ -0,0 +1 @@ +Add documentation for configuration a forward proxy. diff --git a/changelog.d/10498.feature b/changelog.d/10498.feature new file mode 100644 index 0000000000..5df896572d --- /dev/null +++ b/changelog.d/10498.feature @@ -0,0 +1 @@ +Add support for "marker" events which makes historical events discoverable for servers that already have all of the scrollback history (part of MSC2716). diff --git a/changelog.d/10504.misc b/changelog.d/10504.misc new file mode 100644 index 0000000000..1479a5022d --- /dev/null +++ b/changelog.d/10504.misc @@ -0,0 +1 @@ +Reduce errors in PostgreSQL logs due to concurrent serialization errors. diff --git a/changelog.d/10507.misc b/changelog.d/10507.misc new file mode 100644 index 0000000000..5dfd116e60 --- /dev/null +++ b/changelog.d/10507.misc @@ -0,0 +1 @@ +Include room ID in ignored EDU log messages. Contributed by @ilmari. diff --git a/changelog.d/10513.feature b/changelog.d/10513.feature new file mode 100644 index 0000000000..153b2df7b2 --- /dev/null +++ b/changelog.d/10513.feature @@ -0,0 +1 @@ +Add a configuration setting for the time a `/sync` response is cached for. diff --git a/changelog.d/10527.misc b/changelog.d/10527.misc new file mode 100644 index 0000000000..3cf22f9daf --- /dev/null +++ b/changelog.d/10527.misc @@ -0,0 +1 @@ +Prepare for the new spaces summary endpoint (updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946)). diff --git a/changelog.d/10529.misc b/changelog.d/10529.misc new file mode 100644 index 0000000000..4caf22523c --- /dev/null +++ b/changelog.d/10529.misc @@ -0,0 +1 @@ +Fix CI to not break when run against branches rather than pull requests. diff --git a/changelog.d/10530.misc b/changelog.d/10530.misc new file mode 100644 index 0000000000..3cf22f9daf --- /dev/null +++ b/changelog.d/10530.misc @@ -0,0 +1 @@ +Prepare for the new spaces summary endpoint (updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946)). diff --git a/changelog.d/10532.bugfix b/changelog.d/10532.bugfix new file mode 100644 index 0000000000..d95e3d9b59 --- /dev/null +++ b/changelog.d/10532.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where protocols which are not implemented by any appservices were incorrectly returned via `GET /_matrix/client/r0/thirdparty/protocols`. diff --git a/changelog.d/10537.misc b/changelog.d/10537.misc new file mode 100644 index 0000000000..c9e045300c --- /dev/null +++ b/changelog.d/10537.misc @@ -0,0 +1 @@ +Mark all events stemming from the MSC2716 `/batch_send` endpoint as historical. diff --git a/changelog.d/10539.misc b/changelog.d/10539.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10539.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/changelog.d/10541.bugfix b/changelog.d/10541.bugfix new file mode 100644 index 0000000000..bb946e0920 --- /dev/null +++ b/changelog.d/10541.bugfix @@ -0,0 +1 @@ +Fix exceptions in logs when failing to get remote room list. diff --git a/changelog.d/10542.misc b/changelog.d/10542.misc new file mode 100644 index 0000000000..44b70b4730 --- /dev/null +++ b/changelog.d/10542.misc @@ -0,0 +1 @@ +Convert `Transaction` and `Edu` objects to attrs. diff --git a/changelog.d/10546.feature b/changelog.d/10546.feature new file mode 100644 index 0000000000..7709d010b3 --- /dev/null +++ b/changelog.d/10546.feature @@ -0,0 +1 @@ +Add a setting to disable TLS when sending email. diff --git a/changelog.d/9581.feature b/changelog.d/9581.feature new file mode 100644 index 0000000000..fa1949cd4b --- /dev/null +++ b/changelog.d/9581.feature @@ -0,0 +1 @@ +Add `get_userinfo_by_id` method to ModuleApi. diff --git a/debian/build_virtualenv b/debian/build_virtualenv index 68c8659953..801ecb9086 100755 --- a/debian/build_virtualenv +++ b/debian/build_virtualenv @@ -100,3 +100,18 @@ esac # add a dependency on the right version of python to substvars. PYPKG=`basename $SNAKE` echo "synapse:pydepends=$PYPKG" >> debian/matrix-synapse-py3.substvars + + +# add a couple of triggers. This is needed so that dh-virtualenv can rebuild +# the venv when the system python changes (see +# https://dh-virtualenv.readthedocs.io/en/latest/tutorial.html#step-2-set-up-packaging-for-your-project) +# +# we do it here rather than the more conventional way of just adding it to +# debian/matrix-synapse-py3.triggers, because we need to add a trigger on the +# right version of python. +cat >>"debian/.debhelper/generated/matrix-synapse-py3/triggers" <<EOF +# triggers for dh-virtualenv +interest-noawait $SNAKE +interest dh-virtualenv-interpreter-update + +EOF diff --git a/debian/changelog b/debian/changelog index 7b44341bc6..22f9406e5b 100644 --- a/debian/changelog +++ b/debian/changelog @@ -14,6 +14,8 @@ matrix-synapse-py3 (1.40.0~rc1) stable; urgency=medium [ Richard van der Hoff ] * Drop backwards-compatibility code that was required to support Ubuntu Xenial. + * Update package triggers so that the virtualenv is correctly rebuilt + when the system python is rebuilt, on recent Python versions. [ Synapse Packaging team ] * New synapse release 1.40.0~rc1. diff --git a/debian/matrix-synapse-py3.triggers b/debian/matrix-synapse-py3.triggers deleted file mode 100644 index f8c1fdb021..0000000000 --- a/debian/matrix-synapse-py3.triggers +++ /dev/null @@ -1,9 +0,0 @@ -# Register interest in Python interpreter changes and -# don't make the Python package dependent on the virtualenv package -# processing (noawait) -interest-noawait /usr/bin/python3.5 -interest-noawait /usr/bin/python3.6 -interest-noawait /usr/bin/python3.7 - -# Also provide a symbolic trigger for all dh-virtualenv packages -interest dh-virtualenv-interpreter-update diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 10be12d638..3d320a1c43 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -7,6 +7,7 @@ - [Installation](setup/installation.md) - [Using Postgres](postgres.md) - [Configuring a Reverse Proxy](reverse_proxy.md) + - [Configuring a Forward/Outbound Proxy](setup/forward_proxy.md) - [Configuring a Turn Server](turn-howto.md) - [Delegation](delegate.md) diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 1a217f35db..aeebcaf45f 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -210,6 +210,8 @@ presence: # # This option replaces federation_ip_range_blacklist in Synapse v1.25.0. # +# Note: The value is ignored when an HTTP proxy is in use +# #ip_range_blacklist: # - '127.0.0.0/8' # - '10.0.0.0/8' @@ -711,6 +713,15 @@ caches: # #expiry_time: 30m + # Controls how long the results of a /sync request are cached for after + # a successful response is returned. A higher duration can help clients with + # intermittent connections, at the cost of higher memory usage. + # + # By default, this is zero, which means that sync responses are not cached + # at all. + # + #sync_response_cache_duration: 2m + ## Database ## @@ -963,6 +974,8 @@ media_store_path: "DATADIR/media_store" # This must be specified if url_preview_enabled is set. It is recommended that # you uncomment the following list as a starting point. # +# Note: The value is ignored when an HTTP proxy is in use +# #url_preview_ip_range_blacklist: # - '127.0.0.0/8' # - '10.0.0.0/8' @@ -2229,6 +2242,14 @@ email: # #require_transport_security: true + # Uncomment the following to disable TLS for SMTP. + # + # By default, if the server supports TLS, it will be used, and the server + # must present a certificate that is valid for 'smtp_host'. If this option + # is set to false, TLS will not be used. + # + #enable_tls: false + # notif_from defines the "From" address to use when sending emails. # It must be set if email sending is enabled. # diff --git a/docs/setup/forward_proxy.md b/docs/setup/forward_proxy.md new file mode 100644 index 0000000000..a0720ab342 --- /dev/null +++ b/docs/setup/forward_proxy.md @@ -0,0 +1,74 @@ +# Using a forward proxy with Synapse + +You can use Synapse with a forward or outbound proxy. An example of when +this is necessary is in corporate environments behind a DMZ (demilitarized zone). +Synapse supports routing outbound HTTP(S) requests via a proxy. Only HTTP(S) +proxy is supported, not SOCKS proxy or anything else. + +## Configure + +The `http_proxy`, `https_proxy`, `no_proxy` environment variables are used to +specify proxy settings. The environment variable is not case sensitive. +- `http_proxy`: Proxy server to use for HTTP requests. +- `https_proxy`: Proxy server to use for HTTPS requests. +- `no_proxy`: Comma-separated list of hosts, IP addresses, or IP ranges in CIDR + format which should not use the proxy. Synapse will directly connect to these hosts. + +The `http_proxy` and `https_proxy` environment variables have the form: `[scheme://][<username>:<password>@]<host>[:<port>]` +- Supported schemes are `http://` and `https://`. The default scheme is `http://` + for compatibility reasons; it is recommended to set a scheme. If scheme is set + to `https://` the connection uses TLS between Synapse and the proxy. + + **NOTE**: Synapse validates the certificates. If the certificate is not + valid, then the connection is dropped. +- Default port if not given is `1080`. +- Username and password are optional and will be used to authenticate against + the proxy. + +**Examples** +- HTTP_PROXY=http://USERNAME:PASSWORD@10.0.1.1:8080/ +- HTTPS_PROXY=http://USERNAME:PASSWORD@proxy.example.com:8080/ +- NO_PROXY=master.hostname.example.com,10.1.0.0/16,172.30.0.0/16 + +**NOTE**: +Synapse does not apply the IP blacklist to connections through the proxy (since +the DNS resolution is done by the proxy). It is expected that the proxy or firewall +will apply blacklisting of IP addresses. + +## Connection types + +The proxy will be **used** for: + +- push +- url previews +- phone-home stats +- recaptcha validation +- CAS auth validation +- OpenID Connect +- Federation (checking public key revocation) + +It will **not be used** for: + +- Application Services +- Identity servers +- Outbound federation +- In worker configurations + - connections between workers + - connections from workers to Redis +- Fetching public keys of other servers +- Downloading remote media + +## Troubleshooting + +If a proxy server is used with TLS (HTTPS) and no connections are established, +it is most likely due to the proxy's certificates. To test this, the validation +in Synapse can be deactivated. + +**NOTE**: This has an impact on security and is for testing purposes only! + +To deactivate the certificate validation, the following setting must be made in +[homserver.yaml](../usage/configuration/homeserver_sample_config.md). + +```yaml +use_insecure_ssl_client_just_for_testing_do_not_use: true +``` diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index cba015d942..5d0ef8dd3a 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -65,4 +65,4 @@ if [[ -n "$1" ]]; then fi # Run the tests! -go test -v -tags synapse_blacklist,msc2946,msc3083,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... +go test -v -tags synapse_blacklist,msc2946,msc3083,msc2403,msc2716 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... diff --git a/synapse/config/cache.py b/synapse/config/cache.py index 8d5f38b5d9..d119427ad8 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -151,6 +151,15 @@ class CacheConfig(Config): # entries are never evicted based on time. # #expiry_time: 30m + + # Controls how long the results of a /sync request are cached for after + # a successful response is returned. A higher duration can help clients with + # intermittent connections, at the cost of higher memory usage. + # + # By default, this is zero, which means that sync responses are not cached + # at all. + # + #sync_response_cache_duration: 2m """ def read_config(self, config, **kwargs): @@ -212,6 +221,10 @@ class CacheConfig(Config): else: self.expiry_time_msec = None + self.sync_response_cache_duration = self.parse_duration( + cache_config.get("sync_response_cache_duration", 0) + ) + # Resize all caches (if necessary) with the new factors we've loaded self.resize_all_caches() diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 8d8f166e9b..42526502f0 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -80,6 +80,12 @@ class EmailConfig(Config): self.require_transport_security = email_config.get( "require_transport_security", False ) + self.enable_smtp_tls = email_config.get("enable_tls", True) + if self.require_transport_security and not self.enable_smtp_tls: + raise ConfigError( + "email.require_transport_security requires email.enable_tls to be true" + ) + if "app_name" in email_config: self.email_app_name = email_config["app_name"] else: @@ -368,6 +374,14 @@ class EmailConfig(Config): # #require_transport_security: true + # Uncomment the following to disable TLS for SMTP. + # + # By default, if the server supports TLS, it will be used, and the server + # must present a certificate that is valid for 'smtp_host'. If this option + # is set to false, TLS will not be used. + # + #enable_tls: false + # notif_from defines the "From" address to use when sending emails. # It must be set if email sending is enabled. # diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 0dfb3a227a..7481f3bf5f 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from collections import namedtuple from typing import Dict, List +from urllib.request import getproxies_environment # type: ignore from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set from synapse.python_dependencies import DependencyException, check_requirements @@ -22,6 +24,8 @@ from synapse.util.module_loader import load_module from ._base import Config, ConfigError +logger = logging.getLogger(__name__) + DEFAULT_THUMBNAIL_SIZES = [ {"width": 32, "height": 32, "method": "crop"}, {"width": 96, "height": 96, "method": "crop"}, @@ -36,6 +40,9 @@ THUMBNAIL_SIZE_YAML = """\ # method: %(method)s """ +HTTP_PROXY_SET_WARNING = """\ +The Synapse config url_preview_ip_range_blacklist will be ignored as an HTTP(s) proxy is configured.""" + ThumbnailRequirement = namedtuple( "ThumbnailRequirement", ["width", "height", "method", "media_type"] ) @@ -180,12 +187,17 @@ class ContentRepositoryConfig(Config): e.message # noqa: B306, DependencyException.message is a property ) + proxy_env = getproxies_environment() if "url_preview_ip_range_blacklist" not in config: - raise ConfigError( - "For security, you must specify an explicit target IP address " - "blacklist in url_preview_ip_range_blacklist for url previewing " - "to work" - ) + if "http" not in proxy_env or "https" not in proxy_env: + raise ConfigError( + "For security, you must specify an explicit target IP address " + "blacklist in url_preview_ip_range_blacklist for url previewing " + "to work" + ) + else: + if "http" in proxy_env or "https" in proxy_env: + logger.warning("".join(HTTP_PROXY_SET_WARNING)) # we always blacklist '0.0.0.0' and '::', which are supposed to be # unroutable addresses. @@ -292,6 +304,8 @@ class ContentRepositoryConfig(Config): # This must be specified if url_preview_enabled is set. It is recommended that # you uncomment the following list as a starting point. # + # Note: The value is ignored when an HTTP proxy is in use + # #url_preview_ip_range_blacklist: %(ip_range_blacklist)s diff --git a/synapse/config/server.py b/synapse/config/server.py index b9e0c0b300..187b4301a0 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -960,6 +960,8 @@ class ServerConfig(Config): # # This option replaces federation_ip_range_blacklist in Synapse v1.25.0. # + # Note: The value is ignored when an HTTP proxy is in use + # #ip_range_blacklist: %(ip_range_blacklist)s diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index b7a10da15a..2eefac04fd 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1108,7 +1108,8 @@ class FederationClient(FederationBase): The response from the remote server. Raises: - HttpResponseException: There was an exception returned from the remote server + HttpResponseException / RequestSendFailed: There was an exception + returned from the remote server SynapseException: M_FORBIDDEN when the remote server has disallowed publicRoom requests over federation @@ -1290,7 +1291,7 @@ class FederationClient(FederationBase): ) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, auto_attribs=True) class FederationSpaceSummaryEventResult: """Represents a single event in the result of a successful get_space_summary call. @@ -1299,12 +1300,13 @@ class FederationSpaceSummaryEventResult: object attributes. """ - event_type = attr.ib(type=str) - state_key = attr.ib(type=str) - via = attr.ib(type=Sequence[str]) + event_type: str + room_id: str + state_key: str + via: Sequence[str] # the raw data, including the above keys - data = attr.ib(type=JsonDict) + data: JsonDict @classmethod def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryEventResult": @@ -1321,6 +1323,10 @@ class FederationSpaceSummaryEventResult: if not isinstance(event_type, str): raise ValueError("Invalid event: 'event_type' must be a str") + room_id = d.get("room_id") + if not isinstance(room_id, str): + raise ValueError("Invalid event: 'room_id' must be a str") + state_key = d.get("state_key") if not isinstance(state_key, str): raise ValueError("Invalid event: 'state_key' must be a str") @@ -1335,15 +1341,15 @@ class FederationSpaceSummaryEventResult: if any(not isinstance(v, str) for v in via): raise ValueError("Invalid event: 'via' must be a list of strings") - return cls(event_type, state_key, via, d) + return cls(event_type, room_id, state_key, via, d) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, auto_attribs=True) class FederationSpaceSummaryResult: """Represents the data returned by a successful get_space_summary call.""" - rooms = attr.ib(type=Sequence[JsonDict]) - events = attr.ib(type=Sequence[FederationSpaceSummaryEventResult]) + rooms: Sequence[JsonDict] + events: Sequence[FederationSpaceSummaryEventResult] @classmethod def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryResult": diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 145b9161d9..0385aadefa 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -195,13 +195,17 @@ class FederationServer(FederationBase): origin, room_id, versions, limit ) - res = self._transaction_from_pdus(pdus).get_dict() + res = self._transaction_dict_from_pdus(pdus) return 200, res async def on_incoming_transaction( - self, origin: str, transaction_data: JsonDict - ) -> Tuple[int, Dict[str, Any]]: + self, + origin: str, + transaction_id: str, + destination: str, + transaction_data: JsonDict, + ) -> Tuple[int, JsonDict]: # If we receive a transaction we should make sure that kick off handling # any old events in the staging area. if not self._started_handling_of_staged_events: @@ -212,8 +216,14 @@ class FederationServer(FederationBase): # accurate as possible. request_time = self._clock.time_msec() - transaction = Transaction(**transaction_data) - transaction_id = transaction.transaction_id # type: ignore + transaction = Transaction( + transaction_id=transaction_id, + destination=destination, + origin=origin, + origin_server_ts=transaction_data.get("origin_server_ts"), # type: ignore + pdus=transaction_data.get("pdus"), # type: ignore + edus=transaction_data.get("edus"), + ) if not transaction_id: raise Exception("Transaction missing transaction_id") @@ -221,9 +231,7 @@ class FederationServer(FederationBase): logger.debug("[%s] Got transaction", transaction_id) # Reject malformed transactions early: reject if too many PDUs/EDUs - if len(transaction.pdus) > 50 or ( # type: ignore - hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore - ): + if len(transaction.pdus) > 50 or len(transaction.edus) > 100: logger.info("Transaction PDU or EDU count too large. Returning 400") return 400, {} @@ -263,7 +271,7 @@ class FederationServer(FederationBase): # CRITICAL SECTION: the first thing we must do (before awaiting) is # add an entry to _active_transactions. assert origin not in self._active_transactions - self._active_transactions[origin] = transaction.transaction_id # type: ignore + self._active_transactions[origin] = transaction.transaction_id try: result = await self._handle_incoming_transaction( @@ -291,11 +299,11 @@ class FederationServer(FederationBase): if response: logger.debug( "[%s] We've already responded to this request", - transaction.transaction_id, # type: ignore + transaction.transaction_id, ) return response - logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore + logger.debug("[%s] Transaction is new", transaction.transaction_id) # We process PDUs and EDUs in parallel. This is important as we don't # want to block things like to device messages from reaching clients @@ -334,7 +342,7 @@ class FederationServer(FederationBase): report back to the sending server. """ - received_pdus_counter.inc(len(transaction.pdus)) # type: ignore + received_pdus_counter.inc(len(transaction.pdus)) origin_host, _ = parse_server_name(origin) @@ -342,7 +350,7 @@ class FederationServer(FederationBase): newest_pdu_ts = 0 - for p in transaction.pdus: # type: ignore + for p in transaction.pdus: # FIXME (richardv): I don't think this works: # https://github.com/matrix-org/synapse/issues/8429 if "unsigned" in p: @@ -436,10 +444,10 @@ class FederationServer(FederationBase): return pdu_results - async def _handle_edus_in_txn(self, origin: str, transaction: Transaction): + async def _handle_edus_in_txn(self, origin: str, transaction: Transaction) -> None: """Process the EDUs in a received transaction.""" - async def _process_edu(edu_dict): + async def _process_edu(edu_dict: JsonDict) -> None: received_edus_counter.inc() edu = Edu( @@ -452,7 +460,7 @@ class FederationServer(FederationBase): await concurrently_execute( _process_edu, - getattr(transaction, "edus", []), + transaction.edus, TRANSACTION_CONCURRENCY_LIMIT, ) @@ -538,7 +546,7 @@ class FederationServer(FederationBase): pdu = await self.handler.get_persisted_pdu(origin, event_id) if pdu: - return 200, self._transaction_from_pdus([pdu]).get_dict() + return 200, self._transaction_dict_from_pdus([pdu]) else: return 404, "" @@ -879,18 +887,20 @@ class FederationServer(FederationBase): ts_now_ms = self._clock.time_msec() return await self.store.get_user_id_for_open_id_token(token, ts_now_ms) - def _transaction_from_pdus(self, pdu_list: List[EventBase]) -> Transaction: + def _transaction_dict_from_pdus(self, pdu_list: List[EventBase]) -> JsonDict: """Returns a new Transaction containing the given PDUs suitable for transmission. """ time_now = self._clock.time_msec() pdus = [p.get_pdu_json(time_now) for p in pdu_list] return Transaction( + # Just need a dummy transaction ID and destination since it won't be used. + transaction_id="", origin=self.server_name, pdus=pdus, origin_server_ts=int(time_now), - destination=None, - ) + destination="", + ).get_dict() async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None: """Process a PDU received in a federation /send/ transaction. diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 2f9c9bc2cd..4fead6ca29 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -45,7 +45,7 @@ class TransactionActions: `None` if we have not previously responded to this transaction or a 2-tuple of `(int, dict)` representing the response code and response body. """ - transaction_id = transaction.transaction_id # type: ignore + transaction_id = transaction.transaction_id if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") @@ -56,7 +56,7 @@ class TransactionActions: self, origin: str, transaction: Transaction, code: int, response: JsonDict ) -> None: """Persist how we responded to a transaction.""" - transaction_id = transaction.transaction_id # type: ignore + transaction_id = transaction.transaction_id if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 72a635830b..dc555cca0b 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -27,6 +27,7 @@ from synapse.logging.opentracing import ( tags, whitelisted_homeserver, ) +from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.metrics import measure_func @@ -104,13 +105,13 @@ class TransactionManager: len(edus), ) - transaction = Transaction.create_new( + transaction = Transaction( origin_server_ts=int(self.clock.time_msec()), transaction_id=txn_id, origin=self._server_name, destination=destination, - pdus=pdus, - edus=edus, + pdus=[p.get_pdu_json() for p in pdus], + edus=[edu.get_dict() for edu in edus], ) self._next_txn_id += 1 @@ -131,7 +132,7 @@ class TransactionManager: # FIXME (richardv): I also believe it no longer works. We (now?) store # "age_ts" in "unsigned" rather than at the top level. See # https://github.com/matrix-org/synapse/issues/8429. - def json_data_cb(): + def json_data_cb() -> JsonDict: data = transaction.get_dict() now = int(self.clock.time_msec()) if "pdus" in data: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 6a8d3ad4fe..90a7c16b62 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -143,7 +143,7 @@ class TransportLayerClient: """Sends the given Transaction to its destination Args: - transaction (Transaction) + transaction Returns: Succeeds when we get a 2xx HTTP response. The result diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 5e059d6e09..640f46fff6 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -450,21 +450,12 @@ class FederationSendServlet(BaseFederationServerServlet): len(transaction_data.get("edus", [])), ) - # We should ideally be getting this from the security layer. - # origin = body["origin"] - - # Add some extra data to the transaction dict that isn't included - # in the request body. - transaction_data.update( - transaction_id=transaction_id, destination=self.server_name - ) - except Exception as e: logger.exception(e) return 400, {"error": "Invalid transaction"} code, response = await self.handler.on_incoming_transaction( - origin, transaction_data + origin, transaction_id, self.server_name, transaction_data ) return code, response diff --git a/synapse/federation/units.py b/synapse/federation/units.py index c83a261918..b9b12fbea5 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -17,18 +17,17 @@ server protocol. """ import logging -from typing import Optional +from typing import List, Optional import attr from synapse.types import JsonDict -from synapse.util.jsonobject import JsonEncodedObject logger = logging.getLogger(__name__) -@attr.s(slots=True) -class Edu(JsonEncodedObject): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class Edu: """An Edu represents a piece of data sent from one homeserver to another. In comparison to Pdus, Edus are not persisted for a long time on disk, are @@ -36,10 +35,10 @@ class Edu(JsonEncodedObject): internal ID or previous references graph. """ - edu_type = attr.ib(type=str) - content = attr.ib(type=dict) - origin = attr.ib(type=str) - destination = attr.ib(type=str) + edu_type: str + content: dict + origin: str + destination: str def get_dict(self) -> JsonDict: return { @@ -55,14 +54,21 @@ class Edu(JsonEncodedObject): "destination": self.destination, } - def get_context(self): + def get_context(self) -> str: return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}") - def strip_context(self): + def strip_context(self) -> None: getattr(self, "content", {})["org.matrix.opentracing_context"] = "{}" -class Transaction(JsonEncodedObject): +def _none_to_list(edus: Optional[List[JsonDict]]) -> List[JsonDict]: + if edus is None: + return [] + return edus + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class Transaction: """A transaction is a list of Pdus and Edus to be sent to a remote home server with some extra metadata. @@ -78,47 +84,21 @@ class Transaction(JsonEncodedObject): """ - valid_keys = [ - "transaction_id", - "origin", - "destination", - "origin_server_ts", - "previous_ids", - "pdus", - "edus", - ] - - internal_keys = ["transaction_id", "destination"] - - required_keys = [ - "transaction_id", - "origin", - "destination", - "origin_server_ts", - "pdus", - ] - - def __init__(self, transaction_id=None, pdus: Optional[list] = None, **kwargs): - """If we include a list of pdus then we decode then as PDU's - automatically. - """ - - # If there's no EDUs then remove the arg - if "edus" in kwargs and not kwargs["edus"]: - del kwargs["edus"] - - super().__init__(transaction_id=transaction_id, pdus=pdus or [], **kwargs) - - @staticmethod - def create_new(pdus, **kwargs): - """Used to create a new transaction. Will auto fill out - transaction_id and origin_server_ts keys. - """ - if "origin_server_ts" not in kwargs: - raise KeyError("Require 'origin_server_ts' to construct a Transaction") - if "transaction_id" not in kwargs: - raise KeyError("Require 'transaction_id' to construct a Transaction") - - kwargs["pdus"] = [p.get_pdu_json() for p in pdus] - - return Transaction(**kwargs) + # Required keys. + transaction_id: str + origin: str + destination: str + origin_server_ts: int + pdus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list) + edus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list) + + def get_dict(self) -> JsonDict: + """A JSON-ready dictionary of valid keys which aren't internal.""" + result = { + "origin": self.origin, + "origin_server_ts": self.origin_server_ts, + "pdus": self.pdus, + } + if self.edus: + result["edus"] = self.edus + return result diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 21a17cd2e8..4ab4046650 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -392,9 +392,6 @@ class ApplicationServicesHandler: protocols[p].append(info) def _merge_instances(infos: List[JsonDict]) -> JsonDict: - if not infos: - return {} - # Merge the 'instances' lists of multiple results, but just take # the other fields from the first as they ought to be identical # copy the result so as not to corrupt the cached one @@ -406,7 +403,9 @@ class ApplicationServicesHandler: return combined - return {p: _merge_instances(protocols[p]) for p in protocols.keys()} + return { + p: _merge_instances(protocols[p]) for p in protocols.keys() if protocols[p] + } async def _get_services_for_event( self, event: EventBase diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 8197b60b76..9a5e726533 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -42,6 +42,7 @@ from twisted.internet import defer from synapse import event_auth from synapse.api.constants import ( + EventContentFields, EventTypes, Membership, RejectedReason, @@ -108,21 +109,33 @@ soft_failed_event_counter = Counter( ) -@attr.s(slots=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class _NewEventInfo: """Holds information about a received event, ready for passing to _auth_and_persist_events Attributes: event: the received event - state: the state at that event + state: the state at that event, according to /state_ids from a remote + homeserver. Only populated for backfilled events which are going to be a + new backwards extremity. + + claimed_auth_event_map: a map of (type, state_key) => event for the event's + claimed auth_events. + + This can include events which have not yet been persisted, in the case that + we are backfilling a batch of events. + + Note: May be incomplete: if we were unable to find all of the claimed auth + events. Also, treat the contents with caution: the events might also have + been rejected, might not yet have been authorized themselves, or they might + be in the wrong room. - auth_events: the auth_event map for that event """ - event = attr.ib(type=EventBase) - state = attr.ib(type=Optional[Sequence[EventBase]], default=None) - auth_events = attr.ib(type=Optional[MutableStateMap[EventBase]], default=None) + event: EventBase + state: Optional[Sequence[EventBase]] + claimed_auth_event_map: StateMap[EventBase] class FederationHandler(BaseHandler): @@ -262,7 +275,12 @@ class FederationHandler(BaseHandler): state = None - # Get missing pdus if necessary. + # Check that the event passes auth based on the state at the event. This is + # done for events that are to be added to the timeline (non-outliers). + # + # Get missing pdus if necessary: + # - Fetching any missing prev events to fill in gaps in the graph + # - Fetching state if we have a hole in the graph if not pdu.internal_metadata.is_outlier(): # We only backfill backwards to the min depth. min_depth = await self.get_min_depth_for_context(pdu.room_id) @@ -432,6 +450,13 @@ class FederationHandler(BaseHandler): affected=event_id, ) + # A second round of checks for all events. Check that the event passes auth + # based on `auth_events`, this allows us to assert that the event would + # have been allowed at some point. If an event passes this check its OK + # for it to be used as part of a returned `/state` request, as either + # a) we received the event as part of the original join and so trust it, or + # b) we'll do a state resolution with existing state before it becomes + # part of the "current state", which adds more protection. await self._process_received_pdu(origin, pdu, state=state) async def _get_missing_events_for_pdu( @@ -889,6 +914,79 @@ class FederationHandler(BaseHandler): "resync_device_due_to_pdu", self._resync_device, event.sender ) + await self._handle_marker_event(origin, event) + + async def _handle_marker_event(self, origin: str, marker_event: EventBase): + """Handles backfilling the insertion event when we receive a marker + event that points to one. + + Args: + origin: Origin of the event. Will be called to get the insertion event + marker_event: The event to process + """ + + if marker_event.type != EventTypes.MSC2716_MARKER: + # Not a marker event + return + + if marker_event.rejected_reason is not None: + # Rejected event + return + + # Skip processing a marker event if the room version doesn't + # support it. + room_version = await self.store.get_room_version(marker_event.room_id) + if not room_version.msc2716_historical: + return + + logger.debug("_handle_marker_event: received %s", marker_event) + + insertion_event_id = marker_event.content.get( + EventContentFields.MSC2716_MARKER_INSERTION + ) + + if insertion_event_id is None: + # Nothing to retrieve then (invalid marker) + return + + logger.debug( + "_handle_marker_event: backfilling insertion event %s", insertion_event_id + ) + + await self._get_events_and_persist( + origin, + marker_event.room_id, + [insertion_event_id], + ) + + insertion_event = await self.store.get_event( + insertion_event_id, allow_none=True + ) + if insertion_event is None: + logger.warning( + "_handle_marker_event: server %s didn't return insertion event %s for marker %s", + origin, + insertion_event_id, + marker_event.event_id, + ) + return + + logger.debug( + "_handle_marker_event: succesfully backfilled insertion event %s from marker event %s", + insertion_event, + marker_event, + ) + + await self.store.insert_insertion_extremity( + insertion_event_id, marker_event.room_id + ) + + logger.debug( + "_handle_marker_event: insertion extremity added for %s from marker event %s", + insertion_event, + marker_event, + ) + async def _resync_device(self, sender: str) -> None: """We have detected that the device list for the given user may be out of sync, so we try and resync them. @@ -1000,7 +1098,7 @@ class FederationHandler(BaseHandler): _NewEventInfo( event=ev, state=events_to_state[e_id], - auth_events={ + claimed_auth_event_map={ ( auth_events[a_id].type, auth_events[a_id].state_key, @@ -1057,9 +1155,19 @@ class FederationHandler(BaseHandler): async def _maybe_backfill_inner( self, room_id: str, current_depth: int, limit: int ) -> bool: - extremities = await self.store.get_oldest_events_with_depth_in_room(room_id) + oldest_events_with_depth = ( + await self.store.get_oldest_event_ids_with_depth_in_room(room_id) + ) + insertion_events_to_be_backfilled = ( + await self.store.get_insertion_event_backwards_extremities_in_room(room_id) + ) + logger.debug( + "_maybe_backfill_inner: extremities oldest_events_with_depth=%s insertion_events_to_be_backfilled=%s", + oldest_events_with_depth, + insertion_events_to_be_backfilled, + ) - if not extremities: + if not oldest_events_with_depth and not insertion_events_to_be_backfilled: logger.debug("Not backfilling as no extremeties found.") return False @@ -1089,10 +1197,12 @@ class FederationHandler(BaseHandler): # state *before* the event, ignoring the special casing certain event # types have. - forward_events = await self.store.get_successor_events(list(extremities)) + forward_event_ids = await self.store.get_successor_events( + list(oldest_events_with_depth) + ) extremities_events = await self.store.get_events( - forward_events, + forward_event_ids, redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, ) @@ -1106,10 +1216,19 @@ class FederationHandler(BaseHandler): redact=False, check_history_visibility_only=True, ) + logger.debug( + "_maybe_backfill_inner: filtered_extremities %s", filtered_extremities + ) - if not filtered_extremities: + if not filtered_extremities and not insertion_events_to_be_backfilled: return False + extremities = { + **oldest_events_with_depth, + # TODO: insertion_events_to_be_backfilled is currently skipping the filtered_extremities checks + **insertion_events_to_be_backfilled, + } + # Check if we reached a point where we should start backfilling. sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1])) max_depth = sorted_extremeties_tuple[0][1] @@ -2208,7 +2327,7 @@ class FederationHandler(BaseHandler): event: EventBase, context: EventContext, state: Optional[Iterable[EventBase]] = None, - auth_events: Optional[MutableStateMap[EventBase]] = None, + claimed_auth_event_map: Optional[StateMap[EventBase]] = None, backfilled: bool = False, ) -> None: """ @@ -2220,17 +2339,18 @@ class FederationHandler(BaseHandler): context: The event context. - NB that this function potentially modifies it. state: The state events used to check the event for soft-fail. If this is not provided the current state events will be used. - auth_events: - Map from (event_type, state_key) to event - Normally, our calculated auth_events based on the state of the room - at the event's position in the DAG, though occasionally (eg if the - event is an outlier), may be the auth events claimed by the remote - server. + claimed_auth_event_map: + A map of (type, state_key) => event for the event's claimed auth_events. + Possibly incomplete, and possibly including events that are not yet + persisted, or authed, or in the right room. + + Only populated where we may not already have persisted these events - + for example, when populating outliers. + backfilled: True if the event was backfilled. """ context = await self._check_event_auth( @@ -2238,7 +2358,7 @@ class FederationHandler(BaseHandler): event, context, state=state, - auth_events=auth_events, + claimed_auth_event_map=claimed_auth_event_map, backfilled=backfilled, ) @@ -2302,7 +2422,7 @@ class FederationHandler(BaseHandler): event, res, state=ev_info.state, - auth_events=ev_info.auth_events, + claimed_auth_event_map=ev_info.claimed_auth_event_map, backfilled=backfilled, ) return res @@ -2568,7 +2688,7 @@ class FederationHandler(BaseHandler): event: EventBase, context: EventContext, state: Optional[Iterable[EventBase]] = None, - auth_events: Optional[MutableStateMap[EventBase]] = None, + claimed_auth_event_map: Optional[StateMap[EventBase]] = None, backfilled: bool = False, ) -> EventContext: """ @@ -2580,21 +2700,19 @@ class FederationHandler(BaseHandler): context: The event context. - NB that this function potentially modifies it. state: The state events used to check the event for soft-fail. If this is not provided the current state events will be used. - auth_events: - Map from (event_type, state_key) to event - Normally, our calculated auth_events based on the state of the room - at the event's position in the DAG, though occasionally (eg if the - event is an outlier), may be the auth events claimed by the remote - server. + claimed_auth_event_map: + A map of (type, state_key) => event for the event's claimed auth_events. + Possibly incomplete, and possibly including events that are not yet + persisted, or authed, or in the right room. - Also NB that this function adds entries to it. + Only populated where we may not already have persisted these events - + for example, when populating outliers, or the state for a backwards + extremity. - If this is not provided, it is calculated from the previous state IDs. backfilled: True if the event was backfilled. Returns: @@ -2603,7 +2721,12 @@ class FederationHandler(BaseHandler): room_version = await self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - if not auth_events: + if claimed_auth_event_map: + # if we have a copy of the auth events from the event, use that as the + # basis for auth. + auth_events = claimed_auth_event_map + else: + # otherwise, we calculate what the auth events *should* be, and use that prev_state_ids = await context.get_prev_state_ids() auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=True @@ -2611,18 +2734,11 @@ class FederationHandler(BaseHandler): auth_events_x = await self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} - # This is a hack to fix some old rooms where the initial join event - # didn't reference the create event in its auth events. - if event.type == EventTypes.Member and not event.auth_event_ids(): - if len(event.prev_event_ids()) == 1 and event.depth < 5: - c = await self.store.get_event( - event.prev_event_ids()[0], allow_none=True - ) - if c and c.type == EventTypes.Create: - auth_events[(c.type, c.state_key)] = c - try: - context = await self._update_auth_events_and_context_for_auth( + ( + context, + auth_events_for_auth, + ) = await self._update_auth_events_and_context_for_auth( origin, event, context, auth_events ) except Exception: @@ -2635,9 +2751,10 @@ class FederationHandler(BaseHandler): "Ignoring failure and continuing processing of event.", event.event_id, ) + auth_events_for_auth = auth_events try: - event_auth.check(room_version_obj, event, auth_events=auth_events) + event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth) except AuthError as e: logger.warning("Failed auth resolution for %r because %s", event, e) context.rejected = RejectedReason.AUTH_ERROR @@ -2662,8 +2779,8 @@ class FederationHandler(BaseHandler): origin: str, event: EventBase, context: EventContext, - auth_events: MutableStateMap[EventBase], - ) -> EventContext: + input_auth_events: StateMap[EventBase], + ) -> Tuple[EventContext, StateMap[EventBase]]: """Helper for _check_event_auth. See there for docs. Checks whether a given event has the expected auth events. If it @@ -2680,7 +2797,7 @@ class FederationHandler(BaseHandler): event: context: - auth_events: + input_auth_events: Map from (event_type, state_key) to event Normally, our calculated auth_events based on the state of the room @@ -2688,11 +2805,12 @@ class FederationHandler(BaseHandler): event is an outlier), may be the auth events claimed by the remote server. - Also NB that this function adds entries to it. - Returns: - updated context + updated context, updated auth event map """ + # take a copy of input_auth_events before we modify it. + auth_events: MutableStateMap[EventBase] = dict(input_auth_events) + event_auth_events = set(event.auth_event_ids()) # missing_auth is the set of the event's auth_events which we don't yet have @@ -2721,7 +2839,7 @@ class FederationHandler(BaseHandler): # The other side isn't around or doesn't implement the # endpoint, so lets just bail out. logger.info("Failed to get event auth from remote: %s", e1) - return context + return context, auth_events seen_remotes = await self.store.have_seen_events( event.room_id, [e.event_id for e in remote_auth_chain] @@ -2752,7 +2870,10 @@ class FederationHandler(BaseHandler): await self.state_handler.compute_event_context(e) ) await self._auth_and_persist_event( - origin, e, missing_auth_event_context, auth_events=auth + origin, + e, + missing_auth_event_context, + claimed_auth_event_map=auth, ) if e.event_id in event_auth_events: @@ -2770,14 +2891,14 @@ class FederationHandler(BaseHandler): # obviously be empty # (b) alternatively, why don't we do it earlier? logger.info("Skipping auth_event fetch for outlier") - return context + return context, auth_events different_auth = event_auth_events.difference( e.event_id for e in auth_events.values() ) if not different_auth: - return context + return context, auth_events logger.info( "auth_events refers to events which are not in our calculated auth " @@ -2803,7 +2924,7 @@ class FederationHandler(BaseHandler): # XXX: should we reject the event in this case? It feels like we should, # but then shouldn't we also do so if we've failed to fetch any of the # auth events? - return context + return context, auth_events # now we state-resolve between our own idea of the auth events, and the remote's # idea of them. @@ -2833,7 +2954,7 @@ class FederationHandler(BaseHandler): event, context, auth_events ) - return context + return context, auth_events async def _update_context_for_auth_events( self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 0961dec5ab..8ffeabacf9 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -824,6 +824,7 @@ class IdentityHandler(BaseHandler): room_avatar_url: str, room_join_rules: str, room_name: str, + room_type: Optional[str], inviter_display_name: str, inviter_avatar_url: str, id_access_token: Optional[str] = None, @@ -843,6 +844,7 @@ class IdentityHandler(BaseHandler): notifications. room_join_rules: The join rules of the email (e.g. "public"). room_name: The m.room.name of the room. + room_type: The type of the room from its m.room.create event (e.g "m.space"). inviter_display_name: The current display name of the inviter. inviter_avatar_url: The URL of the inviter's avatar. @@ -869,6 +871,10 @@ class IdentityHandler(BaseHandler): "sender_display_name": inviter_display_name, "sender_avatar_url": inviter_avatar_url, } + + if room_type is not None: + invite_config["org.matrix.msc3288.room_type"] = room_type + # If a custom web client location is available, include it in the request. if self._web_client_location: invite_config["org.matrix.web_client_location"] = self._web_client_location diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index b9085bbccb..5fd4525700 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -70,7 +70,8 @@ class ReceiptsHandler(BaseHandler): ) if not is_in_room: logger.info( - "Ignoring receipt from %s as we're not in the room", + "Ignoring receipt for room %r from server %s as we're not in the room", + room_id, origin, ) continue diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index fae2c098e3..6d433fad41 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -356,6 +356,12 @@ class RoomListHandler(BaseHandler): include_all_networks: bool = False, third_party_instance_id: Optional[str] = None, ) -> JsonDict: + """Get the public room list from remote server + + Raises: + SynapseError + """ + if not self.enable_room_list_search: return {"chunk": [], "total_room_count_estimate": 0} @@ -395,13 +401,16 @@ class RoomListHandler(BaseHandler): limit = None since_token = None - res = await self._get_remote_list_cached( - server_name, - limit=limit, - since_token=since_token, - include_all_networks=include_all_networks, - third_party_instance_id=third_party_instance_id, - ) + try: + res = await self._get_remote_list_cached( + server_name, + limit=limit, + since_token=since_token, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, + ) + except (RequestSendFailed, HttpResponseException): + raise SynapseError(502, "Failed to fetch room list") if search_filter: res = { @@ -423,20 +432,21 @@ class RoomListHandler(BaseHandler): include_all_networks: bool = False, third_party_instance_id: Optional[str] = None, ) -> JsonDict: + """Wrapper around FederationClient.get_public_rooms that caches the + result. + """ + repl_layer = self.hs.get_federation_client() if search_filter: # We can't cache when asking for search - try: - return await repl_layer.get_public_rooms( - server_name, - limit=limit, - since_token=since_token, - search_filter=search_filter, - include_all_networks=include_all_networks, - third_party_instance_id=third_party_instance_id, - ) - except (RequestSendFailed, HttpResponseException): - raise SynapseError(502, "Failed to fetch room list") + return await repl_layer.get_public_rooms( + server_name, + limit=limit, + since_token=since_token, + search_filter=search_filter, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, + ) key = ( server_name, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 65ad3efa6a..ba13196218 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -19,7 +19,12 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple from synapse import types -from synapse.api.constants import AccountDataTypes, EventTypes, Membership +from synapse.api.constants import ( + AccountDataTypes, + EventContentFields, + EventTypes, + Membership, +) from synapse.api.errors import ( AuthError, Codes, @@ -1237,6 +1242,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if room_name_event: room_name = room_name_event.content.get("name", "") + room_type = None + room_create_event = room_state.get((EventTypes.Create, "")) + if room_create_event: + room_type = room_create_event.content.get(EventContentFields.ROOM_TYPE) + room_join_rules = "" join_rules_event = room_state.get((EventTypes.JoinRules, "")) if join_rules_event: @@ -1263,6 +1273,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): room_avatar_url=room_avatar_url, room_join_rules=room_join_rules, room_name=room_name, + room_type=room_type, inviter_display_name=inviter_display_name, inviter_avatar_url=inviter_avatar_url, id_access_token=id_access_token, diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py index e9f6aef06f..dda9659c11 100644 --- a/synapse/handlers/send_email.py +++ b/synapse/handlers/send_email.py @@ -16,7 +16,12 @@ import email.utils import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText -from typing import TYPE_CHECKING +from io import BytesIO +from typing import TYPE_CHECKING, Optional + +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IReactorTCP +from twisted.mail.smtp import ESMTPSenderFactory from synapse.logging.context import make_deferred_yieldable @@ -26,19 +31,75 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +async def _sendmail( + reactor: IReactorTCP, + smtphost: str, + smtpport: int, + from_addr: str, + to_addr: str, + msg_bytes: bytes, + username: Optional[bytes] = None, + password: Optional[bytes] = None, + require_auth: bool = False, + require_tls: bool = False, + tls_hostname: Optional[str] = None, +) -> None: + """A simple wrapper around ESMTPSenderFactory, to allow substitution in tests + + Params: + reactor: reactor to use to make the outbound connection + smtphost: hostname to connect to + smtpport: port to connect to + from_addr: "From" address for email + to_addr: "To" address for email + msg_bytes: Message content + username: username to authenticate with, if auth is enabled + password: password to give when authenticating + require_auth: if auth is not offered, fail the request + require_tls: if TLS is not offered, fail the reqest + tls_hostname: TLS hostname to check for. None to disable TLS. + """ + msg = BytesIO(msg_bytes) + + d: "Deferred[object]" = Deferred() + + factory = ESMTPSenderFactory( + username, + password, + from_addr, + to_addr, + msg, + d, + heloFallback=True, + requireAuthentication=require_auth, + requireTransportSecurity=require_tls, + hostname=tls_hostname, + ) + + # the IReactorTCP interface claims host has to be a bytes, which seems to be wrong + reactor.connectTCP(smtphost, smtpport, factory, timeout=30, bindAddress=None) # type: ignore[arg-type] + + await make_deferred_yieldable(d) + + class SendEmailHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self._sendmail = hs.get_sendmail() self._reactor = hs.get_reactor() self._from = hs.config.email.email_notif_from self._smtp_host = hs.config.email.email_smtp_host self._smtp_port = hs.config.email.email_smtp_port - self._smtp_user = hs.config.email.email_smtp_user - self._smtp_pass = hs.config.email.email_smtp_pass + + user = hs.config.email.email_smtp_user + self._smtp_user = user.encode("utf-8") if user is not None else None + passwd = hs.config.email.email_smtp_pass + self._smtp_pass = passwd.encode("utf-8") if passwd is not None else None self._require_transport_security = hs.config.email.require_transport_security + self._enable_tls = hs.config.email.enable_smtp_tls + + self._sendmail = _sendmail async def send_email( self, @@ -82,17 +143,16 @@ class SendEmailHandler: logger.info("Sending email to %s" % email_address) - await make_deferred_yieldable( - self._sendmail( - self._smtp_host, - raw_from, - raw_to, - multipart_msg.as_string().encode("utf8"), - reactor=self._reactor, - port=self._smtp_port, - requireAuthentication=self._smtp_user is not None, - username=self._smtp_user, - password=self._smtp_pass, - requireTransportSecurity=self._require_transport_security, - ) + await self._sendmail( + self._reactor, + self._smtp_host, + self._smtp_port, + raw_from, + raw_to, + multipart_msg.as_string().encode("utf8"), + username=self._smtp_user, + password=self._smtp_pass, + require_auth=self._smtp_user is not None, + require_tls=self._require_transport_security, + tls_hostname=self._smtp_host if self._enable_tls else None, ) diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index 5f7d4602bd..2517f278b6 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -16,7 +16,17 @@ import itertools import logging import re from collections import deque -from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Set, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, +) import attr @@ -116,20 +126,22 @@ class SpaceSummaryHandler: max_children = max_rooms_per_space if processed_rooms else None if is_in_room: - room, events = await self._summarize_local_room( + room_entry = await self._summarize_local_room( requester, None, room_id, suggested_only, max_children ) + events: Collection[JsonDict] = [] + if room_entry: + rooms_result.append(room_entry.room) + events = room_entry.children + logger.debug( "Query of local room %s returned events %s", room_id, ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], ) - - if room: - rooms_result.append(room) else: - fed_rooms, fed_events = await self._summarize_remote_room( + fed_rooms = await self._summarize_remote_room( queue_entry, suggested_only, max_children, @@ -141,12 +153,10 @@ class SpaceSummaryHandler: # user is not permitted see. # # Filter the returned results to only what is accessible to the user. - room_ids = set() events = [] - for room in fed_rooms: - fed_room_id = room.get("room_id") - if not fed_room_id or not isinstance(fed_room_id, str): - continue + for room_entry in fed_rooms: + room = room_entry.room + fed_room_id = room_entry.room_id # The room should only be included in the summary if: # a. the user is in the room; @@ -169,7 +179,9 @@ class SpaceSummaryHandler: # Check if the user is a member of any of the allowed spaces # from the response. - allowed_rooms = room.get("allowed_spaces") + allowed_rooms = room.get("allowed_room_ids") or room.get( + "allowed_spaces" + ) if ( not include_room and allowed_rooms @@ -188,22 +200,23 @@ class SpaceSummaryHandler: # The user can see the room, include it! if include_room: + # Before returning to the client, remove the allowed_room_ids + # and allowed_spaces keys. + room.pop("allowed_room_ids", None) + room.pop("allowed_spaces", None) + rooms_result.append(room) - room_ids.add(fed_room_id) + events.extend(room_entry.children) # All rooms returned don't need visiting again (even if the user # didn't have access to them). processed_rooms.add(fed_room_id) - for event in fed_events: - if event.get("room_id") in room_ids: - events.append(event) - logger.debug( "Query of %s returned rooms %s, events %s", room_id, - [room.get("room_id") for room in fed_rooms], - ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in fed_events], + [room_entry.room.get("room_id") for room_entry in fed_rooms], + ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], ) # the room we queried may or may not have been returned, but don't process @@ -230,11 +243,6 @@ class SpaceSummaryHandler: ) processed_events.add(ev_key) - # Before returning to the client, remove the allowed_spaces key for any - # rooms. - for room in rooms_result: - room.pop("allowed_spaces", None) - return {"rooms": rooms_result, "events": events_result} async def federation_space_summary( @@ -283,20 +291,20 @@ class SpaceSummaryHandler: # already done this room continue - logger.debug("Processing room %s", room_id) - - room, events = await self._summarize_local_room( + room_entry = await self._summarize_local_room( None, origin, room_id, suggested_only, max_rooms_per_space ) processed_rooms.add(room_id) - if room: - rooms_result.append(room) - events_result.extend(events) + if room_entry: + rooms_result.append(room_entry.room) + events_result.extend(room_entry.children) - # add any children to the queue - room_queue.extend(edge_event["state_key"] for edge_event in events) + # add any children to the queue + room_queue.extend( + edge_event["state_key"] for edge_event in room_entry.children + ) return {"rooms": rooms_result, "events": events_result} @@ -307,7 +315,7 @@ class SpaceSummaryHandler: room_id: str, suggested_only: bool, max_children: Optional[int], - ) -> Tuple[Optional[JsonDict], Sequence[JsonDict]]: + ) -> Optional["_RoomEntry"]: """ Generate a room entry and a list of event entries for a given room. @@ -326,21 +334,16 @@ class SpaceSummaryHandler: to a server-set limit. Returns: - A tuple of: - The room information, if the room should be returned to the - user. None, otherwise. - - An iterable of the sorted children events. This may be limited - to a maximum size or may include all children. + A room entry if the room should be returned. None, otherwise. """ if not await self._is_room_accessible(room_id, requester, origin): - return None, () + return None - room_entry = await self._build_room_entry(room_id) + room_entry = await self._build_room_entry(room_id, for_federation=bool(origin)) # If the room is not a space, return just the room information. if room_entry.get("room_type") != RoomTypes.SPACE: - return room_entry, () + return _RoomEntry(room_id, room_entry) # Otherwise, look for child rooms/spaces. child_events = await self._get_child_events(room_id) @@ -363,7 +366,7 @@ class SpaceSummaryHandler: ) ) - return room_entry, events_result + return _RoomEntry(room_id, room_entry, events_result) async def _summarize_remote_room( self, @@ -371,7 +374,7 @@ class SpaceSummaryHandler: suggested_only: bool, max_children: Optional[int], exclude_rooms: Iterable[str], - ) -> Tuple[Sequence[JsonDict], Sequence[JsonDict]]: + ) -> Iterable["_RoomEntry"]: """ Request room entries and a list of event entries for a given room by querying a remote server. @@ -386,11 +389,7 @@ class SpaceSummaryHandler: Rooms IDs which do not need to be summarized. Returns: - A tuple of: - An iterable of rooms. - - An iterable of the sorted children events. This may be limited - to a maximum size or may include all children. + An iterable of room entries. """ room_id = room.room_id logger.info("Requesting summary for %s via %s", room_id, room.via) @@ -414,11 +413,30 @@ class SpaceSummaryHandler: e, exc_info=logger.isEnabledFor(logging.DEBUG), ) - return (), () + return () + + # Group the events by their room. + children_by_room: Dict[str, List[JsonDict]] = {} + for ev in res.events: + if ev.event_type == EventTypes.SpaceChild: + children_by_room.setdefault(ev.room_id, []).append(ev.data) + + # Generate the final results. + results = [] + for fed_room in res.rooms: + fed_room_id = fed_room.get("room_id") + if not fed_room_id or not isinstance(fed_room_id, str): + continue - return res.rooms, tuple( - ev.data for ev in res.events if ev.event_type == EventTypes.SpaceChild - ) + results.append( + _RoomEntry( + fed_room_id, + fed_room, + children_by_room.get(fed_room_id, []), + ) + ) + + return results async def _is_room_accessible( self, room_id: str, requester: Optional[str], origin: Optional[str] @@ -532,8 +550,18 @@ class SpaceSummaryHandler: ) return False - async def _build_room_entry(self, room_id: str) -> JsonDict: - """Generate en entry suitable for the 'rooms' list in the summary response""" + async def _build_room_entry(self, room_id: str, for_federation: bool) -> JsonDict: + """ + Generate en entry suitable for the 'rooms' list in the summary response. + + Args: + room_id: The room ID to summarize. + for_federation: True if this is a summary requested over federation + (which includes additional fields). + + Returns: + The JSON dictionary for the room. + """ stats = await self._store.get_room_with_stats(room_id) # currently this should be impossible because we call @@ -546,15 +574,6 @@ class SpaceSummaryHandler: current_state_ids[(EventTypes.Create, "")] ) - room_version = await self._store.get_room_version(room_id) - allowed_rooms = None - if await self._event_auth_handler.has_restricted_join_rules( - current_state_ids, room_version - ): - allowed_rooms = await self._event_auth_handler.get_rooms_that_allow_join( - current_state_ids - ) - entry = { "room_id": stats["room_id"], "name": stats["name"], @@ -569,9 +588,25 @@ class SpaceSummaryHandler: "guest_can_join": stats["guest_access"] == "can_join", "creation_ts": create_event.origin_server_ts, "room_type": create_event.content.get(EventContentFields.ROOM_TYPE), - "allowed_spaces": allowed_rooms, } + # Federation requests need to provide additional information so the + # requested server is able to filter the response appropriately. + if for_federation: + room_version = await self._store.get_room_version(room_id) + if await self._event_auth_handler.has_restricted_join_rules( + current_state_ids, room_version + ): + allowed_rooms = ( + await self._event_auth_handler.get_rooms_that_allow_join( + current_state_ids + ) + ) + if allowed_rooms: + entry["allowed_room_ids"] = allowed_rooms + # TODO Remove this key once the API is stable. + entry["allowed_spaces"] = allowed_rooms + # Filter out Nones – rather omit the field altogether room_entry = {k: v for k, v in entry.items() if v is not None} @@ -606,10 +641,21 @@ class SpaceSummaryHandler: return sorted(filter(_has_valid_via, events), key=_child_events_comparison_key) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, auto_attribs=True) class _RoomQueueEntry: - room_id = attr.ib(type=str) - via = attr.ib(type=Sequence[str]) + room_id: str + via: Sequence[str] + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class _RoomEntry: + room_id: str + # The room summary for this room. + room: JsonDict + # An iterable of the sorted, stripped children events for children of this room. + # + # This may not include all children. + children: Collection[JsonDict] = () def _has_valid_via(e: EventBase) -> bool: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index f30bfcc93c..590642f510 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -269,14 +269,22 @@ class SyncHandler: self.presence_handler = hs.get_presence_handler() self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() - self.response_cache: ResponseCache[SyncRequestKey] = ResponseCache( - hs.get_clock(), "sync" - ) self.state = hs.get_state_handler() self.auth = hs.get_auth() self.storage = hs.get_storage() self.state_store = self.storage.state + # TODO: flush cache entries on subsequent sync request. + # Once we get the next /sync request (ie, one with the same access token + # that sets 'since' to 'next_batch'), we know that device won't need a + # cached result any more, and we could flush the entry from the cache to save + # memory. + self.response_cache: ResponseCache[SyncRequestKey] = ResponseCache( + hs.get_clock(), + "sync", + timeout_ms=hs.config.caches.sync_response_cache_duration, + ) + # ExpiringCache((User, Device)) -> LruCache(user_id => event_id) self.lazy_loaded_members_cache: ExpiringCache[ Tuple[str, Optional[str]], LruCache[str, str] diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 0cb651a400..a97c448595 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -335,7 +335,8 @@ class TypingWriterHandler(FollowerTypingHandler): ) if not is_in_room: logger.info( - "Ignoring typing update from %s as we're not in the room", + "Ignoring typing update for room %r from server %s as we're not in the room", + room_id, origin, ) return diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 473812b8e2..1cc13fc97b 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -45,7 +45,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.state import StateFilter -from synapse.types import JsonDict, Requester, UserID, create_requester +from synapse.types import JsonDict, Requester, UserID, UserInfo, create_requester from synapse.util import Clock from synapse.util.caches.descriptors import cached @@ -174,6 +174,16 @@ class ModuleApi: """The application name configured in the homeserver's configuration.""" return self._hs.config.email.email_app_name + async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: + """Get user info by user_id + + Args: + user_id: Fully qualified user id. + Returns: + UserInfo object if a user was found, otherwise None + """ + return await self._store.get_userinfo_by_id(user_id) + async def get_user_by_req( self, req: SynapseRequest, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 502a917588..f887970b76 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -23,7 +23,6 @@ from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, - HttpResponseException, InvalidClientCredentialsError, ShadowBanError, SynapseError, @@ -458,6 +457,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): "state_key": state_event["state_key"], } + # Mark all events as historical + event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True + # Make the state events float off on their own fake_prev_event_id = "$" + random_string(43) @@ -562,7 +564,10 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): "type": EventTypes.MSC2716_CHUNK, "sender": requester.user.to_string(), "room_id": room_id, - "content": {EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to}, + "content": { + EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to, + EventContentFields.MSC2716_HISTORICAL: True, + }, # Since the chunk event is put at the end of the chunk, # where the newest-in-time event is, copy the origin_server_ts from # the last event we're inserting @@ -589,10 +594,6 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): for ev in events_to_create: assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) - # Mark all events as historical - # This has important semantics within the Synapse internals to backfill properly - ev["content"][EventContentFields.MSC2716_HISTORICAL] = True - event_dict = { "type": ev["type"], "origin_server_ts": ev["origin_server_ts"], @@ -602,6 +603,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): "prev_events": prev_event_ids.copy(), } + # Mark all events as historical + event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True + event, context = await self.event_creation_handler.create_event( await self._create_requester_for_user_id_from_app_service( ev["sender"], requester.app_service @@ -778,12 +782,9 @@ class PublicRoomListRestServlet(TransactionRestServlet): Codes.INVALID_PARAM, ) - try: - data = await handler.get_remote_public_room_list( - server, limit=limit, since_token=since_token - ) - except HttpResponseException as e: - raise e.to_synapse_error() + data = await handler.get_remote_public_room_list( + server, limit=limit, since_token=since_token + ) else: data = await handler.get_local_public_room_list( limit=limit, since_token=since_token @@ -831,17 +832,15 @@ class PublicRoomListRestServlet(TransactionRestServlet): Codes.INVALID_PARAM, ) - try: - data = await handler.get_remote_public_room_list( - server, - limit=limit, - since_token=since_token, - search_filter=search_filter, - include_all_networks=include_all_networks, - third_party_instance_id=third_party_instance_id, - ) - except HttpResponseException as e: - raise e.to_synapse_error() + data = await handler.get_remote_public_room_list( + server, + limit=limit, + since_token=since_token, + search_filter=search_filter, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, + ) + else: data = await handler.get_local_public_room_list( limit=limit, diff --git a/synapse/server.py b/synapse/server.py index 095dba9ad0..6c867f0f47 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -34,8 +34,6 @@ from typing import ( ) import twisted.internet.tcp -from twisted.internet import defer -from twisted.mail.smtp import sendmail from twisted.web.iweb import IPolicyForHTTPS from twisted.web.resource import IResource @@ -443,10 +441,6 @@ class HomeServer(metaclass=abc.ABCMeta): return RoomShutdownHandler(self) @cache_in_self - def get_sendmail(self) -> Callable[..., defer.Deferred]: - return sendmail - - @cache_in_self def get_state_handler(self) -> StateHandler: return StateHandler(self) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index c8015a3848..95d2caff62 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -941,13 +941,13 @@ class DatabasePool: `lock` should generally be set to True (the default), but can be set to False if either of the following are true: - - * there is a UNIQUE INDEX on the key columns. In this case a conflict - will cause an IntegrityError in which case this function will retry - the update. - - * we somehow know that we are the only thread which will be updating - this table. + 1. there is a UNIQUE INDEX on the key columns. In this case a conflict + will cause an IntegrityError in which case this function will retry + the update. + 2. we somehow know that we are the only thread which will be updating + this table. + As an additional note, this parameter only matters for old SQLite versions + because we will use native upserts otherwise. Args: table: The table to upsert into diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 1edc96042b..1f0a39eac4 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -755,81 +755,145 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): """ @trace - def _claim_e2e_one_time_keys(txn): - sql = ( - "SELECT key_id, key_json FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " LIMIT 1" + def _claim_e2e_one_time_key_simple( + txn, user_id: str, device_id: str, algorithm: str + ) -> Optional[Tuple[str, str]]: + """Claim OTK for device for DBs that don't support RETURNING. + + Returns: + A tuple of key name (algorithm + key ID) and key JSON, if an + OTK was found. + """ + + sql = """ + SELECT key_id, key_json FROM e2e_one_time_keys_json + WHERE user_id = ? AND device_id = ? AND algorithm = ? + LIMIT 1 + """ + + txn.execute(sql, (user_id, device_id, algorithm)) + otk_row = txn.fetchone() + if otk_row is None: + return None + + key_id, key_json = otk_row + + self.db_pool.simple_delete_one_txn( + txn, + table="e2e_one_time_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + "key_id": key_id, + }, ) - fallback_sql = ( - "SELECT key_id, key_json, used FROM e2e_fallback_keys_json" - " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " LIMIT 1" + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - result = {} - delete = [] - used_fallbacks = [] - for user_id, device_id, algorithm in query_list: - user_result = result.setdefault(user_id, {}) - device_result = user_result.setdefault(device_id, {}) - txn.execute(sql, (user_id, device_id, algorithm)) - otk_row = txn.fetchone() - if otk_row is not None: - key_id, key_json = otk_row - device_result[algorithm + ":" + key_id] = key_json - delete.append((user_id, device_id, algorithm, key_id)) - else: - # no one-time key available, so see if there's a fallback - # key - txn.execute(fallback_sql, (user_id, device_id, algorithm)) - fallback_row = txn.fetchone() - if fallback_row is not None: - key_id, key_json, used = fallback_row - device_result[algorithm + ":" + key_id] = key_json - if not used: - used_fallbacks.append( - (user_id, device_id, algorithm, key_id) - ) - - # drop any one-time keys that were claimed - sql = ( - "DELETE FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " AND key_id = ?" + + return f"{algorithm}:{key_id}", key_json + + @trace + def _claim_e2e_one_time_key_returning( + txn, user_id: str, device_id: str, algorithm: str + ) -> Optional[Tuple[str, str]]: + """Claim OTK for device for DBs that support RETURNING. + + Returns: + A tuple of key name (algorithm + key ID) and key JSON, if an + OTK was found. + """ + + # We can use RETURNING to do the fetch and DELETE in once step. + sql = """ + DELETE FROM e2e_one_time_keys_json + WHERE user_id = ? AND device_id = ? AND algorithm = ? + AND key_id IN ( + SELECT key_id FROM e2e_one_time_keys_json + WHERE user_id = ? AND device_id = ? AND algorithm = ? + LIMIT 1 + ) + RETURNING key_id, key_json + """ + + txn.execute( + sql, (user_id, device_id, algorithm, user_id, device_id, algorithm) ) - for user_id, device_id, algorithm, key_id in delete: - log_kv( - { - "message": "Executing claim e2e_one_time_keys transaction on database." - } - ) - txn.execute(sql, (user_id, device_id, algorithm, key_id)) - log_kv({"message": "finished executing and invalidating cache"}) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) + otk_row = txn.fetchone() + if otk_row is None: + return None + + key_id, key_json = otk_row + return f"{algorithm}:{key_id}", key_json + + results = {} + for user_id, device_id, algorithm in query_list: + if self.database_engine.supports_returning: + # If we support RETURNING clause we can use a single query that + # allows us to use autocommit mode. + _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning + db_autocommit = True + else: + _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple + db_autocommit = False + + row = await self.db_pool.runInteraction( + "claim_e2e_one_time_keys", + _claim_e2e_one_time_key, + user_id, + device_id, + algorithm, + db_autocommit=db_autocommit, + ) + if row: + device_results = results.setdefault(user_id, {}).setdefault( + device_id, {} ) - # mark fallback keys as used - for user_id, device_id, algorithm, key_id in used_fallbacks: - self.db_pool.simple_update_txn( - txn, - "e2e_fallback_keys_json", - { + device_results[row[0]] = row[1] + continue + + # No one-time key available, so see if there's a fallback + # key + row = await self.db_pool.simple_select_one( + table="e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + }, + retcols=("key_id", "key_json", "used"), + desc="_get_fallback_key", + allow_none=True, + ) + if row is None: + continue + + key_id = row["key_id"] + key_json = row["key_json"] + used = row["used"] + + # Mark fallback key as used if not already. + if not used: + await self.db_pool.simple_update_one( + table="e2e_fallback_keys_json", + keyvalues={ "user_id": user_id, "device_id": device_id, "algorithm": algorithm, "key_id": key_id, }, - {"used": True}, + updatevalues={"used": True}, + desc="_get_fallback_key_set_used", ) - self._invalidate_cache_and_stream( - txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) + await self.invalidate_cache_and_stream( + "get_e2e_unused_fallback_key_types", (user_id, device_id) ) - return result + device_results = results.setdefault(user_id, {}).setdefault(device_id, {}) + device_results[f"{algorithm}:{key_id}"] = key_json - return await self.db_pool.runInteraction( - "claim_e2e_one_time_keys", _claim_e2e_one_time_keys - ) + return results class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 44018c1c31..bddf5ef192 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -671,27 +671,97 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas # Return all events where not all sets can reach them. return {eid for eid, n in event_to_missing_sets.items() if n} - async def get_oldest_events_with_depth_in_room(self, room_id): + async def get_oldest_event_ids_with_depth_in_room(self, room_id) -> Dict[str, int]: + """Gets the oldest events(backwards extremities) in the room along with the + aproximate depth. + + We use this function so that we can compare and see if someones current + depth at their current scrollback is within pagination range of the + event extremeties. If the current depth is close to the depth of given + oldest event, we can trigger a backfill. + + Args: + room_id: Room where we want to find the oldest events + + Returns: + Map from event_id to depth + """ + + def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id): + # Assemble a dictionary with event_id -> depth for the oldest events + # we know of in the room. Backwards extremeties are the oldest + # events we know of in the room but we only know of them because + # some other event referenced them by prev_event and aren't peristed + # in our database yet (meaning we don't know their depth + # specifically). So we need to look for the aproximate depth from + # the events connected to the current backwards extremeties. + sql = """ + SELECT b.event_id, MAX(e.depth) FROM events as e + /** + * Get the edge connections from the event_edges table + * so we can see whether this event's prev_events points + * to a backward extremity in the next join. + */ + INNER JOIN event_edges as g + ON g.event_id = e.event_id + /** + * We find the "oldest" events in the room by looking for + * events connected to backwards extremeties (oldest events + * in the room that we know of so far). + */ + INNER JOIN event_backward_extremities as b + ON g.prev_event_id = b.event_id + WHERE b.room_id = ? AND g.is_state is ? + GROUP BY b.event_id + """ + + txn.execute(sql, (room_id, False)) + + return dict(txn) + return await self.db_pool.runInteraction( - "get_oldest_events_with_depth_in_room", - self.get_oldest_events_with_depth_in_room_txn, + "get_oldest_event_ids_with_depth_in_room", + get_oldest_event_ids_with_depth_in_room_txn, room_id, ) - def get_oldest_events_with_depth_in_room_txn(self, txn, room_id): - sql = ( - "SELECT b.event_id, MAX(e.depth) FROM events as e" - " INNER JOIN event_edges as g" - " ON g.event_id = e.event_id" - " INNER JOIN event_backward_extremities as b" - " ON g.prev_event_id = b.event_id" - " WHERE b.room_id = ? AND g.is_state is ?" - " GROUP BY b.event_id" - ) + async def get_insertion_event_backwards_extremities_in_room( + self, room_id + ) -> Dict[str, int]: + """Get the insertion events we know about that we haven't backfilled yet. - txn.execute(sql, (room_id, False)) + We use this function so that we can compare and see if someones current + depth at their current scrollback is within pagination range of the + insertion event. If the current depth is close to the depth of given + insertion event, we can trigger a backfill. - return dict(txn) + Args: + room_id: Room where we want to find the oldest events + + Returns: + Map from event_id to depth + """ + + def get_insertion_event_backwards_extremities_in_room_txn(txn, room_id): + sql = """ + SELECT b.event_id, MAX(e.depth) FROM insertion_events as i + /* We only want insertion events that are also marked as backwards extremities */ + INNER JOIN insertion_event_extremities as b USING (event_id) + /* Get the depth of the insertion event from the events table */ + INNER JOIN events AS e USING (event_id) + WHERE b.room_id = ? + GROUP BY b.event_id + """ + + txn.execute(sql, (room_id,)) + + return dict(txn) + + return await self.db_pool.runInteraction( + "get_insertion_event_backwards_extremities_in_room", + get_insertion_event_backwards_extremities_in_room_txn, + room_id, + ) async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]: """Returns the event ID and depth for the event that has the max depth from a set of event IDs @@ -1041,7 +1111,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas if row[1] not in event_results: queue.put((-row[0], row[1])) - # Navigate up the DAG by prev_event txn.execute(query, (event_id, False, limit - len(event_results))) prev_event_id_results = txn.fetchall() logger.debug( @@ -1136,6 +1205,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas _delete_old_forward_extrem_cache_txn, ) + async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None: + await self.db_pool.simple_upsert( + table="insertion_event_extremities", + keyvalues={"event_id": event_id}, + values={ + "event_id": event_id, + "room_id": room_id, + }, + insertion_values={}, + desc="insert_insertion_extremity", + lock=False, + ) + async def insert_received_event_to_staging( self, origin: str, event: EventBase ) -> None: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 86baf397fb..40b53274fb 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1845,6 +1845,18 @@ class PersistEventsStore: }, ) + # When we receive an event with a `chunk_id` referencing the + # `next_chunk_id` of the insertion event, we can remove it from the + # `insertion_event_extremities` table. + sql = """ + DELETE FROM insertion_event_extremities WHERE event_id IN ( + SELECT event_id FROM insertion_events + WHERE next_chunk_id = ? + ) + """ + + txn.execute(sql, (chunk_id,)) + def _handle_redaction(self, txn, redacted_event_id): """Handles receiving a redaction and checking whether we need to remove any redacted relations from the database. @@ -2101,15 +2113,17 @@ class PersistEventsStore: Forward extremities are handled when we first start persisting the events. """ + # From the events passed in, add all of the prev events as backwards extremities. + # Ignore any events that are already backwards extrems or outliers. query = ( "INSERT INTO event_backward_extremities (event_id, room_id)" " SELECT ?, ? WHERE NOT EXISTS (" - " SELECT 1 FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" + " SELECT 1 FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" " )" " AND NOT EXISTS (" - " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " - " AND outlier = ?" + " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " + " AND outlier = ?" " )" ) @@ -2123,6 +2137,8 @@ class PersistEventsStore: ], ) + # Delete all these events that we've already fetched and now know that their + # prev events are the new backwards extremeties. query = ( "DELETE FROM event_backward_extremities" " WHERE event_id = ? AND room_id = ?" diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 3c86adab56..375463e4e9 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -14,7 +14,6 @@ import logging import threading -from collections import namedtuple from typing import ( Collection, Container, @@ -27,6 +26,7 @@ from typing import ( overload, ) +import attr from constantly import NamedConstant, Names from typing_extensions import Literal @@ -42,7 +42,11 @@ from synapse.api.room_versions import ( from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.events.utils import prune_event -from synapse.logging.context import PreserveLoggingContext, current_context +from synapse.logging.context import ( + PreserveLoggingContext, + current_context, + make_deferred_yieldable, +) from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, @@ -56,6 +60,8 @@ from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id +from synapse.util import unwrapFirstError +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter @@ -74,7 +80,10 @@ EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events -_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) +@attr.s(slots=True, auto_attribs=True) +class _EventCacheEntry: + event: EventBase + redacted_event: Optional[EventBase] class EventRedactBehaviour(Names): @@ -161,6 +170,13 @@ class EventsWorkerStore(SQLBaseStore): max_size=hs.config.caches.event_cache_size, ) + # Map from event ID to a deferred that will result in a map from event + # ID to cache entry. Note that the returned dict may not have the + # requested event in it if the event isn't in the DB. + self._current_event_fetches: Dict[ + str, ObservableDeferred[Dict[str, _EventCacheEntry]] + ] = {} + self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] self._event_fetch_ongoing = 0 @@ -476,7 +492,9 @@ class EventsWorkerStore(SQLBaseStore): return events - async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): + async def _get_events_from_cache_or_db( + self, event_ids: Iterable[str], allow_rejected: bool = False + ) -> Dict[str, _EventCacheEntry]: """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -485,53 +503,107 @@ class EventsWorkerStore(SQLBaseStore): Args: - event_ids (Iterable[str]): The event_ids of the events to fetch + event_ids: The event_ids of the events to fetch - allow_rejected (bool): Whether to include rejected events. If False, + allow_rejected: Whether to include rejected events. If False, rejected events are omitted from the response. Returns: - Dict[str, _EventCacheEntry]: - map from event id to result + map from event id to result """ event_entry_map = self._get_events_from_cache( - event_ids, allow_rejected=allow_rejected + event_ids, ) - missing_events_ids = [e for e in event_ids if e not in event_entry_map] + missing_events_ids = {e for e in event_ids if e not in event_entry_map} + + # We now look up if we're already fetching some of the events in the DB, + # if so we wait for those lookups to finish instead of pulling the same + # events out of the DB multiple times. + already_fetching: Dict[str, defer.Deferred] = {} + + for event_id in missing_events_ids: + deferred = self._current_event_fetches.get(event_id) + if deferred is not None: + # We're already pulling the event out of the DB. Add the deferred + # to the collection of deferreds to wait on. + already_fetching[event_id] = deferred.observe() + + missing_events_ids.difference_update(already_fetching) if missing_events_ids: log_ctx = current_context() log_ctx.record_event_fetch(len(missing_events_ids)) + # Add entries to `self._current_event_fetches` for each event we're + # going to pull from the DB. We use a single deferred that resolves + # to all the events we pulled from the DB (this will result in this + # function returning more events than requested, but that can happen + # already due to `_get_events_from_db`). + fetching_deferred: ObservableDeferred[ + Dict[str, _EventCacheEntry] + ] = ObservableDeferred(defer.Deferred()) + for event_id in missing_events_ids: + self._current_event_fetches[event_id] = fetching_deferred + # Note that _get_events_from_db is also responsible for turning db rows # into FrozenEvents (via _get_event_from_row), which involves seeing if # the events have been redacted, and if so pulling the redaction event out # of the database to check it. # - missing_events = await self._get_events_from_db( - missing_events_ids, allow_rejected=allow_rejected - ) + try: + missing_events = await self._get_events_from_db( + missing_events_ids, + ) - event_entry_map.update(missing_events) + event_entry_map.update(missing_events) + except Exception as e: + with PreserveLoggingContext(): + fetching_deferred.errback(e) + raise e + finally: + # Ensure that we mark these events as no longer being fetched. + for event_id in missing_events_ids: + self._current_event_fetches.pop(event_id, None) + + with PreserveLoggingContext(): + fetching_deferred.callback(missing_events) + + if already_fetching: + # Wait for the other event requests to finish and add their results + # to ours. + results = await make_deferred_yieldable( + defer.gatherResults( + already_fetching.values(), + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) + + for result in results: + event_entry_map.update(result) + + if not allow_rejected: + event_entry_map = { + event_id: entry + for event_id, entry in event_entry_map.items() + if not entry.event.rejected_reason + } return event_entry_map def _invalidate_get_event_cache(self, event_id): self._get_event_cache.invalidate((event_id,)) - def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): - """Fetch events from the caches + def _get_events_from_cache( + self, events: Iterable[str], update_metrics: bool = True + ) -> Dict[str, _EventCacheEntry]: + """Fetch events from the caches. - Args: - events (Iterable[str]): list of event_ids to fetch - allow_rejected (bool): Whether to return events that were rejected - update_metrics (bool): Whether to update the cache hit ratio metrics + May return rejected events. - Returns: - dict of event_id -> _EventCacheEntry for each event_id in cache. If - allow_rejected is `False` then there will still be an entry but it - will be `None` + Args: + events: list of event_ids to fetch + update_metrics: Whether to update the cache hit ratio metrics """ event_map = {} @@ -542,10 +614,7 @@ class EventsWorkerStore(SQLBaseStore): if not ret: continue - if allow_rejected or not ret.event.rejected_reason: - event_map[event_id] = ret - else: - event_map[event_id] = None + event_map[event_id] = ret return event_map @@ -672,23 +741,23 @@ class EventsWorkerStore(SQLBaseStore): with PreserveLoggingContext(): self.hs.get_reactor().callFromThread(fire, event_list, e) - async def _get_events_from_db(self, event_ids, allow_rejected=False): + async def _get_events_from_db( + self, event_ids: Iterable[str] + ) -> Dict[str, _EventCacheEntry]: """Fetch a bunch of events from the database. + May return rejected events. + Returned events will be added to the cache for future lookups. Unknown events are omitted from the response. Args: - event_ids (Iterable[str]): The event_ids of the events to fetch - - allow_rejected (bool): Whether to include rejected events. If False, - rejected events are omitted from the response. + event_ids: The event_ids of the events to fetch Returns: - Dict[str, _EventCacheEntry]: - map from event id to result. May return extra events which - weren't asked for. + map from event id to result. May return extra events which + weren't asked for. """ fetched_events = {} events_to_fetch = event_ids @@ -717,9 +786,6 @@ class EventsWorkerStore(SQLBaseStore): rejected_reason = row["rejected_reason"] - if not allow_rejected and rejected_reason: - continue - # If the event or metadata cannot be parsed, log the error and act # as if the event is unknown. try: diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 6ad1a0cf7f..14670c2881 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -29,7 +29,7 @@ from synapse.storage.databases.main.stats import StatsStore from synapse.storage.types import Connection, Cursor from synapse.storage.util.id_generators import IdGenerator from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import UserID +from synapse.types import UserID, UserInfo from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -146,6 +146,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): @cached() async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + """Deprecated: use get_userinfo_by_id instead""" return await self.db_pool.simple_select_one( table="users", keyvalues={"name": user_id}, @@ -166,6 +167,33 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="get_user_by_id", ) + async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: + """Get a UserInfo object for a user by user ID. + + Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed, + this method should be cached. + + Args: + user_id: The user to fetch user info for. + Returns: + `UserInfo` object if user found, otherwise `None`. + """ + user_data = await self.get_user_by_id(user_id) + if not user_data: + return None + return UserInfo( + appservice_id=user_data["appservice_id"], + consent_server_notice_sent=user_data["consent_server_notice_sent"], + consent_version=user_data["consent_version"], + creation_ts=user_data["creation_ts"], + is_admin=bool(user_data["admin"]), + is_deactivated=bool(user_data["deactivated"]), + is_guest=bool(user_data["is_guest"]), + is_shadow_banned=bool(user_data["shadow_banned"]), + user_id=UserID.from_string(user_data["name"]), + user_type=user_data["user_type"], + ) + async def is_trial_user(self, user_id: str) -> bool: """Checks if user is in the "trial" period, i.e. within the first N days of registration defined by `mau_trial_days` config diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 68f1b40ea6..e8157ba3d4 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -629,14 +629,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): # We don't update the event cache hit ratio as it completely throws off # the hit ratio counts. After all, we don't populate the cache if we # miss it here - event_map = self._get_events_from_cache( - member_event_ids, allow_rejected=False, update_metrics=False - ) + event_map = self._get_events_from_cache(member_event_ids, update_metrics=False) missing_member_event_ids = [] for event_id in member_event_ids: ev_entry = event_map.get(event_id) - if ev_entry: + if ev_entry and not ev_entry.event.rejected_reason: if ev_entry.event.membership == Membership.JOIN: users_in_room[ev_entry.event.state_key] = ProfileInfo( display_name=ev_entry.event.content.get("displayname", None), diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 36340a652a..fd4dd67d91 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 61 +SCHEMA_VERSION = 62 """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the diff --git a/synapse/storage/schema/main/delta/62/01insertion_event_extremities.sql b/synapse/storage/schema/main/delta/62/01insertion_event_extremities.sql new file mode 100644 index 0000000000..b731ef284a --- /dev/null +++ b/synapse/storage/schema/main/delta/62/01insertion_event_extremities.sql @@ -0,0 +1,24 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- Add a table that keeps track of which "insertion" events need to be backfilled +CREATE TABLE IF NOT EXISTS insertion_event_extremities( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS insertion_event_extremities_event_id ON insertion_event_extremities(event_id); +CREATE INDEX IF NOT EXISTS insertion_event_extremities_room_id ON insertion_event_extremities(room_id); diff --git a/synapse/types.py b/synapse/types.py index 429bb013d2..80fa903c4b 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -751,3 +751,32 @@ def get_verify_key_from_cross_signing_key(key_info): # and return that one key for key_id, key_data in keys.items(): return (key_id, decode_verify_key_bytes(key_id, decode_base64(key_data))) + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class UserInfo: + """Holds information about a user. Result of get_userinfo_by_id. + + Attributes: + user_id: ID of the user. + appservice_id: Application service ID that created this user. + consent_server_notice_sent: Version of policy documents the user has been sent. + consent_version: Version of policy documents the user has consented to. + creation_ts: Creation timestamp of the user. + is_admin: True if the user is an admin. + is_deactivated: True if the user has been deactivated. + is_guest: True if the user is a guest user. + is_shadow_banned: True if the user has been shadow-banned. + user_type: User type (None for normal user, 'support' and 'bot' other options). + """ + + user_id: UserID + appservice_id: Optional[int] + consent_server_notice_sent: Optional[str] + consent_version: Optional[str] + user_type: Optional[str] + creation_ts: int + is_admin: bool + is_deactivated: bool + is_guest: bool + is_shadow_banned: bool diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py deleted file mode 100644 index abc12f0837..0000000000 --- a/synapse/util/jsonobject.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -class JsonEncodedObject: - """A common base class for defining protocol units that are represented - as JSON. - - Attributes: - unrecognized_keys (dict): A dict containing all the key/value pairs we - don't recognize. - """ - - valid_keys = [] # keys we will store - """A list of strings that represent keys we know about - and can handle. If we have values for these keys they will be - included in the `dictionary` instance variable. - """ - - internal_keys = [] # keys to ignore while building dict - """A list of strings that should *not* be encoded into JSON. - """ - - required_keys = [] - """A list of strings that we require to exist. If they are not given upon - construction it raises an exception. - """ - - def __init__(self, **kwargs): - """Takes the dict of `kwargs` and loads all keys that are *valid* - (i.e., are included in the `valid_keys` list) into the dictionary` - instance variable. - - Any keys that aren't recognized are added to the `unrecognized_keys` - attribute. - - Args: - **kwargs: Attributes associated with this protocol unit. - """ - for required_key in self.required_keys: - if required_key not in kwargs: - raise RuntimeError("Key %s is required" % required_key) - - self.unrecognized_keys = {} # Keys we were given not listed as valid - for k, v in kwargs.items(): - if k in self.valid_keys or k in self.internal_keys: - self.__dict__[k] = v - else: - self.unrecognized_keys[k] = v - - def get_dict(self): - """Converts this protocol unit into a :py:class:`dict`, ready to be - encoded as JSON. - - The keys it encodes are: `valid_keys` - `internal_keys` - - Returns - dict - """ - d = { - k: _encode(v) - for (k, v) in self.__dict__.items() - if k in self.valid_keys and k not in self.internal_keys - } - d.update(self.unrecognized_keys) - return d - - def get_internal_dict(self): - d = { - k: _encode(v, internal=True) - for (k, v) in self.__dict__.items() - if k in self.valid_keys - } - d.update(self.unrecognized_keys) - return d - - def __str__(self): - return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__)) - - -def _encode(obj, internal=False): - if type(obj) is list: - return [_encode(o, internal=internal) for o in obj] - - if isinstance(obj, JsonEncodedObject): - if internal: - return obj.get_internal_dict() - else: - return obj.get_dict() - - return obj diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 024c5e963c..43998020b2 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -133,11 +133,131 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.assertEquals(result.room_id, room_id) self.assertEquals(result.servers, servers) - def _mkservice(self, is_interested): + def test_get_3pe_protocols_no_appservices(self): + self.mock_store.get_app_services.return_value = [] + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) + ) + self.mock_as_api.get_3pe_protocol.assert_not_called() + self.assertEquals(response, {}) + + def test_get_3pe_protocols_no_protocols(self): + service = self._mkservice(False, []) + self.mock_store.get_app_services.return_value = [service] + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_not_called() + self.assertEquals(response, {}) + + def test_get_3pe_protocols_protocol_no_response(self): + service = self._mkservice(False, ["my-protocol"]) + self.mock_store.get_app_services.return_value = [service] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_called_once_with( + service, "my-protocol" + ) + self.assertEquals(response, {}) + + def test_get_3pe_protocols_select_one_protocol(self): + service = self._mkservice(False, ["my-protocol"]) + self.mock_store.get_app_services.return_value = [service] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) + ) + self.mock_as_api.get_3pe_protocol.assert_called_once_with( + service, "my-protocol" + ) + self.assertEquals( + response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} + ) + + def test_get_3pe_protocols_one_protocol(self): + service = self._mkservice(False, ["my-protocol"]) + self.mock_store.get_app_services.return_value = [service] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_called_once_with( + service, "my-protocol" + ) + self.assertEquals( + response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} + ) + + def test_get_3pe_protocols_multiple_protocol(self): + service_one = self._mkservice(False, ["my-protocol"]) + service_two = self._mkservice(False, ["other-protocol"]) + self.mock_store.get_app_services.return_value = [service_one, service_two] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_called() + self.assertEquals( + response, + { + "my-protocol": {"x-protocol-data": 42, "instances": []}, + "other-protocol": {"x-protocol-data": 42, "instances": []}, + }, + ) + + def test_get_3pe_protocols_multiple_info(self): + service_one = self._mkservice(False, ["my-protocol"]) + service_two = self._mkservice(False, ["my-protocol"]) + + async def get_3pe_protocol(service, unusedProtocol): + if service == service_one: + return { + "x-protocol-data": 42, + "instances": [{"desc": "Alice's service"}], + } + if service == service_two: + return { + "x-protocol-data": 36, + "x-not-used": 45, + "instances": [{"desc": "Bob's service"}], + } + raise Exception("Unexpected service") + + self.mock_store.get_app_services.return_value = [service_one, service_two] + self.mock_as_api.get_3pe_protocol = get_3pe_protocol + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + # It's expected that the second service's data doesn't appear in the response + self.assertEquals( + response, + { + "my-protocol": { + "x-protocol-data": 42, + "instances": [ + { + "desc": "Alice's service", + }, + {"desc": "Bob's service"}, + ], + }, + }, + ) + + def _mkservice(self, is_interested, protocols=None): service = Mock() service.is_interested.return_value = make_awaitable(is_interested) service.token = "mock_service_token" service.url = "mock_service_url" + service.protocols = protocols return service def _mkservice_alias(self, is_interested_in_alias): diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py index 01975c13d4..6cc1a02e12 100644 --- a/tests/handlers/test_space_summary.py +++ b/tests/handlers/test_space_summary.py @@ -26,7 +26,7 @@ from synapse.api.constants import ( from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict -from synapse.handlers.space_summary import _child_events_comparison_key +from synapse.handlers.space_summary import _child_events_comparison_key, _RoomEntry from synapse.rest import admin from synapse.rest.client.v1 import login, room from synapse.server import HomeServer @@ -351,26 +351,30 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # events before child events). # Note that these entries are brief, but should contain enough info. - rooms = [ - { - "room_id": subspace, - "world_readable": True, - "room_type": RoomTypes.SPACE, - }, - { - "room_id": subroom, - "world_readable": True, - }, - ] - event_content = {"via": [fed_hostname]} - events = [ - { - "room_id": subspace, - "state_key": subroom, - "content": event_content, - }, + return [ + _RoomEntry( + subspace, + { + "room_id": subspace, + "world_readable": True, + "room_type": RoomTypes.SPACE, + }, + [ + { + "room_id": subspace, + "state_key": subroom, + "content": {"via": [fed_hostname]}, + } + ], + ), + _RoomEntry( + subroom, + { + "room_id": subroom, + "world_readable": True, + }, + ), ] - return rooms, events # Add a room to the space which is on another server. self._add_child(self.space, subspace, self.token) @@ -436,70 +440,95 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ): # Note that these entries are brief, but should contain enough info. rooms = [ - { - "room_id": public_room, - "world_readable": False, - "join_rules": JoinRules.PUBLIC, - }, - { - "room_id": knock_room, - "world_readable": False, - "join_rules": JoinRules.KNOCK, - }, - { - "room_id": not_invited_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - { - "room_id": invited_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - { - "room_id": restricted_room, - "world_readable": False, - "join_rules": JoinRules.RESTRICTED, - "allowed_spaces": [], - }, - { - "room_id": restricted_accessible_room, - "world_readable": False, - "join_rules": JoinRules.RESTRICTED, - "allowed_spaces": [self.room], - }, - { - "room_id": world_readable_room, - "world_readable": True, - "join_rules": JoinRules.INVITE, - }, - { - "room_id": joined_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ] - - # Place each room in the sub-space. - event_content = {"via": [fed_hostname]} - events = [ - { - "room_id": subspace, - "state_key": room["room_id"], - "content": event_content, - } - for room in rooms + _RoomEntry( + public_room, + { + "room_id": public_room, + "world_readable": False, + "join_rules": JoinRules.PUBLIC, + }, + ), + _RoomEntry( + knock_room, + { + "room_id": knock_room, + "world_readable": False, + "join_rules": JoinRules.KNOCK, + }, + ), + _RoomEntry( + not_invited_room, + { + "room_id": not_invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + _RoomEntry( + invited_room, + { + "room_id": invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + _RoomEntry( + restricted_room, + { + "room_id": restricted_room, + "world_readable": False, + "join_rules": JoinRules.RESTRICTED, + "allowed_spaces": [], + }, + ), + _RoomEntry( + restricted_accessible_room, + { + "room_id": restricted_accessible_room, + "world_readable": False, + "join_rules": JoinRules.RESTRICTED, + "allowed_spaces": [self.room], + }, + ), + _RoomEntry( + world_readable_room, + { + "room_id": world_readable_room, + "world_readable": True, + "join_rules": JoinRules.INVITE, + }, + ), + _RoomEntry( + joined_room, + { + "room_id": joined_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), ] # Also include the subspace. rooms.insert( 0, - { - "room_id": subspace, - "world_readable": True, - }, + _RoomEntry( + subspace, + { + "room_id": subspace, + "world_readable": True, + }, + # Place each room in the sub-space. + [ + { + "room_id": subspace, + "state_key": room.room_id, + "content": {"via": [fed_hostname]}, + } + for room in rooms + ], + ), ) - return rooms, events + return rooms # Add a room to the space which is on another server. self._add_child(self.space, subspace, self.token) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 81d9e2f484..0b817cc701 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -79,6 +79,16 @@ class ModuleApiTestCase(HomeserverTestCase): displayname = self.get_success(self.store.get_profile_displayname("bob")) self.assertEqual(displayname, "Bobberino") + def test_get_userinfo_by_id(self): + user_id = self.register_user("alice", "1234") + found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) + self.assertEqual(found_user.user_id.to_string(), user_id) + self.assertIdentical(found_user.is_admin, False) + + def test_get_userinfo_by_id__no_user_found(self): + found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test")) + self.assertIsNone(found_user) + def test_sending_events_into_room(self): """Tests that a module can send events into a room""" # Mock out create_and_send_nonmember_event to check whether events are being sent diff --git a/tests/push/test_email.py b/tests/push/test_email.py index e04bc5c9a6..a487706758 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -45,14 +45,6 @@ class EmailPusherTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): - # List[Tuple[Deferred, args, kwargs]] - self.email_attempts = [] - - def sendmail(*args, **kwargs): - d = Deferred() - self.email_attempts.append((d, args, kwargs)) - return d - config = self.default_config() config["email"] = { "enable_notifs": True, @@ -75,7 +67,17 @@ class EmailPusherTests(HomeserverTestCase): config["public_baseurl"] = "aaa" config["start_pushers"] = True - hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + hs = self.setup_test_homeserver(config=config) + + # List[Tuple[Deferred, args, kwargs]] + self.email_attempts = [] + + def sendmail(*args, **kwargs): + d = Deferred() + self.email_attempts.append((d, args, kwargs)) + return d + + hs.get_send_email_handler()._sendmail = sendmail return hs diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 3df070c936..1a9528ec20 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -19,11 +19,14 @@ import json from typing import Iterable -from unittest.mock import Mock +from unittest.mock import Mock, call from urllib import parse as urlparse +from twisted.internet import defer + import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.errors import HttpResponseException from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin from synapse.rest.client.v1 import directory, login, profile, room @@ -1124,6 +1127,93 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) +class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): + """Test that we correctly fallback to local filtering if a remote server + doesn't support search. + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver(federation_client=Mock()) + + def prepare(self, reactor, clock, hs): + self.register_user("user", "pass") + self.token = self.login("user", "pass") + + self.federation_client = hs.get_federation_client() + + def test_simple(self): + "Simple test for searching rooms over federation" + self.federation_client.get_public_rooms.side_effect = ( + lambda *a, **k: defer.succeed({}) + ) + + search_filter = {"generic_search_term": "foobar"} + + channel = self.make_request( + "POST", + b"/_matrix/client/r0/publicRooms?server=testserv", + content={"filter": search_filter}, + access_token=self.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.federation_client.get_public_rooms.assert_called_once_with( + "testserv", + limit=100, + since_token=None, + search_filter=search_filter, + include_all_networks=False, + third_party_instance_id=None, + ) + + def test_fallback(self): + "Test that searching public rooms over federation falls back if it gets a 404" + + # The `get_public_rooms` should be called again if the first call fails + # with a 404, when using search filters. + self.federation_client.get_public_rooms.side_effect = ( + HttpResponseException(404, "Not Found", b""), + defer.succeed({}), + ) + + search_filter = {"generic_search_term": "foobar"} + + channel = self.make_request( + "POST", + b"/_matrix/client/r0/publicRooms?server=testserv", + content={"filter": search_filter}, + access_token=self.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.federation_client.get_public_rooms.assert_has_calls( + [ + call( + "testserv", + limit=100, + since_token=None, + search_filter=search_filter, + include_all_networks=False, + third_party_instance_id=None, + ), + call( + "testserv", + limit=None, + since_token=None, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ), + ] + ) + + class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 317a2287e3..e7e617e9df 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -47,12 +47,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): config = self.default_config() # Email config. - self.email_attempts = [] - - async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): - self.email_attempts.append(msg) - return - config["email"] = { "enable_notifs": False, "template_dir": os.path.abspath( @@ -67,7 +61,16 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): } config["public_baseurl"] = "https://example.com" - hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + hs = self.setup_test_homeserver(config=config) + + async def sendmail( + reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs + ): + self.email_attempts.append(msg) + + self.email_attempts = [] + hs.get_send_email_handler()._sendmail = sendmail + return hs def prepare(self, reactor, clock, hs): @@ -511,11 +514,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): config = self.default_config() # Email config. - self.email_attempts = [] - - async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): - self.email_attempts.append(msg) - config["email"] = { "enable_notifs": False, "template_dir": os.path.abspath( @@ -530,7 +528,16 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): } config["public_baseurl"] = "https://example.com" - self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + self.hs = self.setup_test_homeserver(config=config) + + async def sendmail( + reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs + ): + self.email_attempts.append(msg) + + self.email_attempts = [] + self.hs.get_send_email_handler()._sendmail = sendmail + return self.hs def prepare(self, reactor, clock, hs): diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 1cad5f00eb..a52e5e608a 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -509,10 +509,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): } # Email config. - self.email_attempts = [] - - async def sendmail(*args, **kwargs): - self.email_attempts.append((args, kwargs)) config["email"] = { "enable_notifs": True, @@ -532,7 +528,13 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): } config["public_baseurl"] = "aaa" - self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + self.hs = self.setup_test_homeserver(config=config) + + async def sendmail(*args, **kwargs): + self.email_attempts.append((args, kwargs)) + + self.email_attempts = [] + self.hs.get_send_email_handler()._sendmail = sendmail self.store = self.hs.get_datastore() diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 932970fd9a..d05d367685 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -14,7 +14,10 @@ import json from synapse.logging.context import LoggingContext +from synapse.rest import admin +from synapse.rest.client.v1 import login, room from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.util.async_helpers import yieldable_gather_results from tests import unittest @@ -94,3 +97,50 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): res = self.get_success(self.store.have_seen_events("room1", ["event10"])) self.assertEquals(res, {"event10"}) self.assertEquals(ctx.get_resource_usage().db_txn_count, 0) + + +class EventCacheTestCase(unittest.HomeserverTestCase): + """Test that the various layers of event cache works.""" + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store: EventsWorkerStore = hs.get_datastore() + + self.user = self.register_user("user", "pass") + self.token = self.login(self.user, "pass") + + self.room = self.helper.create_room_as(self.user, tok=self.token) + + res = self.helper.send(self.room, tok=self.token) + self.event_id = res["event_id"] + + # Reset the event cache so the tests start with it empty + self.store._get_event_cache.clear() + + def test_simple(self): + """Test that we cache events that we pull from the DB.""" + + with LoggingContext("test") as ctx: + self.get_success(self.store.get_event(self.event_id)) + + # We should have fetched the event from the DB + self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) + + def test_dedupe(self): + """Test that if we request the same event multiple times we only pull it + out once. + """ + + with LoggingContext("test") as ctx: + d = yieldable_gather_results( + self.store.get_event, [self.event_id, self.event_id] + ) + self.get_success(d) + + # We should have fetched the event from the DB + self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) diff --git a/tests/test_federation.py b/tests/test_federation.py index 0ed8326f55..3785799f46 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -75,10 +75,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) self.handler = self.homeserver.get_federation_handler() - self.handler._check_event_auth = ( - lambda origin, event, context, state, auth_events, backfilled: succeed( - context - ) + self.handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed( + context ) self.client = self.homeserver.get_federation_client() self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed( |