summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/device.py15
-rw-r--r--synapse/logging/opentracing.py8
-rw-r--r--synapse/module_api/__init__.py6
-rw-r--r--synapse/replication/http/_base.py154
-rw-r--r--synapse/storage/databases/main/client_ips.py153
-rw-r--r--synapse/storage/state.py172
6 files changed, 382 insertions, 126 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 75e6019760..6eafbea25d 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -14,7 +14,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+)
 
 from synapse.api import errors
 from synapse.api.constants import EventTypes
@@ -595,7 +606,7 @@ class DeviceHandler(DeviceWorkerHandler):
 
 
 def _update_device_from_client_ips(
-    device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict]
+    device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
 ) -> None:
     ip = client_ips.get((device["user_id"], device["device_id"]), {})
     device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 5276c4bfcc..20d23a4260 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -807,6 +807,14 @@ def trace(func=None, opname=None):
                         result.addCallbacks(call_back, err_back)
 
                     else:
+                        if inspect.isawaitable(result):
+                            logger.error(
+                                "@trace may not have wrapped %s correctly! "
+                                "The function is not async but returned a %s.",
+                                func.__qualname__,
+                                type(result).__name__,
+                            )
+
                         scope.__exit__(None, None, None)
 
                     return result
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 8ae21bc43c..b2a228c231 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -773,9 +773,9 @@ class ModuleApi:
             # Sanitize some of the data. We don't want to return tokens.
             return [
                 UserIpAndAgent(
-                    ip=str(data["ip"]),
-                    user_agent=str(data["user_agent"]),
-                    last_seen=int(data["last_seen"]),
+                    ip=data["ip"],
+                    user_agent=data["user_agent"],
+                    last_seen=data["last_seen"],
                 )
                 for data in raw_data
             ]
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index f1b78d09f9..e047ec74d8 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -182,85 +182,87 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
             )
 
         @trace(opname="outgoing_replication_request")
-        @outgoing_gauge.track_inprogress()
         async def send_request(*, instance_name="master", **kwargs):
