summary refs log tree commit diff
path: root/synapse/storage/event_federation.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2017-05-24 14:22:41 +0100
committerErik Johnston <erik@matrix.org>2017-05-24 15:23:31 +0100
commitc049472b8ad75d1d9a627803cd698cfe8c5570b8 (patch)
tree1af3e88e793ec2b47bc5cfacfe876e4319f4e026 /synapse/storage/event_federation.py
parentMerge pull request #2242 from matrix-org/erikj/email_refactor (diff)
downloadsynapse-c049472b8ad75d1d9a627803cd698cfe8c5570b8.tar.xz
Only store event_auth for state events
Diffstat (limited to 'synapse/storage/event_federation.py')
-rw-r--r--synapse/storage/event_federation.py35
1 files changed, 29 insertions, 6 deletions
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 519059c306..72126c682e 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -44,18 +44,41 @@ class EventFederationStore(SQLBaseStore):
             self._delete_old_forward_extrem_cache, 60 * 60 * 1000
         )
 
-    def get_auth_chain(self, event_ids):
-        return self.get_auth_chain_ids(event_ids).addCallback(self._get_events)
+    def get_auth_chain(self, event_ids, include_given=False):
+        """Get auth events for given event_ids. The events *must* be state events.
 
-    def get_auth_chain_ids(self, event_ids):
+        Args:
+            event_ids (list): state events
+            include_given (bool): include the given events in result
+
+        Returns:
+            list of events
+        """
+        return self.get_auth_chain_ids(
+            event_ids, include_given=include_given,
+        ).addCallback(self._get_events)
+
+    def get_auth_chain_ids(self, event_ids, include_given=False):
+        """Get auth events for given event_ids. The events *must* be state events.
+
+        Args:
+            event_ids (list): state events
+            include_given (bool): include the given events in result
+
+        Returns:
+            list of event_ids
+        """
         return self.runInteraction(
             "get_auth_chain_ids",
             self._get_auth_chain_ids_txn,
-            event_ids
+            event_ids, include_given
         )
 
-    def _get_auth_chain_ids_txn(self, txn, event_ids):
-        results = set()
+    def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
+        if include_given:
+            results = set(event_ids)
+        else:
+            results = set()
 
         base_sql = (
             "SELECT auth_id FROM event_auth WHERE event_id IN (%s)"