summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/config/database.py55
-rw-r--r--synapse/crypto/event_signing.py9
-rw-r--r--synapse/crypto/keyring.py13
-rw-r--r--synapse/handlers/presence.py9
-rw-r--r--synapse/handlers/room.py15
-rw-r--r--synapse/handlers/room_member.py2
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py30
-rw-r--r--synapse/storage/data_stores/__init__.py21
-rw-r--r--synapse/storage/data_stores/main/state.py4
9 files changed, 113 insertions, 45 deletions
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 134824789c..219b32f670 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -15,7 +15,6 @@
 import logging
 import os
 from textwrap import indent
-from typing import List
 
 import yaml
 
@@ -30,16 +29,13 @@ class DatabaseConnectionConfig:
     Args:
         name: A label for the database, used for logging.
         db_config: The config for a particular database, as per `database`
-            section of main config. Has two fields: `name` for database
-            module name, and `args` for the args to give to the database
-            connector.
-        data_stores: The list of data stores that should be provisioned on the
-            database. Defaults to all data stores.
+            section of main config. Has three fields: `name` for database
+            module name, `args` for the args to give to the database
+            connector, and optional `data_stores` that is a list of stores to
+            provision on this database (defaulting to all).
     """
 
-    def __init__(
-        self, name: str, db_config: dict, data_stores: List[str] = ["main", "state"]
-    ):
+    def __init__(self, name: str, db_config: dict):
         if db_config["name"] not in ("sqlite3", "psycopg2"):
             raise ConfigError("Unsupported database type %r" % (db_config["name"],))
 
@@ -48,6 +44,10 @@ class DatabaseConnectionConfig:
                 {"cp_min": 1, "cp_max": 1, "check_same_thread": False}
             )
 
+        data_stores = db_config.get("data_stores")
+        if data_stores is None:
+            data_stores = ["main", "state"]
+
         self.name = name
         self.config = db_config
         self.data_stores = data_stores
@@ -59,14 +59,43 @@ class DatabaseConfig(Config):
     def read_config(self, config, **kwargs):
         self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
 
+        # We *experimentally* support specifying multiple databases via the
+        # `databases` key. This is a map from a label to database config in the
+        # same format as the `database` config option, plus an extra
+        # `data_stores` key to specify which data store goes where. For example:
+        #
+        #   databases:
+        #       master:
+        #           name: psycopg2
+        #           data_stores: ["main"]
+        #           args: {}
+        #       state:
+        #           name: psycopg2
+        #           data_stores: ["state"]
+        #           args: {}
+
+        multi_database_config = config.get("databases")
         database_config = config.get("database")
 
-        if database_config is None:
-            database_config = {"name": "sqlite3", "args": {}}
+        if multi_database_config and database_config:
+            raise ConfigError("Can't specify both 'database' and 'datbases' in config")
+
+        if multi_database_config:
+            if config.get("database_path"):
+                raise ConfigError("Can't specify 'database_path' with 'databases'")
+
+            self.databases = [
+                DatabaseConnectionConfig(name, db_conf)
+                for name, db_conf in multi_database_config.items()
+            ]
+
+        else:
+            if database_config is None:
+                database_config = {"name": "sqlite3", "args": {}}
 
-        self.databases = [DatabaseConnectionConfig("master", database_config)]
+            self.databases = [DatabaseConnectionConfig("master", database_config)]
 
-        self.set_databasepath(config.get("database_path"))
+            self.set_databasepath(config.get("database_path"))
 
     def generate_config_section(self, data_dir_path, database_conf, **kwargs):
         if not database_conf:
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index ccaa8a9920..e65bd61d97 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import collections.abc
 import hashlib
 import logging
 
@@ -40,8 +40,11 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
     # some malformed events lack a 'hashes'. Protect against it being missing
     # or a weird type by basically treating it the same as an unhashed event.
     hashes = event.get("hashes")
-    if not isinstance(hashes, dict):
-        raise SynapseError(400, "Malformed 'hashes'", Codes.UNAUTHORIZED)
+    # nb it might be a frozendict or a dict
+    if not isinstance(hashes, collections.abc.Mapping):
+        raise SynapseError(
+            400, "Malformed 'hashes': %s" % (type(hashes),), Codes.UNAUTHORIZED
+        )
 
     if name not in hashes:
         raise SynapseError(
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 7cfad192e8..6fe5a6a26a 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -511,17 +511,18 @@ class BaseV2KeyFetcher(object):
         server_name = response_json["server_name"]
         verified = False
         for key_id in response_json["signatures"].get(server_name, {}):
-            # each of the keys used for the signature must be present in the response
-            # json.
             key = verify_keys.get(key_id)
             if not key:
-                raise KeyLookupError(
-                    "Key response is signed by key id %s:%s but that key is not "
-                    "present in the response" % (server_name, key_id)
-                )
+                # the key may not be present in verify_keys if:
+                #  * we got the key from the notary server, and:
+                #  * the key belongs to the notary server, and:
+                #  * the notary server is using a different key to sign notary
+                #    responses.
+                continue
 
             verify_signed_json(response_json, server_name, key.verify_key)
             verified = True
+            break
 
         if not verified:
             raise KeyLookupError(
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 240c4add12..202aa9294f 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -95,12 +95,7 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
 
 
 class PresenceHandler(object):
-    def __init__(self, hs):
-        """
-
-        Args:
-            hs (synapse.server.HomeServer):
-        """
+    def __init__(self, hs: "synapse.server.HomeServer"):
         self.hs = hs
         self.is_mine = hs.is_mine
         self.is_mine_id = hs.is_mine_id
@@ -230,7 +225,7 @@ class PresenceHandler(object):
         is some spurious presence changes that will self-correct.
         """
         # If the DB pool has already terminated, don't try updating
