diff --git a/exchanges/orderbook/depth.go b/exchanges/orderbook/depth.go index d97b8cea282..e5b572ba94a 100644 --- a/exchanges/orderbook/depth.go +++ b/exchanges/orderbook/depth.go @@ -8,9 +8,11 @@ import ( "github.com/gofrs/uuid" "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/common/key" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/dispatch" "github.com/thrasher-corp/gocryptotrader/exchanges/alert" + "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/log" ) @@ -784,11 +786,29 @@ func (d *Depth) GetTranches(count int) (ask, bid []Tranche, err error) { } // GetPair returns the pair associated with the depth -func (d *Depth) GetPair() (currency.Pair, error) { +func (d *Depth) GetPair() currency.Pair { d.m.RLock() defer d.m.RUnlock() - if d.pair.IsEmpty() { - return currency.Pair{}, currency.ErrCurrencyPairEmpty - } - return d.pair, nil + return d.pair +} + +// GetAsset returns the asset associated with the depth +func (d *Depth) GetAsset() asset.Item { + d.m.RLock() + defer d.m.RUnlock() + return d.asset +} + +// GetExchange returns the exchange associated with the depth +func (d *Depth) GetExchange() string { + d.m.RLock() + defer d.m.RUnlock() + return d.exchange +} + +// GetKey returns the key associated with the depth +func (d *Depth) GetKey() key.ExchangePairAsset { + d.m.RLock() + defer d.m.RUnlock() + return key.ExchangePairAsset{Exchange: d.exchange, Base: d.pair.Base.Item, Quote: d.pair.Quote.Item, Asset: d.asset} } diff --git a/exchanges/orderbook/depth_test.go b/exchanges/orderbook/depth_test.go index da96ef1b558..9db77e43b2e 100644 --- a/exchanges/orderbook/depth_test.go +++ b/exchanges/orderbook/depth_test.go @@ -10,6 +10,8 @@ import ( "github.com/gofrs/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/thrasher-corp/gocryptotrader/common/key" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" ) @@ -677,22 +679,6 @@ func TestGetTranches(t *testing.T) { assert.Len(t, bidT, 5, "bids should have correct number of tranches") } -func TestGetPair(t *testing.T) { - t.Parallel() - depth := NewDepth(id) - - _, err := depth.GetPair() - assert.ErrorIs(t, err, currency.ErrCurrencyPairEmpty, "GetPair should error correctly") - - expected := currency.NewPair(currency.BTC, currency.WABI) - depth.pair = expected - - pair, err := depth.GetPair() - assert.NoError(t, err, "GetPair should not error") - - assert.Equal(t, expected, pair, "GetPair should return correct pair") -} - func getInvalidDepth() *Depth { depth := NewDepth(id) _ = depth.Invalidate(errors.New("invalid reasoning")) @@ -910,3 +896,39 @@ var movementTests = []struct { {[]any{20.0, true}, Movement{NominalPercentage: 0.7105459985041137, ImpactPercentage: FullLiquidityExhaustedPercentage, SlippageCost: 190.0, FullBookSideConsumed: true}}, }}, } + +func TestGetPair(t *testing.T) { + t.Parallel() + depth := NewDepth(id) + require.Empty(t, depth.GetPair()) + depth.pair = currency.NewPair(currency.BTC, currency.WABI) + require.Equal(t, depth.pair, depth.GetPair()) +} + +func TestGetAsset(t *testing.T) { + t.Parallel() + depth := NewDepth(id) + require.Empty(t, depth.GetAsset()) + depth.asset = asset.Spot + require.Equal(t, depth.asset, depth.GetAsset()) +} + +func TestGetExchange(t *testing.T) { + t.Parallel() + depth := NewDepth(id) + require.Empty(t, depth.GetExchange()) + depth.exchange = "test" + require.Equal(t, depth.exchange, depth.GetExchange()) +} + +func TestGetKey(t *testing.T) { + t.Parallel() + depth := NewDepth(id) + require.Empty(t, depth.GetKey()) + depth.exchange = "test" + depth.pair = currency.NewPair(currency.BTC, currency.WABI) + depth.asset = asset.Spot + require.Equal(t, + key.ExchangePairAsset{Exchange: depth.exchange, Base: depth.pair.Base.Item, Quote: depth.pair.Quote.Item, Asset: depth.asset}, + depth.GetKey()) +}