summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crypto/src/asn1/Asn1InputStream.cs60
-rw-r--r--crypto/src/asn1/BERBitString.cs179
-rw-r--r--crypto/src/asn1/BerOctetString.cs2
-rw-r--r--crypto/src/asn1/DerBitString.cs242
4 files changed, 358 insertions, 125 deletions
diff --git a/crypto/src/asn1/Asn1InputStream.cs b/crypto/src/asn1/Asn1InputStream.cs
index b56d890fa..20734fd59 100644
--- a/crypto/src/asn1/Asn1InputStream.cs
+++ b/crypto/src/asn1/Asn1InputStream.cs
@@ -95,30 +95,18 @@ namespace Org.BouncyCastle.Asn1
             if (!isConstructed)
                 return CreatePrimitiveDerObject(tagNo, defIn, tmpBuffers);
 
-            // TODO There are other tags that may be constructed (e.g. BitString)
             switch (tagNo)
             {
+            case Asn1Tags.BitString:
+            {
+                return BuildConstructedBitString(ReadVector(defIn));
+            }
             case Asn1Tags.OctetString:
             {
                 //
                 // yes, people actually do this...
                 //
-                Asn1EncodableVector v = ReadVector(defIn);
-                Asn1OctetString[] strings = new Asn1OctetString[v.Count];
-
-                for (int i = 0; i != strings.Length; i++)
-                {
-                    Asn1Encodable asn1Obj = v[i];
-                    if (!(asn1Obj is Asn1OctetString))
-                    {
-                        throw new Asn1Exception("unknown object encountered in constructed OCTET STRING: "
-                            + Platform.GetTypeName(asn1Obj));
-                    }
-
-                    strings[i] = (Asn1OctetString)asn1Obj;
-                }
-
-                return new BerOctetString(strings);
+                return BuildConstructedOctetString(ReadVector(defIn));
             }
             case Asn1Tags.Sequence:
                 return CreateDerSequence(defIn);
@@ -227,6 +215,42 @@ namespace Org.BouncyCastle.Asn1
             }
         }
 
