-rw-r--r--  changelog.d/7658.feature | 1
-rw-r--r--  changelog.d/8432.misc | 1
-rw-r--r--  changelog.d/8433.misc | 1
-rw-r--r--  changelog.d/8443.misc | 1
-rwxr-xr-x  demo/start.sh | 2
-rw-r--r--  docs/openid.md | 41
-rw-r--r--  docs/sample_config.yaml | 8
-rw-r--r--  mypy.ini | 4
-rw-r--r--  stubs/txredisapi.pyi | 20
-rw-r--r--  synapse/config/oidc_config.py | 9
-rw-r--r--  synapse/config/tls.py | 18
-rw-r--r--  synapse/federation/federation_server.py | 5
-rw-r--r--  synapse/handlers/directory.py | 2
-rw-r--r--  synapse/handlers/oidc_handler.py | 11
-rw-r--r--  synapse/handlers/room.py | 2
-rw-r--r--  synapse/handlers/room_member.py | 2
-rw-r--r--  synapse/handlers/sync.py | 2
-rw-r--r--  synapse/http/server.py | 4
-rw-r--r--  synapse/logging/_structured.py | 10
-rw-r--r--  synapse/push/push_rule_evaluator.py | 4
-rw-r--r--  synapse/replication/tcp/handler.py | 6
-rw-r--r--  synapse/replication/tcp/protocol.py | 10
-rw-r--r--  synapse/replication/tcp/redis.py | 40
-rw-r--r--  synapse/state/__init__.py | 2
-rw-r--r--  synapse/storage/databases/main/censor_events.py | 6
-rw-r--r--  synapse/storage/databases/main/events.py | 18
-rw-r--r--  synapse/storage/databases/main/stream.py | 2
-rw-r--r--  synapse/storage/util/id_generators.py | 2
-rw-r--r--  tests/handlers/test_oidc.py | 10
-rw-r--r--  tests/replication/_base.py | 224
-rw-r--r--  tests/replication/test_sharded_event_persister.py | 102
-rw-r--r--  tests/unittest.py | 2
32 files changed, 477 insertions, 95 deletions
diff --git a/changelog.d/7658.feature b/changelog.d/7658.feature
new file mode 100644
index 0000000000..fbf345988d
--- /dev/null
+++ b/changelog.d/7658.feature
@@ -0,0 +1 @@
+Add a configuration option for always using the "userinfo endpoint" for OpenID Connect. This fixes support for some identity providers, e.g. GitLab. Contributed by Benjamin Koch.
diff --git a/changelog.d/8432.misc b/changelog.d/8432.misc
new file mode 100644
index 0000000000..01fdad4caf
--- /dev/null
+++ b/changelog.d/8432.misc
@@ -0,0 +1 @@
+Check for unreachable code with mypy.
diff --git a/changelog.d/8433.misc b/changelog.d/8433.misc
new file mode 100644
index 0000000000..05f8b5bbf4
--- /dev/null
+++ b/changelog.d/8433.misc
@@ -0,0 +1 @@
+Add unit test for event persister sharding.
diff --git a/changelog.d/8443.misc b/changelog.d/8443.misc
new file mode 100644
index 0000000000..633598e6b3
--- /dev/null
+++ b/changelog.d/8443.misc
@@ -0,0 +1 @@
+Configure `public_baseurl` when using demo scripts.
diff --git a/demo/start.sh b/demo/start.sh
index 83396e5c33..f6b5ea137f 100755
--- a/demo/start.sh
+++ b/demo/start.sh
@@ -30,6 +30,8 @@ for port in 8080 8081 8082; do
     if ! grep -F "Customisation made by demo/start.sh" -q  $DIR/etc/$port.config; then
         printf '\n\n# Customisation made by demo/start.sh\n' >> $DIR/etc/$port.config
 
+        echo "public_baseurl: http://localhost:$port/" >> $DIR/etc/$port.config
+
         echo 'enable_registration: true' >> $DIR/etc/$port.config
 
         # Warning, this heredoc depends on the interaction of tabs and spaces. Please don't
diff --git a/docs/openid.md b/docs/openid.md
index 70b37f858b..4873681999 100644
--- a/docs/openid.md
+++ b/docs/openid.md
@@ -238,13 +238,36 @@ Synapse config:
 
 ```yaml
 oidc_config:
-   enabled: true
-   issuer: "https://id.twitch.tv/oauth2/"
-   client_id: "your-client-id" # TO BE FILLED
-   client_secret: "your-client-secret" # TO BE FILLED
-   client_auth_method: "client_secret_post"
-   user_mapping_provider:
-     config:
-       localpart_template: '{{ user.preferred_username }}'
-       display_name_template: '{{ user.name }}'
+  enabled: true
+  issuer: "https://id.twitch.tv/oauth2/"
+  client_id: "your-client-id" # TO BE FILLED
+  client_secret: "your-client-secret" # TO BE FILLED
+  client_auth_method: "client_secret_post"
+  user_mapping_provider:
+    config:
+      localpart_template: "{{ user.preferred_username }}"
+      display_name_template: "{{ user.name }}"
+```
+
+### GitLab
+
+1. Create a [new application](https://gitlab.com/profile/applications).
+2. Add the `read_user` and `openid` scopes.
+3. Add this Callback URL: `[synapse public baseurl]/_synapse/oidc/callback`
+
+Synapse config:
+
+```yaml
+oidc_config:
+  enabled: true
+  issuer: "https://gitlab.com/"
+  client_id: "your-client-id" # TO BE FILLED
+  client_secret: "your-client-secret" # TO BE FILLED
+  client_auth_method: "client_secret_post"
+  scopes: ["openid", "read_user"]
+  user_profile_method: "userinfo_endpoint"
+  user_mapping_provider:
+    config:
+      localpart_template: '{{ user.nickname }}'
+      display_name_template: '{{ user.name }}'
 ```
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 8a3206e845..b2c1d7a737 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1714,6 +1714,14 @@ oidc_config:
   #
   #skip_verification: true
 
+  # Whether to fetch the user profile from the userinfo endpoint. Valid
+  # values are: "auto" or "userinfo_endpoint".
+  #
+  # Defaults to "auto", which fetches the userinfo endpoint only if "openid" is not
+  # included in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
+  #
+  #user_profile_method: "userinfo_endpoint"
+
   # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
   # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
   #
diff --git a/mypy.ini b/mypy.ini
index 7986781432..e84ad04e41 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -6,6 +6,7 @@ check_untyped_defs = True
 show_error_codes = True
 show_traceback = True
 mypy_path = stubs
+warn_unreachable = True
 files =
   synapse/api,
   synapse/appservice,
@@ -142,3 +143,6 @@ ignore_missing_imports = True
 
 [mypy-nacl.*]
 ignore_missing_imports = True
+
+[mypy-hiredis]
+ignore_missing_imports = True
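
For context on `warn_unreachable`: it makes mypy flag statements it can prove will never execute, which is what surfaced several of the dead branches removed below (e.g. in `synapse/handlers/room.py` and `synapse/logging/_structured.py`). A minimal illustrative sketch, not taken from the Synapse codebase:

```python
from typing import Union

def describe(value: Union[int, str]) -> str:
    if isinstance(value, int):
        return "an int"
    return "a str"  # fine: mypy narrows `value` to str here

def describe_badly(value: Union[int, str]) -> str:
    if isinstance(value, int):
        return "an int"
    if isinstance(value, str):
        return "a str"
    # With `warn_unreachable = True`, mypy reports here:
    #   error: Statement is unreachable  [unreachable]
    return "something else"
```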
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
index c66413f003..522244bb57 100644
--- a/stubs/txredisapi.pyi
+++ b/stubs/txredisapi.pyi
@@ -16,7 +16,7 @@
 """Contains *incomplete* type hints for txredisapi.
 """
 
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Type
 
 class RedisProtocol:
     def publish(self, channel: str, message: bytes): ...
@@ -42,3 +42,21 @@ def lazyConnection(
 
 class SubscriberFactory:
     def buildProtocol(self, addr): ...
+
+class ConnectionHandler: ...
+
+class RedisFactory:
+    continueTrying: bool
+    handler: RedisProtocol
+    def __init__(
+        self,
+        uuid: str,
+        dbid: Optional[int],
+        poolsize: int,
+        isLazy: bool = False,
+        handler: Type = ConnectionHandler,
+        charset: str = "utf-8",
+        password: Optional[str] = None,
+        replyTimeout: Optional[int] = None,
+        convertNumbers: Optional[int] = True,
+    ): ...
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index f924116819..7597fbc864 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -56,6 +56,7 @@ class OIDCConfig(Config):
         self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
         self.oidc_jwks_uri = oidc_config.get("jwks_uri")
         self.oidc_skip_verification = oidc_config.get("skip_verification", False)
+        self.oidc_user_profile_method = oidc_config.get("user_profile_method", "auto")
         self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False)
 
         ump_config = oidc_config.get("user_mapping_provider", {})
@@ -159,6 +160,14 @@ class OIDCConfig(Config):
           #
           #skip_verification: true
 
+          # Whether to fetch the user profile from the userinfo endpoint. Valid
+          # values are: "auto" or "userinfo_endpoint".
+          #
+          # Defaults to "auto", which fetches the userinfo endpoint only if "openid" is
+          # not included in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
+          #
+          #user_profile_method: "userinfo_endpoint"
+
           # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
           # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
           #
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 9ddb8b546b..ad37b93c02 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -18,7 +18,7 @@ import os
 import warnings
 from datetime import datetime
 from hashlib import sha256
-from typing import List
+from typing import List, Optional
 
 from unpaddedbase64 import encode_base64
 
@@ -177,8 +177,8 @@ class TlsConfig(Config):
             "use_insecure_ssl_client_just_for_testing_do_not_use"
         )
 
-        self.tls_certificate = None
-        self.tls_private_key = None
+        self.tls_certificate = None  # type: Optional[crypto.X509]
+        self.tls_private_key = None  # type: Optional[crypto.PKey]
 
     def is_disk_cert_valid(self, allow_self_signed=True):
         """
@@ -226,12 +226,12 @@ class TlsConfig(Config):
         days_remaining = (expires_on - now).days
         return days_remaining
 
-    def read_certificate_from_disk(self, require_cert_and_key):
+    def read_certificate_from_disk(self, require_cert_and_key: bool):
         """
         Read the certificates and private key from disk.
 
         Args:
-            require_cert_and_key (bool): set to True to throw an error if the certificate
+            require_cert_and_key: set to True to throw an error if the certificate
                 and key file are not given
         """
         if require_cert_and_key:
@@ -479,13 +479,13 @@ class TlsConfig(Config):
             }
         )
 
-    def read_tls_certificate(self):
+    def read_tls_certificate(self) -> crypto.X509:
         """Reads the TLS certificate from the configured file, and returns it
 
         Also checks if it is self-signed, and warns if so
 
         Returns:
-            OpenSSL.crypto.X509: the certificate
+            The certificate
         """
         cert_path = self.tls_certificate_file
         logger.info("Loading TLS certificate from %s", cert_path)
@@ -504,11 +504,11 @@ class TlsConfig(Config):
 
         return cert
 
-    def read_tls_private_key(self):
+    def read_tls_private_key(self) -> crypto.PKey:
         """Reads the TLS private key from the configured file, and returns it
 
         Returns:
-            OpenSSL.crypto.PKey: the private key
+            The private key
         """
         private_key_path = self.tls_private_key_file
         logger.info("Loading TLS key from %s", private_key_path)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 24329dd0e3..02f11e1209 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -22,7 +22,6 @@ from typing import (
     Callable,
     Dict,
     List,
-    Match,
     Optional,
     Tuple,
     Union,
@@ -825,14 +824,14 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
     return False
 
 
-def _acl_entry_matches(server_name: str, acl_entry: str) -> Match:
+def _acl_entry_matches(server_name: str, acl_entry: Any) -> bool:
     if not isinstance(acl_entry, str):
         logger.warning(
             "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
         )
         return False
     regex = glob_to_regex(acl_entry)
-    return regex.match(server_name)
+    return bool(regex.match(server_name))
 
 
 class FederationHandlerRegistry:
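
For reference, `glob_to_regex` (from `synapse.util`) compiles a server-ACL glob into a regex, and `re.Pattern.match` returns `Optional[Match]`, never a bool — which is why the old `-> Match` annotation was wrong and a `bool(...)` coercion is needed. A rough sketch of the behaviour, using a simplified stand-in for the real `glob_to_regex`:

```python
import re

def glob_to_regex_sketch(glob: str) -> "re.Pattern":
    # Simplified stand-in: translate '*' and '?' globs into regex
    # equivalents and anchor the pattern to the whole server name.
    escaped = re.escape(glob).replace(r"\*", ".*").replace(r"\?", ".")
    return re.compile(r"\A" + escaped + r"\Z", re.IGNORECASE)

# re.match returns a Match object or None, never a plain bool ...
match = glob_to_regex_sketch("*.example.com").match("hs.example.com")
# ... so `_acl_entry_matches` coerces it for its declared `-> bool` type.
assert bool(match) is True
assert bool(glob_to_regex_sketch("*.example.com").match("example.org")) is False
```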
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 62aa9a2da8..6f15c68240 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -383,7 +383,7 @@ class DirectoryHandler(BaseHandler):
         """
         creator = await self.store.get_room_alias_creator(alias.to_string())
 
-        if creator is not None and creator == user_id:
+        if creator == user_id:
             return True
 
         # Resolve the alias to the corresponding room.
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 19cd652675..05ac86e697 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -96,6 +96,7 @@ class OidcHandler:
         self.hs = hs
         self._callback_url = hs.config.oidc_callback_url  # type: str
         self._scopes = hs.config.oidc_scopes  # type: List[str]
+        self._user_profile_method = hs.config.oidc_user_profile_method  # type: str
         self._client_auth = ClientAuth(
             hs.config.oidc_client_id,
             hs.config.oidc_client_secret,
@@ -196,11 +197,11 @@ class OidcHandler:
                     % (m["response_types_supported"],)
                 )
 
-        # If the openid scope was not requested, we need a userinfo endpoint to fetch user infos
+        # Ensure there's a userinfo endpoint to fetch from if it is required.
         if self._uses_userinfo:
             if m.get("userinfo_endpoint") is None:
                 raise ValueError(
-                    'provider has no "userinfo_endpoint", even though it is required because the "openid" scope is not requested'
+                    'provider has no "userinfo_endpoint", even though it is required'
                 )
         else:
             # If we're not using userinfo, we need a valid jwks to validate the ID token
@@ -220,8 +221,10 @@ class OidcHandler:
         ``access_token`` with the ``userinfo_endpoint``.
         """
 
-        # Maybe that should be user-configurable and not inferred?
-        return "openid" not in self._scopes
+        return (
+            "openid" not in self._scopes
+            or self._user_profile_method == "userinfo_endpoint"
+        )
 
     async def load_metadata(self) -> OpenIDProviderMetadata:
         """Load and validate the provider metadata.
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index d5f7c78edf..f1a6699cd4 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -962,8 +962,6 @@ class RoomCreationHandler(BaseHandler):
             try:
                 random_string = stringutils.random_string(18)
                 gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
-                if isinstance(gen_room_id, bytes):
-                    gen_room_id = gen_room_id.decode("utf-8")
                 await self.store.store_room(
                     room_id=gen_room_id,
                     room_creator_user_id=creator_id,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 8feba8c90a..5ec36f591d 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -642,7 +642,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
     async def send_membership_event(
         self,
-        requester: Requester,
+        requester: Optional[Requester],
         event: EventBase,
         context: EventContext,
         ratelimit: bool = True,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index bfe2583002..260ec19b41 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -87,7 +87,7 @@ class SyncConfig:
 class TimelineBatch:
     prev_batch = attr.ib(type=StreamToken)
     events = attr.ib(type=List[EventBase])
-    limited = attr.ib(bool)
+    limited = attr.ib(type=bool)
 
     def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 996a31a9ec..09ed74f6ce 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -257,7 +257,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
             if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
                 callback_return = await raw_callback_return
             else:
-                callback_return = raw_callback_return
+                callback_return = raw_callback_return  # type: ignore
 
             return callback_return
 
@@ -406,7 +406,7 @@ class JsonResource(DirectServeJsonResource):
         if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
             callback_return = await raw_callback_return
         else:
-            callback_return = raw_callback_return
+            callback_return = raw_callback_return  # type: ignore
 
         return callback_return
 
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index 144506c8f2..0fc2ea609e 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 import os.path
 import sys
@@ -89,14 +88,7 @@ class LogContextObserver:
         context = current_context()
 
         # Copy the context information to the log event.
-        if context is not None:
-            context.copy_to_twisted_log_entry(event)
-        else:
-            # If there's no logging context, not even the root one, we might be
-            # starting up or it might be from non-Synapse code. Log it as if it
-            # came from the root logger.
-            event["request"] = None
-            event["scope"] = None
+        context.copy_to_twisted_log_entry(event)
 
         self.observer(event)
 
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 709ace01e5..3a68ce636f 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -16,7 +16,7 @@
 
 import logging
 import re
-from typing import Any, Dict, List, Pattern, Union
+from typing import Any, Dict, List, Optional, Pattern, Union
 
 from synapse.events import EventBase
 from synapse.types import UserID
@@ -181,7 +181,7 @@ class PushRuleEvaluatorForEvent:
 
         return r.search(body)
 
-    def _get_value(self, dotted_key: str) -> str:
+    def _get_value(self, dotted_key: str) -> Optional[str]:
         return self._value_cache.get(dotted_key, None)
 
 
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index b323841f73..e92da7b263 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -251,10 +251,9 @@ class ReplicationCommandHandler:
         using TCP.
         """
         if hs.config.redis.redis_enabled:
-            import txredisapi
-
             from synapse.replication.tcp.redis import (
                 RedisDirectTcpReplicationClientFactory,
+                lazyConnection,
             )
 
             logger.info(
@@ -271,7 +270,8 @@ class ReplicationCommandHandler:
             # connection after SUBSCRIBE is called).
 
             # First create the connection for sending commands.
-            outbound_redis_connection = txredisapi.lazyConnection(
+            outbound_redis_connection = lazyConnection(
+                reactor=hs.get_reactor(),
                 host=hs.config.redis_host,
                 port=hs.config.redis_port,
                 password=hs.config.redis.redis_password,
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 0b0d204e64..a509e599c2 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -51,10 +51,11 @@ import fcntl
 import logging
 import struct
 from inspect import isawaitable
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
 from prometheus_client import Counter
 
+from twisted.internet import task
 from twisted.protocols.basic import LineOnlyReceiver
 from twisted.python.failure import Failure
 
@@ -152,9 +153,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         self.last_received_command = self.clock.time_msec()
         self.last_sent_command = 0
-        self.time_we_closed = None  # When we requested the connection be closed
+        # When we requested the connection be closed
+        self.time_we_closed = None  # type: Optional[int]
 
-        self.received_ping = False  # Have we reecived a ping from the other side
+        self.received_ping = False  # Have we received a ping from the other side
 
         self.state = ConnectionStates.CONNECTING
 
@@ -165,7 +167,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.pending_commands = []  # type: List[Command]
 
         # The LoopingCall for sending pings.
-        self._send_ping_loop = None
+        self._send_ping_loop = None  # type: Optional[task.LoopingCall]
 
         # a logcontext which we use for processing incoming commands. We declare it as a
         # background process so that the CPU stats get reported to prometheus.
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index f225e533de..de19705c1f 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -15,7 +15,7 @@
 
 import logging
 from inspect import isawaitable
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
 import txredisapi
 
@@ -228,3 +228,41 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
         p.password = self.password
 
         return p
+
+
+def lazyConnection(
+    reactor,
+    host: str = "localhost",
+    port: int = 6379,
+    dbid: Optional[int] = None,
+    reconnect: bool = True,
+    charset: str = "utf-8",
+    password: Optional[str] = None,
+    connectTimeout: Optional[int] = None,
+    replyTimeout: Optional[int] = None,
+    convertNumbers: bool = True,
+) -> txredisapi.RedisProtocol:
+    """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
+    reactor.
+    """
+
+    isLazy = True
+    poolsize = 1
+
+    uuid = "%s:%d" % (host, port)
+    factory = txredisapi.RedisFactory(
+        uuid,
+        dbid,
+        poolsize,
+        isLazy,
+        txredisapi.ConnectionHandler,
+        charset,
+        password,
+        replyTimeout,
+        convertNumbers,
+    )
+    factory.continueTrying = reconnect
+    for x in range(poolsize):
+        reactor.connectTCP(host, port, factory, connectTimeout)
+
+    return factory.handler
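
A hedged usage sketch of the wrapper — the point being that callers can supply their own reactor (a homeserver's via `hs.get_reactor()`, or a fake clock-controlled one in tests), whereas `txredisapi.lazyConnection` hard-codes the global reactor:

```python
from twisted.internet import reactor  # the global reactor, for production use

from synapse.replication.tcp.redis import lazyConnection

conn = lazyConnection(
    reactor=reactor,
    host="localhost",
    port=6379,
    password=None,
    reconnect=True,
)
# `conn` is usable immediately: txredisapi's ConnectionHandler queues
# commands issued before the TCP connection is established and replays
# them once it connects.
```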
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 31082bb16a..5b0900aa3c 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -738,7 +738,7 @@ def _make_state_cache_entry(
 
     # failing that, look for the closest match.
     prev_group = None
-    delta_ids = None
+    delta_ids = None  # type: Optional[StateMap[str]]
 
     for old_group, old_state in state_groups_ids.items():
         n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v}
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index f211ddbaf8..4bb2b9c28c 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -21,8 +21,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.databases.main.events import encode_json
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util.frozenutils import frozendict_json_encoder
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -105,7 +105,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
                 and original_event.internal_metadata.is_redacted()
             ):
                 # Redaction was allowed
-                pruned_json = encode_json(
+                pruned_json = frozendict_json_encoder.encode(
                     prune_event_dict(
                         original_event.room_version, original_event.get_dict()
                     )
@@ -171,7 +171,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
                 return
 
             # Prune the event's dict then convert it to JSON.
-            pruned_json = encode_json(
+            pruned_json = frozendict_json_encoder.encode(
                 prune_event_dict(event.room_version, event.get_dict())
             )
 
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 18def01f50..78e645592f 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -52,16 +52,6 @@ event_counter = Counter(
 )
 
 
-def encode_json(json_object):
-    """
-    Encode a Python object as JSON and return it in a Unicode string.
-    """
-    out = frozendict_json_encoder.encode(json_object)
-    if isinstance(out, bytes):
-        out = out.decode("utf8")
-    return out
-
-
 _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
 
 
@@ -743,7 +733,9 @@ class PersistEventsStore:
                     logger.exception("")
                     raise
 
-                metadata_json = encode_json(event.internal_metadata.get_dict())
+                metadata_json = frozendict_json_encoder.encode(
+                    event.internal_metadata.get_dict()
+                )
 
                 sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
                 txn.execute(sql, (metadata_json, event.event_id))
@@ -797,10 +789,10 @@ class PersistEventsStore:
                 {
                     "event_id": event.event_id,
                     "room_id": event.room_id,
-                    "internal_metadata": encode_json(
+                    "internal_metadata": frozendict_json_encoder.encode(
                         event.internal_metadata.get_dict()
                     ),
-                    "json": encode_json(event_dict(event)),
+                    "json": frozendict_json_encoder.encode(event_dict(event)),
                     "format_version": event.format_version,
                 }
                 for event, _ in events_and_contexts
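
The `encode_json` wrapper could be deleted because `json.JSONEncoder.encode` always returns `str`, never `bytes`, so the decode branch was dead code (exactly the kind of thing `warn_unreachable` flags). A rough sketch of how an encoder like `frozendict_json_encoder` handles frozen mappings via a `default` hook — the stand-in type here is illustrative, not Synapse's actual implementation:

```python
import json
from collections.abc import Mapping

class FrozenDictLike(Mapping):
    """Illustrative stand-in for the real frozendict type."""

    def __init__(self, **kwargs):
        self._dict = dict(kwargs)

    def __getitem__(self, key):
        return self._dict[key]

    def __iter__(self):
        return iter(self._dict)

    def __len__(self):
        return len(self._dict)

def _handle_frozendict(obj):
    # Called by JSONEncoder only for objects it cannot serialize natively.
    if isinstance(obj, FrozenDictLike):
        return dict(obj)
    raise TypeError("Object of type %s is not JSON serializable" % type(obj))

frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict)

out = frozendict_json_encoder.encode({"content": FrozenDictLike(body="hi")})
assert out == '{"content": {"body": "hi"}}'
assert isinstance(out, str)  # always str, hence no decode("utf8") needed
```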
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 37249f1e3f..1d27439536 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -546,7 +546,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
 
     async def get_room_event_before_stream_ordering(
         self, room_id: str, stream_ordering: int
-    ) -> Tuple[int, int, str]:
+    ) -> Optional[Tuple[int, int, str]]:
         """Gets details of the first event in a room at or before a stream ordering
 
         Args:
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 48efbb5067..c92cd4a6ba 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -421,7 +421,7 @@ class MultiWriterIdGenerator:
             self._unfinished_ids.discard(next_id)
             self._finished_ids.add(next_id)
 
-            new_cur = None
+            new_cur = None  # type: Optional[int]
 
             if self._unfinished_ids:
                 # If there are unfinished IDs then the new position will be the
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index d5087e58be..b6f436c016 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -286,9 +286,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 h._validate_metadata,
             )
 
-        # Tests for configs that the userinfo endpoint
+        # Tests for configs that require the userinfo endpoint
         self.assertFalse(h._uses_userinfo)
-        h._scopes = []  # do not request the openid scope
+        self.assertEqual(h._user_profile_method, "auto")
+        h._user_profile_method = "userinfo_endpoint"
+        self.assertTrue(h._uses_userinfo)
+
+        # Revert the profile method and do not request the "openid" scope.
+        h._user_profile_method = "auto"
+        h._scopes = []
         self.assertTrue(h._uses_userinfo)
         self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata)
 
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index ae60874ec3..81ea985b9f 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -12,13 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 from typing import Any, Callable, List, Optional, Tuple
 
 import attr
+import hiredis
 
 from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
+from twisted.internet.protocol import Protocol
 from twisted.internet.task import LoopingCall
 from twisted.web.http import HTTPChannel
 
@@ -27,7 +28,7 @@ from synapse.app.generic_worker import (
     GenericWorkerServer,
 )
 from synapse.http.server import JsonResource
-from synapse.http.site import SynapseRequest
+from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.replication.http import ReplicationRestResource, streams
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
@@ -197,19 +198,37 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         self.server_factory = ReplicationStreamProtocolFactory(self.hs)
         self.streamer = self.hs.get_replication_streamer()
 
+        # Fake in-memory Redis server that servers can connect to.
+        self._redis_server = FakeRedisPubSubServer()
+
         store = self.hs.get_datastore()
         self.database_pool = store.db_pool
 
         self.reactor.lookups["testserv"] = "1.2.3.4"
+        self.reactor.lookups["localhost"] = "127.0.0.1"
+
+        # A map from a HS instance to the associated HTTP Site to use for
+        # handling inbound HTTP requests to that instance.
+        self._hs_to_site = {self.hs: self.site}
+
+        if self.hs.config.redis.redis_enabled:
+            # Handle attempts to connect to fake redis server.
+            self.reactor.add_tcp_client_callback(
+                "localhost", 6379, self.connect_any_redis_attempts,
+            )
 
-        self._worker_hs_to_resource = {}
+            self.hs.get_tcp_replication().start_replication(self.hs)
 
         # When we see a connection attempt to the master replication listener we
         # automatically set up the connection. This is so that tests don't
         # manually have to go and explicitly set it up each time (plus sometimes
         # it is impossible to write the handling explicitly in the tests).
+        #
+        # Register the master replication listener:
         self.reactor.add_tcp_client_callback(
-            "1.2.3.4", 8765, self._handle_http_replication_attempt
+            "1.2.3.4",
+            8765,
+            lambda: self._handle_http_replication_attempt(self.hs, 8765),
         )
 
     def create_test_json_resource(self):
@@ -253,28 +272,63 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             **kwargs
         )
 
+        # If the instance is in the `instance_map` config then workers may try
+        # and send HTTP requests to it, so we register it with
+        # `_handle_http_replication_attempt` like we do with the master HS.
+        instance_name = worker_hs.get_instance_name()
+        instance_loc = worker_hs.config.worker.instance_map.get(instance_name)
+        if instance_loc:
+            # Ensure the host is one that has a fake DNS entry.
+            if instance_loc.host not in self.reactor.lookups:
+                raise Exception(
+                    "Host does not have an IP for instance_map[%r].host = %r"
+                    % (instance_name, instance_loc.host,)
+                )
+
+            self.reactor.add_tcp_client_callback(
+                self.reactor.lookups[instance_loc.host],
+                instance_loc.port,
+                lambda: self._handle_http_replication_attempt(
+                    worker_hs, instance_loc.port
+                ),
+            )
+
         store = worker_hs.get_datastore()
         store.db_pool._db_pool = self.database_pool._db_pool
 
-        repl_handler = ReplicationCommandHandler(worker_hs)
-        client = ClientReplicationStreamProtocol(
-            worker_hs, "client", "test", self.clock, repl_handler,
-        )
-        server = self.server_factory.buildProtocol(None)
+        # Set up TCP replication between master and the new worker if we don't
+        # have Redis support enabled.
+        if not worker_hs.config.redis_enabled:
+            repl_handler = ReplicationCommandHandler(worker_hs)
+            client = ClientReplicationStreamProtocol(
+                worker_hs, "client", "test", self.clock, repl_handler,
+            )
+            server = self.server_factory.buildProtocol(None)
 
-        client_transport = FakeTransport(server, self.reactor)
-        client.makeConnection(client_transport)
+            client_transport = FakeTransport(server, self.reactor)
+            client.makeConnection(client_transport)
 
-        server_transport = FakeTransport(client, self.reactor)
-        server.makeConnection(server_transport)
+            server_transport = FakeTransport(client, self.reactor)
+            server.makeConnection(server_transport)
 
         # Set up a resource for the worker
-        resource = ReplicationRestResource(self.hs)
+        resource = ReplicationRestResource(worker_hs)
 
         for servlet in self.servlets:
             servlet(worker_hs, resource)
 
-        self._worker_hs_to_resource[worker_hs] = resource
+        self._hs_to_site[worker_hs] = SynapseSite(
+            logger_name="synapse.access.http.fake",
+            site_tag="{}-{}".format(
+                worker_hs.config.server.server_name, worker_hs.get_instance_name()
+            ),
+            config=worker_hs.config.server.listeners[0],
+            resource=resource,
+            server_version_string="1",
+        )
+
+        if worker_hs.config.redis.redis_enabled:
+            worker_hs.get_tcp_replication().start_replication(worker_hs)
 
         return worker_hs
 
@@ -285,7 +339,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         return config
 
     def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
-        render(request, self._worker_hs_to_resource[worker_hs], self.reactor)
+        render(request, self._hs_to_site[worker_hs].resource, self.reactor)
 
     def replicate(self):
         """Tell the master side of replication that something has happened, and then
@@ -294,9 +348,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         self.streamer.on_notifier_poke()
         self.pump()
 
-    def _handle_http_replication_attempt(self):
-        """Handles a connection attempt to the master replication HTTP
-        listener.
+    def _handle_http_replication_attempt(self, hs, repl_port):
+        """Handles a connection attempt to the given HS replication HTTP
+        listener on the given port.
         """
 
         # We should have at least one outbound connection attempt, where the
@@ -305,7 +359,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         self.assertGreaterEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
         self.assertEqual(host, "1.2.3.4")
-        self.assertEqual(port, 8765)
+        self.assertEqual(port, repl_port)
 
         # Set up client side protocol
         client_protocol = client_factory.buildProtocol(None)
@@ -315,7 +369,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         # Set up the server side protocol
         channel = _PushHTTPChannel(self.reactor)
         channel.requestFactory = request_factory
-        channel.site = self.site
+        channel.site = self._hs_to_site[hs]
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -333,6 +387,32 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         # inside `connectTCP` before the connection has been passed back to the
         # code that requested the TCP connection.
 
+    def connect_any_redis_attempts(self):
+        """If redis is enabled we need to deal with workers connecting to a
+        redis server. We don't want to use a real Redis server so we use a
+        fake one.
+        """
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+        self.assertEqual(host, "localhost")
+        self.assertEqual(port, 6379)
+
+        client_protocol = client_factory.buildProtocol(None)
+        server_protocol = self._redis_server.buildProtocol(None)
+
+        client_to_server_transport = FakeTransport(
+            server_protocol, self.reactor, client_protocol
+        )
+        client_protocol.makeConnection(client_to_server_transport)
+
+        server_to_client_transport = FakeTransport(
+            client_protocol, self.reactor, server_protocol
+        )
+        server_protocol.makeConnection(server_to_client_transport)
+
+        return client_to_server_transport, server_to_client_transport
+
 
 class TestReplicationDataHandler(GenericWorkerReplicationHandler):
     """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
@@ -467,3 +547,105 @@ class _PullToPushProducer:
                 pass
 
             self.stopProducing()
+
+
+class FakeRedisPubSubServer:
+    """A fake Redis server for pub/sub.
+    """
+
+    def __init__(self):
+        self._subscribers = set()
+
+    def add_subscriber(self, conn):
+        """A connection has called SUBSCRIBE
+        """
+        self._subscribers.add(conn)
+
+    def remove_subscriber(self, conn):
+        """A connection has called UNSUBSCRIBE
+        """
+        self._subscribers.discard(conn)
+
+    def publish(self, conn, channel, msg) -> int:
+        """A connection want to publish a message to subscribers.
+        """
+        for sub in self._subscribers:
+            sub.send(["message", channel, msg])
+
+        return len(self._subscribers)
+
+    def buildProtocol(self, addr):
+        return FakeRedisPubSubProtocol(self)
+
+
+class FakeRedisPubSubProtocol(Protocol):
+    """A connection from a client talking to the fake Redis server.
+    """
+
+    def __init__(self, server: FakeRedisPubSubServer):
+        self._server = server
+        self._reader = hiredis.Reader()
+
+    def dataReceived(self, data):
+        self._reader.feed(data)
+
+        # We might get multiple messages in one packet.
+        while True:
+            msg = self._reader.gets()
+
+            if msg is False:
+                # No more messages.
+                return
+
+            if not isinstance(msg, list):
+                # Inbound commands should always be a list
+                raise Exception("Expected redis list")
+
+            self.handle_command(msg[0], *msg[1:])
+
+    def handle_command(self, command, *args):
+        """Received a Redis command from the client.
+        """
+
+        # We currently only support pub/sub.
+        if command == b"PUBLISH":
+            channel, message = args
+            num_subscribers = self._server.publish(self, channel, message)
+            self.send(num_subscribers)
+        elif command == b"SUBSCRIBE":
+            (channel,) = args
+            self._server.add_subscriber(self)
+            self.send(["subscribe", channel, 1])
+        else:
+            raise Exception("Unknown command")
+
+    def send(self, msg):
+        """Send a message back to the client.
+        """
+        raw = self.encode(msg).encode("utf-8")
+
+        self.transport.write(raw)
+        self.transport.flush()
+
+    def encode(self, obj):
+        """Encode an object to its Redis format.
+
+        Supports: strings/bytes, integers and list/tuples.
+        """
+
+        if isinstance(obj, bytes):
+            # We assume bytes are just unicode strings.
+            obj = obj.decode("utf-8")
+
+        if isinstance(obj, str):
+            return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
+        if isinstance(obj, int):
+            return ":{val}\r\n".format(val=obj)
+        if isinstance(obj, (list, tuple)):
+            items = "".join(self.encode(a) for a in obj)
+            return "*{len}\r\n{items}".format(len=len(obj), items=items)
+
+        raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
+
+    def connectionLost(self, reason):
+        self._server.remove_subscriber(self)
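
A worked example of the wire format produced by `encode` above (the Redis serialization protocol, RESP: bulk strings as `$<len>\r\n<str>\r\n`, integers as `:<val>\r\n`, arrays as `*<len>\r\n<items>`):

```python
def encode(obj):
    # Same rules as FakeRedisPubSubProtocol.encode above.
    if isinstance(obj, bytes):
        obj = obj.decode("utf-8")
    if isinstance(obj, str):
        return "$%d\r\n%s\r\n" % (len(obj), obj)
    if isinstance(obj, int):
        return ":%d\r\n" % obj
    if isinstance(obj, (list, tuple)):
        return "*%d\r\n%s" % (len(obj), "".join(encode(a) for a in obj))
    raise TypeError("cannot encode %r" % (obj,))

# What a subscriber receives when another connection publishes "hello":
assert encode(["message", "channel", "hello"]) == (
    "*3\r\n$7\r\nmessage\r\n$7\r\nchannel\r\n$5\r\nhello\r\n"
)
```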
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
new file mode 100644
index 0000000000..6068d14905
--- /dev/null
+++ b/tests/replication/test_sharded_event_persister.py
@@ -0,0 +1,102 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.utils import USE_POSTGRES_FOR_TESTS
+
+logger = logging.getLogger(__name__)
+
+
+class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
+    """Checks event persisting sharding works
+    """
+
+    # Event persister sharding requires postgres (due to needing
+    # `MultiWriterIdGenerator`).
+    if not USE_POSTGRES_FOR_TESTS:
+        skip = "Requires Postgres"
+
+    servlets = [
+        admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        # Register a user who sends a message that we'll get notified about
+        self.other_user_id = self.register_user("otheruser", "pass")
+        self.other_access_token = self.login("otheruser", "pass")
+
+    def default_config(self):
+        conf = super().default_config()
+        conf["redis"] = {"enabled": "true"}
+        conf["stream_writers"] = {"events": ["worker1", "worker2"]}
+        conf["instance_map"] = {
+            "worker1": {"host": "testserv", "port": 1001},
+            "worker2": {"host": "testserv", "port": 1002},
+        }
+        return conf
+
+    def test_basic(self):
+        """Simple test to ensure that multiple rooms can be created and joined,
+        and that different rooms get handled by different instances.
+        """
+
+        self.make_worker_hs(
+            "synapse.app.generic_worker", {"worker_name": "worker1"},
+        )
+
+        self.make_worker_hs(
+            "synapse.app.generic_worker", {"worker_name": "worker2"},
+        )
+
+        persisted_on_1 = False
+        persisted_on_2 = False
+
+        store = self.hs.get_datastore()
+
+        user_id = self.register_user("user", "pass")
+        access_token = self.login("user", "pass")
+
+        # Keep making new rooms until we see rooms being persisted on both
+        # workers.
+        for _ in range(10):
+            # Create a room
+            room = self.helper.create_room_as(user_id, tok=access_token)
+
+            # The other user joins
+            self.helper.join(
+                room=room, user=self.other_user_id, tok=self.other_access_token
+            )
+
+            # The other user sends some messages
+            response = self.helper.send(room, body="Hi!", tok=self.other_access_token)
+            event_id = response["event_id"]
+
+            # The event position includes which instance persisted the event.
+            pos = self.get_success(store.get_position_for_event(event_id))
+
+            persisted_on_1 |= pos.instance_name == "worker1"
+            persisted_on_2 |= pos.instance_name == "worker2"
+
+            if persisted_on_1 and persisted_on_2:
+                break
+
+        self.assertTrue(persisted_on_1)
+        self.assertTrue(persisted_on_2)
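
On why distinct rooms end up on different persisters: with multiple `stream_writers.events` entries, Synapse routes each room to one writer deterministically by hashing the room ID over the configured list. The sketch below is illustrative only; the exact hash Synapse uses is an assumption, not shown in this diff:

```python
import hashlib
from typing import List

def pick_event_persister(room_id: str, writers: List[str]) -> str:
    # Deterministic assignment: the same room always lands on the same
    # writer, while different rooms spread across the list.
    digest = hashlib.sha256(room_id.encode("utf8")).digest()
    return writers[int.from_bytes(digest, "big") % len(writers)]

writers = ["worker1", "worker2"]
# With enough distinct rooms both workers get hit, which is what
# test_basic above loops (up to 10 rooms) until it observes.
```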
diff --git a/tests/unittest.py b/tests/unittest.py
index e654c0442d..82ede9de34 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -241,7 +241,7 @@ class HomeserverTestCase(TestCase):
         # create a site to wrap the resource.
         self.site = SynapseSite(
             logger_name="synapse.access.http.fake",
-            site_tag="test",
+            site_tag=self.hs.config.server.server_name,
             config=self.hs.config.server.listeners[0],
             resource=self.resource,
             server_version_string="1",