summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/urls.py2
-rwxr-xr-xsynapse/app/homeserver.py1
-rw-r--r--synapse/event_auth.py11
-rw-r--r--synapse/events/__init__.py13
-rw-r--r--synapse/federation/transaction_queue.py36
-rw-r--r--synapse/handlers/federation.py78
-rw-r--r--synapse/handlers/message.py25
-rw-r--r--synapse/handlers/profile.py8
-rw-r--r--synapse/handlers/sync.py4
-rw-r--r--synapse/handlers/typing.py1
-rw-r--r--synapse/http/site.py27
-rw-r--r--synapse/metrics/background_process_metrics.py8
-rw-r--r--synapse/python_dependencies.py5
-rw-r--r--synapse/storage/events.py102
-rw-r--r--synapse/storage/transactions.py23
-rw-r--r--synapse/util/caches/expiringcache.py24
17 files changed, 240 insertions, 130 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index b1f7a89fba..43c5821ade 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -27,4 +27,4 @@ try:
 except ImportError:
     pass
 
-__version__ = "0.33.5.1"
+__version__ = "0.33.6"
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index 71347912f1..6d9f1ca0ef 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -64,7 +64,7 @@ class ConsentURIBuilder(object):
         """
         mac = hmac.new(
             key=self._hmac_secret,
-            msg=user_id,
+            msg=user_id.encode('ascii'),
             digestmod=sha256,
         ).hexdigest()
         consent_uri = "%s_matrix/consent?%s" % (
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index a98fdbd210..e3f0d99a3f 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -386,7 +386,6 @@ def setup(config_options):
         hs.get_pusherpool().start()
         hs.get_datastore().start_profiling()
         hs.get_datastore().start_doing_background_updates()
-        hs.get_federation_client().start_get_pdu_cache()
 
     reactor.callWhenRunning(start)
 
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 6baeccca38..af3eee95b9 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -98,9 +98,9 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
     creation_event = auth_events.get((EventTypes.Create, ""), None)
 
     if not creation_event:
-        raise SynapseError(
+        raise AuthError(
             403,
-            "Room %r does not exist" % (event.room_id,)
+            "No create event in auth events",
         )
 
     creating_domain = get_domain_from_id(event.room_id)
@@ -155,10 +155,7 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
 
         if user_level < invite_level:
             raise AuthError(
-                403, (
-                    "You cannot issue a third party invite for %s." %
-                    (event.content.display_name,)
-                )
+                403, "You don't have permission to invite users",
             )
         else:
             logger.debug("Allowing! %s", event)
@@ -305,7 +302,7 @@ def _is_membership_change_allowed(event, auth_events):
 
             if user_level < invite_level:
                 raise AuthError(
-                    403, "You cannot invite user %s." % target_user_id
+                    403, "You don't have permission to invite users",
                 )
     elif Membership.JOIN == membership:
         # Joins are valid iff caller == target and they were:
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index b782af6308..12f1eb0a3e 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -13,15 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+from distutils.util import strtobool
+
 import six
 
 from synapse.util.caches import intern_dict
 from synapse.util.frozenutils import freeze
 
 # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
-# bugs where we accidentally share e.g. signature dicts. However, converting
-# a dict to frozen_dicts is expensive.
-USE_FROZEN_DICTS = True
+# bugs where we accidentally share e.g. signature dicts. However, converting a
+# dict to frozen_dicts is expensive.
+#
+# NOTE: This is overridden by the configuration by the Synapse worker apps, but
+# for the sake of tests, it is set here while it cannot be configured on the
+# homeserver object itself.
+USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
 
 
 class _EventInternalMetadata(object):
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 8cbf8c4f7f..98b5950800 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -137,26 +137,6 @@ class TransactionQueue(object):
 
         self._processing_pending_presence = False
 
-    def can_send_to(self, destination):
-        """Can we send messages to the given server?
-
-        We can't send messages to ourselves. If we are running on localhost
-        then we can only federation with other servers running on localhost.
-        Otherwise we only federate with servers on a public domain.
-
-        Args:
-            destination(str): The server we are possibly trying to send to.
-        Returns:
-            bool: True if we can send to the server.
-        """
-
-        if destination == self.server_name:
-            return False
-        if self.server_name.startswith("localhost"):
-            return destination.startswith("localhost")
-        else:
-            return not destination.startswith("localhost")
-
     def notify_new_events(self, current_id):
         """This gets called when we have some new events we might want to
         send out to other servers.
