summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/federation.py62
-rw-r--r--synapse/storage/data_stores/main/events_worker.py34
-rw-r--r--synapse/storage/data_stores/main/room.py12
3 files changed, 34 insertions, 74 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index d9d0cd9eef..7784b80b77 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -2122,14 +2122,9 @@ class FederationHandler(BaseHandler):
         #
         # we start by checking if they are in the store, and then try calling /event_auth/.
         if missing_auth:
-            # TODO: can we use store.have_seen_events here instead?
-            have_events = yield self.store.get_seen_events_with_rejections(missing_auth)
-            logger.debug("Found events %s in the store", have_events)
-            missing_auth.difference_update(have_events.keys())
-        else:
-            have_events = {}
-
-        have_events.update({e.event_id: "" for e in auth_events.values()})
+            have_events = yield self.store.have_seen_events(missing_auth)
+            logger.debug("Events %s are in the store", have_events)
+            missing_auth.difference_update(have_events)
 
         if missing_auth:
             # If we don't have all the auth events, we need to get them.
@@ -2175,9 +2170,6 @@ class FederationHandler(BaseHandler):
                     except AuthError:
                         pass
 
-                have_events = yield self.store.get_seen_events_with_rejections(
-                    event.auth_event_ids()
-                )
             except Exception:
                 logger.exception("Failed to get auth chain")
 
@@ -2207,39 +2199,33 @@ class FederationHandler(BaseHandler):
         # idea of them.
 
         room_version = yield self.store.get_room_version(event.room_id)
-        different_event_ids = [
-            d for d in different_auth if d in have_events and not have_events[d]
-        ]
 
-        if different_event_ids:
-            # XXX: currently this checks for redactions but I'm not convinced that is
-            # necessary?
-            different_events = yield self.store.get_events_as_list(different_event_ids)
+        # XXX: currently this checks for redactions but I'm not convinced that is
+        # necessary?
+        different_events = yield self.store.get_events_as_list(different_auth)
 
-            local_view = dict(auth_events)
-            remote_view = dict(auth_events)
-            remote_view.update({(d.type, d.state_key): d for d in different_events})
+        local_view = dict(auth_events)
+        remote_view = dict(auth_events)
+        remote_view.update({(d.type, d.state_key): d for d in different_events})
 
-            new_state = yield self.state_handler.resolve_events(
-                room_version,
-                [list(local_view.values()), list(remote_view.values())],
-                event,
-            )
+        new_state = yield self.state_handler.resolve_events(
+            room_version, [list(local_view.values()), list(remote_view.values())], event
+        )
 
-            logger.info(
-                "After state res: updating auth_events with new state %s",
-                {
-                    (d.type, d.state_key): d.event_id
-                    for d in new_state.values()
-                    if auth_events.get((d.type, d.state_key)) != d
-                },
-            )
+        logger.info(
+            "After state res: updating auth_events with new state %s",
+            {
+                (d.type, d.state_key): d.event_id
+                for d in new_state.values()
+                if auth_events.get((d.type, d.state_key)) != d
+            },
+        )
 
-            auth_events.update(new_state)
+        auth_events.update(new_state)
 
-            context = yield self._update_context_for_auth_events(
-                event, context, auth_events
-            )
+        context = yield self._update_context_for_auth_events(
+            event, context, auth_events
+        )
 
         return context
 
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index ec4af29299..6a08a746b6 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -783,40 +783,6 @@ class EventsWorkerStore(SQLBaseStore):
             yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk)
         return results
 
-    def get_seen_events_with_rejections(self, event_ids):
-        """Given a list of event ids, check if we rejected them.
-
-        Args:
-            event_ids (list[str])
-
-        Returns:
-            Deferred[dict[str, str|None):
-                Has an entry for each event id we already have seen. Maps to
-                the rejected reason string if we rejected the event, else maps
-                to None.
-        """
-        if not event_ids:
-            return defer.succeed({})
-
-        def f(txn):
-            sql = (
-                "SELECT e.event_id, reason FROM events as e "
-                "LEFT JOIN rejections as r ON e.event_id = r.event_id "
-                "WHERE e.event_id = ?"
-            )
-
-            res = {}
-            for event_id in event_ids:
-                txn.execute(sql, (event_id,))
-                row = txn.fetchone()
-                if row:
-                    _, rejected = row
-                    res[event_id] = rejected
-
-            return res
-
-        return self.runInteraction("get_seen_events_with_rejections", f)
-
     def _get_total_state_event_counts_txn(self, txn, room_id):
         """
         See get_total_state_event_counts.
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 8f9b6365c1..f309e3640c 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -28,6 +28,7 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.background_updates import BackgroundUpdateStore
 from synapse.storage.data_stores.main.search import SearchStore
 from synapse.types import ThirdPartyInstanceID
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -360,9 +361,9 @@ class RoomWorkerStore(SQLBaseStore):
         defer.returnValue(row)
 
 
-class RoomStore(RoomWorkerStore, SearchStore):
+class RoomBackgroundUpdateStore(BackgroundUpdateStore):
     def __init__(self, db_conn, hs):
-        super(RoomStore, self).__init__(db_conn, hs)
+        super(RoomBackgroundUpdateStore, self).__init__(db_conn, hs)
 
         self.config = hs.config
 
@@ -438,6 +439,13 @@ class RoomStore(RoomWorkerStore, SearchStore):
 
         defer.returnValue(batch_size)
 
+
+class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
+    def __init__(self, db_conn, hs):
+        super(RoomStore, self).__init__(db_conn, hs)
+
+        self.config = hs.config
+
     @defer.inlineCallbacks
     def store_room(self, room_id, room_creator_user_id, is_public):
         """Stores a room.