-            if instance_name == local_instance_name:
-                raise Exception("Trying to send HTTP request to self")
-            if instance_name == "master":
-                host = master_host
-                port = master_port
-            elif instance_name in instance_map:
-                host = instance_map[instance_name].host
-                port = instance_map[instance_name].port
-            else:
-                raise Exception(
-                    "Instance %r not in 'instance_map' config" % (instance_name,)
+            with outgoing_gauge.track_inprogress():
+                if instance_name == local_instance_name:
+                    raise Exception("Trying to send HTTP request to self")
+                if instance_name == "master":
+                    host = master_host
+                    port = master_port
+                elif instance_name in instance_map:
+                    host = instance_map[instance_name].host
+                    port = instance_map[instance_name].port
+                else:
+                    raise Exception(
+                        "Instance %r not in 'instance_map' config" % (instance_name,)
+                    )
+
+                data = await cls._serialize_payload(**kwargs)
+
+                url_args = [
+                    urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
+                ]
+
+                if cls.CACHE:
+                    txn_id = random_string(10)
+                    url_args.append(txn_id)
+
+                if cls.METHOD == "POST":
+                    request_func = client.post_json_get_json
+                elif cls.METHOD == "PUT":
+                    request_func = client.put_json
+                elif cls.METHOD == "GET":
+                    request_func = client.get_json
+                else:
+                    # We have already asserted in the constructor that a
+                    # compatible was picked, but lets be paranoid.
+                    raise Exception(
+                        "Unknown METHOD on %s replication endpoint" % (cls.NAME,)
+                    )
+
+                uri = "http://%s:%s/_synapse/replication/%s/%s" % (
+                    host,
+                    port,
+                    cls.NAME,
+                    "/".join(url_args),
                 )
 
-            data = await cls._serialize_payload(**kwargs)
-
-            url_args = [
-                urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
-            ]
-
-            if cls.CACHE:
-                txn_id = random_string(10)
-                url_args.append(txn_id)
-
-            if cls.METHOD == "POST":
-                request_func = client.post_json_get_json
-            elif cls.METHOD == "PUT":
-                request_func = client.put_json
-            elif cls.METHOD == "GET":
-                request_func = client.get_json
-            else:
-                # We have already asserted in the constructor that a
-                # compatible was picked, but lets be paranoid.
-                raise Exception(
-                    "Unknown METHOD on %s replication endpoint" % (cls.NAME,)
-                )
-
-            uri = "http://%s:%s/_synapse/replication/%s/%s" % (
-                host,
-                port,
-                cls.NAME,
-                "/".join(url_args),
-            )
-
-            try:
-                # We keep retrying the same request for timeouts. This is so that we
-                # have a good idea that the request has either succeeded or failed on
-                # the master, and so whether we should clean up or not.
-                while True:
-                    headers: Dict[bytes, List[bytes]] = {}
-                    # Add an authorization header, if configured.
-                    if replication_secret:
-                        headers[b"Authorization"] = [b"Bearer " + replication_secret]
-                    opentracing.inject_header_dict(headers, check_destination=False)
-                    try:
-                        result = await request_func(uri, data, headers=headers)
-                        break
-                    except RequestTimedOutError:
-                        if not cls.RETRY_ON_TIMEOUT:
-                            raise
-
-                    logger.warning("%s request timed out; retrying", cls.NAME)
-
-                    # If we timed out we probably don't need to worry about backing
-                    # off too much, but lets just wait a little anyway.
-                    await clock.sleep(1)
-            except HttpResponseException as e:
-                # We convert to SynapseError as we know that it was a SynapseError
-                # on the main process that we should send to the client. (And
-                # importantly, not stack traces everywhere)
-                _outgoing_request_counter.labels(cls.NAME, e.code).inc()
-                raise e.to_synapse_error()
-            except Exception as e:
-                _outgoing_request_counter.labels(cls.NAME, "ERR").inc()
-                raise SynapseError(502, "Failed to talk to main process") from e
-
-            _outgoing_request_counter.labels(cls.NAME, 200).inc()
-            return result
+                try:
+                    # We keep retrying the same request for timeouts. This is so that we
+                    # have a good idea that the request has either succeeded or failed
+                    # on the master, and so whether we should clean up or not.
+                    while True:
+                        headers: Dict[bytes, List[bytes]] = {}
+                        # Add an authorization header, if configured.
+                        if replication_secret:
+                            headers[b"Authorization"] = [
+                                b"Bearer " + replication_secret
+                            ]
+                        opentracing.inject_header_dict(headers, check_destination=False)
+                        try:
+                            result = await request_func(uri, data, headers=headers)
+                            break
+                        except RequestTimedOutError:
+                            if not cls.RETRY_ON_TIMEOUT:
+                                raise
+
+                        logger.warning("%s request timed out; retrying", cls.NAME)
+
+                        # If we timed out we probably don't need to worry about backing
+                        # off too much, but lets just wait a little anyway.
+                        await clock.sleep(1)
+                except HttpResponseException as e:
+                    # We convert to SynapseError as we know that it was a SynapseError
+                    # on the main process that we should send to the client. (And
+                    # importantly, not stack traces everywhere)
+                    _outgoing_request_counter.labels(cls.NAME, e.code).inc()
+                    raise e.to_synapse_error()
+                except Exception as e:
+                    _outgoing_request_counter.labels(cls.NAME, "ERR").inc()
+                    raise SynapseError(502, "Failed to talk to main process") from e
+
+                _outgoing_request_counter.labels(cls.NAME, 200).inc()
+                return result
 
         return send_request
 
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index c77acc7c84..b81d9218ce 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -13,14 +13,26 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast
+
+from typing_extensions import TypedDict
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
-from synapse.types import UserID
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+    make_tuple_comparison_clause,
+)
+from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
+from synapse.storage.types import Connection
+from synapse.types import JsonDict, UserID
 from synapse.util.caches.lrucache import LruCache
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 # Number of msec of granularity to store the user IP 'last seen' time. Smaller
@@ -29,8 +41,31 @@ logger = logging.getLogger(__name__)
 LAST_SEEN_GRANULARITY = 120 * 1000
 
 
+class DeviceLastConnectionInfo(TypedDict):
+    """Metadata for the last connection seen for a user and device combination"""
+
+    # These types must match the columns in the `devices` table
+    user_id: str
+    device_id: str
+
+    ip: Optional[str]
+    user_agent: Optional[str]
+    last_seen: Optional[int]
+
+
+class LastConnectionInfo(TypedDict):
+    """Metadata for the last connection seen for an access token and IP combination"""
+
+    # These types must match the columns in the `user_ips` table
+    access_token: str
+    ip: str
+
+    user_agent: str
+    last_seen: int
+
+
 class ClientIpBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -81,8 +116,10 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             "devices_last_seen", self._devices_last_seen_update
         )
 
