summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/6629.misc1
-rw-r--r--synapse/handlers/message.py43
-rw-r--r--synapse/handlers/room_member.py11
-rw-r--r--synapse/storage/data_stores/main/event_federation.py70
-rw-r--r--synapse/types.py12
-rw-r--r--tests/storage/test_event_federation.py15
-rw-r--r--tests/unittest.py6
7 files changed, 55 insertions, 103 deletions
diff --git a/changelog.d/6629.misc b/changelog.d/6629.misc
new file mode 100644
index 0000000000..68f77af05b
--- /dev/null
+++ b/changelog.d/6629.misc
@@ -0,0 +1 @@
+Simplify event creation code by removing redundant queries on the event_reference_hashes table.
\ No newline at end of file
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 4ad752205f..8ea3aca2f4 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -48,7 +48,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
-from synapse.types import RoomAlias, UserID, create_requester
+from synapse.types import Collection, RoomAlias, UserID, create_requester
 from synapse.util.async_helpers import Linearizer
 from synapse.util.frozenutils import frozendict_json_encoder
 from synapse.util.metrics import measure_func
@@ -422,7 +422,7 @@ class EventCreationHandler(object):
         event_dict,
         token_id=None,
         txn_id=None,
-        prev_events_and_hashes=None,
+        prev_event_ids: Optional[Collection[str]] = None,
         require_consent=True,
     ):
         """
@@ -439,10 +439,9 @@ class EventCreationHandler(object):
             token_id (str)
             txn_id (str)
 
-            prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
+            prev_event_ids:
                 the forward extremities to use as the prev_events for the
-                new event. For each event, a tuple of (event_id, hashes, depth)
-                where *hashes* is a map from algorithm to hash.
+                new event.
 
                 If None, they will be requested from the database.
 
@@ -498,9 +497,7 @@ class EventCreationHandler(object):
             builder.internal_metadata.txn_id = txn_id
 
         event, context = yield self.create_new_client_event(
-            builder=builder,
-            requester=requester,
-            prev_events_and_hashes=prev_events_and_hashes,
+            builder=builder, requester=requester, prev_event_ids=prev_event_ids,
         )
 
         # In an ideal world we wouldn't need the second part of this condition. However,
@@ -714,7 +711,7 @@ class EventCreationHandler(object):
     @measure_func("create_new_client_event")
     @defer.inlineCallbacks
     def create_new_client_event(
-        self, builder, requester=None, prev_events_and_hashes=None
+        self, builder, requester=None, prev_event_ids: Optional[Collection[str]] = None
     ):
         """Create a new event for a local client
 
@@ -723,10 +720,9 @@ class EventCreationHandler(object):
 
             requester (synapse.types.Requester|None):
 
-            prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
+            prev_event_ids:
                 the forward extremities to use as the prev_events for the
-                new event. For each event, a tuple of (event_id, hashes, depth)
-                where *hashes* is a map from algorithm to hash.
+                new event.
 
                 If None, they will be requested from the database.
 
