diff --git a/lib/Echidna/ABI.hs b/lib/Echidna/ABI.hs index 6c6d6d10e..7533abb9b 100644 --- a/lib/Echidna/ABI.hs +++ b/lib/Echidna/ABI.hs @@ -1,7 +1,7 @@ module Echidna.ABI where import Control.Monad (liftM2, liftM3, foldM, replicateM) -import Control.Monad.Random.Strict (MonadRandom, join, getRandom, getRandoms, getRandomR) +import Control.Monad.Random.Strict (MonadRandom, join, getRandom, getRandoms, getRandomR, uniform, fromList) import Control.Monad.Random.Strict qualified as Random import Data.Binary.Put (runPut, putWord32be) import Data.BinaryWord (unsignedWord) @@ -274,7 +274,31 @@ shrinkAbiValue = \case -- | Given a 'SolCall', generate a random \"smaller\" (simpler) call. shrinkAbiCall :: MonadRandom m => SolCall -> m SolCall -shrinkAbiCall = traverse $ traverse shrinkAbiValue +shrinkAbiCall (name, vals) = do + let numShrinkable = length $ filter canShrinkAbiValue vals + + halfwayVal <- getRandomR (0, numShrinkable) + -- This list was made arbitrarily. Feel free to change + let numToShrinkOptions = [1, 2, halfwayVal, numShrinkable] + + numToShrink <- min numShrinkable <$> uniform numToShrinkOptions + shrunkVals <- shrinkVals (fromIntegral numShrinkable) (fromIntegral numToShrink) vals + pure (name, shrunkVals) + where + shrinkVals 0 _ l = pure l + shrinkVals _ 0 l = pure l + shrinkVals _ _ [] = pure [] + shrinkVals numShrinkable numToShrink (h:t) + | not (canShrinkAbiValue h) = (h:) <$> shrinkVals numShrinkable numToShrink t + | otherwise = do + -- We want to pick which ones to shrink uniformly from the vals list. + -- Odds of shrinking one element is numToShrink/numShrinkable. + shouldShrink <- fromList [(True, numToShrink), (False, numShrinkable-numToShrink)] + h' <- if shouldShrink then shrinkAbiValue h else pure h + let + numShrinkable' = numShrinkable-1 + numToShrink' = if shouldShrink then numToShrink-1 else numToShrink + (h':) <$> shrinkVals numShrinkable' numToShrink' t -- | Given an 'AbiValue', generate a random \"similar\" value of the same 'AbiType'. mutateAbiValue :: MonadRandom m => AbiValue -> m AbiValue