@@ -279,10 +259,7 @@ class TransactionQueue(object):
         self._order += 1
 
         destinations = set(destinations)
-        destinations = set(
-            dest for dest in destinations if self.can_send_to(dest)
-        )
-
+        destinations.discard(self.server_name)
         logger.debug("Sending to: %s", str(destinations))
 
         if not destinations:
@@ -358,7 +335,7 @@ class TransactionQueue(object):
 
         for destinations, states in hosts_and_states:
             for destination in destinations:
-                if not self.can_send_to(destination):
+                if destination == self.server_name:
                     continue
 
                 self.pending_presence_by_dest.setdefault(
@@ -377,7 +354,8 @@ class TransactionQueue(object):
             content=content,
         )
 
-        if not self.can_send_to(destination):
+        if destination == self.server_name:
+            logger.info("Not sending EDU to ourselves")
             return
 
         sent_edus_counter.inc()
@@ -392,10 +370,8 @@ class TransactionQueue(object):
         self._attempt_new_transaction(destination)
 
     def send_device_messages(self, destination):
-        if destination == self.server_name or destination == "localhost":
-            return
-
-        if not self.can_send_to(destination):
+        if destination == self.server_name:
+            logger.info("Not sending device update to ourselves")
             return
 
         self._attempt_new_transaction(destination)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 38bebbf598..45d955e6f5 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -18,7 +18,6 @@
 
 import itertools
 import logging
-import sys
 
 import six
 from six import iteritems, itervalues
@@ -106,7 +105,7 @@ class FederationHandler(BaseHandler):
 
         self.hs = hs
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastore()  # type: synapse.storage.DataStore
         self.federation_client = hs.get_federation_client()
         self.state_handler = hs.get_state_handler()
         self.server_name = hs.hostname
@@ -323,14 +322,22 @@ class FederationHandler(BaseHandler):
                         affected=pdu.event_id,
                     )
 
-                # Calculate the state of the previous events, and
-                # de-conflict them to find the current state.
-                state_groups = []
+                # Calculate the state after each of the previous events, and
+                # resolve them to find the correct state at the current event.
                 auth_chains = set()
+                event_map = {
+                    event_id: pdu,
+                }
                 try:
                     # Get the state of the events we know about
-                    ours = yield self.store.get_state_groups(room_id, list(seen))
-                    state_groups.append(ours)
+                    ours = yield self.store.get_state_groups_ids(room_id, seen)
+
+                    # state_maps is a list of mappings from (type, state_key) to event_id
+                    # type: list[dict[tuple[str, str], str]]
+                    state_maps = list(ours.values())
+
+                    # we don't need this any more, let's delete it.
+                    del ours
 
                     # Ask the remote server for the states we don't
                     # know about
@@ -350,28 +357,54 @@ class FederationHandler(BaseHandler):
                                 )
                             )
 
+                            # we want the state *after* p; get_state_for_room returns the
+                            # state *before* p.
+                            remote_event = yield self.federation_client.get_pdu(
+                                [origin], p, outlier=True,
+                            )
+
+                            if remote_event is None:
+                                raise Exception(
+                                    "Unable to get missing prev_event %s" % (p, )
+                                )
+
+                            if remote_event.is_state():
+                                remote_state.append(remote_event)
+
                             # XXX hrm I'm not convinced that duplicate events will compare
                             # for equality, so I'm not sure this does what the author
                             # hoped.
                             auth_chains.update(got_auth_chain)
 
