Skip to content

Commit

Permalink
v1.1.0: Support enumerations with gaps. (awakesecurity#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
j6carey authored Jul 11, 2019
1 parent 43d8220 commit 4f355bb
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 19 deletions.
3 changes: 2 additions & 1 deletion proto3-wire.cabal
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: proto3-wire
version: 1.0.0
version: 1.1.0
synopsis: A low-level implementation of the Protocol Buffers (version 3) wire format
license: Apache-2.0
license-file: LICENSE
Expand All @@ -13,6 +13,7 @@ cabal-version: >=1.10
library
exposed-modules: Proto3.Wire
Proto3.Wire.Builder
Proto3.Wire.Class
Proto3.Wire.Decode
Proto3.Wire.Encode
Proto3.Wire.Tutorial
Expand Down
7 changes: 5 additions & 2 deletions src/Proto3/Wire.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
-- | See the "Proto3.Wire.Tutorial" module.

module Proto3.Wire
( -- * Message Structure
FieldNumber(..)
( -- * Support Classes
ProtoEnum(..)
-- * Message Structure
, FieldNumber(..)
, fieldNumber
-- * Decoding Messages
, at
Expand All @@ -27,5 +29,6 @@ module Proto3.Wire
, repeated
) where

import Proto3.Wire.Class
import Proto3.Wire.Types
import Proto3.Wire.Decode
44 changes: 44 additions & 0 deletions src/Proto3/Wire/Class.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
{-# LANGUAGE DefaultSignatures #-}
{-
Copyright 2019 Awake Networks
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-}

-- | This module defines classes which are shared by the encoding and decoding
-- modules.

module Proto3.Wire.Class
( ProtoEnum(..)
) where

import Data.Int (Int32)
import qualified Safe

-- | Similar to 'Enum', but allows gaps in the sequence of numeric codes,
-- and uses 'Int32' in order to match the proto3 specification.
--
-- Absent gaps, you can use an automatic derivation of 'Bounded' and 'Enum',
-- then use the default implementations for all methods of this class. But
-- if gaps are involved, then you must instantiate this class directly and
-- supply the specific numeric codes desired for each constructor.
class ProtoEnum a where
-- | Default implementation: `Safe.toEnumMay`.
toProtoEnumMay :: Int32 -> Maybe a
default toProtoEnumMay :: (Bounded a, Enum a) => Int32 -> Maybe a
toProtoEnumMay = Safe.toEnumMay . fromIntegral

-- | Default implementation: 'fromEnum'.
fromProtoEnum :: a -> Int32
default fromProtoEnum :: Enum a => a -> Int32
fromProtoEnum = fromIntegral . fromEnum
19 changes: 11 additions & 8 deletions src/Proto3/Wire/Decode.hs
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ import Data.Text.Lazy.Encoding ( decodeUtf8' )
import qualified Data.Traversable as T
import Data.Int ( Int32, Int64 )
import Data.Word ( Word8, Word32, Word64 )
import Proto3.Wire.Class
import Proto3.Wire.Types
import qualified Safe

-- | Decode a zigzag-encoded numeric type.
-- See: http://stackoverflow.com/questions/2210923/zig-zag-decoding
Expand Down Expand Up @@ -348,7 +348,10 @@ bytes = Parser $

-- | Parse a Boolean value.
bool :: Parser RawPrimitive Bool
bool = fmap (Safe.toEnumDef False) parseVarInt
bool = Parser $
\case
VarintField i -> return $! i /= 0
wrong -> throwWireTypeError "bool" wrong

-- | Parse a primitive with the @int32@ wire type.
int32 :: Parser RawPrimitive Int32
Expand Down Expand Up @@ -395,14 +398,14 @@ text = Parser $

-- | Parse a primitive with an enumerated type.
--
-- This parser will return 'Left' if the encoded integer value is outside the
-- acceptable range of the 'Bounded' instance.
enum :: forall e. (Enum e, Bounded e) => Parser RawPrimitive (Either Int e)
-- This parser will return 'Left' if the encoded integer value
-- is not a code for a known enumerator.
enum :: forall e. ProtoEnum e => Parser RawPrimitive (Either Int32 e)
enum = fmap toEither parseVarInt
where
toEither :: Int -> Either Int e
toEither :: Int32 -> Either Int32 e
toEither i
| Just e <- Safe.toEnumMay i = Right e
| Just e <- toProtoEnumMay i = Right e
| otherwise = Left i

-- | Parse a packed collection of variable-width integer values (any of @int32@,
Expand Down Expand Up @@ -573,4 +576,4 @@ embedded' parser = Parser $
wrong -> throwWireTypeError "embedded" wrong


-- TODO test repeated and embedded better for reverse logic...
-- TODO test repeated and embedded better for reverse logic...
39 changes: 32 additions & 7 deletions src/Proto3/Wire/Encode.hs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ module Proto3.Wire.Encode
, float
, double
, enum
, bool
-- * Strings
, bytes
, string
Expand Down Expand Up @@ -94,6 +95,7 @@ import qualified Data.Text.Lazy as Text.Lazy
import qualified Data.Text.Lazy.Encoding as Text.Lazy.Encoding
import Data.Word ( Word8, Word32, Word64 )
import qualified Proto3.Wire.Builder as WB
import Proto3.Wire.Class
import Proto3.Wire.Types

-- $setup
Expand Down Expand Up @@ -265,16 +267,39 @@ double num d = fieldHeader num Fixed64 <> MessageBuilder (WB.doubleLE d)

-- | Encode a value with an enumerable type.
--
-- It can be useful to derive an 'Enum' instance for a type in order to
-- emulate enums appearing in .proto files.
-- You should instantiate 'ProtoEnum' for a type in
-- order to emulate enums appearing in .proto files.
--
-- For example:
--
-- >>> data Shape = Circle | Square | Triangle deriving (Enum)
-- >>> 1 `enum` True <> 2 `enum` Circle
-- Proto3.Wire.Encode.unsafeFromLazyByteString "\b\SOH\DLE\NUL"
enum :: Enum e => FieldNumber -> e -> MessageBuilder
enum num e = fieldHeader num Varint <> base128Varint (fromIntegral (fromEnum e))
-- >>> :{
-- data Shape = Circle | Square | Triangle deriving (Bounded, Enum)
-- instance ProtoEnum Shape
-- data Gap = Gap0 | Gap3
-- instance ProtoEnum Gap where
-- toProtoEnumMay i = case i of
-- 0 -> Just Gap0
-- 3 -> Just Gap3
-- _ -> Nothing
-- fromProtoEnum g = case g of
-- Gap0 -> 0
-- Gap3 -> 3
-- :}
--
-- >>> 1 `enum` Triangle <> 2 `enum` Gap3
-- Proto3.Wire.Encode.unsafeFromLazyByteString "\b\STX\DLE\ETX"
enum :: ProtoEnum e => FieldNumber -> e -> MessageBuilder
enum num e =
fieldHeader num Varint <> base128Varint (fromIntegral (fromProtoEnum e))

-- | Encode a boolean value
--
-- For example:
--
-- >>> 1 `bool` True
-- Proto3.Wire.Encode.unsafeFromLazyByteString "\b\SOH"
bool :: FieldNumber -> Bool -> MessageBuilder
bool num i = fieldHeader num Varint <> base128Varint (fromIntegral (fromEnum i))

-- | Encode a sequence of octets as a field of type 'bytes'.
--
Expand Down
2 changes: 1 addition & 1 deletion test/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ roundTripTests = testGroup "Roundtrip tests"
(Encode.double (fieldNumber 1))
(one Decode.double 0 `at` fieldNumber 1)
, roundTrip "bool"
(Encode.enum (fieldNumber 1))
(Encode.bool (fieldNumber 1))
(one Decode.bool False `at` fieldNumber 1)
, roundTrip "text"
(Encode.text (fieldNumber 1) . T.pack)
Expand Down

0 comments on commit 4f355bb

Please sign in to comment.