| diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index e4ce087afe..c20c89aa8f 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -87,13 +87,26 @@ class LoggingContext(object):
         """Get the current logging context from thread local storage"""
         return getattr(cls.thread_local, "current_context", cls.sentinel)
 
+    @classmethod
+    def set_current_context(cls, context):
+        """Set the current logging context in thread local storage
+        Args:
+            context(LoggingContext): The context to activate.
+        Returns:
+            The context that was previously active
+        """
+        current = cls.current_context()
+        if current is not context:
+            current.stop()
+            cls.thread_local.current_context = context
+            context.start()
+        return current
+
     def __enter__(self):
         """Enters this logging context into thread local storage"""
         if self.parent_context is not None:
             raise Exception("Attempt to enter logging context multiple times")
-        self.parent_context = self.current_context()
-        self.thread_local.current_context = self
-        self.start()
+        self.parent_context = self.set_current_context(self)
         return self
 
     def __exit__(self, type, value, traceback):
@@ -102,17 +115,16 @@ class LoggingContext(object):
         Returns:
             None to avoid suppressing any exeptions that were thrown.
         """
-        if self.thread_local.current_context is not self:
-            if self.thread_local.current_context is self.sentinel:
+        current = self.set_current_context(self.parent_context)
+        if current is not self:
+            if current is self.sentinel:
                 logger.debug("Expected logging context %s has been lost", self)
             else:
                 logger.warn(
                     "Current logging context %s is not expected context %s",
-                    self.thread_local.current_context,
+                    current,
                     self
                 )
-        self.thread_local.current_context = self.parent_context
-        self.stop()
         self.parent_context = None
 
     def __getattr__(self, name):
@@ -194,17 +206,13 @@ class PreserveLoggingContext(object):
 
     def __enter__(self):
         """Captures the current logging context"""
-        self.current_context = LoggingContext.current_context()
-        if self.new_context is not self.current_context:
-            self.current_context.stop()
-        LoggingContext.thread_local.current_context = self.new_context
+        self.current_context = LoggingContext.set_current_context(
+            self.new_context
+        )
 
     def __exit__(self, type, value, traceback):
         """Restores the current logging context"""
-        context = LoggingContext.thread_local.current_context
-        LoggingContext.thread_local.current_context = self.current_context
-        if context is not self.current_context:
-            self.current_context.start()
+        LoggingContext.set_current_context(self.current_context)
         if self.current_context is not LoggingContext.sentinel:
             if self.current_context.parent_context is None:
                 logger.warn(
 |