+        internal virtual DerBitString BuildConstructedBitString(Asn1EncodableVector contentsElements)
+        {
+            DerBitString[] bitStrings = new DerBitString[contentsElements.Count];
+
+            for (int i = 0; i != bitStrings.Length; i++)
+            {
+                DerBitString bitString = contentsElements[i] as DerBitString;
+                if (null == bitString)
+                    throw new Asn1Exception("unknown object encountered in constructed BIT STRING: "
+                        + Platform.GetTypeName(contentsElements[i]));
+
+                bitStrings[i] = bitString;
+            }
+
+            // TODO Probably ought to be DLBitString
+            return new BerBitString(bitStrings);
+        }
+
+        internal virtual Asn1OctetString BuildConstructedOctetString(Asn1EncodableVector contentsElements)
+        {
+            Asn1OctetString[] octetStrings = new Asn1OctetString[contentsElements.Count];
+
+            for (int i = 0; i != octetStrings.Length; i++)
+            {
+                Asn1OctetString octetString = contentsElements[i] as Asn1OctetString;
+                if (null == octetString)
+                    throw new Asn1Exception("unknown object encountered in constructed OCTET STRING: "
+                        + Platform.GetTypeName(contentsElements[i]));
+
+                octetStrings[i] = octetString;
+            }
+
+            // TODO Probably ought to be DerOctetString (no DLOctetString available)
+            return new BerOctetString(octetStrings);
+        }
+
         internal virtual int Limit
         {
             get { return limit; }
@@ -404,7 +428,7 @@ namespace Org.BouncyCastle.Asn1
             switch (tagNo)
             {
                 case Asn1Tags.BitString:
-                    return DerBitString.FromAsn1Octets(bytes);
+                    return DerBitString.CreatePrimitive(bytes);
                 case Asn1Tags.GeneralizedTime:
                     return new DerGeneralizedTime(bytes);
                 case Asn1Tags.GeneralString:
diff --git a/crypto/src/asn1/BERBitString.cs b/crypto/src/asn1/BERBitString.cs
index a738a75e6..1756ee9c0 100644
--- a/crypto/src/asn1/BERBitString.cs
+++ b/crypto/src/asn1/BERBitString.cs
@@ -1,4 +1,5 @@
 using System;
+using System.Diagnostics;
 
 using Org.BouncyCastle.Utilities;
 
@@ -7,36 +8,192 @@ namespace Org.BouncyCastle.Asn1
     public class BerBitString
         : DerBitString
     {
-        public BerBitString(byte[] data, int padBits)
+        private const int DefaultSegmentLimit = 1000;
+
+        internal static byte[] FlattenBitStrings(DerBitString[] bitStrings)
+        {
+            int count = bitStrings.Length;
+            switch (count)
+            {
+            case 0:
+                // No bits
+                return new byte[]{ 0 };
+            case 1:
+                return bitStrings[0].contents;
+            default:
+            {
+                int last = count - 1, totalLength = 0;
+                for (int i = 0; i < last; ++i)
+                {
+                    byte[] elementContents = bitStrings[i].contents;
+                    if (elementContents[0] != 0)
+                        throw new ArgumentException("only the last nested bitstring can have padding", "bitStrings");
+
+                    totalLength += elementContents.Length - 1;
+                }
+
+                // Last one can have padding
+                byte[] lastElementContents = bitStrings[last].contents;
+                byte padBits = lastElementContents[0];
+                totalLength += lastElementContents.Length;
+
+                byte[] contents = new byte[totalLength];
+                contents[0] = padBits;
+
+                int pos = 1;
+                for (int i = 0; i < count; ++i)
+                {
+                    byte[] elementContents = bitStrings[i].contents;
+                    int length = elementContents.Length - 1;
+                    Array.Copy(elementContents, 1, contents, pos, length);
+                    pos += length;
+                }
+
+                Debug.Assert(pos == totalLength);
+                return contents;
+            }
+            }
+        }
+
+        private readonly int segmentLimit;
+        private readonly DerBitString[] elements;
+
+        public BerBitString(byte data, int padBits)
             : base(data, padBits)
-		{
-		}
+        {
+            this.elements = null;
+            this.segmentLimit = DefaultSegmentLimit;
+        }
+
+        public BerBitString(byte[] data)
+            : this(data, 0)
+        {
+        }
 
-		public BerBitString(byte[] data)
-            : base(data)
+        public BerBitString(byte[] data, int padBits)
+            : this(data, padBits, DefaultSegmentLimit)
 		{
-		}
+        }
+
+        public BerBitString(byte[] data, int padBits, int segmentLimit)
+            : base(data, padBits)
+        {
+            this.elements = null;
+            this.segmentLimit = segmentLimit;
+        }
 
         public BerBitString(int namedBits)
             : base(namedBits)
         {
+            this.elements = null;
+            this.segmentLimit = DefaultSegmentLimit;
         }
 
         public BerBitString(Asn1Encodable obj)
-            : base(obj)
+            : this(obj.GetDerEncoded(), 0)
 		{
-		}
+        }
+
+        public BerBitString(DerBitString[] elements)
+            : this(elements, DefaultSegmentLimit)
+        {
+        }
+
+        public BerBitString(DerBitString[] elements, int segmentLimit)
+            : base(FlattenBitStrings(elements), false)
+        {
+            this.elements = elements;
+            this.segmentLimit = segmentLimit;
+        }
+
+        internal BerBitString(byte[] contents, bool check)
+            : base(contents, check)
+        {
+            this.elements = null;
+            this.segmentLimit = DefaultSegmentLimit;
+        }
+
+        private bool IsConstructed
+        {
+            get { return null != elements || contents.Length > segmentLimit; }
+        }
+
+        internal override int EncodedLength(bool withID)
+        {
+            throw Platform.CreateNotImplementedException("BerBitString.EncodedLength");
+
+            // TODO This depends on knowing it's not DER
+            //if (!IsConstructed)
+            //    return EncodedLength(withID, contents.Length);
+
+            //int totalLength = withID ? 4 : 3;
+
+            //if (null != elements)
+            //{
+            //    for (int i = 0; i < elements.Length; ++i)
+            //    {
+            //        totalLength += elements[i].EncodedLength(true);
+            //    }
+            //}
+            //else if (contents.Length < 2)
+            //{
+            //    // No bits
+            //}
+            //else
+            //{
+            //    int extraSegments = (contents.Length - 2) / (segmentLimit - 1);
+            //    totalLength += extraSegments * EncodedLength(true, segmentLimit);
+
+            //    int lastSegmentLength = contents.Length - (extraSegments * (segmentLimit - 1));
+            //    totalLength += EncodedLength(true, lastSegmentLength);
+            //}
+
+            //return totalLength;
+        }
 
         internal override void Encode(Asn1OutputStream asn1Out, bool withID)
         {
-            if (asn1Out.IsBer)
+            if (!asn1Out.IsBer)
+            {
+                base.Encode(asn1Out, withID);
+                return;
+            }
+
+            if (!IsConstructed)
             {
-                asn1Out.WriteEncodingDL(withID, Asn1Tags.BitString, (byte)mPadBits, mData, 0, mData.Length);
+                Encode(asn1Out, withID, contents, 0, contents.Length);
+                return;
+            }
+
+            asn1Out.WriteIdentifier(withID, Asn1Tags.Constructed | Asn1Tags.BitString);
+            asn1Out.WriteByte(0x80);
+
+            if (null != elements)
+            {
+                asn1Out.WritePrimitives(elements);
+            }
+            else if (contents.Length < 2)
+            {
+                // No bits
             }
             else
             {
-                base.Encode(asn1Out, withID);
+                byte pad = contents[0];
+                int length = contents.Length;
+                int remaining = length - 1;
+                int segmentLength = segmentLimit - 1;
+
+                while (remaining > segmentLength)
+                {
+                    Encode(asn1Out, true, (byte)0, contents, length - remaining, segmentLength);
+                    remaining -= segmentLength;
+                }
+
+                Encode(asn1Out, true, pad, contents, length - remaining, remaining);
             }
+
+            asn1Out.WriteByte(0x00);
+            asn1Out.WriteByte(0x00);
         }
     }
 }
