Improve sorting for SETs
1 files changed, 47 insertions, 53 deletions
diff --git a/crypto/src/asn1/Asn1Set.cs b/crypto/src/asn1/Asn1Set.cs
index 2e77ca2a9..cf039d7fe 100644
--- a/crypto/src/asn1/Asn1Set.cs
+++ b/crypto/src/asn1/Asn1Set.cs
@@ -278,67 +278,30 @@ namespace Org.BouncyCastle.Asn1
return encObj;
}
- /**
- * return true if a <= b (arrays are assumed padded with zeros).
- */
- private bool LessThanOrEqual(
- byte[] a,
- byte[] b)
+ protected internal void Sort()
{
- int len = System.Math.Min(a.Length, b.Length);
- for (int i = 0; i != len; ++i)
+ if (_set.Count < 2)
+ return;
+
+ Asn1Encodable[] items = new Asn1Encodable[_set.Count];
+ byte[][] keys = new byte[_set.Count][];
+
+ for (int i = 0; i < _set.Count; ++i)
{
- if (a[i] != b[i])
- {
- return a[i] < b[i];
- }
+ Asn1Encodable item = (Asn1Encodable)_set[i];
+ items[i] = item;
+ keys[i] = item.GetEncoded(Asn1Encodable.Der);
}
- return len == a.Length;
- }
- protected internal void Sort()
- {
- if (_set.Count > 1)
- {
- bool swapped = true;
- int lastSwap = _set.Count - 1;
+ Array.Sort(keys, items, new DerComparer());
- while (swapped)
- {
- int index = 0;
- int swapIndex = 0;
- byte[] a = ((Asn1Encodable) _set[0]).GetEncoded();
-
- swapped = false;
-
- while (index != lastSwap)
- {
- byte[] b = ((Asn1Encodable) _set[index + 1]).GetEncoded();
-
- if (LessThanOrEqual(a, b))
- {
- a = b;
- }
- else
- {
- object o = _set[index];
- _set[index] = _set[index + 1];
- _set[index + 1] = o;
-
- swapped = true;
- swapIndex = index;
- }
-
- index++;
- }
-
- lastSwap = swapIndex;
- }
+ for (int i = 0; i < _set.Count; ++i)
+ {
+ _set[i] = items[i];
}
}
- protected internal void AddObject(
- Asn1Encodable obj)
+ protected internal void AddObject(Asn1Encodable obj)
{
_set.Add(obj);
}
@@ -347,5 +310,36 @@ namespace Org.BouncyCastle.Asn1
{
return CollectionUtilities.ToString(_set);
}
+
+ private class DerComparer
+ : IComparer
+ {
+ public int Compare(object x, object y)
+ {
+ byte[] a = (byte[])x, b = (byte[])y;
+ int len = System.Math.Min(a.Length, b.Length);
+ for (int i = 0; i != len; ++i)
+ {
+ byte ai = a[i], bi = b[i];
+ if (ai != bi)
+ return ai < bi ? -1 : 1;
+ }
+ if (a.Length > b.Length)
+ return AllZeroesFrom(a, len) ? 0 : 1;
+ if (a.Length < b.Length)
+ return AllZeroesFrom(b, len) ? 0 : -1;
+ return 0;
+ }
+
+ private bool AllZeroesFrom(byte[] bs, int pos)
+ {
+ while (pos < bs.Length)
+ {
+ if (bs[pos++] != 0)
+ return false;
+ }
+ return true;
+ }
+ }
}
}
|