-                            state_group = {
+                            remote_state_map = {
                                 (x.type, x.state_key): x.event_id for x in remote_state
                             }
-                            state_groups.append(state_group)
+                            state_maps.append(remote_state_map)
+
+                            for x in remote_state:
+                                event_map[x.event_id] = x
 
                     # Resolve any conflicting state
+                    @defer.inlineCallbacks
                     def fetch(ev_ids):
-                        return self.store.get_events(
-                            ev_ids, get_prev_content=False, check_redacted=False
+                        fetched = yield self.store.get_events(
+                            ev_ids, get_prev_content=False, check_redacted=False,
                         )
+                        # add any events we fetch here to the `event_map` so that we
+                        # can use them to build the state event list below.
+                        event_map.update(fetched)
+                        defer.returnValue(fetched)
 
                     room_version = yield self.store.get_room_version(room_id)
                     state_map = yield resolve_events_with_factory(
-                        room_version, state_groups, {event_id: pdu}, fetch
+                        room_version, state_maps, event_map, fetch,
                     )
 
-                    state = (yield self.store.get_events(state_map.values())).values()
+                    # we need to give _process_received_pdu the actual state events
+                    # rather than event ids, so generate that now.
+                    state = [
+                        event_map[e] for e in six.itervalues(state_map)
+                    ]
                     auth_chain = list(auth_chains)
                 except Exception:
                     logger.warn(
@@ -1568,6 +1601,9 @@ class FederationHandler(BaseHandler):
             auth_events=auth_events,
         )
 
+        # reraise does not allow inlineCallbacks to preserve the stacktrace, so we
+        # hack around with a try/finally instead.
+        success = False
         try:
             if not event.internal_metadata.is_outlier() and not backfilled:
                 yield self.action_generator.handle_push_actions_for_event(
@@ -1578,15 +1614,13 @@ class FederationHandler(BaseHandler):
                 [(event, context)],
                 backfilled=backfilled,
             )
-        except:  # noqa: E722, as we reraise the exception this is fine.
-            tp, value, tb = sys.exc_info()
-
-            logcontext.run_in_background(
-                self.store.remove_push_actions_from_staging,
-                event.event_id,
-            )
-
-            six.reraise(tp, value, tb)
+            success = True
+        finally:
+            if not success:
+                logcontext.run_in_background(
+                    self.store.remove_push_actions_from_staging,
+                    event.event_id,
+                )
 
         defer.returnValue(context)
 
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index e484061cc0..4954b23a0d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -14,9 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import sys
 
-import six
 from six import iteritems, itervalues, string_types
 
 from canonicaljson import encode_canonical_json, json
@@ -624,6 +622,9 @@ class EventCreationHandler(object):
             event, context
         )
 
+        # reraise does not allow inlineCallbacks to preserve the stacktrace, so we
+        # hack around with a try/finally instead.
+        success = False
         try:
             # If we're a worker we need to hit out to the master.
             if self.config.worker_app:
@@ -636,6 +637,7 @@ class EventCreationHandler(object):
                     ratelimit=ratelimit,
                     extra_users=extra_users,
                 )
+                success = True
                 return
 
             yield self.persist_and_notify_client_event(
@@ -645,17 +647,16 @@ class EventCreationHandler(object):
                 ratelimit=ratelimit,
                 extra_users=extra_users,
             )
-        except:  # noqa: E722, as we reraise the exception this is fine.
-            # Ensure that we actually remove the entries in the push actions
-            # staging area, if we calculated them.
-            tp, value, tb = sys.exc_info()
-
-            run_in_background(
-                self.store.remove_push_actions_from_staging,
-                event.event_id,
-            )
 
