This commit is contained in:
@@ -76,6 +76,7 @@ APP_MODE=backtest go run ./cmd/bot
|
||||
| `STRATEGY_ROLLING_SHORT` | количество торговых дней | `60` | рекомендуется `> 0` | Короткое окно статистики overnight-доходности. Больше - стабильнее оценка, но медленнее реакция; меньше - быстрее реакция, но больше шум. |
|
||||
| `STRATEGY_ROLLING_LONG` | количество торговых дней | `252` | рекомендуется `>= STRATEGY_ROLLING_SHORT` и `> 0` | Длинное окно для проверки положительного долгосрочного edge и глубины backfill. Больше требует больше истории. |
|
||||
| `STRATEGY_EWMA_LAMBDA` | дробь для EWMA | `0.08` | рабочий диапазон `(0, 1]`; вне диапазона EWMA-функция использует `0.08` | Вес новых наблюдений в EWMA. Больше - свежее движение влияет сильнее. |
|
||||
| `STRATEGY_ALLOCATION_METHOD` | `equal_weight` | `equal_weight` | сейчас поддерживается только `equal_weight` | Метод распределения капитала между выбранными сигналами. Текущая реализация делит лимит экспозиции поровну между выбранными инструментами. |
|
||||
| `STRATEGY_MIN_TSTAT_60` | decimal t-stat | `1.25` | валидации нет; обычно `>= 0` | Минимальная статистическая значимость короткого edge. Выше - меньше входов, ниже - больше входов. |
|
||||
| `STRATEGY_MIN_WIN_RATE_60` | доля прибыльных overnight-дней | `0.55` | рекомендуется `0..1` | Минимальная доля положительных overnight-наблюдений. Выше - строже фильтр сигналов. |
|
||||
| `STRATEGY_MIN_NET_EDGE_BPS` | bps | `10` | валидации нет; обычно `>= 0` | Минимальный ожидаемый edge после издержек. Выше - меньше, но потенциально качественнее сигналы. |
|
||||
@@ -146,7 +147,7 @@ APP_MODE=backtest go run ./cmd/bot
|
||||
| --- | --- | --- | --- | --- |
|
||||
| `COMM_REQUIRE_ZERO_COMMISSION` | `true` или `false` | `true` | boolean | При `true` сигналы по инструментам с ожидаемой комиссией `> 0` отклоняются. |
|
||||
| `COMM_QUARANTINE_ON_NONZERO` | `true` или `false` | `true` | boolean | При фактической брокерской комиссии `> 0` инструмент переводится в quarantine, а система останавливается через HALT по zero-commission policy. |
|
||||
| `COMM_FREE_ORDER_COUNT_POLICY` | `submitted` | `submitted` | жёстко только `submitted` | Политика учёта бесплатных заявок: счётчик увеличивается при отправке заявки. Другие значения запрещены валидацией. |
|
||||
| `COMM_FREE_ORDER_COUNT_POLICY` | `submitted` или `cancel_counts` | `submitted` | одно из двух значений | Политика учёта бесплатных заявок: `submitted` считает только отправку новой заявки, `cancel_counts` дополнительно считает успешные отмены перед repost. |
|
||||
|
||||
### BT
|
||||
|
||||
|
||||
@@ -130,7 +130,7 @@ func run() error {
|
||||
MaxSpreadBps: spread,
|
||||
MaxTickBps: tick,
|
||||
AssumedSpreadBps: assumed,
|
||||
RequireZeroCommission: *requireZeroCommission,
|
||||
RequireZeroCommission: requireZeroCommission,
|
||||
UseMinuteModel: *useMinuteModel,
|
||||
})
|
||||
result, err := engine.RunWithMinuteCandles(candles, minuteCandles)
|
||||
|
||||
@@ -244,6 +244,7 @@ func buildScheduler(clock timeutil.Clock, sm statemachine.System, cfg config.Con
|
||||
})
|
||||
execEngine := execution.NewEngine(cfg.App.Mode, cfg.TInvest.AccountID, gateway, repo)
|
||||
execEngine.SetMaxQuoteAge(time.Duration(cfg.Execution.MaxQuoteAgeSec) * time.Second)
|
||||
execEngine.SetFreeOrderCountPolicy(cfg.Commission.FreeOrderCountPolicy)
|
||||
services := scheduler.Services{
|
||||
Repo: repo,
|
||||
Gateway: gateway,
|
||||
@@ -272,6 +273,7 @@ func buildScheduler(clock timeutil.Clock, sm statemachine.System, cfg config.Con
|
||||
EntryWindowEnd: cfg.Execution.EntryWindowEnd,
|
||||
NoNewEntryAfter: cfg.Execution.NoNewEntryAfter,
|
||||
ExitWatchStart: cfg.Execution.ExitWatchStart,
|
||||
ExitNotBefore: cfg.Execution.ExitNotBefore,
|
||||
ExitWindowStart: cfg.Execution.ExitWindowStart,
|
||||
ExitWindowEnd: cfg.Execution.ExitWindowEnd,
|
||||
HardExitDeadline: cfg.Execution.HardExitDeadline,
|
||||
@@ -287,6 +289,7 @@ func buildScheduler(clock timeutil.Clock, sm statemachine.System, cfg config.Con
|
||||
APIOutageHalt: time.Duration(cfg.Risk.APIOutageHaltSec) * time.Second,
|
||||
RequireZeroCommission: cfg.Commission.RequireZeroCommission,
|
||||
QuarantineOnNonZero: cfg.Commission.QuarantineOnNonZero,
|
||||
FreeOrderCountPolicy: cfg.Commission.FreeOrderCountPolicy,
|
||||
ReconciliationInterval: 5 * time.Minute,
|
||||
MaxOpenPositions: minPositive(cfg.Strategy.MaxPositions, cfg.Risk.MaxOpenPositions),
|
||||
}, services)
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
package backtest
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
func TestRequireZeroCommissionDefaultDoesNotOverrideExplicitFalse(t *testing.T) {
|
||||
defaultEngine := New(Config{})
|
||||
if !defaultEngine.requireZeroCommission() {
|
||||
t.Fatal("default require_zero_commission should be true")
|
||||
}
|
||||
requireZero := false
|
||||
explicitEngine := New(Config{RequireZeroCommission: &requireZero})
|
||||
if explicitEngine.requireZeroCommission() {
|
||||
t.Fatal("explicit require_zero_commission=false was overridden")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssumedSpreadUsesFundTypeSpecificDefaults(t *testing.T) {
|
||||
engine := New(Config{
|
||||
AssumedSpreadBps: decimal.NewFromInt(20),
|
||||
InstrumentFundTypes: map[string]string{
|
||||
"mm": "money_market",
|
||||
"eq": "equity",
|
||||
},
|
||||
})
|
||||
if got := engine.assumedSpreadBps("mm"); !got.Equal(decimal.NewFromInt(5)) {
|
||||
t.Fatalf("money market spread=%s, want 5", got)
|
||||
}
|
||||
if got := engine.assumedSpreadBps("eq"); !got.Equal(decimal.NewFromInt(25)) {
|
||||
t.Fatalf("equity spread=%s, want 25", got)
|
||||
}
|
||||
if got := engine.assumedSpreadBps("unknown"); !got.Equal(decimal.NewFromInt(20)) {
|
||||
t.Fatalf("default spread=%s, want 20", got)
|
||||
}
|
||||
}
|
||||
+106
-35
@@ -8,6 +8,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
@@ -19,35 +20,40 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
EntrySlippageBps decimal.Decimal
|
||||
ExitSlippageBps decimal.Decimal
|
||||
CommissionRoundtripBps decimal.Decimal
|
||||
RiskBufferBps decimal.Decimal
|
||||
InitialEquity decimal.Decimal
|
||||
OutputDir string
|
||||
RollingShort int
|
||||
RollingLong int
|
||||
EWMALambda float64
|
||||
MinTStat60 decimal.Decimal
|
||||
MinWinRate60 decimal.Decimal
|
||||
MinNetEdgeBps decimal.Decimal
|
||||
MinADVRUB decimal.Decimal
|
||||
MaxSpreadBps decimal.Decimal
|
||||
MaxTickBps decimal.Decimal
|
||||
RequireZeroCommission bool
|
||||
MaxPositions int
|
||||
MaxPositionPct decimal.Decimal
|
||||
MaxTotalExposurePct decimal.Decimal
|
||||
MaxParticipationRate decimal.Decimal
|
||||
CashUsageBuffer decimal.Decimal
|
||||
RiskBudgetPct decimal.Decimal
|
||||
MinOrderNotionalRUB decimal.Decimal
|
||||
AssumedSpreadBps decimal.Decimal
|
||||
AssumedTickBps decimal.Decimal
|
||||
Lot int64
|
||||
UseMinuteModel bool
|
||||
EntryWindow TimeWindow
|
||||
ExitWindow TimeWindow
|
||||
EntrySlippageBps decimal.Decimal
|
||||
ExitSlippageBps decimal.Decimal
|
||||
CommissionRoundtripBps decimal.Decimal
|
||||
RiskBufferBps decimal.Decimal
|
||||
InitialEquity decimal.Decimal
|
||||
OutputDir string
|
||||
RollingShort int
|
||||
RollingLong int
|
||||
EWMALambda float64
|
||||
MinTStat60 decimal.Decimal
|
||||
MinWinRate60 decimal.Decimal
|
||||
MinNetEdgeBps decimal.Decimal
|
||||
MinADVRUB decimal.Decimal
|
||||
MaxSpreadBps decimal.Decimal
|
||||
MaxSpreadBpsMoneyMarket decimal.Decimal
|
||||
MaxSpreadBpsBondFunds decimal.Decimal
|
||||
MaxSpreadBpsEquityFunds decimal.Decimal
|
||||
MaxTickBps decimal.Decimal
|
||||
RequireZeroCommission *bool
|
||||
MaxPositions int
|
||||
MaxPositionPct decimal.Decimal
|
||||
MaxTotalExposurePct decimal.Decimal
|
||||
MaxParticipationRate decimal.Decimal
|
||||
CashUsageBuffer decimal.Decimal
|
||||
RiskBudgetPct decimal.Decimal
|
||||
MinOrderNotionalRUB decimal.Decimal
|
||||
AssumedSpreadBps decimal.Decimal
|
||||
AssumedSpreadBpsByFundType map[string]decimal.Decimal
|
||||
InstrumentFundTypes map[string]string
|
||||
AssumedTickBps decimal.Decimal
|
||||
Lot int64
|
||||
UseMinuteModel bool
|
||||
EntryWindow TimeWindow
|
||||
ExitWindow TimeWindow
|
||||
}
|
||||
|
||||
type TimeWindow struct {
|
||||
@@ -120,6 +126,15 @@ func (cfg Config) withDefaults() Config {
|
||||
if cfg.MaxSpreadBps.IsZero() {
|
||||
cfg.MaxSpreadBps = decimal.NewFromInt(20)
|
||||
}
|
||||
if cfg.MaxSpreadBpsMoneyMarket.IsZero() {
|
||||
cfg.MaxSpreadBpsMoneyMarket = decimal.NewFromInt(5)
|
||||
}
|
||||
if cfg.MaxSpreadBpsBondFunds.IsZero() {
|
||||
cfg.MaxSpreadBpsBondFunds = decimal.NewFromInt(10)
|
||||
}
|
||||
if cfg.MaxSpreadBpsEquityFunds.IsZero() {
|
||||
cfg.MaxSpreadBpsEquityFunds = decimal.NewFromInt(25)
|
||||
}
|
||||
if cfg.MaxTickBps.IsZero() {
|
||||
cfg.MaxTickBps = decimal.NewFromInt(10)
|
||||
}
|
||||
@@ -132,8 +147,9 @@ func (cfg Config) withDefaults() Config {
|
||||
if cfg.AssumedTickBps.IsZero() {
|
||||
cfg.AssumedTickBps = cfg.MaxTickBps
|
||||
}
|
||||
if !cfg.RequireZeroCommission && cfg.CommissionRoundtripBps.IsZero() {
|
||||
cfg.RequireZeroCommission = true
|
||||
if cfg.RequireZeroCommission == nil {
|
||||
requireZero := true
|
||||
cfg.RequireZeroCommission = &requireZero
|
||||
}
|
||||
if cfg.MaxPositions == 0 {
|
||||
cfg.MaxPositions = 5
|
||||
@@ -260,7 +276,7 @@ func (e Engine) RunWithMinuteCandles(candlesByInstrument map[string][]domain.Can
|
||||
Lots: lots,
|
||||
Notional: notional,
|
||||
NetPnL: pnl,
|
||||
SpreadBps: e.cfg.AssumedSpreadBps,
|
||||
SpreadBps: c.spreadBps,
|
||||
SlippageBps: e.cfg.EntrySlippageBps.Add(e.cfg.ExitSlippageBps),
|
||||
OvernightGap: c.overnightGap,
|
||||
CapacityRUB: capacity,
|
||||
@@ -355,6 +371,7 @@ type candidate struct {
|
||||
buy decimal.Decimal
|
||||
sell decimal.Decimal
|
||||
netEdge decimal.Decimal
|
||||
spreadBps decimal.Decimal
|
||||
adv decimal.Decimal
|
||||
q05Abs decimal.Decimal
|
||||
overnightGap decimal.Decimal
|
||||
@@ -381,7 +398,8 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle,
|
||||
return candidate{}, false, nil
|
||||
}
|
||||
rawEdge := decimal.NewFromFloat(short.Mean).Mul(decimal.NewFromInt(10_000))
|
||||
cost := e.cfg.AssumedSpreadBps.
|
||||
spreadBps := e.assumedSpreadBps(instrumentUID)
|
||||
cost := spreadBps.
|
||||
Add(e.cfg.EntrySlippageBps).
|
||||
Add(e.cfg.ExitSlippageBps).
|
||||
Add(e.cfg.CommissionRoundtripBps).
|
||||
@@ -389,7 +407,7 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle,
|
||||
netEdge := rawEdge.Sub(cost)
|
||||
adv := features.ADV(history, e.cfg.Lot, 20)
|
||||
switch {
|
||||
case e.cfg.RequireZeroCommission && e.cfg.CommissionRoundtripBps.IsPositive():
|
||||
case e.requireZeroCommission() && e.cfg.CommissionRoundtripBps.IsPositive():
|
||||
return candidate{}, false, nil
|
||||
case !decimal.NewFromFloat(short.Mean).IsPositive() || !decimal.NewFromFloat(long.Mean).IsPositive():
|
||||
return candidate{}, false, nil
|
||||
@@ -399,7 +417,7 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle,
|
||||
return candidate{}, false, nil
|
||||
case netEdge.LessThan(e.cfg.MinNetEdgeBps):
|
||||
return candidate{}, false, nil
|
||||
case e.cfg.AssumedSpreadBps.GreaterThan(e.cfg.MaxSpreadBps):
|
||||
case spreadBps.GreaterThan(e.maxSpreadBps(instrumentUID)):
|
||||
return candidate{}, false, nil
|
||||
case e.cfg.AssumedTickBps.GreaterThan(e.cfg.MaxTickBps):
|
||||
return candidate{}, false, nil
|
||||
@@ -425,6 +443,7 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle,
|
||||
buy: buy,
|
||||
sell: sell,
|
||||
netEdge: netEdge,
|
||||
spreadBps: spreadBps,
|
||||
adv: adv,
|
||||
q05Abs: q05Abs,
|
||||
overnightGap: gap,
|
||||
@@ -432,6 +451,58 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle,
|
||||
}, true, nil
|
||||
}
|
||||
|
||||
func (e Engine) requireZeroCommission() bool {
|
||||
return e.cfg.RequireZeroCommission != nil && *e.cfg.RequireZeroCommission
|
||||
}
|
||||
|
||||
func (e Engine) assumedSpreadBps(instrumentUID string) decimal.Decimal {
|
||||
fundType := normalizedFundType(e.cfg.InstrumentFundTypes[instrumentUID])
|
||||
if !fundType.IsZeroValue {
|
||||
if spread, ok := e.cfg.AssumedSpreadBpsByFundType[fundType.Key]; ok {
|
||||
return spread
|
||||
}
|
||||
return e.maxSpreadBpsForFundType(fundType.Raw)
|
||||
}
|
||||
return e.cfg.AssumedSpreadBps
|
||||
}
|
||||
|
||||
func (e Engine) maxSpreadBps(instrumentUID string) decimal.Decimal {
|
||||
fundType := normalizedFundType(e.cfg.InstrumentFundTypes[instrumentUID])
|
||||
if fundType.IsZeroValue {
|
||||
return e.cfg.MaxSpreadBps
|
||||
}
|
||||
return e.maxSpreadBpsForFundType(fundType.Raw)
|
||||
}
|
||||
|
||||
func (e Engine) maxSpreadBpsForFundType(fundType string) decimal.Decimal {
|
||||
switch {
|
||||
case strings.Contains(fundType, "money"):
|
||||
return e.cfg.MaxSpreadBpsMoneyMarket
|
||||
case strings.Contains(fundType, "bond"):
|
||||
return e.cfg.MaxSpreadBpsBondFunds
|
||||
case strings.Contains(fundType, "equity"):
|
||||
return e.cfg.MaxSpreadBpsEquityFunds
|
||||
default:
|
||||
return e.cfg.MaxSpreadBps
|
||||
}
|
||||
}
|
||||
|
||||
type normalizedType struct {
|
||||
Raw string
|
||||
Key string
|
||||
IsZeroValue bool
|
||||
}
|
||||
|
||||
func normalizedFundType(raw string) normalizedType {
|
||||
raw = strings.ToLower(strings.TrimSpace(raw))
|
||||
if raw == "" {
|
||||
return normalizedType{IsZeroValue: true}
|
||||
}
|
||||
key := strings.ReplaceAll(raw, "-", "_")
|
||||
key = strings.ReplaceAll(key, " ", "_")
|
||||
return normalizedType{Raw: raw, Key: key}
|
||||
}
|
||||
|
||||
func prepareCandles(candlesByInstrument map[string][]domain.Candle) map[string][]domain.Candle {
|
||||
prepared := make(map[string][]domain.Candle, len(candlesByInstrument))
|
||||
for instrumentUID, candles := range candlesByInstrument {
|
||||
|
||||
+22
-10
@@ -68,14 +68,15 @@ type TelegramConfig struct {
|
||||
}
|
||||
|
||||
type StrategyConfig struct {
|
||||
RollingShort int `env:"ROLLING_SHORT" envDefault:"60"`
|
||||
RollingLong int `env:"ROLLING_LONG" envDefault:"252"`
|
||||
EWMALambda float64 `env:"EWMA_LAMBDA" envDefault:"0.08"`
|
||||
MinTStat60 decimal.Decimal `env:"MIN_TSTAT_60" envDefault:"1.25"`
|
||||
MinWinRate60 decimal.Decimal `env:"MIN_WIN_RATE_60" envDefault:"0.55"`
|
||||
MinNetEdgeBps decimal.Decimal `env:"MIN_NET_EDGE_BPS" envDefault:"10"`
|
||||
RiskBufferBps decimal.Decimal `env:"RISK_BUFFER_BPS" envDefault:"5"`
|
||||
MaxPositions int `env:"MAX_POSITIONS" envDefault:"5"`
|
||||
RollingShort int `env:"ROLLING_SHORT" envDefault:"60"`
|
||||
RollingLong int `env:"ROLLING_LONG" envDefault:"252"`
|
||||
EWMALambda float64 `env:"EWMA_LAMBDA" envDefault:"0.08"`
|
||||
AllocationMethod string `env:"ALLOCATION_METHOD" envDefault:"equal_weight"`
|
||||
MinTStat60 decimal.Decimal `env:"MIN_TSTAT_60" envDefault:"1.25"`
|
||||
MinWinRate60 decimal.Decimal `env:"MIN_WIN_RATE_60" envDefault:"0.55"`
|
||||
MinNetEdgeBps decimal.Decimal `env:"MIN_NET_EDGE_BPS" envDefault:"10"`
|
||||
RiskBufferBps decimal.Decimal `env:"RISK_BUFFER_BPS" envDefault:"5"`
|
||||
MaxPositions int `env:"MAX_POSITIONS" envDefault:"5"`
|
||||
}
|
||||
|
||||
type ExecutionConfig struct {
|
||||
@@ -203,8 +204,19 @@ func (c *Config) Validate() error {
|
||||
if c.Risk.CommissionToleranceRUB.IsNegative() {
|
||||
return errors.New("RISK_COMMISSION_TOLERANCE_RUB must be non-negative")
|
||||
}
|
||||
if c.Commission.FreeOrderCountPolicy != "submitted" {
|
||||
return fmt.Errorf("COMM_FREE_ORDER_COUNT_POLICY must be submitted, got %q", c.Commission.FreeOrderCountPolicy)
|
||||
if c.Commission.FreeOrderCountPolicy == "" {
|
||||
c.Commission.FreeOrderCountPolicy = "submitted"
|
||||
}
|
||||
switch c.Commission.FreeOrderCountPolicy {
|
||||
case "submitted", "cancel_counts":
|
||||
default:
|
||||
return fmt.Errorf("COMM_FREE_ORDER_COUNT_POLICY must be submitted or cancel_counts, got %q", c.Commission.FreeOrderCountPolicy)
|
||||
}
|
||||
if c.Strategy.AllocationMethod == "" {
|
||||
c.Strategy.AllocationMethod = "equal_weight"
|
||||
}
|
||||
if c.Strategy.AllocationMethod != "equal_weight" {
|
||||
return fmt.Errorf("STRATEGY_ALLOCATION_METHOD must be equal_weight, got %q", c.Strategy.AllocationMethod)
|
||||
}
|
||||
if err := c.validateWindows(); err != nil {
|
||||
return err
|
||||
|
||||
@@ -19,6 +19,14 @@ func TestValidateRequiresAccountIDForBrokerModes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAllowsCancelCountsFreeOrderPolicy(t *testing.T) {
|
||||
cfg := minimalBrokerConfig(domain.ModeSandbox)
|
||||
cfg.Commission.FreeOrderCountPolicy = "cancel_counts"
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Fatalf("Validate cancel_counts: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func minimalBrokerConfig(mode domain.Mode) Config {
|
||||
return Config{
|
||||
App: AppConfig{
|
||||
@@ -44,6 +52,7 @@ func minimalBrokerConfig(mode domain.Mode) Config {
|
||||
QuoteDepth: 20,
|
||||
OrderPollIntervalMS: 500,
|
||||
},
|
||||
Strategy: StrategyConfig{AllocationMethod: "equal_weight"},
|
||||
Risk: RiskConfig{
|
||||
APIOutageHaltSec: 180,
|
||||
ReconciliationWindowHours: 72,
|
||||
|
||||
@@ -17,6 +17,11 @@ import (
|
||||
var ErrBrokerOrdersDisabled = errors.New("broker orders are disabled for current mode")
|
||||
var ErrEmptyOrderBook = errors.New("order book has no usable bid/ask")
|
||||
|
||||
const (
|
||||
FreeOrderPolicySubmitted = "submitted"
|
||||
FreeOrderPolicyCancelCounts = "cancel_counts"
|
||||
)
|
||||
|
||||
type Gateway interface {
|
||||
PostLimitOrder(ctx context.Context, accountID, instrumentUID string, side domain.Side, lots int64, price decimal.Decimal, clientOrderID string) (domain.Order, error)
|
||||
CancelOrder(ctx context.Context, accountID, orderID string) error
|
||||
@@ -24,12 +29,13 @@ type Gateway interface {
|
||||
}
|
||||
|
||||
type Engine struct {
|
||||
mode domain.Mode
|
||||
accountID string
|
||||
gateway Gateway
|
||||
store repository.Repository
|
||||
maxQuoteAge time.Duration
|
||||
mu sync.Map
|
||||
mode domain.Mode
|
||||
accountID string
|
||||
gateway Gateway
|
||||
store repository.Repository
|
||||
maxQuoteAge time.Duration
|
||||
freeOrderCountPolicy string
|
||||
mu sync.Map
|
||||
}
|
||||
|
||||
type MonitorConfig struct {
|
||||
@@ -44,13 +50,22 @@ type MonitorConfig struct {
|
||||
}
|
||||
|
||||
func NewEngine(mode domain.Mode, accountID string, gateway Gateway, store repository.Repository) Engine {
|
||||
return Engine{mode: mode, accountID: accountID, gateway: gateway, store: store}
|
||||
return Engine{mode: mode, accountID: accountID, gateway: gateway, store: store, freeOrderCountPolicy: FreeOrderPolicySubmitted}
|
||||
}
|
||||
|
||||
func (e *Engine) SetMaxQuoteAge(maxQuoteAge time.Duration) {
|
||||
e.maxQuoteAge = maxQuoteAge
|
||||
}
|
||||
|
||||
func (e *Engine) SetFreeOrderCountPolicy(policy string) {
|
||||
switch policy {
|
||||
case FreeOrderPolicyCancelCounts:
|
||||
e.freeOrderCountPolicy = policy
|
||||
default:
|
||||
e.freeOrderCountPolicy = FreeOrderPolicySubmitted
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) PlaceEntry(ctx context.Context, accountIDHash string, instrument domain.Instrument, tradeDate time.Time, lots int64, book domain.OrderBook, improveTicks int, attempt int) (domain.Order, error) {
|
||||
if err := e.checkQuoteFresh(book); err != nil {
|
||||
return domain.Order{}, err
|
||||
@@ -251,7 +266,15 @@ func (e *Engine) Cancel(ctx context.Context, order domain.Order) error {
|
||||
return err
|
||||
}
|
||||
if e.store != nil {
|
||||
return e.store.UpdateOrderStatus(ctx, order.ClientOrderID, domain.OrderStatusCancelled, order.FilledLots, order.RawStateJSON)
|
||||
return e.store.RunInTx(ctx, func(ctx context.Context, repo repository.Repository) error {
|
||||
if err := repo.UpdateOrderStatus(ctx, order.ClientOrderID, domain.OrderStatusCancelled, order.FilledLots, order.RawStateJSON); err != nil {
|
||||
return err
|
||||
}
|
||||
if e.cancelCountsAsFreeOrder() {
|
||||
return repo.IncrementFreeOrders(ctx, order.TradeDate, order.InstrumentUID, 1)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -485,12 +508,21 @@ func (e *Engine) ensureRepostBudget(ctx context.Context, order domain.Order, ins
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if instrument.FreeOrderLimitPerDay-sent < 1 {
|
||||
return fmt.Errorf("%w: %s remaining=0", risk.ErrFreeOrderBudget, instrument.InstrumentUID)
|
||||
needed := 1
|
||||
if e.cancelCountsAsFreeOrder() {
|
||||
needed = 2
|
||||
}
|
||||
remaining := instrument.FreeOrderLimitPerDay - sent
|
||||
if remaining < needed {
|
||||
return fmt.Errorf("%w: %s remaining=%d needed=%d", risk.ErrFreeOrderBudget, instrument.InstrumentUID, remaining, needed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) cancelCountsAsFreeOrder() bool {
|
||||
return e.freeOrderCountPolicy == FreeOrderPolicyCancelCounts
|
||||
}
|
||||
|
||||
func (e *Engine) checkQuoteFresh(book domain.OrderBook) error {
|
||||
if e.maxQuoteAge <= 0 {
|
||||
return nil
|
||||
|
||||
@@ -96,6 +96,40 @@ func TestPaperPlaceEntryFillsAndCountsSubmittedOrder(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCancelCountsAsFreeOrderWhenPolicyEnabled(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := testutil.NewMemoryRepository()
|
||||
gateway := tinvest.NewFakeGateway()
|
||||
engine := NewEngine(domain.ModeSandbox, "account", gateway, repo)
|
||||
engine.SetFreeOrderCountPolicy(FreeOrderPolicyCancelCounts)
|
||||
tradeDate := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC)
|
||||
order, err := engine.PlaceLimit(ctx, domain.Order{
|
||||
ClientOrderID: "order-1",
|
||||
AccountIDHash: "hash",
|
||||
InstrumentUID: "uid",
|
||||
TradeDate: tradeDate,
|
||||
Side: domain.SideBuy,
|
||||
OrderType: domain.OrderTypeLimit,
|
||||
LimitPrice: decimal.NewFromInt(100),
|
||||
QuantityLots: 1,
|
||||
Status: domain.OrderStatusNew,
|
||||
AttemptNo: 1,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := engine.Cancel(ctx, order); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sent, err := repo.GetFreeOrdersSent(ctx, tradeDate, "uid")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if sent != 2 {
|
||||
t.Fatalf("free order counter=%d, want submit+cancel", sent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlaceEntryRejectsStaleQuote(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
engine := NewEngine(domain.ModeSandbox, "account", tinvest.NewFakeGateway(), testutil.NewMemoryRepository())
|
||||
|
||||
@@ -40,7 +40,8 @@ func NewPipeline(repo repository.Repository, cfg PipelineConfig) Pipeline {
|
||||
|
||||
func (p Pipeline) Recompute(ctx context.Context, instrument domain.Instrument, tradeDate time.Time, spread SpreadResult) (domain.FeatureSet, error) {
|
||||
from := tradeDate.AddDate(0, 0, -p.cfg.RollingLong-5)
|
||||
candles, err := p.repo.ListDailyCandles(ctx, instrument.InstrumentUID, from, tradeDate)
|
||||
to := dateOnly(tradeDate).AddDate(0, 0, -1)
|
||||
candles, err := p.repo.ListDailyCandles(ctx, instrument.InstrumentUID, from, to)
|
||||
if err != nil {
|
||||
return domain.FeatureSet{}, err
|
||||
}
|
||||
@@ -74,8 +75,9 @@ func (p Pipeline) intervalVolume(ctx context.Context, instrument domain.Instrume
|
||||
if lookback <= 0 {
|
||||
lookback = defaultIntervalVolumeLookback
|
||||
}
|
||||
from := window.Start.On(date.AddDate(0, 0, -lookback), loc).UTC()
|
||||
to := window.End.On(date, loc).UTC()
|
||||
toDate := dateOnly(date).AddDate(0, 0, -1)
|
||||
from := window.Start.On(toDate.AddDate(0, 0, -lookback+1), loc).UTC()
|
||||
to := window.End.On(toDate, loc).UTC()
|
||||
candles, err := p.repo.ListMinuteCandles(ctx, instrument.InstrumentUID, from, to)
|
||||
if err != nil {
|
||||
return decimal.Zero, err
|
||||
@@ -84,6 +86,7 @@ func (p Pipeline) intervalVolume(ctx context.Context, instrument domain.Instrume
|
||||
}
|
||||
|
||||
func Compute(instrument domain.Instrument, candles []domain.Candle, tradeDate time.Time, spread SpreadResult, cfg PipelineConfig, entryVolume, exitVolume decimal.Decimal) (domain.FeatureSet, error) {
|
||||
candles = historicalDailyCandles(candles, tradeDate)
|
||||
if len(candles) < 2 {
|
||||
return domain.FeatureSet{}, fmt.Errorf("need at least 2 candles, got %d", len(candles))
|
||||
}
|
||||
@@ -138,6 +141,22 @@ func Compute(instrument domain.Instrument, candles []domain.Candle, tradeDate ti
|
||||
}, nil
|
||||
}
|
||||
|
||||
func historicalDailyCandles(candles []domain.Candle, tradeDate time.Time) []domain.Candle {
|
||||
tradeDay := dateOnly(tradeDate)
|
||||
out := make([]domain.Candle, 0, len(candles))
|
||||
for _, candle := range candles {
|
||||
if dateOnly(candle.TradeDate).Before(tradeDay) {
|
||||
out = append(out, candle)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func dateOnly(ts time.Time) time.Time {
|
||||
year, month, day := ts.UTC().Date()
|
||||
return time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
|
||||
func IntervalVolume(candles []domain.Candle, lot int64) decimal.Decimal {
|
||||
if lot <= 0 {
|
||||
return decimal.Zero
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package features
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/testutil"
|
||||
"overnight-trading-bot/internal/timeutil"
|
||||
)
|
||||
|
||||
@@ -74,6 +76,41 @@ func TestAverageIntervalVolumeUsesExecutionWindowsAcrossDays(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecomputeExcludesTradeDateDailyCandle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := testutil.NewMemoryRepository()
|
||||
start := time.Date(2026, 6, 1, 0, 0, 0, 0, time.UTC)
|
||||
var candles []domain.Candle
|
||||
for i := 0; i < 6; i++ {
|
||||
closePrice := decimal.NewFromInt(100)
|
||||
if i == 5 {
|
||||
closePrice = decimal.NewFromInt(100000)
|
||||
}
|
||||
candles = append(candles, domain.Candle{
|
||||
InstrumentUID: "uid",
|
||||
TradeDate: start.AddDate(0, 0, i),
|
||||
Open: decimal.NewFromInt(100),
|
||||
Close: closePrice,
|
||||
VolumeLots: decimal.NewFromInt(1),
|
||||
})
|
||||
}
|
||||
if err := repo.UpsertDailyCandles(ctx, candles); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pipeline := NewPipeline(repo, PipelineConfig{
|
||||
RollingShort: 2,
|
||||
RollingLong: 2,
|
||||
EWMALambda: 0.08,
|
||||
})
|
||||
got, err := pipeline.Recompute(ctx, domain.Instrument{InstrumentUID: "uid", Lot: 1}, start.AddDate(0, 0, 5), SpreadResult{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !got.ADV20.Equal(decimal.NewFromInt(100)) {
|
||||
t.Fatalf("ADV20=%s, want tradeDate candle excluded", got.ADV20)
|
||||
}
|
||||
}
|
||||
|
||||
func mustTOD(raw string) timeutil.TimeOfDay {
|
||||
tod, err := timeutil.ParseTimeOfDay(raw)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -22,7 +24,10 @@ func New(level string, out io.Writer) *slog.Logger {
|
||||
default:
|
||||
slogLevel = slog.LevelInfo
|
||||
}
|
||||
return slog.New(slog.NewJSONHandler(out, &slog.HandlerOptions{Level: slogLevel}))
|
||||
return slog.New(slog.NewJSONHandler(out, &slog.HandlerOptions{
|
||||
Level: slogLevel,
|
||||
ReplaceAttr: redactAttr,
|
||||
}))
|
||||
}
|
||||
|
||||
type SDKLogger struct {
|
||||
@@ -31,18 +36,69 @@ type SDKLogger struct {
|
||||
|
||||
func (l SDKLogger) Infof(template string, args ...any) {
|
||||
if l.Logger != nil {
|
||||
l.Logger.Info(template, "args", args)
|
||||
l.Logger.Info(RedactString(template), "args", redactArgs(args))
|
||||
}
|
||||
}
|
||||
|
||||
func (l SDKLogger) Errorf(template string, args ...any) {
|
||||
if l.Logger != nil {
|
||||
l.Logger.Error(template, "args", args)
|
||||
l.Logger.Error(RedactString(template), "args", redactArgs(args))
|
||||
}
|
||||
}
|
||||
|
||||
func (l SDKLogger) Fatalf(template string, args ...any) {
|
||||
if l.Logger != nil {
|
||||
l.Logger.Error(template, "args", args)
|
||||
l.Logger.Error(RedactString(template), "args", redactArgs(args))
|
||||
}
|
||||
}
|
||||
|
||||
var sensitiveStringPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`(?i)((?:account[_-]?id|token)\s*[:=]\s*)("[^"]+"|'[^']+'|[^\s,}]+)`),
|
||||
regexp.MustCompile(`(?i)("(?:accountId|account_id|token)"\s*:\s*)("[^"]*"|null)`),
|
||||
}
|
||||
|
||||
func redactAttr(_ []string, attr slog.Attr) slog.Attr {
|
||||
if attr.Value.Kind() == slog.KindString {
|
||||
attr.Value = slog.StringValue(RedactString(attr.Value.String()))
|
||||
}
|
||||
return attr
|
||||
}
|
||||
|
||||
func redactArgs(args []any) []any {
|
||||
out := make([]any, len(args))
|
||||
for i, arg := range args {
|
||||
out[i] = redactAny(arg)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func redactAny(value any) any {
|
||||
switch typed := value.(type) {
|
||||
case string:
|
||||
return RedactString(typed)
|
||||
case []string:
|
||||
out := make([]string, len(typed))
|
||||
for i, item := range typed {
|
||||
out[i] = RedactString(item)
|
||||
}
|
||||
return out
|
||||
case []any:
|
||||
out := make([]any, len(typed))
|
||||
for i, item := range typed {
|
||||
out[i] = redactAny(item)
|
||||
}
|
||||
return out
|
||||
case fmt.Stringer:
|
||||
return RedactString(typed.String())
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
func RedactString(raw string) string {
|
||||
redacted := raw
|
||||
for _, pattern := range sensitiveStringPatterns {
|
||||
redacted = pattern.ReplaceAllString(redacted, `${1}"[REDACTED]"`)
|
||||
}
|
||||
return redacted
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
@@ -18,6 +19,7 @@ import (
|
||||
var defaultCommissionTolerance = decimal.RequireFromString("0.01")
|
||||
|
||||
type Engine struct {
|
||||
mu *sync.Mutex
|
||||
repo repository.Repository
|
||||
gateway tinvest.Gateway
|
||||
accountID string
|
||||
@@ -31,6 +33,7 @@ type Engine struct {
|
||||
|
||||
func New(repo repository.Repository, gateway tinvest.Gateway, accountID, accountIDHash string) Engine {
|
||||
return Engine{
|
||||
mu: &sync.Mutex{},
|
||||
repo: repo,
|
||||
gateway: gateway,
|
||||
accountID: accountID,
|
||||
@@ -64,6 +67,10 @@ func (e Engine) WithCommissionPolicy(requireZero, quarantineOnNonZero bool, tole
|
||||
}
|
||||
|
||||
func (e Engine) Run(ctx context.Context) ([]domain.ReconciliationDiff, error) {
|
||||
if e.mu != nil {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
}
|
||||
localOrders, err := e.repo.ListActiveOrders(ctx, e.accountIDHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -150,6 +157,7 @@ func (e Engine) Run(ctx context.Context) ([]domain.ReconciliationDiff, error) {
|
||||
})
|
||||
}
|
||||
}
|
||||
diffs = append(diffs, compareCash(localPositions, portfolio, e.commissionTolerance)...)
|
||||
from := now.Add(-e.window)
|
||||
recentOrders, err := e.repo.ListOrders(ctx, e.accountIDHash, from, now)
|
||||
if err != nil {
|
||||
@@ -204,6 +212,55 @@ func compareOperations(orders []domain.Order, operations []domain.Operation) []d
|
||||
return compareOperationsWithPolicy(orders, operations, false, defaultCommissionTolerance)
|
||||
}
|
||||
|
||||
func compareCash(localPositions []domain.Position, portfolio domain.Portfolio, tolerance decimal.Decimal) []domain.ReconciliationDiff {
|
||||
if tolerance.IsNegative() {
|
||||
tolerance = decimal.Zero
|
||||
}
|
||||
expectedCash, ok := expectedCashFromLocalPositions(localPositions, portfolio)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
diff := money.Abs(expectedCash.Sub(portfolio.Cash))
|
||||
if diff.LessThanOrEqual(tolerance) {
|
||||
return nil
|
||||
}
|
||||
return []domain.ReconciliationDiff{{
|
||||
Kind: "cash_mismatch",
|
||||
Message: fmt.Sprintf("expected cash=%s broker cash=%s diff=%s", expectedCash.StringFixed(2), portfolio.Cash.StringFixed(2), diff.StringFixed(2)),
|
||||
Critical: true,
|
||||
}}
|
||||
}
|
||||
|
||||
func expectedCashFromLocalPositions(localPositions []domain.Position, portfolio domain.Portfolio) (decimal.Decimal, bool) {
|
||||
if !portfolio.Equity.IsPositive() {
|
||||
return decimal.Zero, false
|
||||
}
|
||||
if len(localPositions) == 0 {
|
||||
if len(portfolio.Holdings) != 0 {
|
||||
return decimal.Zero, false
|
||||
}
|
||||
return portfolio.Equity, true
|
||||
}
|
||||
holdingByInstrument := make(map[string]domain.Holding, len(portfolio.Holdings))
|
||||
for _, holding := range portfolio.Holdings {
|
||||
holdingByInstrument[holding.InstrumentUID] = holding
|
||||
}
|
||||
positionMarketValue := decimal.Zero
|
||||
for _, pos := range localPositions {
|
||||
if pos.Lots <= 0 {
|
||||
continue
|
||||
}
|
||||
holding, ok := holdingByInstrument[pos.InstrumentUID]
|
||||
if !ok || holding.QuantityLots <= 0 || !holding.MarketValue.IsPositive() {
|
||||
return decimal.Zero, false
|
||||
}
|
||||
positionMarketValue = positionMarketValue.Add(holding.MarketValue.
|
||||
Mul(decimal.NewFromInt(pos.Lots)).
|
||||
Div(decimal.NewFromInt(holding.QuantityLots)))
|
||||
}
|
||||
return portfolio.Equity.Sub(positionMarketValue), true
|
||||
}
|
||||
|
||||
func compareOperationsWithPolicy(orders []domain.Order, operations []domain.Operation, requireZeroCommission bool, commissionTolerance decimal.Decimal) []domain.ReconciliationDiff {
|
||||
var diffs []domain.ReconciliationDiff
|
||||
if commissionTolerance.IsNegative() {
|
||||
|
||||
@@ -170,3 +170,37 @@ func TestReconciliationSkipsFreshInFlightLocalOrders(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconciliationFindsCashMismatch(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := testutil.NewMemoryRepository()
|
||||
gateway := tinvest.NewFakeGateway()
|
||||
if err := repo.UpsertPosition(ctx, domain.Position{
|
||||
AccountIDHash: "hash",
|
||||
InstrumentUID: "uid",
|
||||
OpenTradeDate: time.Now().UTC(),
|
||||
Lots: 2,
|
||||
Status: domain.PositionHoldingOvernight,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
gateway.Portfolio = domain.Portfolio{
|
||||
Equity: decimal.NewFromInt(1000),
|
||||
Cash: decimal.NewFromInt(700),
|
||||
Holdings: []domain.Holding{{
|
||||
InstrumentUID: "uid",
|
||||
QuantityLots: 2,
|
||||
MarketValue: decimal.NewFromInt(200),
|
||||
}},
|
||||
}
|
||||
diffs, err := New(repo, gateway, "account", "hash").Run(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, diff := range diffs {
|
||||
if diff.Kind == "cash_mismatch" && diff.Critical {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatalf("missing cash_mismatch in %+v", diffs)
|
||||
}
|
||||
|
||||
+154
-22
@@ -45,6 +45,7 @@ type Config struct {
|
||||
EntryWindowEnd timeutil.TimeOfDay
|
||||
NoNewEntryAfter timeutil.TimeOfDay
|
||||
ExitWatchStart timeutil.TimeOfDay
|
||||
ExitNotBefore timeutil.TimeOfDay
|
||||
ExitWindowStart timeutil.TimeOfDay
|
||||
ExitWindowEnd timeutil.TimeOfDay
|
||||
HardExitDeadline timeutil.TimeOfDay
|
||||
@@ -60,6 +61,7 @@ type Config struct {
|
||||
APIOutageHalt time.Duration
|
||||
RequireZeroCommission bool
|
||||
QuarantineOnNonZero bool
|
||||
FreeOrderCountPolicy string
|
||||
ReconciliationInterval time.Duration
|
||||
MaxOpenPositions int
|
||||
}
|
||||
@@ -133,6 +135,13 @@ func (s *Scheduler) Step(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
now := s.clock.Now().In(s.cfg.Location)
|
||||
reported, err := s.sendMissedDailyReport(ctx, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if reported {
|
||||
return nil
|
||||
}
|
||||
phase := s.phase(now)
|
||||
switch phase {
|
||||
case domain.StateWaitExitWindow:
|
||||
@@ -158,10 +167,14 @@ func (s *Scheduler) Step(ctx context.Context) error {
|
||||
|
||||
func (s Scheduler) phase(now time.Time) domain.SystemState {
|
||||
tod := sinceMidnight(now)
|
||||
exitWindowStart := s.cfg.ExitWindowStart.Duration
|
||||
if s.cfg.ExitNotBefore.Duration > exitWindowStart {
|
||||
exitWindowStart = s.cfg.ExitNotBefore.Duration
|
||||
}
|
||||
switch {
|
||||
case tod >= s.cfg.ExitWatchStart.Duration && tod < s.cfg.ExitWindowStart.Duration:
|
||||
case tod >= s.cfg.ExitWatchStart.Duration && tod < exitWindowStart:
|
||||
return domain.StateWaitExitWindow
|
||||
case tod >= s.cfg.ExitWindowStart.Duration && tod < s.cfg.ExitWindowEnd.Duration:
|
||||
case tod >= exitWindowStart && tod < s.cfg.ExitWindowEnd.Duration:
|
||||
return domain.StatePlaceExitOrders
|
||||
case tod >= s.cfg.ExitWindowEnd.Duration && tod < s.cfg.HardExitDeadline.Duration:
|
||||
return domain.StateMonitorExitOrders
|
||||
@@ -463,7 +476,7 @@ func (s *Scheduler) placeEntryOrders(ctx context.Context, now time.Time) error {
|
||||
}
|
||||
continue
|
||||
}
|
||||
pre, err := s.preTradeCheck(ctx, now, portfolio, projectedOpenPositions, tradingStatus, book.ReceivedAt)
|
||||
pre, err := s.preTradeCheck(ctx, now, sig.InstrumentUID, portfolio, projectedOpenPositions, tradingStatus, book.ReceivedAt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -585,7 +598,7 @@ func (s *Scheduler) placeExitOrders(ctx context.Context, now time.Time) error {
|
||||
if !ok {
|
||||
return fmt.Errorf("instrument %s is not in registry", pos.InstrumentUID)
|
||||
}
|
||||
if _, err := s.svc.FreeOrders.Check(ctx, exitTradeDate, instrument, s.cfg.MaxExitOrderAttempts); err != nil {
|
||||
if _, err := s.svc.FreeOrders.Check(ctx, exitTradeDate, instrument, s.orderBudgetNeededForAttempts(s.cfg.MaxExitOrderAttempts)); err != nil {
|
||||
if insertErr := s.recordPreTradeReject(ctx, pos.InstrumentUID, err.Error(), `{"reason":"free_order_budget_insufficient"}`); insertErr != nil {
|
||||
return insertErr
|
||||
}
|
||||
@@ -609,7 +622,7 @@ func (s *Scheduler) placeExitOrders(ctx context.Context, now time.Time) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pre, err := s.preTradeCheck(ctx, now, portfolio, len(positionsList), tradingStatus, book.ReceivedAt)
|
||||
pre, err := s.preTradeCheck(ctx, now, pos.InstrumentUID, portfolio, len(positionsList), tradingStatus, book.ReceivedAt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -722,12 +735,47 @@ func (s *Scheduler) reconcileAndReport(ctx context.Context, now time.Time) error
|
||||
if err := s.transitionTo(ctx, domain.StateReconcile); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.reconcileCritical(ctx, "reconciliation_critical"); err != nil {
|
||||
return err
|
||||
if s.cfg.Mode.AllowsBrokerOrders() {
|
||||
if err := s.reconcileCritical(ctx, "reconciliation_critical"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return s.sendDailyReport(ctx, now, "ok")
|
||||
}
|
||||
|
||||
func (s *Scheduler) sendMissedDailyReport(ctx context.Context, now time.Time) (bool, error) {
|
||||
if s.svc.Repo == nil || !s.hasStateMachine() {
|
||||
return false, nil
|
||||
}
|
||||
tod := sinceMidnight(now)
|
||||
if tod < s.cfg.EntrySignalTime.Duration {
|
||||
return false, nil
|
||||
}
|
||||
phase := s.phase(now)
|
||||
if phase == domain.StateReconcile || phase == domain.StateReport {
|
||||
return false, nil
|
||||
}
|
||||
state, halted, _, err := s.svc.Repo.GetSystemState(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if halted || state == domain.StateHalted {
|
||||
return false, nil
|
||||
}
|
||||
if state != domain.StateInit && state != domain.StateSleep {
|
||||
return false, nil
|
||||
}
|
||||
tradeDate := tradingDate(now)
|
||||
sent, err := s.svc.Repo.WasDailyReportSent(ctx, tradeDate, s.svc.AccountIDHash)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if sent {
|
||||
return false, nil
|
||||
}
|
||||
return true, s.reconcileAndReport(ctx, now)
|
||||
}
|
||||
|
||||
func (s *Scheduler) sendDailyReport(ctx context.Context, now time.Time, riskStatus string) error {
|
||||
tradeDate := tradingDate(now)
|
||||
sent, err := s.svc.Repo.WasDailyReportSent(ctx, tradeDate, s.svc.AccountIDHash)
|
||||
@@ -1081,7 +1129,7 @@ func (s Scheduler) repostPreTradeCheck(ctx context.Context, now time.Time, order
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pre, err := s.preTradeCheck(ctx, now, portfolio, len(openPositions), tradingStatus, book.ReceivedAt)
|
||||
pre, err := s.preTradeCheck(ctx, now, order.InstrumentUID, portfolio, len(openPositions), tradingStatus, book.ReceivedAt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1092,23 +1140,96 @@ func (s Scheduler) repostPreTradeCheck(ctx context.Context, now time.Time, order
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s Scheduler) preTradeCheck(ctx context.Context, now time.Time, portfolio domain.Portfolio, openPositions int, tradingStatus domain.TradingStatus, quoteReceivedAt time.Time) (risk.PreTradeResult, error) {
|
||||
func (s Scheduler) preTradeCheck(ctx context.Context, now time.Time, instrumentUID string, portfolio domain.Portfolio, openPositions int, tradingStatus domain.TradingStatus, quoteReceivedAt time.Time) (risk.PreTradeResult, error) {
|
||||
metrics, err := s.riskMetrics(ctx, now, portfolio)
|
||||
if err != nil {
|
||||
if haltErr := s.halt(ctx, "database_unavailable", fmt.Sprintf("pre-trade risk metrics unavailable: %s", err), instrumentUID); haltErr != nil {
|
||||
return risk.PreTradeResult{}, fmt.Errorf("database_unavailable: %w; halt failed: %v", err, haltErr)
|
||||
}
|
||||
return risk.PreTradeResult{Allowed: false, Reason: "database_unavailable"}, fmt.Errorf("%w: database_unavailable", statemachine.ErrSystemHalted)
|
||||
}
|
||||
unknownOrder, unknownHolding, err := s.unknownBrokerState(ctx, portfolio)
|
||||
if err != nil {
|
||||
return risk.PreTradeResult{}, err
|
||||
}
|
||||
return s.svc.Risk.PreTradeCheck(risk.PreTradeInput{
|
||||
Portfolio: portfolio,
|
||||
OpenPositions: openPositions,
|
||||
DailyPnL: metrics.dailyPnL,
|
||||
WeeklyPnL: metrics.weeklyPnL,
|
||||
MonthlyDrawdownPct: metrics.monthlyDrawdownPct,
|
||||
AvgSlippageBps10: metrics.avgSlippageBps10,
|
||||
TradingStatus: tradingStatus,
|
||||
QuoteReceivedAt: quoteReceivedAt,
|
||||
Now: now.UTC(),
|
||||
MarketClose: s.marketCloseOn(now),
|
||||
}), nil
|
||||
result := s.svc.Risk.PreTradeCheck(risk.PreTradeInput{
|
||||
Portfolio: portfolio,
|
||||
OpenPositions: openPositions,
|
||||
DailyPnL: metrics.dailyPnL,
|
||||
WeeklyPnL: metrics.weeklyPnL,
|
||||
MonthlyDrawdownPct: metrics.monthlyDrawdownPct,
|
||||
AvgSlippageBps10: metrics.avgSlippageBps10,
|
||||
TradingStatus: tradingStatus,
|
||||
QuoteReceivedAt: quoteReceivedAt,
|
||||
Now: now.UTC(),
|
||||
MarketClose: s.marketCloseOn(now),
|
||||
UnknownBrokerOrder: unknownOrder,
|
||||
UnknownBrokerHolding: unknownHolding,
|
||||
})
|
||||
if !result.Allowed && isHardHaltPreTradeReason(result.Reason) {
|
||||
if err := s.halt(ctx, result.Reason, fmt.Sprintf("pre-trade hard limit breached: %s", result.Reason), instrumentUID); err != nil {
|
||||
return result, err
|
||||
}
|
||||
return result, fmt.Errorf("%w: %s", statemachine.ErrSystemHalted, result.Reason)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s Scheduler) unknownBrokerState(ctx context.Context, portfolio domain.Portfolio) (bool, bool, error) {
|
||||
if !s.cfg.Mode.AllowsBrokerOrders() {
|
||||
return false, false, nil
|
||||
}
|
||||
localOrders, err := s.svc.Repo.ListActiveOrders(ctx, s.svc.AccountIDHash)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
localByBroker := make(map[string]struct{}, len(localOrders))
|
||||
for _, order := range localOrders {
|
||||
if order.BrokerOrderID != "" {
|
||||
localByBroker[order.BrokerOrderID] = struct{}{}
|
||||
}
|
||||
}
|
||||
brokerOrders, err := s.svc.Gateway.GetActiveOrders(ctx, s.svc.AccountID)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
for _, brokerOrder := range brokerOrders {
|
||||
if brokerOrder.BrokerOrderID == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := localByBroker[brokerOrder.BrokerOrderID]; !ok {
|
||||
return true, false, nil
|
||||
}
|
||||
}
|
||||
localPositions, err := s.svc.Repo.ListOpenPositions(ctx, s.svc.AccountIDHash)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
localLots := make(map[string]int64, len(localPositions))
|
||||
for _, pos := range localPositions {
|
||||
localLots[pos.InstrumentUID] += pos.Lots
|
||||
}
|
||||
for _, holding := range portfolio.Holdings {
|
||||
if holding.QuantityLots > 0 && localLots[holding.InstrumentUID] == 0 {
|
||||
return false, true, nil
|
||||
}
|
||||
}
|
||||
return false, false, nil
|
||||
}
|
||||
|
||||
func isHardHaltPreTradeReason(reason string) bool {
|
||||
switch reason {
|
||||
case "database_unavailable",
|
||||
"unknown_broker_order",
|
||||
"unknown_broker_position",
|
||||
"trading_status_unknown_before_order",
|
||||
"max_daily_loss",
|
||||
"max_weekly_loss",
|
||||
"max_monthly_drawdown":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
type preTradeMetrics struct {
|
||||
@@ -1255,13 +1376,24 @@ func repostAfter(now, deadline time.Time, attempts int, poll time.Duration) time
|
||||
}
|
||||
|
||||
func (s Scheduler) maxOrderAttemptsPerTrade() int {
|
||||
needed := s.cfg.MaxEntryOrderAttempts + s.cfg.MaxExitOrderAttempts
|
||||
needed := s.orderBudgetNeededForAttempts(s.cfg.MaxEntryOrderAttempts) + s.orderBudgetNeededForAttempts(s.cfg.MaxExitOrderAttempts)
|
||||
if needed <= 0 {
|
||||
return 1
|
||||
}
|
||||
return needed
|
||||
}
|
||||
|
||||
func (s Scheduler) orderBudgetNeededForAttempts(attempts int) int {
|
||||
if attempts <= 0 {
|
||||
attempts = 1
|
||||
}
|
||||
needed := attempts
|
||||
if s.cfg.FreeOrderCountPolicy == execution.FreeOrderPolicyCancelCounts {
|
||||
needed += attempts - 1
|
||||
}
|
||||
return needed
|
||||
}
|
||||
|
||||
func isSizingSkipReason(reason string) bool {
|
||||
return reason == "lots_below_one" || reason == "min_order_notional"
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -57,6 +58,33 @@ func TestPhaseUsesMoscowWindows(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPhaseHonorsExitNotBeforeWhenWindowStartsEarlier(t *testing.T) {
|
||||
loc := time.FixedZone("MSK", 3*60*60)
|
||||
s := Scheduler{cfg: Config{
|
||||
Location: loc,
|
||||
EntrySignalTime: mustTOD("18:10:00"),
|
||||
ExitWatchStart: mustTOD("09:50:00"),
|
||||
ExitNotBefore: mustTOD("10:03:00"),
|
||||
ExitWindowStart: mustTOD("10:00:00"),
|
||||
ExitWindowEnd: mustTOD("10:25:00"),
|
||||
HardExitDeadline: mustTOD("10:45:00"),
|
||||
}}
|
||||
at, err := time.Parse(time.RFC3339, "2026-06-06T10:01:00+03:00")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := s.phase(at.In(loc)); got != domain.StateWaitExitWindow {
|
||||
t.Fatalf("phase before ExitNotBefore=%s, want WAIT_EXIT_WINDOW", got)
|
||||
}
|
||||
at, err = time.Parse(time.RFC3339, "2026-06-06T10:04:00+03:00")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := s.phase(at.In(loc)); got != domain.StatePlaceExitOrders {
|
||||
t.Fatalf("phase after ExitNotBefore=%s, want PLACE_EXIT_ORDERS", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInfrastructureOutageRequiresThreshold(t *testing.T) {
|
||||
gateway := tinvest.NewFakeGateway()
|
||||
gateway.ServerTime = time.Now().UTC().Add(-10 * time.Second)
|
||||
@@ -260,6 +288,80 @@ func TestNonZeroCommissionQuarantinesInstrumentAndHalts(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreTradeDailyLossBreachHalts(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := testutil.NewMemoryRepository()
|
||||
now := time.Date(2026, 6, 8, 18, 20, 0, 0, time.UTC)
|
||||
closedAt := now.Add(-time.Hour)
|
||||
if err := repo.UpsertPosition(ctx, domain.Position{
|
||||
AccountIDHash: "hash",
|
||||
InstrumentUID: "uid",
|
||||
OpenTradeDate: tradingDate(now),
|
||||
Status: domain.PositionExitFilled,
|
||||
NetPnL: decimal.NewFromInt(-200),
|
||||
ClosedAt: &closedAt,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
notifier := &countNotifier{}
|
||||
s := Scheduler{
|
||||
cfg: Config{Mode: domain.ModePaper, Location: time.UTC},
|
||||
svc: Services{
|
||||
Repo: repo,
|
||||
Risk: risk.NewManager(repo, risk.ManagerConfig{MaxDailyLossPct: decimal.RequireFromString("0.01")}),
|
||||
Notifier: notifier,
|
||||
AccountIDHash: "hash",
|
||||
},
|
||||
}
|
||||
_, err := s.preTradeCheck(ctx, now, "uid", domain.Portfolio{
|
||||
Equity: decimal.NewFromInt(10000),
|
||||
Cash: decimal.NewFromInt(10000),
|
||||
}, 0, domain.TradingStatusNormal, now)
|
||||
if !errors.Is(err, statemachine.ErrSystemHalted) {
|
||||
t.Fatalf("err=%v, want ErrSystemHalted", err)
|
||||
}
|
||||
if !repo.Halted || repo.HaltReason != "pre-trade hard limit breached: max_daily_loss" {
|
||||
t.Fatalf("halted=%v reason=%q", repo.Halted, repo.HaltReason)
|
||||
}
|
||||
if notifier.alerts != 1 {
|
||||
t.Fatalf("alerts=%d, want 1", notifier.alerts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStepSendsMissedDailyReportAfterEntrySignalTime(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := testutil.NewMemoryRepository()
|
||||
notifier := &countNotifier{}
|
||||
now := time.Date(2026, 6, 8, 18, 15, 0, 0, time.UTC)
|
||||
s := Scheduler{
|
||||
clock: fixedClock{now: now},
|
||||
cfg: Config{
|
||||
Mode: domain.ModePaper,
|
||||
Location: time.UTC,
|
||||
EntrySignalTime: mustTOD("18:10:00"),
|
||||
},
|
||||
sm: statemachine.New(repo, domain.ModePaper),
|
||||
svc: Services{
|
||||
Repo: repo,
|
||||
Notifier: notifier,
|
||||
AccountIDHash: "hash",
|
||||
},
|
||||
}
|
||||
if err := s.Step(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if notifier.reports != 1 {
|
||||
t.Fatalf("reports=%d, want catch-up report", notifier.reports)
|
||||
}
|
||||
sent, err := repo.WasDailyReportSent(ctx, now, "hash")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !sent {
|
||||
t.Fatal("daily report was not marked as sent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSizeReductionRuleCutsSizerAfterBadExpectedErrors(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := testutil.NewMemoryRepository()
|
||||
|
||||
@@ -37,6 +37,9 @@ func (s System) Recover(ctx context.Context, reconcile reconciliation.Engine) (d
|
||||
case domain.StatePlaceEntryOrders, domain.StateMonitorEntryOrders,
|
||||
domain.StatePlaceExitOrders, domain.StateMonitorExitOrders,
|
||||
domain.StateHoldOvernight:
|
||||
if !s.mode.AllowsBrokerOrders() {
|
||||
return state, nil
|
||||
}
|
||||
diffs, err := reconcile.Run(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
||||
@@ -80,7 +80,7 @@ func TestCalendarRecoveryAllowsRestartInsideExitWindow(t *testing.T) {
|
||||
func TestRecoverFromMonitorEntryHaltsOnCriticalReconciliationDiff(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := testutil.NewMemoryRepository()
|
||||
if err := repo.SaveSystemState(ctx, domain.StateMonitorEntryOrders, domain.ModePaper, false, "", "{}"); err != nil {
|
||||
if err := repo.SaveSystemState(ctx, domain.StateMonitorEntryOrders, domain.ModeSandbox, false, "", "{}"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := repo.UpsertOrder(ctx, domain.Order{
|
||||
@@ -97,7 +97,7 @@ func TestRecoverFromMonitorEntryHaltsOnCriticalReconciliationDiff(t *testing.T)
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
system := New(repo, domain.ModePaper)
|
||||
system := New(repo, domain.ModeSandbox)
|
||||
state, err := system.Recover(ctx, reconciliation.New(repo, tinvest.NewFakeGateway(), "account", "hash"))
|
||||
if err == nil {
|
||||
t.Fatal("expected critical reconciliation error")
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/shopspring/decimal"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
"overnight-trading-bot/internal/logging"
|
||||
@@ -407,6 +408,13 @@ func orderFromPostResponse(resp *pb.PostOrderResponse, accountID, clientOrderID
|
||||
return domain.Order{}
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
avgFillPrice := decimal.Zero
|
||||
if resp.GetLotsExecuted() > 0 {
|
||||
avgFillPrice = money.MoneyValueToDecimal(resp.GetExecutedOrderPrice())
|
||||
if !avgFillPrice.IsPositive() {
|
||||
avgFillPrice = limitPrice
|
||||
}
|
||||
}
|
||||
return domain.Order{
|
||||
ClientOrderID: clientOrderID,
|
||||
BrokerOrderID: resp.GetOrderId(),
|
||||
@@ -417,7 +425,7 @@ func orderFromPostResponse(resp *pb.PostOrderResponse, accountID, clientOrderID
|
||||
LimitPrice: limitPrice,
|
||||
QuantityLots: resp.GetLotsRequested(),
|
||||
FilledLots: resp.GetLotsExecuted(),
|
||||
AvgFillPrice: limitPrice,
|
||||
AvgFillPrice: avgFillPrice,
|
||||
Status: mapOrderStatus(resp.GetExecutionReportStatus()),
|
||||
Commission: money.MoneyValueToDecimal(resp.GetExecutedCommission()),
|
||||
RawStateJSON: marshalProto(resp),
|
||||
@@ -438,6 +446,10 @@ func orderFromState(state *pb.OrderState, accountID string) domain.Order {
|
||||
if state.GetOrderDate() != nil {
|
||||
orderDate = state.GetOrderDate().AsTime().UTC()
|
||||
}
|
||||
avgFillPrice := decimal.Zero
|
||||
if state.GetLotsExecuted() > 0 {
|
||||
avgFillPrice = money.MoneyValueToDecimal(state.GetAveragePositionPrice())
|
||||
}
|
||||
return domain.Order{
|
||||
ClientOrderID: state.GetOrderRequestId(),
|
||||
BrokerOrderID: state.GetOrderId(),
|
||||
@@ -448,7 +460,7 @@ func orderFromState(state *pb.OrderState, accountID string) domain.Order {
|
||||
LimitPrice: money.MoneyValueToDecimal(state.GetInitialSecurityPrice()),
|
||||
QuantityLots: state.GetLotsRequested(),
|
||||
FilledLots: state.GetLotsExecuted(),
|
||||
AvgFillPrice: money.MoneyValueToDecimal(state.GetAveragePositionPrice()),
|
||||
AvgFillPrice: avgFillPrice,
|
||||
Status: mapOrderStatus(state.GetExecutionReportStatus()),
|
||||
Commission: money.MoneyValueToDecimal(state.GetExecutedCommission()),
|
||||
RawStateJSON: marshalProto(state),
|
||||
@@ -478,10 +490,52 @@ func marshalProto(msg proto.Message) string {
|
||||
if msg == nil {
|
||||
return "{}"
|
||||
}
|
||||
raw, err := protojson.Marshal(msg)
|
||||
sanitized := proto.Clone(msg)
|
||||
clearSensitiveProtoFields(sanitized.ProtoReflect())
|
||||
raw, err := protojson.Marshal(sanitized)
|
||||
if err != nil {
|
||||
fallback, _ := json.Marshal(map[string]string{"marshal_error": err.Error()})
|
||||
return string(fallback)
|
||||
}
|
||||
return string(raw)
|
||||
}
|
||||
|
||||
func clearSensitiveProtoFields(message protoreflect.Message) {
|
||||
if !message.IsValid() {
|
||||
return
|
||||
}
|
||||
fields := message.Descriptor().Fields()
|
||||
for i := 0; i < fields.Len(); i++ {
|
||||
field := fields.Get(i)
|
||||
if isSensitiveProtoField(field.Name()) {
|
||||
message.Clear(field)
|
||||
continue
|
||||
}
|
||||
value := message.Get(field)
|
||||
switch {
|
||||
case field.IsList():
|
||||
list := value.List()
|
||||
if field.Kind() == protoreflect.MessageKind || field.Kind() == protoreflect.GroupKind {
|
||||
for j := 0; j < list.Len(); j++ {
|
||||
clearSensitiveProtoFields(list.Get(j).Message())
|
||||
}
|
||||
}
|
||||
case field.IsMap():
|
||||
if field.MapValue().Kind() == protoreflect.MessageKind || field.MapValue().Kind() == protoreflect.GroupKind {
|
||||
value.Map().Range(func(_ protoreflect.MapKey, value protoreflect.Value) bool {
|
||||
clearSensitiveProtoFields(value.Message())
|
||||
return true
|
||||
})
|
||||
}
|
||||
case field.Kind() == protoreflect.MessageKind || field.Kind() == protoreflect.GroupKind:
|
||||
if message.Has(field) {
|
||||
clearSensitiveProtoFields(value.Message())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isSensitiveProtoField(name protoreflect.Name) bool {
|
||||
normalized := strings.ReplaceAll(strings.ToLower(string(name)), "_", "")
|
||||
return normalized == "accountid"
|
||||
}
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
package tinvest
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
pb "github.com/russianinvestments/invest-api-go-sdk/proto"
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"overnight-trading-bot/internal/domain"
|
||||
)
|
||||
|
||||
func TestOrderFromPostResponseZeroFillHasZeroAvgPrice(t *testing.T) {
|
||||
order := orderFromPostResponse(&pb.PostOrderResponse{
|
||||
OrderId: "broker",
|
||||
ExecutionReportStatus: pb.OrderExecutionReportStatus_EXECUTION_REPORT_STATUS_NEW,
|
||||
LotsRequested: 1,
|
||||
LotsExecuted: 0,
|
||||
ExecutedOrderPrice: &pb.MoneyValue{Currency: "rub", Units: 100},
|
||||
InstrumentUid: "uid",
|
||||
}, "account", "client", domain.SideBuy, decimal.NewFromInt(100))
|
||||
if !order.AvgFillPrice.IsZero() {
|
||||
t.Fatalf("avg fill price=%s, want zero for unfilled order", order.AvgFillPrice)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalProtoRedactsAccountID(t *testing.T) {
|
||||
raw := marshalProto(&pb.OrderTrades{
|
||||
OrderId: "order",
|
||||
AccountId: "plain-account-id",
|
||||
InstrumentUid: "uid",
|
||||
})
|
||||
if strings.Contains(raw, "plain-account-id") || strings.Contains(raw, "accountId") || strings.Contains(raw, "account_id") {
|
||||
t.Fatalf("raw proto leaked account id: %s", raw)
|
||||
}
|
||||
if !strings.Contains(raw, "order") {
|
||||
t.Fatalf("sanitizer removed non-sensitive data: %s", raw)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user