-    async def _remove_user_ip_nonunique(self, progress, batch_size):
-        def f(conn):
+    async def _remove_user_ip_nonunique(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        def f(conn: LoggingDatabaseConnection) -> None:
             txn = conn.cursor()
             txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
             txn.close()
@@ -93,14 +130,14 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
         )
         return 1
 
-    async def _analyze_user_ip(self, progress, batch_size):
+    async def _analyze_user_ip(self, progress: JsonDict, batch_size: int) -> int:
         # Background update to analyze user_ips table before we run the
         # deduplication background update. The table may not have been analyzed
         # for ages due to the table locks.
         #
         # This will lock out the naive upserts to user_ips while it happens, but
         # the analyze should be quick (28GB table takes ~10s)
-        def user_ips_analyze(txn):
+        def user_ips_analyze(txn: LoggingTransaction) -> None:
             txn.execute("ANALYZE user_ips")
 
         await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
@@ -109,16 +146,16 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
         return 1
 
-    async def _remove_user_ip_dupes(self, progress, batch_size):
+    async def _remove_user_ip_dupes(self, progress: JsonDict, batch_size: int) -> int:
         # This works function works by scanning the user_ips table in batches
         # based on `last_seen`. For each row in a batch it searches the rest of
         # the table to see if there are any duplicates, if there are then they
         # are removed and replaced with a suitable row.
 
         # Fetch the start of the batch
-        begin_last_seen = progress.get("last_seen", 0)
+        begin_last_seen: int = progress.get("last_seen", 0)
 
-        def get_last_seen(txn):
+        def get_last_seen(txn: LoggingTransaction) -> Optional[int]:
             txn.execute(
                 """
                 SELECT last_seen FROM user_ips
@@ -129,7 +166,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
                 """,
                 (begin_last_seen, batch_size),
             )
-            row = txn.fetchone()
+            row = cast(Optional[Tuple[int]], txn.fetchone())
             if row:
                 return row[0]
             else:
@@ -149,7 +186,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             end_last_seen,
         )
 
-        def remove(txn):
+        def remove(txn: LoggingTransaction) -> None:
             # This works by looking at all entries in the given time span, and
             # then for each (user_id, access_token, ip) tuple in that range
             # checking for any duplicates in the rest of the table (via a join).
@@ -161,10 +198,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
             # Define the search space, which requires handling the last batch in
             # a different way
+            args: Tuple[int, ...]
             if last:
                 clause = "? <= last_seen"
                 args = (begin_last_seen,)
             else:
+                assert end_last_seen is not None
                 clause = "? <= last_seen AND last_seen < ?"
                 args = (begin_last_seen, end_last_seen)
 
@@ -189,7 +228,9 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
                 ),
                 args,
             )
-            res = txn.fetchall()
+            res = cast(
+                List[Tuple[str, str, str, Optional[str], str, int, int]], txn.fetchall()
+            )
 
             # We've got some duplicates
             for i in res:
@@ -278,13 +319,15 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
         return batch_size
 
