diff --git a/changelog.d/6455.feature b/changelog.d/6455.feature
new file mode 100644
index 0000000000..eb286cb70f
--- /dev/null
+++ b/changelog.d/6455.feature
@@ -0,0 +1 @@
+Include room states on invite events that are sent to application services. Contributed by @Sorunome.
diff --git a/changelog.d/7798.feature b/changelog.d/7798.feature
new file mode 100644
index 0000000000..56ffaf0d4a
--- /dev/null
+++ b/changelog.d/7798.feature
@@ -0,0 +1 @@
+Add experimental support for running multiple federation sender processes.
diff --git a/changelog.d/7802.misc b/changelog.d/7802.misc
new file mode 100644
index 0000000000..d81f8875c5
--- /dev/null
+++ b/changelog.d/7802.misc
@@ -0,0 +1 @@
+Switch from simplejson to the standard library json.
diff --git a/changelog.d/7813.misc b/changelog.d/7813.misc
new file mode 100644
index 0000000000..f3005cfd27
--- /dev/null
+++ b/changelog.d/7813.misc
@@ -0,0 +1 @@
+Add type hints to the http server code and remove an unused parameter.
diff --git a/changelog.d/7815.bugfix b/changelog.d/7815.bugfix
new file mode 100644
index 0000000000..3e7c7d412e
--- /dev/null
+++ b/changelog.d/7815.bugfix
@@ -0,0 +1 @@
+Fix detection of out of sync remote device lists when receiving events from remote users.
diff --git a/changelog.d/7817.bugfix b/changelog.d/7817.bugfix
new file mode 100644
index 0000000000..1c001070d5
--- /dev/null
+++ b/changelog.d/7817.bugfix
@@ -0,0 +1 @@
+Fix bug where Synapse fails to process an incoming event over federation if the server is missing too much of the event's auth chain.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 164a104045..1a2d9fb153 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -118,38 +118,6 @@ pid_file: DATADIR/homeserver.pid
#
#enable_search: false
-# Restrict federation to the following whitelist of domains.
-# N.B. we recommend also firewalling your federation listener to limit
-# inbound federation traffic as early as possible, rather than relying
-# purely on this application-layer restriction. If not specified, the
-# default is to whitelist everything.
-#
-#federation_domain_whitelist:
-# - lon.example.com
-# - nyc.example.com
-# - syd.example.com
-
-# Prevent federation requests from being sent to the following
-# blacklist IP address CIDR ranges. If this option is not specified, or
-# specified with an empty list, no ip range blacklist will be enforced.
-#
-# As of Synapse v1.4.0 this option also affects any outbound requests to identity
-# servers provided by user input.
-#
-# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
-# listed here, since they correspond to unroutable addresses.)
-#
-federation_ip_range_blacklist:
- - '127.0.0.0/8'
- - '10.0.0.0/8'
- - '172.16.0.0/12'
- - '192.168.0.0/16'
- - '100.64.0.0/10'
- - '169.254.0.0/16'
- - '::1/128'
- - 'fe80::/64'
- - 'fc00::/7'
-
# List of ports that Synapse should listen on, their purpose and their
# configuration.
#
@@ -608,6 +576,39 @@ acme:
+# Restrict federation to the following whitelist of domains.
+# N.B. we recommend also firewalling your federation listener to limit
+# inbound federation traffic as early as possible, rather than relying
+# purely on this application-layer restriction. If not specified, the
+# default is to whitelist everything.
+#
+#federation_domain_whitelist:
+# - lon.example.com
+# - nyc.example.com
+# - syd.example.com
+
+# Prevent federation requests from being sent to the following
+# blacklist IP address CIDR ranges. If this option is not specified, or
+# specified with an empty list, no ip range blacklist will be enforced.
+#
+# As of Synapse v1.4.0 this option also affects any outbound requests to identity
+# servers provided by user input.
+#
+# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
+# listed here, since they correspond to unroutable addresses.)
+#
+federation_ip_range_blacklist:
+ - '127.0.0.0/8'
+ - '10.0.0.0/8'
+ - '172.16.0.0/12'
+ - '192.168.0.0/16'
+ - '100.64.0.0/10'
+ - '169.254.0.0/16'
+ - '::1/128'
+ - 'fe80::/64'
+ - 'fc00::/7'
+
+
## Caching ##
# Caching can be configured through the following options.
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 5305038c21..d5d4522336 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -15,13 +15,11 @@
# limitations under the License.
"""Contains exceptions and error codes."""
-
+import json
import logging
from http import HTTPStatus
from typing import Dict, List
-from canonicaljson import json
-
from twisted.web import http
logger = logging.getLogger(__name__)
@@ -573,7 +571,7 @@ class HttpResponseException(CodeMessageException):
# try to parse the body as json, to get better errcode/msg, but
# default to M_UNKNOWN with the HTTP status as the error text
try:
- j = json.loads(self.response)
+ j = json.loads(self.response.decode("utf-8"))
except ValueError:
j = {}
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index f6792d9fc8..e90695f026 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -511,25 +511,7 @@ class GenericWorkerSlavedStore(
SearchWorkerStore,
BaseSlavedStore,
):
- def __init__(self, database, db_conn, hs):
- super(GenericWorkerSlavedStore, self).__init__(database, db_conn, hs)
-
- # We pull out the current federation stream position now so that we
- # always have a known value for the federation position in memory so
- # that we don't have to bounce via a deferred once when we start the
- # replication streams.
- self.federation_out_pos_startup = self._get_federation_out_pos(db_conn)
-
- def _get_federation_out_pos(self, db_conn):
- sql = "SELECT stream_id FROM federation_stream_position WHERE type = ?"
- sql = self.database_engine.convert_param_style(sql)
-
- txn = db_conn.cursor()
- txn.execute(sql, ("federation",))
- rows = txn.fetchall()
- txn.close()
-
- return rows[0][0] if rows else -1
+ pass
class GenericWorkerServer(HomeServer):
@@ -812,19 +794,11 @@ class FederationSenderHandler(object):
self.federation_sender = hs.get_federation_sender()
self._hs = hs
- # if the worker is restarted, we want to pick up where we left off in
- # the replication stream, so load the position from the database.
- #
- # XXX is this actually worthwhile? Whenever the master is restarted, we'll
- # drop some rows anyway (which is mostly fine because we're only dropping
- # typing and presence notifications). If the replication stream is
- # unreliable, why do we do all this hoop-jumping to store the position in the
- # database? See also https://github.com/matrix-org/synapse/issues/7535.
- #
- self.federation_position = self.store.federation_out_pos_startup
+ # Stores the latest position in the federation stream we've gotten up
+ # to. This is always set before we use it.
+ self.federation_position = None
self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
- self._last_ack = self.federation_position
def on_start(self):
# There may be some events that are persisted but haven't been sent,
@@ -932,7 +906,6 @@ class FederationSenderHandler(object):
# We ACK this token over replication so that the master can drop
# its in memory queues
self._hs.get_tcp_replication().send_federation_ack(current_position)
- self._last_ack = current_position
except Exception:
logger.exception("Error updating federation stream position")
@@ -960,7 +933,7 @@ def start(config_options):
)
if config.worker_app == "synapse.app.appservice":
- if config.notify_appservices:
+ if config.appservice.notify_appservices:
sys.stderr.write(
"\nThe appservices must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
@@ -970,13 +943,13 @@ def start(config_options):
sys.exit(1)
# Force the appservice to start since they will be disabled in the main config
- config.notify_appservices = True
+ config.appservice.notify_appservices = True
else:
# For other worker types we force this to off.
- config.notify_appservices = False
+ config.appservice.notify_appservices = False
if config.worker_app == "synapse.app.pusher":
- if config.start_pushers:
+ if config.server.start_pushers:
sys.stderr.write(
"\nThe pushers must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
@@ -986,13 +959,13 @@ def start(config_options):
sys.exit(1)
# Force the pushers to start since they will be disabled in the main config
- config.start_pushers = True
+ config.server.start_pushers = True
else:
# For other worker types we force this to off.
- config.start_pushers = False
+ config.server.start_pushers = False
if config.worker_app == "synapse.app.user_dir":
- if config.update_user_directory:
+ if config.server.update_user_directory:
sys.stderr.write(
"\nThe update_user_directory must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
@@ -1002,13 +975,13 @@ def start(config_options):
sys.exit(1)
# Force the pushers to start since they will be disabled in the main config
- config.update_user_directory = True
+ config.server.update_user_directory = True
else:
# For other worker types we force this to off.
- config.update_user_directory = False
+ config.server.update_user_directory = False
if config.worker_app == "synapse.app.federation_sender":
- if config.send_federation:
+ if config.federation.send_federation:
sys.stderr.write(
"\nThe send_federation must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
@@ -1018,10 +991,10 @@ def start(config_options):
sys.exit(1)
# Force the pushers to start since they will be disabled in the main config
- config.send_federation = True
+ config.federation.send_federation = True
else:
# For other worker types we force this to off.
- config.send_federation = False
+ config.federation.send_federation = False
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index f92bfb420b..1e0e4d497d 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -19,7 +19,7 @@ from prometheus_client import Counter
from twisted.internet import defer
-from synapse.api.constants import ThirdPartyEntityKind
+from synapse.api.constants import EventTypes, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
from synapse.events.utils import serialize_event
from synapse.http.client import SimpleHttpClient
@@ -207,7 +207,7 @@ class ApplicationServiceApi(SimpleHttpClient):
if service.url is None:
return True
- events = self._serialize(events)
+ events = self._serialize(service, events)
if txn_id is None:
logger.warning(
@@ -233,6 +233,18 @@ class ApplicationServiceApi(SimpleHttpClient):
failed_transactions_counter.labels(service.id).inc()
return False
- def _serialize(self, events):
+ def _serialize(self, service, events):
time_now = self.clock.time_msec()
- return [serialize_event(e, time_now, as_client_event=True) for e in events]
+ return [
+ serialize_event(
+ e,
+ time_now,
+ as_client_event=True,
+ is_invite=(
+ e.type == EventTypes.Member
+ and e.membership == "invite"
+ and service.is_interested_in_user(e.state_key)
+ ),
+ )
+ for e in events
+ ]
diff --git a/synapse/config/federation.py b/synapse/config/federation.py
new file mode 100644
index 0000000000..7782ab4c9d
--- /dev/null
+++ b/synapse/config/federation.py
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from hashlib import sha256
+from typing import List, Optional
+
+import attr
+from netaddr import IPSet
+
+from ._base import Config, ConfigError
+
+
+@attr.s
+class ShardedFederationSendingConfig:
+ """Algorithm for choosing which federation sender instance is responsible
+    for which destination host.
+ """
+
+ instances = attr.ib(type=List[str])
+
+ def should_send_to(self, instance_name: str, destination: str) -> bool:
+        """Whether this instance is responsible for sending transactions for
+ the given host.
+ """
+
+ # If multiple federation senders are not defined we always return true.
+ if not self.instances or len(self.instances) == 1:
+ return True
+
+ # We shard by taking the hash, modulo it by the number of federation
+ # senders and then checking whether this instance matches the instance
+ # at that index.
+ #
+ # (Technically this introduces some bias and is not entirely uniform, but
+ # since the hash is so large the bias is ridiculously small).
+ dest_hash = sha256(destination.encode("utf8")).digest()
+ dest_int = int.from_bytes(dest_hash, byteorder="little")
+ remainder = dest_int % (len(self.instances))
+ return self.instances[remainder] == instance_name
+
+
+class FederationConfig(Config):
+ section = "federation"
+
+ def read_config(self, config, **kwargs):
+ # Whether to send federation traffic out in this process. This only
+ # applies to some federation traffic, and so shouldn't be used to
+ # "disable" federation
+ self.send_federation = config.get("send_federation", True)
+
+ federation_sender_instances = config.get("federation_sender_instances") or []
+ self.federation_shard_config = ShardedFederationSendingConfig(
+ federation_sender_instances
+ )
+
+ # FIXME: federation_domain_whitelist needs sytests
+ self.federation_domain_whitelist = None # type: Optional[dict]
+ federation_domain_whitelist = config.get("federation_domain_whitelist", None)
+
+ if federation_domain_whitelist is not None:
+ # turn the whitelist into a hash for speed of lookup
+ self.federation_domain_whitelist = {}
+
+ for domain in federation_domain_whitelist:
+ self.federation_domain_whitelist[domain] = True
+
+ self.federation_ip_range_blacklist = config.get(
+ "federation_ip_range_blacklist", []
+ )
+
+ # Attempt to create an IPSet from the given ranges
+ try:
+ self.federation_ip_range_blacklist = IPSet(
+ self.federation_ip_range_blacklist
+ )
+
+ # Always blacklist 0.0.0.0, ::
+ self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
+ except Exception as e:
+ raise ConfigError(
+ "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e
+ )
+
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
+ return """\
+ # Restrict federation to the following whitelist of domains.
+ # N.B. we recommend also firewalling your federation listener to limit
+ # inbound federation traffic as early as possible, rather than relying
+ # purely on this application-layer restriction. If not specified, the
+ # default is to whitelist everything.
+ #
+ #federation_domain_whitelist:
+ # - lon.example.com
+ # - nyc.example.com
+ # - syd.example.com
+
+ # Prevent federation requests from being sent to the following
+ # blacklist IP address CIDR ranges. If this option is not specified, or
+ # specified with an empty list, no ip range blacklist will be enforced.
+ #
+ # As of Synapse v1.4.0 this option also affects any outbound requests to identity
+ # servers provided by user input.
+ #
+ # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
+ # listed here, since they correspond to unroutable addresses.)
+ #
+ federation_ip_range_blacklist:
+ - '127.0.0.0/8'
+ - '10.0.0.0/8'
+ - '172.16.0.0/12'
+ - '192.168.0.0/16'
+ - '100.64.0.0/10'
+ - '169.254.0.0/16'
+ - '::1/128'
+ - 'fe80::/64'
+ - 'fc00::/7'
+ """
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 264c274c52..8e93d31394 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -23,6 +23,7 @@ from .cas import CasConfig
from .consent_config import ConsentConfig
from .database import DatabaseConfig
from .emailconfig import EmailConfig
+from .federation import FederationConfig
from .groups import GroupsConfig
from .jwt_config import JWTConfig
from .key import KeyConfig
@@ -57,6 +58,7 @@ class HomeServerConfig(RootConfig):
config_classes = [
ServerConfig,
TlsConfig,
+ FederationConfig,
CacheConfig,
DatabaseConfig,
LoggingConfig,
@@ -90,4 +92,5 @@ class HomeServerConfig(RootConfig):
ThirdPartyRulesConfig,
TracerConfig,
RedisConfig,
+ FederationConfig,
]
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 8204664883..b6afa642ca 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -23,7 +23,6 @@ from typing import Any, Dict, Iterable, List, Optional
import attr
import yaml
-from netaddr import IPSet
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.endpoint import parse_and_validate_server_name
@@ -136,11 +135,6 @@ class ServerConfig(Config):
self.use_frozen_dicts = config.get("use_frozen_dicts", False)
self.public_baseurl = config.get("public_baseurl")
- # Whether to send federation traffic out in this process. This only
- # applies to some federation traffic, and so shouldn't be used to
- # "disable" federation
- self.send_federation = config.get("send_federation", True)
-
# Whether to enable user presence.
self.use_presence = config.get("use_presence", True)
@@ -263,34 +257,6 @@ class ServerConfig(Config):
# due to resource constraints
self.admin_contact = config.get("admin_contact", None)
- # FIXME: federation_domain_whitelist needs sytests
- self.federation_domain_whitelist = None # type: Optional[dict]
- federation_domain_whitelist = config.get("federation_domain_whitelist", None)
-
- if federation_domain_whitelist is not None:
- # turn the whitelist into a hash for speed of lookup
- self.federation_domain_whitelist = {}
-
- for domain in federation_domain_whitelist:
- self.federation_domain_whitelist[domain] = True
-
- self.federation_ip_range_blacklist = config.get(
- "federation_ip_range_blacklist", []
- )
-
- # Attempt to create an IPSet from the given ranges
- try:
- self.federation_ip_range_blacklist = IPSet(
- self.federation_ip_range_blacklist
- )
-
- # Always blacklist 0.0.0.0, ::
- self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
- except Exception as e:
- raise ConfigError(
- "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e
- )
-
if self.public_baseurl is not None:
if self.public_baseurl[-1] != "/":
self.public_baseurl += "/"
@@ -743,38 +709,6 @@ class ServerConfig(Config):
#
#enable_search: false
- # Restrict federation to the following whitelist of domains.
- # N.B. we recommend also firewalling your federation listener to limit
- # inbound federation traffic as early as possible, rather than relying
- # purely on this application-layer restriction. If not specified, the
- # default is to whitelist everything.
- #
- #federation_domain_whitelist:
- # - lon.example.com
- # - nyc.example.com
- # - syd.example.com
-
- # Prevent federation requests from being sent to the following
- # blacklist IP address CIDR ranges. If this option is not specified, or
- # specified with an empty list, no ip range blacklist will be enforced.
- #
- # As of Synapse v1.4.0 this option also affects any outbound requests to identity
- # servers provided by user input.
- #
- # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
- # listed here, since they correspond to unroutable addresses.)
- #
- federation_ip_range_blacklist:
- - '127.0.0.0/8'
- - '10.0.0.0/8'
- - '172.16.0.0/12'
- - '192.168.0.0/16'
- - '100.64.0.0/10'
- - '169.254.0.0/16'
- - '::1/128'
- - 'fe80::/64'
- - 'fc00::/7'
-
# List of ports that Synapse should listen on, their purpose and their
# configuration.
#
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index c582355146..c0981eee62 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -65,14 +65,16 @@ def check(
room_id = event.room_id
- # I'm not really expecting to get auth events in the wrong room, but let's
- # sanity-check it
+ # We need to ensure that the auth events are actually for the same room, to
+ # stop people from using powers they've been granted in other rooms for
+ # example.
for auth_event in auth_events.values():
if auth_event.room_id != room_id:
- raise Exception(
+ raise AuthError(
+ 403,
"During auth for event %s in room %s, found event %s in the state "
"which is in room %s"
- % (event.event_id, room_id, auth_event.event_id, auth_event.room_id)
+ % (event.event_id, room_id, auth_event.event_id, auth_event.room_id),
)
if do_sig_check:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 86051decd4..2aab9c5f55 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -14,10 +14,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
import logging
from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union
-from canonicaljson import json
from prometheus_client import Counter, Histogram
from twisted.internet import defer
@@ -526,9 +526,9 @@ class FederationServer(FederationBase):
json_result = {} # type: Dict[str, Dict[str, dict]]
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
- for key_id, json_bytes in keys.items():
+ for key_id, json_str in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
- key_id: json.loads(json_bytes)
+ key_id: json.loads(json_str)
}
logger.info(
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 860b03f7b9..4fc9ff92e5 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -55,6 +55,11 @@ class FederationRemoteSendQueue(object):
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
+ # We may have multiple federation sender instances, so we need to track
+ # their positions separately.
+ self._sender_instances = hs.config.federation.federation_shard_config.instances
+ self._sender_positions = {}
+
# Pending presence map user_id -> UserPresenceState
self.presence_map = {} # type: Dict[str, UserPresenceState]
@@ -261,7 +266,14 @@ class FederationRemoteSendQueue(object):
def get_current_token(self):
return self.pos - 1
- def federation_ack(self, token):
+ def federation_ack(self, instance_name, token):
+ if self._sender_instances:
+ # If we have configured multiple federation sender instances we need
+ # to track their positions separately, and only clear the queue up
+ # to the token all instances have acked.
+ self._sender_positions[instance_name] = token
+ token = min(self._sender_positions.values())
+
self._clear_queue_before_pos(token)
async def get_replication_rows(
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 464d7a41de..4b63a0755f 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -69,6 +69,9 @@ class FederationSender(object):
self._transaction_manager = TransactionManager(hs)
+ self._instance_name = hs.get_instance_name()
+ self._federation_shard_config = hs.config.federation.federation_shard_config
+
# map from destination to PerDestinationQueue
self._per_destination_queues = {} # type: Dict[str, PerDestinationQueue]
@@ -191,7 +194,13 @@ class FederationSender(object):
)
return
- destinations = set(destinations)
+ destinations = {
+ d
+ for d in destinations
+ if self._federation_shard_config.should_send_to(
+ self._instance_name, d
+ )
+ }
if send_on_behalf_of is not None:
# If we are sending the event on behalf of another server
@@ -322,7 +331,12 @@ class FederationSender(object):
# Work out which remote servers should be poked and poke them.
domains = yield self.state.get_current_hosts_in_room(room_id)
- domains = [d for d in domains if d != self.server_name]
+ domains = [
+ d
+ for d in domains
+ if d != self.server_name
+ and self._federation_shard_config.should_send_to(self._instance_name, d)
+ ]
if not domains:
return
@@ -427,6 +441,10 @@ class FederationSender(object):
for destination in destinations:
if destination == self.server_name:
continue
+ if not self._federation_shard_config.should_send_to(
+ self._instance_name, destination
+ ):
+ continue
self._get_per_destination_queue(destination).send_presence(states)
@measure_func("txnqueue._process_presence")
@@ -441,6 +459,12 @@ class FederationSender(object):
for destination in destinations:
if destination == self.server_name:
continue
+
+ if not self._federation_shard_config.should_send_to(
+ self._instance_name, destination
+ ):
+ continue
+
self._get_per_destination_queue(destination).send_presence(states)
def build_and_send_edu(
@@ -462,6 +486,11 @@ class FederationSender(object):
logger.info("Not sending EDU to ourselves")
return
+ if not self._federation_shard_config.should_send_to(
+ self._instance_name, destination
+ ):
+ return
+
edu = Edu(
origin=self.server_name,
destination=destination,
@@ -478,6 +507,11 @@ class FederationSender(object):
edu: edu to send
key: clobbering key for this edu
"""
+ if not self._federation_shard_config.should_send_to(
+ self._instance_name, edu.destination
+ ):
+ return
+
queue = self._get_per_destination_queue(edu.destination)
if key:
queue.send_keyed_edu(edu, key)
@@ -489,6 +523,11 @@ class FederationSender(object):
logger.warning("Not sending device update to ourselves")
return
+ if not self._federation_shard_config.should_send_to(
+ self._instance_name, destination
+ ):
+ return
+
self._get_per_destination_queue(destination).attempt_new_transaction()
def wake_destination(self, destination: str):
@@ -502,6 +541,11 @@ class FederationSender(object):
logger.warning("Not waking up ourselves")
return
+ if not self._federation_shard_config.should_send_to(
+ self._instance_name, destination
+ ):
+ return
+
self._get_per_destination_queue(destination).attempt_new_transaction()
@staticmethod
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 12966e239b..6402136e8a 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -74,6 +74,20 @@ class PerDestinationQueue(object):
self._clock = hs.get_clock()
self._store = hs.get_datastore()
self._transaction_manager = transaction_manager
+ self._instance_name = hs.get_instance_name()
+ self._federation_shard_config = hs.config.federation.federation_shard_config
+
+ self._should_send_on_this_instance = True
+ if not self._federation_shard_config.should_send_to(
+ self._instance_name, destination
+ ):
+ # We don't raise an exception here to avoid taking out any other
+ # processing. We have a guard in `attempt_new_transaction` that
+            # ensures we don't start sending stuff.
+ logger.error(
+ "Create a per destination queue for %s on wrong worker", destination,
+ )
+ self._should_send_on_this_instance = False
self._destination = destination
self.transmission_loop_running = False
@@ -180,6 +194,14 @@ class PerDestinationQueue(object):
logger.debug("TX [%s] Transaction already in progress", self._destination)
return
+ if not self._should_send_on_this_instance:
+ # We don't raise an exception here to avoid taking out any other
+ # processing.
+ logger.error(
+ "Trying to start a transaction to %s on wrong worker", self._destination
+ )
+ return
+
logger.debug("TX [%s] Starting transaction loop", self._destination)
run_as_background_process(
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index d79ffefdb5..786e608fa2 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -104,7 +104,7 @@ class CasHandler:
return user, displayname
def _parse_cas_response(
- self, cas_response_body: str
+ self, cas_response_body: bytes
) -> Tuple[str, Dict[str, Optional[str]]]:
"""
Retrieve the user and other parameters from the CAS response.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ca7da42a3f..e43bccd721 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -61,6 +61,7 @@ from synapse.logging.context import (
run_in_background,
)
from synapse.logging.utils import log_function
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet,
@@ -618,6 +619,11 @@ class FederationHandler(BaseHandler):
will be omitted from the result. Likewise, any events which turn out not to
be in the given room.
+ This function *does not* automatically get missing auth events of the
+ newly fetched events. Callers must include the full auth chain of
+        the missing events in the `event_ids` argument, to ensure that any
+ missing auth events are correctly fetched.
+
Returns:
map from event_id to event
"""
@@ -784,15 +790,25 @@ class FederationHandler(BaseHandler):
resync = True
if resync:
- await self.store.mark_remote_user_device_cache_as_stale(event.sender)
+ run_as_background_process(
+ "resync_device_due_to_pdu", self._resync_device, event.sender
+ )
- # Immediately attempt a resync in the background
- if self.config.worker_app:
- return run_in_background(self._user_device_resync, event.sender)
- else:
- return run_in_background(
- self._device_list_updater.user_device_resync, event.sender
- )
+ async def _resync_device(self, sender: str) -> None:
+ """We have detected that the device list for the given user may be out
+ of sync, so we try and resync them.
+ """
+
+ try:
+ await self.store.mark_remote_user_device_cache_as_stale(sender)
+
+ # Immediately attempt a resync in the background
+ if self.config.worker_app:
+ await self._user_device_resync(user_id=sender)
+ else:
+ await self._device_list_updater.user_device_resync(sender)
+ except Exception:
+ logger.exception("Failed to resync device for %s", sender)
@log_function
async def backfill(self, dest, room_id, limit, extremities):
@@ -1131,12 +1147,16 @@ class FederationHandler(BaseHandler):
):
"""Fetch the given events from a server, and persist them as outliers.
+ This function *does not* recursively get missing auth events of the
+ newly fetched events. Callers must include in the `events` argument
+ any missing events from the auth chain.
+
Logs a warning if we can't find the given event.
"""
room_version = await self.store.get_room_version(room_id)
- event_infos = []
+ event_map = {} # type: Dict[str, EventBase]
async def get_event(event_id: str):
with nested_logging_context(event_id):
@@ -1150,17 +1170,7 @@ class FederationHandler(BaseHandler):
)
return
- # recursively fetch the auth events for this event
- auth_events = await self._get_events_from_store_or_dest(
- destination, room_id, event.auth_event_ids()
- )
- auth = {}
- for auth_event_id in event.auth_event_ids():
- ae = auth_events.get(auth_event_id)
- if ae:
- auth[(ae.type, ae.state_key)] = ae
-
- event_infos.append(_NewEventInfo(event, None, auth))
+ event_map[event.event_id] = event
except Exception as e:
logger.warning(
@@ -1172,6 +1182,32 @@ class FederationHandler(BaseHandler):
await concurrently_execute(get_event, events, 5)
+ # Make a map of auth events for each event. We do this after fetching
+ # all the events as some of the events' auth events will be in the list
+ # of requested events.
+
+ auth_events = [
+ aid
+ for event in event_map.values()
+ for aid in event.auth_event_ids()
+ if aid not in event_map
+ ]
+ persisted_events = await self.store.get_events(
+ auth_events, allow_rejected=True,
+ )
+
+ event_infos = []
+ for event in event_map.values():
+ auth = {}
+ for auth_event_id in event.auth_event_ids():
+ ae = persisted_events.get(auth_event_id) or event_map.get(auth_event_id)
+ if ae:
+ auth[(ae.type, ae.state_key)] = ae
+ else:
+ logger.info("Missing auth event %s", auth_event_id)
+
+ event_infos.append(_NewEventInfo(event, None, auth))
+
await self._handle_new_events(
destination, event_infos,
)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 8743e9839d..505872ee90 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -13,13 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import json
import logging
import urllib
from io import BytesIO
import treq
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from netaddr import IPAddress
from prometheus_client import Counter
from zope.interface import implementer, provider
@@ -371,7 +371,7 @@ class SimpleHttpClient(object):
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- return json.loads(body)
+ return json.loads(body.decode("utf-8"))
else:
raise HttpResponseException(response.code, response.phrase, body)
@@ -412,7 +412,7 @@ class SimpleHttpClient(object):
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- return json.loads(body)
+ return json.loads(body.decode("utf-8"))
else:
raise HttpResponseException(response.code, response.phrase, body)
@@ -441,7 +441,7 @@ class SimpleHttpClient(object):
actual_headers.update(headers)
body = yield self.get_raw(uri, args, headers=headers)
- return json.loads(body)
+ return json.loads(body.decode("utf-8"))
@defer.inlineCallbacks
def put_json(self, uri, json_body, args={}, headers=None):
@@ -485,7 +485,7 @@ class SimpleHttpClient(object):
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- return json.loads(body)
+ return json.loads(body.decode("utf-8"))
else:
raise HttpResponseException(response.code, response.phrase, body)
@@ -503,7 +503,7 @@ class SimpleHttpClient(object):
header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
- HTTP body at text.
+ HTTP body as bytes.
Raises:
HttpResponseException on a non-2xx HTTP response.
"""
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 2b35f86066..cff49202f4 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -217,7 +217,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
return NOT_DONE_YET
@wrap_async_request_handler
- async def _async_render_wrapper(self, request):
+ async def _async_render_wrapper(self, request: SynapseRequest):
"""This is a wrapper that delegates to `_async_render` and handles
exceptions, return values, metrics, etc.
"""
@@ -237,7 +237,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
f = failure.Failure()
self._send_error_response(f, request)
- async def _async_render(self, request):
+ async def _async_render(self, request: Request):
"""Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
no appropriate method exists. Can be overriden in sub classes for
different routing.
@@ -278,7 +278,7 @@ class DirectServeJsonResource(_AsyncResource):
"""
def _send_response(
- self, request, code, response_object,
+ self, request: Request, code: int, response_object: Any,
):
"""Implements _AsyncResource._send_response
"""
@@ -507,14 +507,29 @@ class RootOptionsRedirectResource(OptionsResource, RootRedirect):
def respond_with_json(
- request,
- code,
- json_object,
- send_cors=False,
- response_code_message=None,
- pretty_print=False,
- canonical_json=True,
+ request: Request,
+ code: int,
+ json_object: Any,
+ send_cors: bool = False,
+ pretty_print: bool = False,
+ canonical_json: bool = True,
):
+ """Encodes the given object as JSON and sends it in response to the given request.
+
+ Args:
+ request: The http request to respond to.
+ code: The HTTP response code.
+ json_object: The object to serialize to JSON.
+ send_cors: Whether to send Cross-Origin Resource Sharing headers
+ https://fetch.spec.whatwg.org/#http-cors-protocol
+ pretty_print: Whether to include indentation and line-breaks in the
+ resulting JSON bytes.
+ canonical_json: Whether to use the canonicaljson algorithm when encoding
+ the JSON bytes.
+
+ Returns:
+ twisted.web.server.NOT_DONE_YET if the request is still active.
+ """
# could alternatively use request.notifyFinish() and flip a flag when
# the Deferred fires, but since the flag is RIGHT THERE it seems like
# a waste.
@@ -522,7 +537,7 @@ def respond_with_json(
logger.warning(
"Not sending response to request %s, already disconnected.", request
)
- return
+ return None
if pretty_print:
json_bytes = encode_pretty_printed_json(json_object) + b"\n"
@@ -533,30 +548,26 @@ def respond_with_json(
else:
json_bytes = json.dumps(json_object).encode("utf-8")
- return respond_with_json_bytes(
- request,
- code,
- json_bytes,
- send_cors=send_cors,
- response_code_message=response_code_message,
- )
+ return respond_with_json_bytes(request, code, json_bytes, send_cors=send_cors)
def respond_with_json_bytes(
- request, code, json_bytes, send_cors=False, response_code_message=None
+ request: Request, code: int, json_bytes: bytes, send_cors: bool = False,
):
"""Sends encoded JSON in response to the given request.
Args:
- request (twisted.web.http.Request): The http request to respond to.
- code (int): The HTTP response code.
- json_bytes (bytes): The json bytes to use as the response body.
- send_cors (bool): Whether to send Cross-Origin Resource Sharing headers
+ request: The http request to respond to.
+ code: The HTTP response code.
+ json_bytes: The json bytes to use as the response body.
+ send_cors: Whether to send Cross-Origin Resource Sharing headers
https://fetch.spec.whatwg.org/#http-cors-protocol
+
Returns:
- twisted.web.server.NOT_DONE_YET"""
+ twisted.web.server.NOT_DONE_YET if the request is still active.
+ """
- request.setResponseCode(code, message=response_code_message)
+ request.setResponseCode(code)
request.setHeader(b"Content-Type", b"application/json")
request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),))
request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
@@ -573,12 +584,12 @@ def respond_with_json_bytes(
return NOT_DONE_YET
-def set_cors_headers(request):
- """Set the CORs headers so that javascript running in a web browsers can
+def set_cors_headers(request: Request):
+ """Set the CORS headers so that javascript running in a web browser can
use this API
Args:
- request (twisted.web.http.Request): The http request to add CORs to.
+ request: The http request to add CORS to.
"""
request.setHeader(b"Access-Control-Allow-Origin", b"*")
request.setHeader(
@@ -643,7 +654,7 @@ def set_clickjacking_protection_headers(request: Request):
request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';")
-def finish_request(request):
+def finish_request(request: Request):
""" Finish writing the response to the request.
Twisted throws a RuntimeException if the connection closed before the
@@ -662,7 +673,7 @@ def finish_request(request):
logger.info("Connection disconnected before response was written: %r", e)
-def _request_user_agent_is_curl(request):
+def _request_user_agent_is_curl(request: Request) -> bool:
user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[])
for user_agent in user_agents:
if b"curl" in user_agent:
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 13fcb408a6..3cabe9d02e 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -14,11 +14,9 @@
# limitations under the License.
""" This module contains base REST classes for constructing REST servlets. """
-
+import json
import logging
-from canonicaljson import json
-
from synapse.api.errors import Codes, SynapseError
logger = logging.getLogger(__name__)
@@ -214,16 +212,8 @@ def parse_json_value_from_request(request, allow_empty_body=False):
if not content_bytes and allow_empty_body:
return None
- # Decode to Unicode so that simplejson will return Unicode strings on
- # Python 2
- try:
- content_unicode = content_bytes.decode("utf8")
- except UnicodeDecodeError:
- logger.warning("Unable to decode UTF-8")
- raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
-
try:
- content = json.loads(content_unicode)
+ content = json.loads(content_bytes.decode("utf-8"))
except Exception as e:
logger.warning("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index ccc7f1f0d1..f33801f883 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -293,20 +293,22 @@ class FederationAckCommand(Command):
Format::
- FEDERATION_ACK <token>
+ FEDERATION_ACK <instance_name> <token>
"""
NAME = "FEDERATION_ACK"
- def __init__(self, token):
+ def __init__(self, instance_name, token):
+ self.instance_name = instance_name
self.token = token
@classmethod
def from_line(cls, line):
- return cls(int(line))
+ instance_name, token = line.split(" ")
+ return cls(instance_name, int(token))
def to_line(self):
- return str(self.token)
+ return "%s %s" % (self.instance_name, self.token)
class RemovePusherCommand(Command):
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 55b3b79008..80f5df60f9 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -238,7 +238,7 @@ class ReplicationCommandHandler:
federation_ack_counter.inc()
if self._federation_sender:
- self._federation_sender.federation_ack(cmd.token)
+ self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
async def on_REMOVE_PUSHER(
self, conn: AbstractConnection, cmd: RemovePusherCommand
@@ -527,7 +527,7 @@ class ReplicationCommandHandler:
"""Ack data for the federation stream. This allows the master to drop
data stored purely in memory.
"""
- self.send_command(FederationAckCommand(token))
+ self.send_command(FederationAckCommand(self._instance_name, token))
def send_user_sync(
self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
diff --git a/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql b/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql
new file mode 100644
index 0000000000..1cc2633aad
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We need to store the stream positions by instance in a sharded config world.
+--
+-- We default to master as we want the column to be NOT NULL and we correctly
+-- reset the instance name to match the config each time we start up.
+ALTER TABLE federation_stream_position ADD COLUMN instance_name TEXT NOT NULL DEFAULT 'master';
+
+CREATE UNIQUE INDEX federation_stream_position_instance ON federation_stream_position(type, instance_name);
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index 379d758b5d..5e32c7aa1e 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -45,7 +45,7 @@ from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import Database, make_in_list_sql_clause
from synapse.storage.engines import PostgresEngine
from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -253,6 +253,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def __init__(self, database: Database, db_conn, hs):
super(StreamWorkerStore, self).__init__(database, db_conn, hs)
+ self._instance_name = hs.get_instance_name()
+ self._send_federation = hs.should_send_federation()
+ self._federation_shard_config = hs.config.federation.federation_shard_config
+
+ # If we're a process that sends federation we may need to reset the
+ # `federation_stream_position` table to match the current sharding
+ # config. We don't do this now as otherwise two processes could conflict
+ # during startup which would cause one to die.
+ self._need_to_reset_federation_stream_positions = self._send_federation
+
events_max = self.get_room_max_stream_ordering()
event_cache_prefill, min_event_val = self.db.get_cache_dict(
db_conn,
@@ -793,22 +803,95 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, events
- def get_federation_out_pos(self, typ):
- return self.db.simple_select_one_onecol(
+ async def get_federation_out_pos(self, typ: str) -> int:
+ if self._need_to_reset_federation_stream_positions:
+ await self.db.runInteraction(
+ "_reset_federation_positions_txn", self._reset_federation_positions_txn
+ )
+ self._need_to_reset_federation_stream_positions = False
+
+ return await self.db.simple_select_one_onecol(
table="federation_stream_position",
retcol="stream_id",
- keyvalues={"type": typ},
+ keyvalues={"type": typ, "instance_name": self._instance_name},
desc="get_federation_out_pos",
)
- def update_federation_out_pos(self, typ, stream_id):
- return self.db.simple_update_one(
+ async def update_federation_out_pos(self, typ, stream_id):
+ if self._need_to_reset_federation_stream_positions:
+ await self.db.runInteraction(
+ "_reset_federation_positions_txn", self._reset_federation_positions_txn
+ )
+ self._need_to_reset_federation_stream_positions = False
+
+ return await self.db.simple_update_one(
table="federation_stream_position",
- keyvalues={"type": typ},
+ keyvalues={"type": typ, "instance_name": self._instance_name},
updatevalues={"stream_id": stream_id},
desc="update_federation_out_pos",
)
+ def _reset_federation_positions_txn(self, txn):
+ """Fiddles with the `federation_stream_position` table to make it match
+ the configured federation sender instances during start up.
+ """
+
+ # The federation sender instances may have changed, so we need to
+ # massage the `federation_stream_position` table to have a row per type
+ # per instance sending federation. If there is a mismatch we update the
+ # table with the correct rows using the *minimum* stream ID seen. This
+ # may result in resending of events/EDUs to remote servers, but that is
+ # preferable to dropping them.
+
+ if not self._send_federation:
+ return
+
+ # Pull out the configured instances. If we don't have a shard config then
+ # we assume that we're the only instance sending.
+ configured_instances = self._federation_shard_config.instances
+ if not configured_instances:
+ configured_instances = [self._instance_name]
+ elif self._instance_name not in configured_instances:
+ return
+
+ instances_in_table = self.db.simple_select_onecol_txn(
+ txn,
+ table="federation_stream_position",
+ keyvalues={},
+ retcol="instance_name",
+ )
+
+ if set(instances_in_table) == set(configured_instances):
+ # Nothing to do
+ return
+
+ sql = """
+ SELECT type, MIN(stream_id) FROM federation_stream_position
+ GROUP BY type
+ """
+ txn.execute(sql)
+ min_positions = dict(txn) # Map from type -> min position
+
+ # Ensure we do actually have some values here
+ assert set(min_positions) == {"federation", "events"}
+
+ sql = """
+ DELETE FROM federation_stream_position
+ WHERE NOT (%s)
+ """
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "instance_name", configured_instances
+ )
+ txn.execute(sql % (clause,), args)
+
+ for typ, stream_id in min_positions.items():
+ self.db.simple_upsert_txn(
+ txn,
+ table="federation_stream_position",
+ keyvalues={"type": typ, "instance_name": self._instance_name},
+ values={"stream_id": stream_id},
+ )
+
def has_room_changed_since(self, room_id, stream_id):
return self._events_stream_cache.has_entity_changed(room_id, stream_id)
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 5448d9f0dc..23be1167a3 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -32,6 +32,7 @@ class FederationAckTestCase(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer)
+
return hs
def test_federation_ack_sent(self):
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
new file mode 100644
index 0000000000..519a2dc510
--- /dev/null
+++ b/tests/replication/test_federation_sender_shard.py
@@ -0,0 +1,286 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.app.generic_worker import GenericWorkerServer
+from synapse.events.builder import EventBuilderFactory
+from synapse.replication.http import streams
+from synapse.replication.tcp.handler import ReplicationCommandHandler
+from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.rest.admin import register_servlets_for_client_rest_resource
+from synapse.rest.client.v1 import login, room
+from synapse.types import UserID
+
+from tests import unittest
+from tests.server import FakeTransport
+
+logger = logging.getLogger(__name__)
+
+
+class BaseStreamTestCase(unittest.HomeserverTestCase):
+ """Base class for tests of the replication streams"""
+
+ servlets = [
+ streams.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ # build a replication server
+ self.server_factory = ReplicationStreamProtocolFactory(hs)
+ self.streamer = hs.get_replication_streamer()
+
+ store = hs.get_datastore()
+ self.database = store.db
+
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ def default_config(self):
+ conf = super().default_config()
+ conf["send_federation"] = False
+ return conf
+
+ def make_worker_hs(self, extra_config={}):
+ config = self._get_worker_hs_config()
+ config.update(extra_config)
+
+ mock_federation_client = Mock(spec=["put_json"])
+ mock_federation_client.put_json.side_effect = lambda *_, **__: defer.succeed({})
+
+ worker_hs = self.setup_test_homeserver(
+ http_client=mock_federation_client,
+ homeserverToUse=GenericWorkerServer,
+ config=config,
+ reactor=self.reactor,
+ )
+
+ store = worker_hs.get_datastore()
+ store.db._db_pool = self.database._db_pool
+
+ repl_handler = ReplicationCommandHandler(worker_hs)
+ client = ClientReplicationStreamProtocol(
+ worker_hs, "client", "test", self.clock, repl_handler,
+ )
+ server = self.server_factory.buildProtocol(None)
+
+ client_transport = FakeTransport(server, self.reactor)
+ client.makeConnection(client_transport)
+
+ server_transport = FakeTransport(client, self.reactor)
+ server.makeConnection(server_transport)
+
+ return worker_hs
+
+ def _get_worker_hs_config(self) -> dict:
+ config = self.default_config()
+ config["worker_app"] = "synapse.app.federation_sender"
+ config["worker_replication_host"] = "testserv"
+ config["worker_replication_http_port"] = "8765"
+ return config
+
+ def replicate(self):
+ """Tell the master side of replication that something has happened, and then
+ wait for the replication to occur.
+ """
+ self.streamer.on_notifier_poke()
+ self.pump()
+
+ def create_room_with_remote_server(self, user, token, remote_server="other_server"):
+ room = self.helper.create_room_as(user, tok=token)
+ store = self.hs.get_datastore()
+ federation = self.hs.get_handlers().federation_handler
+
+ prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
+ room_version = self.get_success(store.get_room_version(room))
+
+ factory = EventBuilderFactory(self.hs)
+ factory.hostname = remote_server
+
+ user_id = UserID("user", remote_server).to_string()
+
+ event_dict = {
+ "type": EventTypes.Member,
+ "state_key": user_id,
+ "content": {"membership": Membership.JOIN},
+ "sender": user_id,
+ "room_id": room,
+ }
+
+ builder = factory.for_room_version(room_version, event_dict)
+ join_event = self.get_success(builder.build(prev_event_ids))
+
+ self.get_success(federation.on_send_join_request(remote_server, join_event))
+ self.replicate()
+
+ return room
+
+
+class FederationSenderTestCase(BaseStreamTestCase):
+ servlets = [
+ login.register_servlets,
+ register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ ]
+
+ def test_send_event_single_sender(self):
+ """Test that using a single federation sender worker correctly sends a
+ new event.
+ """
+ worker_hs = self.make_worker_hs({"send_federation": True})
+ mock_client = worker_hs.get_http_client()
+
+ user = self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ room = self.create_room_with_remote_server(user, token)
+
+ mock_client.put_json.reset_mock()
+
+ self.create_and_send_event(room, UserID.from_string(user))
+ self.replicate()
+
+ # Assert that the event was sent out over federation.
+ mock_client.put_json.assert_called()
+ self.assertEqual(mock_client.put_json.call_args[0][0], "other_server")
+ self.assertTrue(mock_client.put_json.call_args[1]["data"].get("pdus"))
+
+ def test_send_event_sharded(self):
+ """Test that using two federation sender workers correctly sends
+ new events.
+ """
+ worker1 = self.make_worker_hs(
+ {
+ "send_federation": True,
+ "worker_name": "sender1",
+ "federation_sender_instances": ["sender1", "sender2"],
+ }
+ )
+ mock_client1 = worker1.get_http_client()
+
+ worker2 = self.make_worker_hs(
+ {
+ "send_federation": True,
+ "worker_name": "sender2",
+ "federation_sender_instances": ["sender1", "sender2"],
+ }
+ )
+ mock_client2 = worker2.get_http_client()
+
+ user = self.register_user("user2", "pass")
+ token = self.login("user2", "pass")
+
+ sent_on_1 = False
+ sent_on_2 = False
+ for i in range(20):
+ server_name = "other_server_%d" % (i,)
+ room = self.create_room_with_remote_server(user, token, server_name)
+ mock_client1.reset_mock()
+ mock_client2.reset_mock()
+
+ self.create_and_send_event(room, UserID.from_string(user))
+ self.replicate()
+
+ if mock_client1.put_json.called:
+ sent_on_1 = True
+ mock_client2.put_json.assert_not_called()
+ self.assertEqual(mock_client1.put_json.call_args[0][0], server_name)
+ self.assertTrue(mock_client1.put_json.call_args[1]["data"].get("pdus"))
+ elif mock_client2.put_json.called:
+ sent_on_2 = True
+ mock_client1.put_json.assert_not_called()
+ self.assertEqual(mock_client2.put_json.call_args[0][0], server_name)
+ self.assertTrue(mock_client2.put_json.call_args[1]["data"].get("pdus"))
+ else:
+ raise AssertionError(
+ "Expected send transaction from one or the other sender"
+ )
+
+ if sent_on_1 and sent_on_2:
+ break
+
+ self.assertTrue(sent_on_1)
+ self.assertTrue(sent_on_2)
+
+ def test_send_typing_sharded(self):
+ """Test that using two federation sender workers correctly sends
+ new typing EDUs.
+ """
+ worker1 = self.make_worker_hs(
+ {
+ "send_federation": True,
+ "worker_name": "sender1",
+ "federation_sender_instances": ["sender1", "sender2"],
+ }
+ )
+ mock_client1 = worker1.get_http_client()
+
+ worker2 = self.make_worker_hs(
+ {
+ "send_federation": True,
+ "worker_name": "sender2",
+ "federation_sender_instances": ["sender1", "sender2"],
+ }
+ )
+ mock_client2 = worker2.get_http_client()
+
+ user = self.register_user("user3", "pass")
+ token = self.login("user3", "pass")
+
+ typing_handler = self.hs.get_typing_handler()
+
+ sent_on_1 = False
+ sent_on_2 = False
+ for i in range(20):
+ server_name = "other_server_%d" % (i,)
+ room = self.create_room_with_remote_server(user, token, server_name)
+ mock_client1.reset_mock()
+ mock_client2.reset_mock()
+
+ self.get_success(
+ typing_handler.started_typing(
+ target_user=UserID.from_string(user),
+ auth_user=UserID.from_string(user),
+ room_id=room,
+ timeout=20000,
+ )
+ )
+
+ self.replicate()
+
+ if mock_client1.put_json.called:
+ sent_on_1 = True
+ mock_client2.put_json.assert_not_called()
+ self.assertEqual(mock_client1.put_json.call_args[0][0], server_name)
+ self.assertTrue(mock_client1.put_json.call_args[1]["data"].get("edus"))
+ elif mock_client2.put_json.called:
+ sent_on_2 = True
+ mock_client1.put_json.assert_not_called()
+ self.assertEqual(mock_client2.put_json.call_args[0][0], server_name)
+ self.assertTrue(mock_client2.put_json.call_args[1]["data"].get("edus"))
+ else:
+ raise AssertionError(
+ "Expected send transaction from one or the other sender"
+ )
+
+ if sent_on_1 and sent_on_2:
+ break
+
+ self.assertTrue(sent_on_1)
+ self.assertTrue(sent_on_2)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index fd97999956..2be7238b00 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -398,7 +398,7 @@ class CASTestCase(unittest.HomeserverTestCase):
</cas:serviceResponse>
"""
% cas_user_id
- )
+ ).encode("utf-8")
mocked_http_client = Mock(spec=["get_raw"])
mocked_http_client.get_raw.side_effect = get_raw
|