diff --git a/crypto/src/asn1/BerOctetString.cs b/crypto/src/asn1/BerOctetString.cs
index 4855e31d1..9963819cf 100644
--- a/crypto/src/asn1/BerOctetString.cs
+++ b/crypto/src/asn1/BerOctetString.cs
@@ -9,7 +9,7 @@ namespace Org.BouncyCastle.Asn1
     public class BerOctetString
         : DerOctetString, IEnumerable
     {
-        private static readonly int DefaultSegmentLimit = 1000;
+        private const int DefaultSegmentLimit = 1000;
 
         public static BerOctetString FromSequence(Asn1Sequence seq)
         {
diff --git a/crypto/src/asn1/DerBitString.cs b/crypto/src/asn1/DerBitString.cs
index 4dabb398f..caf5d6f9f 100644
--- a/crypto/src/asn1/DerBitString.cs
+++ b/crypto/src/asn1/DerBitString.cs
@@ -13,16 +13,12 @@ namespace Org.BouncyCastle.Asn1
 		private static readonly char[] table
 			= { '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' };
 
-		protected readonly byte[]   mData;
-		protected readonly int      mPadBits;
-
         /**
 		 * return a Bit string from the passed in object
 		 *
 		 * @exception ArgumentException if the object cannot be converted.
 		 */
-		public static DerBitString GetInstance(
-			object obj)
+		public static DerBitString GetInstance(object obj)
 		{
 			if (obj == null || obj is DerBitString)
 			{
@@ -52,9 +48,7 @@ namespace Org.BouncyCastle.Asn1
 		 * @exception ArgumentException if the tagged object cannot
 		 *               be converted.
 		 */
-		public static DerBitString GetInstance(
-			Asn1TaggedObject	obj,
-			bool				isExplicit)
+		public static DerBitString GetInstance(Asn1TaggedObject obj, bool isExplicit)
 		{
 			Asn1Object o = obj.GetObject();
 
@@ -63,16 +57,30 @@ namespace Org.BouncyCastle.Asn1
 				return GetInstance(o);
 			}
 
-			return FromAsn1Octets(((Asn1OctetString)o).GetOctets());
+            // Not copied because assumed to be a tagged implicit primitive from the parser
+			return CreatePrimitive(((Asn1OctetString)o).GetOctets());
 		}
 
+        internal readonly byte[] contents;
+
+        public DerBitString(byte data, int padBits)
+        {
+            if (padBits > 7 || padBits < 0)
+                throw new ArgumentException("pad bits cannot be greater than 7 or less than 0", "padBits");
+
+            this.contents = new byte[] { (byte)padBits, data };
+        }
+
+        public DerBitString(byte[] data)
+            : this(data, 0)
+        {
+        }
+
         /**
 		 * @param data the octets making up the bit string.
 		 * @param padBits the number of extra bits at the end of the string.
 		 */
-		public DerBitString(
-			byte[]	data,
-			int		padBits)
+        public DerBitString(byte[] data, int padBits)
 		{
             if (data == null)
                 throw new ArgumentNullException("data");
@@ -81,42 +89,30 @@ namespace Org.BouncyCastle.Asn1
             if (data.Length == 0 && padBits != 0)
                 throw new ArgumentException("if 'data' is empty, 'padBits' must be 0");
 
-            this.mData = Arrays.Clone(data);
-			this.mPadBits = padBits;
-		}
-
-		public DerBitString(
-			byte[] data)
-            : this(data, 0)
-		{
-		}
+            this.contents = Arrays.Prepend(data, (byte)padBits);
+        }
 
-        public DerBitString(
-            int namedBits)
+        public DerBitString(int namedBits)
         {
             if (namedBits == 0)
             {
-                this.mData = new byte[0];
-                this.mPadBits = 0;
+                this.contents = new byte[]{ 0 };
                 return;
             }
 
             int bits = BigInteger.BitLen(namedBits);
             int bytes = (bits + 7) / 8;
-
             Debug.Assert(0 < bytes && bytes <= 4);
 
-            byte[] data = new byte[bytes];
-            --bytes;
+            byte[] data = new byte[1 + bytes];
 
-            for (int i = 0; i < bytes; i++)
+            for (int i = 1; i < bytes; i++)
             {
                 data[i] = (byte)namedBits;
                 namedBits >>= 8;
             }
 
             Debug.Assert((namedBits & 0xFF) != 0);
-
             data[bytes] = (byte)namedBits;
 
             int padBits = 0;
@@ -126,17 +122,38 @@ namespace Org.BouncyCastle.Asn1
             }
 
             Debug.Assert(padBits < 8);
+            data[0] = (byte)padBits;
 
-            this.mData = data;
-            this.mPadBits = padBits;
+            this.contents = data;
         }
 
-        public DerBitString(
-			Asn1Encodable obj)
+        public DerBitString(Asn1Encodable obj)
             : this(obj.GetDerEncoded())
 		{
 		}
 
+        internal DerBitString(byte[] contents, bool check)
+        {
+            if (check)
+            {
+                if (null == contents)
+                    throw new ArgumentNullException("contents");
+                if (contents.Length < 1)
+                    throw new ArgumentException("cannot be empty", "contents");
+
+                int padBits = contents[0];
+                if (padBits > 0)
+                {
+                    if (contents.Length < 2)
+                        throw new ArgumentException("zero length data with non-zero pad bits", "contents");
+                    if (padBits > 7)
+                        throw new ArgumentException("pad bits cannot be greater than 7 or less than 0", "contents");
+                }
+            }
+
+            this.contents = contents;
+        }
+
         /**
          * Return the octets contained in this BIT STRING, checking that this BIT STRING really
          * does represent an octet aligned string. Only use this method when the standard you are
@@ -146,28 +163,27 @@ namespace Org.BouncyCastle.Asn1
          */
         public virtual byte[] GetOctets()
         {
-            if (mPadBits != 0)
+            if (contents[0] != 0)
                 throw new InvalidOperationException("attempt to get non-octet aligned data from BIT STRING");
 
-            return Arrays.Clone(mData);
+            return Arrays.CopyOfRange(contents, 1, contents.Length);
         }
 
         public virtual byte[] GetBytes()
 		{
-            byte[] data = Arrays.Clone(mData);
+            if (contents.Length == 1)
+                return Asn1OctetString.EmptyOctets;
 
+            int padBits = contents[0];
+            byte[] rv = Arrays.CopyOfRange(contents, 1, contents.Length);
             // DER requires pad bits be zero
-            if (mPadBits > 0)
-            {
-                data[data.Length - 1] &= (byte)(0xFF << mPadBits);
-            }
-
-            return data;
-		}
+            rv[rv.Length - 1] &= (byte)(0xFF << padBits);
+            return rv;
+        }
 
         public virtual int PadBits
 		{
-			get { return mPadBits; }
+			get { return contents[0]; }
 		}
 
 		/**
@@ -177,68 +193,90 @@ namespace Org.BouncyCastle.Asn1
 		{
 			get
 			{
-                int value = 0, length = System.Math.Min(4, mData.Length);
-                for (int i = 0; i < length; ++i)
+                int value = 0, end = System.Math.Min(5, contents.Length - 1);
+                for (int i = 1; i < end; ++i)
                 {
-                    value |= (int)mData[i] << (8 * i);
+                    value |= (int)contents[i] << (8 * (i - 1));
                 }
-                if (mPadBits > 0 && length == mData.Length)
+                if (1 <= end && end < 5)
                 {
-                    int mask = (1 << mPadBits) - 1;
-                    value &= ~(mask << (8 * (length - 1)));
+                    int padBits = contents[0];
+                    byte der = (byte)(contents[end] & (0xFF << padBits));
+                    value |= (int)der << (8 * (end - 1));
                 }
                 return value;
-			}
+            }
 		}
 
         internal override int EncodedLength(bool withID)
         {
-            return Asn1OutputStream.GetLengthOfEncodingDL(withID, 1 + mData.Length);
+            return Asn1OutputStream.GetLengthOfEncodingDL(withID, contents.Length);
         }
 
         internal override void Encode(Asn1OutputStream asn1Out, bool withID)
 		{
-            if (mPadBits > 0)
-            {
-                int last = mData[mData.Length - 1];
-                int mask = (1 << mPadBits) - 1;
-                int unusedBits = last & mask;
-
-                if (unusedBits != 0)
-                {
-                    byte[] contents = Arrays.Prepend(mData, (byte)mPadBits);
+            int padBits = contents[0];
+            int length = contents.Length;
+            int last = length - 1;
 
-                    /*
-                     * X.690-0207 11.2.1: Each unused bit in the final octet of the encoding of a bit string value shall be set to zero.
-                     */
-                    contents[contents.Length - 1] = (byte)(last ^ unusedBits);
+            byte lastOctet = contents[last];
+            byte lastOctetDer = (byte)(contents[last] & (0xFF << padBits));
 