-    async def _devices_last_seen_update(self, progress, batch_size):
+    async def _devices_last_seen_update(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Background update to insert last seen info into devices table"""
 
-        last_user_id = progress.get("last_user_id", "")
-        last_device_id = progress.get("last_device_id", "")
+        last_user_id: str = progress.get("last_user_id", "")
+        last_device_id: str = progress.get("last_device_id", "")
 
-        def _devices_last_seen_update_txn(txn):
+        def _devices_last_seen_update_txn(txn: LoggingTransaction) -> int:
             # This consists of two queries:
             #
             #   1. The sub-query searches for the next N devices and joins
@@ -296,6 +339,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             #      we'll just end up updating the same device row multiple
             #      times, which is fine.
 
+            where_args: List[Union[str, int]]
             where_clause, where_args = make_tuple_comparison_clause(
                 [("user_id", last_user_id), ("device_id", last_device_id)],
             )
@@ -319,7 +363,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             }
             txn.execute(sql, where_args + [batch_size])
 
-            rows = txn.fetchall()
+            rows = cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
             if not rows:
                 return 0
 
@@ -350,7 +394,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
 
 class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.user_ips_max_age = hs.config.server.user_ips_max_age
@@ -359,7 +403,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
             self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
 
     @wrap_as_background_process("prune_old_user_ips")
-    async def _prune_old_user_ips(self):
+    async def _prune_old_user_ips(self) -> None:
         """Removes entries in user IPs older than the configured period."""
 
         if self.user_ips_max_age is None:
@@ -394,9 +438,9 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
             )
         """
 
-        timestamp = self.clock.time_msec() - self.user_ips_max_age
+        timestamp = self._clock.time_msec() - self.user_ips_max_age
 
-        def _prune_old_user_ips_txn(txn):
+        def _prune_old_user_ips_txn(txn: LoggingTransaction) -> None:
             txn.execute(sql, (timestamp,))
 
         await self.db_pool.runInteraction(
@@ -405,7 +449,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
 
     async def get_last_client_ip_by_device(
         self, user_id: str, device_id: Optional[str]
-    ) -> Dict[Tuple[str, str], dict]:
+    ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
         """For each device_id listed, give the user_ip it was last seen on.
 
         The result might be slightly out of date as client IPs are inserted in batches.
@@ -423,26 +467,32 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
         if device_id is not None:
             keyvalues["device_id"] = device_id
 
-        res = await self.db_pool.simple_select_list(
-            table="devices",
-            keyvalues=keyvalues,
-            retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+        res = cast(
+            List[DeviceLastConnectionInfo],
+            await self.db_pool.simple_select_list(
+                table="devices",
+                keyvalues=keyvalues,
+                retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+            ),
         )
 
         return {(d["user_id"], d["device_id"]): d for d in res}
 
 
-class ClientIpStore(ClientIpWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
 
-        self.client_ip_last_seen = LruCache(
+        # (user_id, access_token, ip,) -> last_seen
+        self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
             cache_name="client_ip_last_seen", max_size=50000
         )
 
         super().__init__(database, db_conn, hs)
 
         # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
-        self._batch_row_update = {}
+        self._batch_row_update: Dict[
+            Tuple[str, str, str], Tuple[str, Optional[str], int]
+        ] = {}
 
         self._client_ip_looper = self._clock.looping_call(
             self._update_client_ips_batch, 5 * 1000
@@ -452,8 +502,14 @@ class ClientIpStore(ClientIpWorkerStore):
         )
 
     async def insert_client_ip(
-        self, user_id, access_token, ip, user_agent, device_id, now=None
-    ):
+        self,
+        user_id: str,
+        access_token: str,
+        ip: str,
+        user_agent: str,
+        device_id: Optional[str],
+        now: Optional[int] = None,
+    ) -> None:
         if not now:
             now = int(self._clock.time_msec())
         key = (user_id, access_token, ip)
@@ -485,7 +541,11 @@ class ClientIpStore(ClientIpWorkerStore):
             "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
         )
 
-    def _update_client_ips_batch_txn(self, txn, to_update):
+    def _update_client_ips_batch_txn(
+        self,
+        txn: LoggingTransaction,
+        to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]],
+    ) -> None:
         if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
             not self.database_engine.can_native_upsert
         ):
@@ -525,7 +585,7 @@ class ClientIpStore(ClientIpWorkerStore):
 
     async def get_last_client_ip_by_device(
         self, user_id: str, device_id: Optional[str]
-    ) -> Dict[Tuple[str, str], dict]:
+    ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
         """For each device_id listed, give the user_ip it was last seen on
 
         Args:
@@ -538,15 +598,20 @@ class ClientIpStore(ClientIpWorkerStore):
         """
         ret = await super().get_last_client_ip_by_device(user_id, device_id)
 
-        # Update what is retrieved from the database with data which is pending insertion.
+        # Update what is retrieved from the database with data which is pending
+        # insertion, as if it has already been stored in the database.
         for key in self._batch_row_update:
-            uid, access_token, ip = key
+            uid, _access_token, ip = key
             if uid == user_id:
                 user_agent, did, last_seen = self._batch_row_update[key]
+
+                if did is None:
+                    # These updates don't make it to the `devices` table
+                    continue
+
                 if not device_id or did == device_id:
-                    ret[(user_id, device_id)] = {
+                    ret[(user_id, did)] = {
                         "user_id": user_id,
-                        "access_token": access_token,
                         "ip": ip,
                         "user_agent": user_agent,
                         "device_id": did,
@@ -556,12 +621,12 @@ class ClientIpStore(ClientIpWorkerStore):
 
     async def get_user_ip_and_agents(
         self, user: UserID, since_ts: int = 0
-    ) -> List[Dict[str, Union[str, int]]]:
+    ) -> List[LastConnectionInfo]:
         """
         Fetch IP/User Agent connection since a given timestamp.
         """
         user_id = user.to_string()
-        results = {}
+        results: Dict[Tuple[str, str], Tuple[str, int]] = {}
 
         for key in self._batch_row_update:
             (
@@ -574,7 +639,7 @@ class ClientIpStore(ClientIpWorkerStore):
                 if last_seen >= since_ts:
                     results[(access_token, ip)] = (user_agent, last_seen)
 
-        def get_recent(txn):
+        def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
             txn.execute(
                 """
                 SELECT access_token, ip, user_agent, last_seen FROM user_ips
@@ -584,7 +649,7 @@ class ClientIpStore(ClientIpWorkerStore):
                 """,
                 (since_ts, user_id),
             )
-            return txn.fetchall()
+            return cast(List[Tuple[str, str, str, int]], txn.fetchall())
 
         rows = await self.db_pool.runInteraction(
             desc="get_user_ip_and_agents", func=get_recent
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 5e86befde4..b5ba1560d1 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -15,9 +15,11 @@ import logging
 from typing import (
     TYPE_CHECKING,
     Awaitable,
+    Collection,
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Set,
     Tuple,
@@ -29,7 +31,7 @@ from frozendict import frozendict
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateKey, StateMap
 
 if TYPE_CHECKING:
     from typing import FrozenSet  # noqa: used within quoted type hint; flake8 sad
@@ -134,6 +136,23 @@ class StateFilter:
             include_others=True,
         )
 
+    @staticmethod
+    def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool):
+        """
+        Returns a (frozen) StateFilter with the same contents as the parameters
+        specified here, which can be made of mutable types.
+        """
+        types_with_frozen_values: Dict[str, Optional[FrozenSet[str]]] = {}
+        for state_types, state_keys in types.items():
+            if state_keys is not None:
+                types_with_frozen_values[state_types] = frozenset(state_keys)
+            else:
+                types_with_frozen_values[state_types] = None
+
+        return StateFilter(
+            frozendict(types_with_frozen_values), include_others=include_others
+        )
+
     def return_expanded(self) -> "StateFilter":
         """Creates a new StateFilter where type wild cards have been removed
         (except for memberships). The returned filter is a superset of the
@@ -356,6 +375,157 @@ class StateFilter:
 
         return member_filter, non_member_filter
 
+    def _decompose_into_four_parts(
+        self,
+    ) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]:
+        """
+        Decomposes this state filter into 4 constituent parts, which can be
+        thought of as this:
+            all? - minus_wildcards + plus_wildcards + plus_state_keys
+
+        where
+        * all represents ALL state
+        * minus_wildcards represents entire state types to remove
+        * plus_wildcards represents entire state types to add
+        * plus_state_keys represents individual state keys to add
+
+        See `recompose_from_four_parts` for the other direction of this
+        correspondence.
+        """
+        is_all = self.include_others
+        excluded_types: Set[str] = {t for t in self.types if is_all}
+        wildcard_types: Set[str] = {t for t, s in self.types.items() if s is None}
+        concrete_keys: Set[StateKey] = set(self.concrete_types())
+
+        return (is_all, excluded_types), (wildcard_types, concrete_keys)
+
+    @staticmethod
+    def _recompose_from_four_parts(
+        all_part: bool,
+        minus_wildcards: Set[str],
+        plus_wildcards: Set[str],
+        plus_state_keys: Set[StateKey],
+    ) -> "StateFilter":
+        """
+        Recomposes a state filter from 4 parts.
+
+        See `decompose_into_four_parts` (the other direction of this
+        correspondence) for descriptions on each of the parts.
+        """
+
+        # {state type -> set of state keys OR None for wildcard}
+        # (The same structure as that of a StateFilter.)
+        new_types: Dict[str, Optional[Set[str]]] = {}
+
+        # if we start with all, insert the excluded statetypes as empty sets
+        # to prevent them from being included
+        if all_part:
+            new_types.update({state_type: set() for state_type in minus_wildcards})
+
+        # insert the plus wildcards
+        new_types.update({state_type: None for state_type in plus_wildcards})
+
+        # insert the specific state keys
+        for state_type, state_key in plus_state_keys:
+            if state_type in new_types:
+                entry = new_types[state_type]
+                if entry is not None:
+                    entry.add(state_key)
+            elif not all_part:
+                # don't insert if the entire type is already included by
+                # include_others as this would actually shrink the state allowed
+                # by this filter.
+                new_types[state_type] = {state_key}
+
+        return StateFilter.freeze(new_types, include_others=all_part)
+
+    def approx_difference(self, other: "StateFilter") -> "StateFilter":
+        """
+        Returns a state filter which represents `self - other`.
+
+        This is useful for determining what state remains to be pulled out of the
+        database if we want the state included by `self` but already have the state
+        included by `other`.
+
+        The returned state filter
+        - MUST include all state events that are included by this filter (`self`)
+          unless they are included by `other`;
+        - MUST NOT include state events not included by this filter (`self`); and
+        - MAY be an over-approximation: the returned state filter
+          MAY additionally include some state events from `other`.
+
+        This implementation attempts to return the narrowest such state filter.
+        In the case that `self` contains wildcards for state types where
+        `other` contains specific state keys, an approximation must be made:
+        the returned state filter keeps the wildcard, as state filters are not
+        able to express 'all state keys except some given examples'.
+        e.g.
+            StateFilter(m.room.member -> None (wildcard))
+                minus
+            StateFilter(m.room.member -> {'@wombat:example.org'})
+                is approximated as
+            StateFilter(m.room.member -> None (wildcard))
+        """
+
+        # We first transform self and other into an alternative representation:
+        #   - whether or not they include all events to begin with ('all')
+        #   - if so, which event types are excluded? ('excludes')
+        #   - which entire event types to include ('wildcards')
+        #   - which concrete state keys to include ('concrete state keys')
+        (self_all, self_excludes), (
+            self_wildcards,
+            self_concrete_keys,
+        ) = self._decompose_into_four_parts()
+        (other_all, other_excludes), (
+            other_wildcards,
+            other_concrete_keys,
+        ) = other._decompose_into_four_parts()
+
+        # Start with an estimate of the difference based on self
+        new_all = self_all
+        # Wildcards from the other can be added to the exclusion filter
+        new_excludes = self_excludes | other_wildcards
+        # We remove wildcards that appeared as wildcards in the other
+        new_wildcards = self_wildcards - other_wildcards
+        # We filter out the concrete state keys that appear in the other
+        # as wildcards or concrete state keys.
+        new_concrete_keys = {
+            (state_type, state_key)
+            for (state_type, state_key) in self_concrete_keys
+            if state_type not in other_wildcards
+        } - other_concrete_keys
+
+        if other_all:
+            if self_all:
+                # If self starts with all, then we add as wildcards any
+                # types which appear in the other's exclusion filter (but
+                # aren't in the self exclusion filter). This is as the other
+                # filter will return everything BUT the types in its exclusion, so
+                # we need to add those excluded types that also match the self
+                # filter as wildcard types in the new filter.
+                new_wildcards |= other_excludes.difference(self_excludes)
+
+            # If other is an `include_others` then the difference isn't.
+            new_all = False
+            # (We have no need for excludes when we don't start with all, as there
+            #  is nothing to exclude.)
+            new_excludes = set()
+
+            # We also filter out all state types that aren't in the exclusion
+            # list of the other.
+            new_wildcards &= other_excludes
+            new_concrete_keys = {
+                (state_type, state_key)
+                for (state_type, state_key) in new_concrete_keys
+                if state_type in other_excludes
+            }
+
+        # Transform our newly-constructed state filter from the alternative
+        # representation back into the normal StateFilter representation.
+        return StateFilter._recompose_from_four_parts(
+            new_all, new_excludes, new_wildcards, new_concrete_keys
+        )
+
 
 class StateGroupStorage:
     """High level interface to fetching state for event."""