-            six.reraise(tp, value, tb)
+            success = True
+        finally:
+            if not success:
+                # Ensure that we actually remove the entries in the push actions
+                # staging area, if we calculated them.
+                run_in_background(
+                    self.store.remove_push_actions_from_staging,
+                    event.event_id,
+                )
 
     @defer.inlineCallbacks
     def persist_and_notify_client_event(
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index f284d5a385..1dfbde84fd 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -142,10 +142,8 @@ class BaseProfileHandler(BaseHandler):
                 if e.code != 404:
                     logger.exception("Failed to get displayname")
                 raise
-            except Exception:
-                logger.exception("Failed to get displayname")
-            else:
-                defer.returnValue(result["displayname"])
+
+            defer.returnValue(result["displayname"])
 
     @defer.inlineCallbacks
     def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
@@ -199,8 +197,6 @@ class BaseProfileHandler(BaseHandler):
                 if e.code != 404:
                     logger.exception("Failed to get avatar_url")
                 raise
-            except Exception:
-                logger.exception("Failed to get avatar_url")
 
             defer.returnValue(result["avatar_url"])
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c7d69d9d80..67b8ca28c7 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -567,13 +567,13 @@ class SyncHandler(object):
         # be a valid name or canonical_alias - i.e. we're checking that they
         # haven't been "deleted" by blatting {} over the top.
         if name_id:
-            name = yield self.store.get_event(name_id, allow_none=False)
+            name = yield self.store.get_event(name_id, allow_none=True)
             if name and name.content:
                 defer.returnValue(summary)
 
         if canonical_alias_id:
             canonical_alias = yield self.store.get_event(
-                canonical_alias_id, allow_none=False,
+                canonical_alias_id, allow_none=True,
             )
             if canonical_alias and canonical_alias.content:
                 defer.returnValue(summary)
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 65f475d639..c610933dd4 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -224,6 +224,7 @@ class TypingHandler(object):
 
             for domain in set(get_domain_from_id(u) for u in users):
                 if domain != self.server_name:
+                    logger.debug("sending typing update to %s", domain)
                     self.federation.send_edu(
                         destination=domain,
                         edu_type="m.typing",
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 50be2de3bb..e508c0bd4f 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -75,14 +75,14 @@ class SynapseRequest(Request):
         return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % (
             self.__class__.__name__,
             id(self),
-            self.method.decode('ascii', errors='replace'),
+            self.get_method(),
             self.get_redacted_uri(),
             self.clientproto.decode('ascii', errors='replace'),
             self.site.site_tag,
         )
 
     def get_request_id(self):
-        return "%s-%i" % (self.method.decode('ascii'), self.request_seq)
+        return "%s-%i" % (self.get_method(), self.request_seq)
 
     def get_redacted_uri(self):
         uri = self.uri
@@ -90,6 +90,21 @@ class SynapseRequest(Request):
             uri = self.uri.decode('ascii')
         return redact_uri(uri)
 
+    def get_method(self):
+        """Gets the method associated with the request (or placeholder if not
+        method has yet been received).
+
+        Note: This is necessary as the placeholder value in twisted is str
+        rather than bytes, so we need to sanitise `self.method`.
+
+        Returns:
+            str
+        """
+        method = self.method
+        if isinstance(method, bytes):
+            method = self.method.decode('ascii')
+        return method
+
     def get_user_agent(self):
         return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
 
@@ -119,7 +134,7 @@ class SynapseRequest(Request):
             # dispatching to the handler, so that the handler
             # can update the servlet name in the request
             # metrics
-            requests_counter.labels(self.method.decode('ascii'),
+            requests_counter.labels(self.get_method(),
                                     self.request_metrics.name).inc()
 
     @contextlib.contextmanager
@@ -207,14 +222,14 @@ class SynapseRequest(Request):
         self.start_time = time.time()
         self.request_metrics = RequestMetrics()
         self.request_metrics.start(
-            self.start_time, name=servlet_name, method=self.method.decode('ascii'),
+            self.start_time, name=servlet_name, method=self.get_method(),
         )
 
         self.site.access_logger.info(
             "%s - %s - Received request: %s %s",
             self.getClientIP(),
             self.site.site_tag,
-            self.method.decode('ascii'),
+            self.get_method(),
             self.get_redacted_uri()
         )
 
@@ -280,7 +295,7 @@ class SynapseRequest(Request):
             int(usage.db_txn_count),
             self.sentLength,
             code,
-            self.method.decode('ascii'),
+            self.get_method(),
             self.get_redacted_uri(),
             self.clientproto.decode('ascii', errors='replace'),
             user_agent,
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 173908299c..037f1c490e 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -101,9 +101,13 @@ class _Collector(object):
             labels=["name"],
         )
 
-        # We copy the dict so that it doesn't change from underneath us
+        # We copy the dict so that it doesn't change from underneath us.
+        # We also copy the process lists as that can also change
         with _bg_metrics_lock:
-            _background_processes_copy = dict(_background_processes)
+            _background_processes_copy = {
+                k: list(v)
+                for k, v in six.iteritems(_background_processes)
+            }
 
         for desc, processes in six.iteritems(_background_processes_copy):
             background_process_in_flight_count.add_metric(
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 0f339a0320..d4d983b00a 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -58,7 +58,10 @@ REQUIREMENTS = {
     "msgpack-python>=0.3.0": ["msgpack"],
     "phonenumbers>=8.2.0": ["phonenumbers"],
     "six>=1.10": ["six"],
-    "prometheus_client>=0.0.18": ["prometheus_client"],
+
+    # prometheus_client 0.4.0 changed the format of counter metrics
+    # (cf https://github.com/matrix-org/synapse/issues/4001)
+    "prometheus_client>=0.0.18,<0.4.0": ["prometheus_client"],
 
     # we use attr.s(slots), which arrived in 16.0.0
     "attrs>=16.0.0": ["attr>=16.0.0"],
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index e7487311ce..03cedf3a75 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -38,6 +38,7 @@ from synapse.storage.background_updates import BackgroundUpdateStore
 from synapse.storage.event_federation import EventFederationStore
 from synapse.storage.events_worker import EventsWorkerStore
 from synapse.types import RoomStreamToken, get_domain_from_id
+from synapse.util import batch_iter
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 from synapse.util.frozenutils import frozendict_json_encoder
@@ -386,12 +387,10 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
                             )
 
                         for room_id, ev_ctx_rm in iteritems(events_by_room):
-                            # Work out new extremities by recursively adding and removing
-                            # the new events.
                             latest_event_ids = yield self.get_latest_event_ids_in_room(
                                 room_id
                             )
-                            new_latest_event_ids = yield self._calculate_new_extremeties(
+                            new_latest_event_ids = yield self._calculate_new_extremities(
                                 room_id, ev_ctx_rm, latest_event_ids
                             )
 
@@ -400,6 +399,12 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
                                 # No change in extremities, so no change in state
                                 continue
 
+                            # there should always be at least one forward extremity.
+                            # (except during the initial persistence of the send_join
+                            # results, in which case there will be no existing
+                            # extremities, so we'll `continue` above and skip this bit.)
+                            assert new_latest_event_ids, "No forward extremities left!"
+
                             new_forward_extremeties[room_id] = new_latest_event_ids
 
                             len_1 = (
@@ -517,44 +522,79 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
                     )
 
     @defer.inlineCallbacks
-    def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids):
-        """Calculates the new forward extremeties for a room given events to
+    def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids):
+        """Calculates the new forward extremities for a room given events to
         persist.
 
         Assumes that we are only persisting events for one room at a time.
         """
-        new_latest_event_ids = set(latest_event_ids)
-        # First, add all the new events to the list
-        new_latest_event_ids.update(
-            event.event_id for event, ctx in event_contexts
+
+        # we're only interested in new events which aren't outliers and which aren't
+        # being rejected.
+        new_events = [
+            event for event, ctx in event_contexts
             if not event.internal_metadata.is_outlier() and not ctx.rejected
+        ]
+
+        # start with the existing forward extremities
+        result = set(latest_event_ids)
+
+        # add all the new events to the list
+        result.update(
+            event.event_id for event in new_events
         )
-        # Now remove all events that are referenced by the to-be-added events
-        new_latest_event_ids.difference_update(
+
+        # Now remove all events which are prev_events of any of the new events
+        result.difference_update(
             e_id
-            for event, ctx in event_contexts
+            for event in new_events
             for e_id, _ in event.prev_events
-            if not event.internal_metadata.is_outlier() and not ctx.rejected
         )
 
-        # And finally remove any events that are referenced by previously added
-        # events.
-        rows = yield self._simple_select_many_batch(
-            table="event_edges",
-            column="prev_event_id",
-            iterable=list(new_latest_event_ids),
-            retcols=["prev_event_id"],
-            keyvalues={
-                "is_state": False,
-            },
-            desc="_calculate_new_extremeties",
-        )
+        # Finally, remove any events which are prev_events of any existing events.
+        existing_prevs = yield self._get_events_which_are_prevs(result)
+        result.difference_update(existing_prevs)
 
-        new_latest_event_ids.difference_update(
-            row["prev_event_id"] for row in rows
-        )
+        defer.returnValue(result)
 
-        defer.returnValue(new_latest_event_ids)
+    @defer.inlineCallbacks
+    def _get_events_which_are_prevs(self, event_ids):
+        """Filter the supplied list of event_ids to get those which are prev_events of
+        existing (non-outlier/rejected) events.
+
+        Args:
+            event_ids (Iterable[str]): event ids to filter
+
+        Returns:
+            Deferred[List[str]]: filtered event ids
+        """
+        results = []
+
+        def _get_events(txn, batch):
+            sql = """
+            SELECT prev_event_id
+            FROM event_edges
+                INNER JOIN events USING (event_id)
+                LEFT JOIN rejections USING (event_id)
+            WHERE
+                prev_event_id IN (%s)
+                AND NOT events.outlier
+                AND rejections.event_id IS NULL
+            """ % (
+                ",".join("?" for _ in batch),
+            )
+
+            txn.execute(sql, batch)
+            results.extend(r[0] for r in txn)
+
+        for chunk in batch_iter(event_ids, 100):
+            yield self.runInteraction(
+                "_get_events_which_are_prevs",
+                _get_events,
+                chunk,
+            )
+
+        defer.returnValue(results)
 
     @defer.inlineCallbacks
     def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids,
@@ -586,10 +626,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
             the new current state is only returned if we've already calculated
             it.
         """
-
-        if not new_latest_event_ids:
-            return
-
         # map from state_group to ((type, key) -> event_id) state map
         state_groups_map = {}
 
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index baf0379a68..a3032cdce9 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -23,6 +23,7 @@ from canonicaljson import encode_canonical_json
 from twisted.internet import defer
 
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.caches.expiringcache import ExpiringCache
 
 from ._base import SQLBaseStore, db_to_json
 
@@ -49,6 +50,8 @@ _UpdateTransactionRow = namedtuple(
     )
 )
 
+SENTINEL = object()
+
 
 class TransactionStore(SQLBaseStore):
     """A collection of queries for handling PDUs.
@@ -59,6 +62,12 @@ class TransactionStore(SQLBaseStore):
 
         self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
 
+        self._destination_retry_cache = ExpiringCache(
+            cache_name="get_destination_retry_timings",
+            clock=self._clock,
+            expiry_ms=5 * 60 * 1000,
+        )
+
     def get_received_txn_response(self, transaction_id, origin):
         """For an incoming transaction from a given origin, check if we have
         already responded to it. If so, return the response code and response
@@ -155,6 +164,7 @@ class TransactionStore(SQLBaseStore):
         """
         pass
 
+    @defer.inlineCallbacks
     def get_destination_retry_timings(self, destination):
         """Gets the current retry timings (if any) for a given destination.
 
@@ -165,10 +175,20 @@ class TransactionStore(SQLBaseStore):
             None if not retrying
             Otherwise a dict for the retry scheme
         """
-        return self.runInteraction(
+
+        result = self._destination_retry_cache.get(destination, SENTINEL)
+        if result is not SENTINEL:
+            defer.returnValue(result)
+
+        result = yield self.runInteraction(
             "get_destination_retry_timings",
             self._get_destination_retry_timings, destination)
 
+        # We don't hugely care about race conditions between getting and
+        # invalidating the cache, since we time out fairly quickly anyway.
+        self._destination_retry_cache[destination] = result
+        defer.returnValue(result)
+
     def _get_destination_retry_timings(self, txn, destination):
         result = self._simple_select_one_txn(
             txn,
@@ -196,6 +216,7 @@ class TransactionStore(SQLBaseStore):
             retry_interval (int) - how long until next retry in ms
         """
 
+        self._destination_retry_cache.pop(destination, None)
         return self.runInteraction(
             "set_destination_retry_timings",
             self._set_destination_retry_timings,
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 9af4ec4aa8..f369780277 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -16,7 +16,7 @@
 import logging
 from collections import OrderedDict
 
-from six import itervalues
+from six import iteritems, itervalues
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util.caches import register_cache
@@ -24,6 +24,9 @@ from synapse.util.caches import register_cache
 logger = logging.getLogger(__name__)
 
 
+SENTINEL = object()
+
+
 class ExpiringCache(object):
     def __init__(self, cache_name, clock, max_len=0, expiry_ms=0,
                  reset_expiry_on_get=False, iterable=False):
@@ -95,6 +98,21 @@ class ExpiringCache(object):
 
         return entry.value
 
+    def pop(self, key, default=SENTINEL):
+        """Removes and returns the value with the given key from the cache.
+
+        If the key isn't in the cache then `default` will be returned if
+        specified, otherwise `KeyError` will get raised.
+
+        Identical functionality to `dict.pop(..)`.
+        """
+
+        value = self._cache.pop(key, default)
+        if value is SENTINEL:
+            raise KeyError(key)
+
+        return value
+
     def __contains__(self, key):
         return key in self._cache
 
@@ -122,7 +140,7 @@ class ExpiringCache(object):
 
         keys_to_delete = set()
 
-        for key, cache_entry in self._cache.items():
+        for key, cache_entry in iteritems(self._cache):
             if now - cache_entry.time > self._expiry_ms:
                 keys_to_delete.add(key)
 
@@ -146,6 +164,8 @@ class ExpiringCache(object):
 
 
 class _CacheEntry(object):
+    __slots__ = ["time", "value"]
+
     def __init__(self, time, value):
         self.time = time
         self.value = value