-                    asn1Out.WriteEncodingDL(withID, Asn1Tags.BitString, contents);
-                    return;
-                }
+            if (lastOctet == lastOctetDer)
+            {
+                asn1Out.WriteEncodingDL(withID, Asn1Tags.BitString, contents);
+            }
+            else
+            {
+                asn1Out.WriteEncodingDL(withID, Asn1Tags.BitString, contents, 0, last, lastOctetDer);
             }
-
-            asn1Out.WriteEncodingDL(withID, Asn1Tags.BitString, (byte)mPadBits, mData, 0, mData.Length);
 		}
 
         protected override int Asn1GetHashCode()
 		{
-			return mPadBits.GetHashCode() ^ Arrays.GetHashCode(mData);
-		}
+            if (contents.Length < 2)
+                return 1;
+
+            int padBits = contents[0];
+            int last = contents.Length - 1;
+
+            byte lastOctetDer = (byte)(contents[last] & (0xFF << padBits));
 
-		protected override bool Asn1Equals(
-			Asn1Object asn1Object)
+            int hc = Arrays.GetHashCode(contents, 0, last);
+            hc *= 257;
+            hc ^= lastOctetDer;
+            return hc;
+        }
+
+        protected override bool Asn1Equals(Asn1Object asn1Object)
 		{
-			DerBitString other = asn1Object as DerBitString;
+            DerBitString that = asn1Object as DerBitString;
+            if (null == that)
+                return false;
 
-			if (other == null)
-				return false;
+            byte[] thisContents = this.contents, thatContents = that.contents;
 
-			return this.mPadBits == other.mPadBits
-				&& Arrays.AreEqual(this.mData, other.mData);
-		}
+            int length = thisContents.Length;
+            if (thatContents.Length != length)
+                return false;
+            if (length == 1)
+                return true;
+
+            int last = length - 1;
+            for (int i = 0; i < last; ++i)
+            {
+                if (thisContents[i] != thatContents[i])
+                    return false;
+            }
+
+            int padBits = thisContents[0];
+            byte thisLastOctetDer = (byte)(thisContents[last] & (0xFF << padBits));
+            byte thatLastOctetDer = (byte)(thatContents[last] & (0xFF << padBits));
+
+            return thisLastOctetDer == thatLastOctetDer;
+        }
 