@@ -734,22 +730,15 @@ class EventCreationHandler(object):
             Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)]
         """
 
-        if prev_events_and_hashes is not None:
-            assert len(prev_events_and_hashes) <= 10, (
+        if prev_event_ids is not None:
+            assert len(prev_event_ids) <= 10, (
                 "Attempting to create an event with %i prev_events"
-                % (len(prev_events_and_hashes),)
+                % (len(prev_event_ids),)
             )
         else:
-            prev_events_and_hashes = yield self.store.get_prev_events_for_room(
-                builder.room_id
-            )
-
-        prev_events = [
-            (event_id, prev_hashes)
-            for event_id, prev_hashes, _ in prev_events_and_hashes
-        ]
+            prev_event_ids = yield self.store.get_prev_events_for_room(builder.room_id)
 
-        event = yield builder.build(prev_event_ids=[p for p, _ in prev_events])
+        event = yield builder.build(prev_event_ids=prev_event_ids)
         context = yield self.state.compute_event_context(event)
         if requester:
             context.app_service = requester.app_service
@@ -1042,9 +1031,7 @@ class EventCreationHandler(object):
             # For each room we need to find a joined member we can use to send
             # the dummy event with.
 
-            prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id)
-
-            latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes)
+            latest_event_ids = yield self.store.get_prev_events_for_room(room_id)
 
             members = yield self.state.get_current_users_in_room(
                 room_id, latest_event_ids=latest_event_ids
@@ -1063,7 +1050,7 @@ class EventCreationHandler(object):
                             "room_id": room_id,
                             "sender": user_id,
                         },
-                        prev_events_and_hashes=prev_events_and_hashes,
+                        prev_event_ids=latest_event_ids,
                     )
 
                     event.internal_metadata.proactively_send = False
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index dbb0c3dda2..03bb52ccfb 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -25,7 +25,7 @@ from twisted.internet import defer
 from synapse import types
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError, Codes, SynapseError
-from synapse.types import RoomID, UserID
+from synapse.types import Collection, RoomID, UserID
 from synapse.util.async_helpers import Linearizer
 from synapse.util.distributor import user_joined_room, user_left_room
 
@@ -149,7 +149,7 @@ class RoomMemberHandler(object):
         target,
         room_id,
         membership,
-        prev_events_and_hashes,
+        prev_event_ids: Collection[str],
         txn_id=None,
         ratelimit=True,
         content=None,
@@ -177,7 +177,7 @@ class RoomMemberHandler(object):
             },
             token_id=requester.access_token_id,
             txn_id=txn_id,
-            prev_events_and_hashes=prev_events_and_hashes,
+            prev_event_ids=prev_event_ids,
             require_consent=require_consent,
         )
 
@@ -370,8 +370,7 @@ class RoomMemberHandler(object):
             if block_invite:
                 raise SynapseError(403, "Invites have been disabled on this server")
 
-        prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id)
-        latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes)
+        latest_event_ids = yield self.store.get_prev_events_for_room(room_id)
 
         current_state_ids = yield self.state_handler.get_current_state_ids(
             room_id, latest_event_ids=latest_event_ids
@@ -485,7 +484,7 @@ class RoomMemberHandler(object):
             membership=effective_membership_state,
             txn_id=txn_id,
             ratelimit=ratelimit,
-            prev_events_and_hashes=prev_events_and_hashes,
+            prev_event_ids=latest_event_ids,
             content=content,
             require_consent=require_consent,
         )
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 1f517e8fad..5cb8cd96d1 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -14,13 +14,10 @@
 # limitations under the License.
 import itertools
 import logging
-import random
 
 from six.moves import range
 from six.moves.queue import Empty, PriorityQueue
 
-from unpaddedbase64 import encode_base64
-
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
@@ -148,8 +145,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             retcol="event_id",
         )
 
-    @defer.inlineCallbacks
-    def get_prev_events_for_room(self, room_id):
+    def get_prev_events_for_room(self, room_id: str):
         """
         Gets a subset of the current forward extremities in the given room.
 
@@ -160,40 +156,29 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             room_id (str): room_id
 
         Returns:
-            Deferred[list[(str, dict[str, str], int)]]
-                for each event, a tuple of (event_id, hashes, depth)
-                where *hashes* is a map from algorithm to hash.
+            Deferred[List[str]]: the event ids of the forward extremites
+
         """
-        res = yield self.get_latest_event_ids_and_hashes_in_room(room_id)
-        if len(res) > 10:
-            # Sort by reverse depth, so we point to the most recent.
-            res.sort(key=lambda a: -a[2])
 
-            # we use half of the limit for the actual most recent events, and
-            # the other half to randomly point to some of the older events, to
-            # make sure that we don't completely ignore the older events.
-            res = res[0:5] + random.sample(res[5:], 5)
+        return self.db.runInteraction(
+            "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
+        )
 
-        return res
+    def _get_prev_events_for_room_txn(self, txn, room_id: str):
+        # we just use the 10 newest events. Older events will become
+        # prev_events of future events.
 
-    def get_latest_event_ids_and_hashes_in_room(self, room_id):
+        sql = """
+            SELECT e.event_id FROM event_forward_extremities AS f
+            INNER JOIN events AS e USING (event_id)
+            WHERE f.room_id = ?
+            ORDER BY e.depth DESC
+            LIMIT 10
         """
-        Gets the current forward extremities in the given room
 
-        Args:
-            room_id (str): room_id
+        txn.execute(sql, (room_id,))
 
-        Returns:
-            Deferred[list[(str, dict[str, str], int)]]
-                for each event, a tuple of (event_id, hashes, depth)
-                where *hashes* is a map from algorithm to hash.
-        """
-
-        return self.db.runInteraction(
-            "get_latest_event_ids_and_hashes_in_room",
-            self._get_latest_event_ids_and_hashes_in_room,
-            room_id,
-        )
+        return [row[0] for row in txn]
 
     def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter):
         """Get the top rooms with at least N extremities.
@@ -243,27 +228,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             desc="get_latest_event_ids_in_room",
         )
 
-    def _get_latest_event_ids_and_hashes_in_room(self, txn, room_id):
-        sql = (
-            "SELECT e.event_id, e.depth FROM events as e "
-            "INNER JOIN event_forward_extremities as f "
-            "ON e.event_id = f.event_id "
-            "AND e.room_id = f.room_id "
-            "WHERE f.room_id = ?"
-        )
-
-        txn.execute(sql, (room_id,))
-
-        results = []
-        for event_id, depth in txn.fetchall():
-            hashes = self._get_event_reference_hashes_txn(txn, event_id)
-            prev_hashes = {
-                k: encode_base64(v) for k, v in hashes.items() if k == "sha256"
-            }
-            results.append((event_id, prev_hashes, depth))
-
-        return results
-
     def get_min_depth(self, room_id):
         """ For hte given room, get the minimum depth we have seen for it.
         """
diff --git a/synapse/types.py b/synapse/types.py
index aafc3ffe74..cd996c0b5a 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 import re
 import string
+import sys
 from collections import namedtuple
 
 import attr
@@ -23,6 +24,17 @@ from unpaddedbase64 import decode_base64
 
 from synapse.api.errors import SynapseError
 
+# define a version of typing.Collection that works on python 3.5
+if sys.version_info[:3] >= (3, 6, 0):
+    from typing import Collection
+else:
+    from typing import Sized, Iterable, Container, TypeVar
+
+    T_co = TypeVar("T_co", covariant=True)
+
+    class Collection(Iterable[T_co], Container[T_co], Sized):
+        __slots__ = ()
+
 
 class Requester(
     namedtuple(
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index eadfb90a22..a331517f4d 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -60,21 +60,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
                 (event_id, bytearray(b"ffff")),
             )
 
-        for i in range(0, 11):
+        for i in range(0, 20):
             yield self.store.db.runInteraction("insert", insert_event, i)
 
-        # this should get the last five and five others
+        # this should get the last ten
         r = yield self.store.get_prev_events_for_room(room_id)
         self.assertEqual(10, len(r))
-        for i in range(0, 5):
-            el = r[i]
-            depth = el[2]
-            self.assertEqual(10 - i, depth)
-
-        for i in range(5, 5):
-            el = r[i]
-            depth = el[2]
-            self.assertLessEqual(5, depth)
+        for i in range(0, 10):
+            self.assertEqual("$event_%i:local" % (19 - i), r[i])
 
     @defer.inlineCallbacks
     def test_get_rooms_with_many_extremities(self):
diff --git a/tests/unittest.py b/tests/unittest.py
index cbda237278..ddcd4becfe 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -531,10 +531,6 @@ class HomeserverTestCase(TestCase):
         secrets = self.hs.get_secrets()
         requester = Requester(user, None, False, None, None)
 
-        prev_events_and_hashes = None
-        if prev_event_ids:
-            prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids]
-
         event, context = self.get_success(
             event_creator.create_event(
                 requester,
@@ -544,7 +540,7 @@ class HomeserverTestCase(TestCase):
                     "sender": user.to_string(),
                     "content": {"body": secrets.token_hex(), "msgtype": "m.text"},
                 },
-                prev_events_and_hashes=prev_events_and_hashes,
+                prev_event_ids=prev_event_ids,
             )
         )