-        if not self.store.database.is_running():
+        if not self.store.db.is_running():
             return
 
         logger.info(
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 4f489762fc..9cab2adbfb 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 
 """Contains functions for performing events on rooms."""
-import copy
+
 import itertools
 import logging
 import math
@@ -368,13 +368,16 @@ class RoomCreationHandler(BaseHandler):
         # Raise the requester's power level in the new room if necessary
         current_power_level = power_levels["users"][user_id]
         if current_power_level < needed_power_level:
-            # Perform a deepcopy in order to not modify the original power levels in a
-            # room, as its contents are preserved as the state for the old room later on
-            new_power_levels = copy.deepcopy(power_levels)
-            initial_state[(EventTypes.PowerLevels, "")] = new_power_levels
+            # make sure we copy the event content rather than overwriting it.
+            # note that if frozen_dicts are enabled, `power_levels` will be a frozen
+            # dict so we can't just copy.deepcopy it.
 
-            # Assign this power level to the requester
+            new_power_levels = {k: v for k, v in power_levels.items() if k != "users"}
+            new_power_levels["users"] = {
+                k: v for k, v in power_levels.get("users", {}).items() if k != user_id
+            }
             new_power_levels["users"][user_id] = needed_power_level
+            initial_state[(EventTypes.PowerLevels, "")] = new_power_levels
 
         yield self._send_events_for_new_room(
             requester,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 3dc2b2dd8a..03bb52ccfb 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -506,6 +506,8 @@ class RoomMemberHandler(object):
         Returns:
             Deferred
         """
+        logger.info("Transferring room state from %s to %s", old_room_id, room_id)
+
         # Find all local users that were in the old room and copy over each user's state
         users = yield self.store.get_users_in_room(old_room_id)
         yield self.copy_user_state_on_room_upgrade(old_room_id, room_id, users)
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index e7fc3f0431..bf5e0eb844 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -15,6 +15,7 @@
 import logging
 
 from canonicaljson import encode_canonical_json, json
+from signedjson.key import encode_verify_key_base64
 from signedjson.sign import sign_json
 
 from twisted.internet import defer
@@ -216,15 +217,28 @@ class RemoteKey(DirectServeResource):
         if cache_misses and query_remote_on_cache_miss:
             yield self.fetcher.get_keys(cache_misses)
             yield self.query_keys(request, query, query_remote_on_cache_miss=False)
-        else:
-            signed_keys = []
-            for key_json in json_results:
-                key_json = json.loads(key_json)
+            return
+
+        signed_keys = []
+        for key_json in json_results:
+            key_json = json.loads(key_json)
+
+            # backwards-compatibility hack for #6596: if the requested key belongs
+            # to us, make sure that all of the signing keys appear in the
+            # "verify_keys" section.
+            if key_json["server_name"] == self.config.server_name:
+                verify_keys = key_json["verify_keys"]
                 for signing_key in self.config.key_server_signing_keys:
-                    key_json = sign_json(key_json, self.config.server_name, signing_key)
+                    key_id = "%s:%s" % (signing_key.alg, signing_key.version)
+                    verify_keys[key_id] = {
+                        "key": encode_verify_key_base64(signing_key.verify_key)
+                    }
+
+            for signing_key in self.config.key_server_signing_keys:
+                key_json = sign_json(key_json, self.config.server_name, signing_key)
 
-                signed_keys.append(key_json)
+            signed_keys.append(key_json)
 
-            results = {"server_keys": signed_keys}
+        results = {"server_keys": signed_keys}
 
-            respond_with_json_bytes(request, 200, encode_canonical_json(results))
+        respond_with_json_bytes(request, 200, encode_canonical_json(results))
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index d20df5f076..092e803799 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/data_stores/__init__.py
@@ -37,6 +37,8 @@ class DataStores(object):
         # store.
 
         self.databases = []
+        self.main = None
+        self.state = None
 
         for database_config in hs.config.database.databases:
             db_name = database_config.name
@@ -54,10 +56,22 @@ class DataStores(object):
 
                 if "main" in database_config.data_stores:
                     logger.info("Starting 'main' data store")
+
+                    # Sanity check we don't try and configure the main store on
+                    # multiple databases.
+                    if self.main:
+                        raise Exception("'main' data store already configured")
+
                     self.main = main_store_class(database, db_conn, hs)
 
                 if "state" in database_config.data_stores:
                     logger.info("Starting 'state' data store")
+
+                    # Sanity check we don't try and configure the state store on
+                    # multiple databases.
+                    if self.state:
+                        raise Exception("'state' data store already configured")
+
                     self.state = StateGroupDataStore(database, db_conn, hs)
 
                 db_conn.commit()
@@ -65,3 +79,10 @@ class DataStores(object):
                 self.databases.append(database)
 
                 logger.info("Database %r prepared", db_name)
+
+        # Sanity check that we have actually configured all the required stores.
+        if not self.main:
+            raise Exception("No 'main' data store configured")
+
+        if not self.state:
+            raise Exception("No 'main' data store configured")
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 0dc39f139c..d07440e3ed 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import collections.abc
 import logging
 from collections import namedtuple
 from typing import Iterable, Tuple
@@ -107,7 +107,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         predecessor = create_event.content.get("predecessor", None)
 
         # Ensure the key is a dictionary
-        if not isinstance(predecessor, dict):
+        if not isinstance(predecessor, collections.abc.Mapping):
             return None
 
         return predecessor