This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
@@ -18,6 +19,7 @@ import (
|
||||
var defaultCommissionTolerance = decimal.RequireFromString("0.01")
|
||||
|
||||
type Engine struct {
|
||||
mu *sync.Mutex
|
||||
repo repository.Repository
|
||||
gateway tinvest.Gateway
|
||||
accountID string
|
||||
@@ -31,6 +33,7 @@ type Engine struct {
|
||||
|
||||
func New(repo repository.Repository, gateway tinvest.Gateway, accountID, accountIDHash string) Engine {
|
||||
return Engine{
|
||||
mu: &sync.Mutex{},
|
||||
repo: repo,
|
||||
gateway: gateway,
|
||||
accountID: accountID,
|
||||
@@ -64,6 +67,10 @@ func (e Engine) WithCommissionPolicy(requireZero, quarantineOnNonZero bool, tole
|
||||
}
|
||||
|
||||
func (e Engine) Run(ctx context.Context) ([]domain.ReconciliationDiff, error) {
|
||||
if e.mu != nil {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
}
|
||||
localOrders, err := e.repo.ListActiveOrders(ctx, e.accountIDHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -150,6 +157,7 @@ func (e Engine) Run(ctx context.Context) ([]domain.ReconciliationDiff, error) {
|
||||
})
|
||||
}
|
||||
}
|
||||
diffs = append(diffs, compareCash(localPositions, portfolio, e.commissionTolerance)...)
|
||||
from := now.Add(-e.window)
|
||||
recentOrders, err := e.repo.ListOrders(ctx, e.accountIDHash, from, now)
|
||||
if err != nil {
|
||||
@@ -204,6 +212,55 @@ func compareOperations(orders []domain.Order, operations []domain.Operation) []d
|
||||
return compareOperationsWithPolicy(orders, operations, false, defaultCommissionTolerance)
|
||||
}
|
||||
|
||||
func compareCash(localPositions []domain.Position, portfolio domain.Portfolio, tolerance decimal.Decimal) []domain.ReconciliationDiff {
|
||||
if tolerance.IsNegative() {
|
||||
tolerance = decimal.Zero
|
||||
}
|
||||
expectedCash, ok := expectedCashFromLocalPositions(localPositions, portfolio)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
diff := money.Abs(expectedCash.Sub(portfolio.Cash))
|
||||
if diff.LessThanOrEqual(tolerance) {
|
||||
return nil
|
||||
}
|
||||
return []domain.ReconciliationDiff{{
|
||||
Kind: "cash_mismatch",
|
||||
Message: fmt.Sprintf("expected cash=%s broker cash=%s diff=%s", expectedCash.StringFixed(2), portfolio.Cash.StringFixed(2), diff.StringFixed(2)),
|
||||
Critical: true,
|
||||
}}
|
||||
}
|
||||
|
||||
func expectedCashFromLocalPositions(localPositions []domain.Position, portfolio domain.Portfolio) (decimal.Decimal, bool) {
|
||||
if !portfolio.Equity.IsPositive() {
|
||||
return decimal.Zero, false
|
||||
}
|
||||
if len(localPositions) == 0 {
|
||||
if len(portfolio.Holdings) != 0 {
|
||||
return decimal.Zero, false
|
||||
}
|
||||
return portfolio.Equity, true
|
||||
}
|
||||
holdingByInstrument := make(map[string]domain.Holding, len(portfolio.Holdings))
|
||||
for _, holding := range portfolio.Holdings {
|
||||
holdingByInstrument[holding.InstrumentUID] = holding
|
||||
}
|
||||
positionMarketValue := decimal.Zero
|
||||
for _, pos := range localPositions {
|
||||
if pos.Lots <= 0 {
|
||||
continue
|
||||
}
|
||||
holding, ok := holdingByInstrument[pos.InstrumentUID]
|
||||
if !ok || holding.QuantityLots <= 0 || !holding.MarketValue.IsPositive() {
|
||||
return decimal.Zero, false
|
||||
}
|
||||
positionMarketValue = positionMarketValue.Add(holding.MarketValue.
|
||||
Mul(decimal.NewFromInt(pos.Lots)).
|
||||
Div(decimal.NewFromInt(holding.QuantityLots)))
|
||||
}
|
||||
return portfolio.Equity.Sub(positionMarketValue), true
|
||||
}
|
||||
|
||||
func compareOperationsWithPolicy(orders []domain.Order, operations []domain.Operation, requireZeroCommission bool, commissionTolerance decimal.Decimal) []domain.ReconciliationDiff {
|
||||
var diffs []domain.ReconciliationDiff
|
||||
if commissionTolerance.IsNegative() {
|
||||
|
||||
@@ -170,3 +170,37 @@ func TestReconciliationSkipsFreshInFlightLocalOrders(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconciliationFindsCashMismatch(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := testutil.NewMemoryRepository()
|
||||
gateway := tinvest.NewFakeGateway()
|
||||
if err := repo.UpsertPosition(ctx, domain.Position{
|
||||
AccountIDHash: "hash",
|
||||
InstrumentUID: "uid",
|
||||
OpenTradeDate: time.Now().UTC(),
|
||||
Lots: 2,
|
||||
Status: domain.PositionHoldingOvernight,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
gateway.Portfolio = domain.Portfolio{
|
||||
Equity: decimal.NewFromInt(1000),
|
||||
Cash: decimal.NewFromInt(700),
|
||||
Holdings: []domain.Holding{{
|
||||
InstrumentUID: "uid",
|
||||
QuantityLots: 2,
|
||||
MarketValue: decimal.NewFromInt(200),
|
||||
}},
|
||||
}
|
||||
diffs, err := New(repo, gateway, "account", "hash").Run(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, diff := range diffs {
|
||||
if diff.Kind == "cash_mismatch" && diff.Critical {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatalf("missing cash_mismatch in %+v", diffs)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user