-		public override string GetString()
+        public override string GetString()
 		{
 			StringBuilder buffer = new StringBuilder("#");
 
@@ -254,27 +292,41 @@ namespace Org.BouncyCastle.Asn1
 			return buffer.ToString();
 		}
 
-		internal static DerBitString FromAsn1Octets(byte[] octets)
-		{
-	        if (octets.Length < 1)
-	            throw new ArgumentException("truncated BIT STRING detected", "octets");
+        internal static int EncodedLength(bool withID, int contentsLength)
+        {
+            return Asn1OutputStream.GetLengthOfEncodingDL(withID, contentsLength);
+        }
 
-            int padBits = octets[0];
-            byte[] data = Arrays.CopyOfRange(octets, 1, octets.Length);
+        internal static void Encode(Asn1OutputStream asn1Out, bool withID, byte[] buf, int off, int len)
+        {
+            asn1Out.WriteEncodingDL(withID, Asn1Tags.BitString, buf, off, len);
+        }
+
+        internal static void Encode(Asn1OutputStream asn1Out, bool withID, byte pad, byte[] buf, int off, int len)
+        {
+            asn1Out.WriteEncodingDL(withID, Asn1Tags.BitString, pad, buf, off, len);
+        }
+
+		internal static DerBitString CreatePrimitive(byte[] contents)
+		{
+            int length = contents.Length;
+            if (length < 1)
+                throw new ArgumentException("truncated BIT STRING detected", "contents");
 
-            if (padBits > 0 && padBits < 8 && data.Length > 0)
+            int padBits = contents[0];
+            if (padBits > 0)
             {
-                int last = data[data.Length - 1];
-                int mask = (1 << padBits) - 1;
+                if (padBits > 7 || length < 2)
+                    throw new ArgumentException("invalid pad bits detected", "contents");
 
-                if ((last & mask) != 0)
+                byte finalOctet = contents[length - 1];
+                if (finalOctet != (byte)(finalOctet & (0xFF << padBits)))
                 {
-                    return new BerBitString(data, padBits);
+                    return new BerBitString(contents, false);
                 }
             }
 
-            return new DerBitString(data, padBits);
+            return new DerBitString(contents, false);
 		}
 	}
 }
-