diff --git a/README.md b/README.md index 8da3bda..32d0a5d 100644 --- a/README.md +++ b/README.md @@ -76,6 +76,7 @@ APP_MODE=backtest go run ./cmd/bot | `STRATEGY_ROLLING_SHORT` | количество торговых дней | `60` | рекомендуется `> 0` | Короткое окно статистики overnight-доходности. Больше - стабильнее оценка, но медленнее реакция; меньше - быстрее реакция, но больше шум. | | `STRATEGY_ROLLING_LONG` | количество торговых дней | `252` | рекомендуется `>= STRATEGY_ROLLING_SHORT` и `> 0` | Длинное окно для проверки положительного долгосрочного edge и глубины backfill. Больше требует больше истории. | | `STRATEGY_EWMA_LAMBDA` | дробь для EWMA | `0.08` | рабочий диапазон `(0, 1]`; вне диапазона EWMA-функция использует `0.08` | Вес новых наблюдений в EWMA. Больше - свежее движение влияет сильнее. | +| `STRATEGY_ALLOCATION_METHOD` | `equal_weight` | `equal_weight` | сейчас поддерживается только `equal_weight` | Метод распределения капитала между выбранными сигналами. Текущая реализация делит лимит экспозиции поровну между выбранными инструментами. | | `STRATEGY_MIN_TSTAT_60` | decimal t-stat | `1.25` | валидации нет; обычно `>= 0` | Минимальная статистическая значимость короткого edge. Выше - меньше входов, ниже - больше входов. | | `STRATEGY_MIN_WIN_RATE_60` | доля прибыльных overnight-дней | `0.55` | рекомендуется `0..1` | Минимальная доля положительных overnight-наблюдений. Выше - строже фильтр сигналов. | | `STRATEGY_MIN_NET_EDGE_BPS` | bps | `10` | валидации нет; обычно `>= 0` | Минимальный ожидаемый edge после издержек. Выше - меньше, но потенциально качественнее сигналы. | @@ -146,7 +147,7 @@ APP_MODE=backtest go run ./cmd/bot | --- | --- | --- | --- | --- | | `COMM_REQUIRE_ZERO_COMMISSION` | `true` или `false` | `true` | boolean | При `true` сигналы по инструментам с ожидаемой комиссией `> 0` отклоняются. | | `COMM_QUARANTINE_ON_NONZERO` | `true` или `false` | `true` | boolean | При фактической брокерской комиссии `> 0` инструмент переводится в quarantine, а система останавливается через HALT по zero-commission policy. | -| `COMM_FREE_ORDER_COUNT_POLICY` | `submitted` | `submitted` | жёстко только `submitted` | Политика учёта бесплатных заявок: счётчик увеличивается при отправке заявки. Другие значения запрещены валидацией. | +| `COMM_FREE_ORDER_COUNT_POLICY` | `submitted` или `cancel_counts` | `submitted` | одно из двух значений | Политика учёта бесплатных заявок: `submitted` считает только отправку новой заявки, `cancel_counts` дополнительно считает успешные отмены перед repost. | ### BT diff --git a/cmd/backtest/main.go b/cmd/backtest/main.go index e42aac5..3b93151 100644 --- a/cmd/backtest/main.go +++ b/cmd/backtest/main.go @@ -130,7 +130,7 @@ func run() error { MaxSpreadBps: spread, MaxTickBps: tick, AssumedSpreadBps: assumed, - RequireZeroCommission: *requireZeroCommission, + RequireZeroCommission: requireZeroCommission, UseMinuteModel: *useMinuteModel, }) result, err := engine.RunWithMinuteCandles(candles, minuteCandles) diff --git a/internal/app/app.go b/internal/app/app.go index 5990a1d..71c5276 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -244,6 +244,7 @@ func buildScheduler(clock timeutil.Clock, sm statemachine.System, cfg config.Con }) execEngine := execution.NewEngine(cfg.App.Mode, cfg.TInvest.AccountID, gateway, repo) execEngine.SetMaxQuoteAge(time.Duration(cfg.Execution.MaxQuoteAgeSec) * time.Second) + execEngine.SetFreeOrderCountPolicy(cfg.Commission.FreeOrderCountPolicy) services := scheduler.Services{ Repo: repo, Gateway: gateway, @@ -272,6 +273,7 @@ func buildScheduler(clock timeutil.Clock, sm statemachine.System, cfg config.Con EntryWindowEnd: cfg.Execution.EntryWindowEnd, NoNewEntryAfter: cfg.Execution.NoNewEntryAfter, ExitWatchStart: cfg.Execution.ExitWatchStart, + ExitNotBefore: cfg.Execution.ExitNotBefore, ExitWindowStart: cfg.Execution.ExitWindowStart, ExitWindowEnd: cfg.Execution.ExitWindowEnd, HardExitDeadline: cfg.Execution.HardExitDeadline, @@ -287,6 +289,7 @@ func buildScheduler(clock timeutil.Clock, sm statemachine.System, cfg config.Con APIOutageHalt: time.Duration(cfg.Risk.APIOutageHaltSec) * time.Second, RequireZeroCommission: cfg.Commission.RequireZeroCommission, QuarantineOnNonZero: cfg.Commission.QuarantineOnNonZero, + FreeOrderCountPolicy: cfg.Commission.FreeOrderCountPolicy, ReconciliationInterval: 5 * time.Minute, MaxOpenPositions: minPositive(cfg.Strategy.MaxPositions, cfg.Risk.MaxOpenPositions), }, services) diff --git a/internal/backtest/config_test.go b/internal/backtest/config_test.go new file mode 100644 index 0000000..81a0023 --- /dev/null +++ b/internal/backtest/config_test.go @@ -0,0 +1,38 @@ +package backtest + +import ( + "testing" + + "github.com/shopspring/decimal" +) + +func TestRequireZeroCommissionDefaultDoesNotOverrideExplicitFalse(t *testing.T) { + defaultEngine := New(Config{}) + if !defaultEngine.requireZeroCommission() { + t.Fatal("default require_zero_commission should be true") + } + requireZero := false + explicitEngine := New(Config{RequireZeroCommission: &requireZero}) + if explicitEngine.requireZeroCommission() { + t.Fatal("explicit require_zero_commission=false was overridden") + } +} + +func TestAssumedSpreadUsesFundTypeSpecificDefaults(t *testing.T) { + engine := New(Config{ + AssumedSpreadBps: decimal.NewFromInt(20), + InstrumentFundTypes: map[string]string{ + "mm": "money_market", + "eq": "equity", + }, + }) + if got := engine.assumedSpreadBps("mm"); !got.Equal(decimal.NewFromInt(5)) { + t.Fatalf("money market spread=%s, want 5", got) + } + if got := engine.assumedSpreadBps("eq"); !got.Equal(decimal.NewFromInt(25)) { + t.Fatalf("equity spread=%s, want 25", got) + } + if got := engine.assumedSpreadBps("unknown"); !got.Equal(decimal.NewFromInt(20)) { + t.Fatalf("default spread=%s, want 20", got) + } +} diff --git a/internal/backtest/engine.go b/internal/backtest/engine.go index 5013933..bbea3cc 100644 --- a/internal/backtest/engine.go +++ b/internal/backtest/engine.go @@ -8,6 +8,7 @@ import ( "os" "path/filepath" "sort" + "strings" "time" "github.com/shopspring/decimal" @@ -19,35 +20,40 @@ import ( ) type Config struct { - EntrySlippageBps decimal.Decimal - ExitSlippageBps decimal.Decimal - CommissionRoundtripBps decimal.Decimal - RiskBufferBps decimal.Decimal - InitialEquity decimal.Decimal - OutputDir string - RollingShort int - RollingLong int - EWMALambda float64 - MinTStat60 decimal.Decimal - MinWinRate60 decimal.Decimal - MinNetEdgeBps decimal.Decimal - MinADVRUB decimal.Decimal - MaxSpreadBps decimal.Decimal - MaxTickBps decimal.Decimal - RequireZeroCommission bool - MaxPositions int - MaxPositionPct decimal.Decimal - MaxTotalExposurePct decimal.Decimal - MaxParticipationRate decimal.Decimal - CashUsageBuffer decimal.Decimal - RiskBudgetPct decimal.Decimal - MinOrderNotionalRUB decimal.Decimal - AssumedSpreadBps decimal.Decimal - AssumedTickBps decimal.Decimal - Lot int64 - UseMinuteModel bool - EntryWindow TimeWindow - ExitWindow TimeWindow + EntrySlippageBps decimal.Decimal + ExitSlippageBps decimal.Decimal + CommissionRoundtripBps decimal.Decimal + RiskBufferBps decimal.Decimal + InitialEquity decimal.Decimal + OutputDir string + RollingShort int + RollingLong int + EWMALambda float64 + MinTStat60 decimal.Decimal + MinWinRate60 decimal.Decimal + MinNetEdgeBps decimal.Decimal + MinADVRUB decimal.Decimal + MaxSpreadBps decimal.Decimal + MaxSpreadBpsMoneyMarket decimal.Decimal + MaxSpreadBpsBondFunds decimal.Decimal + MaxSpreadBpsEquityFunds decimal.Decimal + MaxTickBps decimal.Decimal + RequireZeroCommission *bool + MaxPositions int + MaxPositionPct decimal.Decimal + MaxTotalExposurePct decimal.Decimal + MaxParticipationRate decimal.Decimal + CashUsageBuffer decimal.Decimal + RiskBudgetPct decimal.Decimal + MinOrderNotionalRUB decimal.Decimal + AssumedSpreadBps decimal.Decimal + AssumedSpreadBpsByFundType map[string]decimal.Decimal + InstrumentFundTypes map[string]string + AssumedTickBps decimal.Decimal + Lot int64 + UseMinuteModel bool + EntryWindow TimeWindow + ExitWindow TimeWindow } type TimeWindow struct { @@ -120,6 +126,15 @@ func (cfg Config) withDefaults() Config { if cfg.MaxSpreadBps.IsZero() { cfg.MaxSpreadBps = decimal.NewFromInt(20) } + if cfg.MaxSpreadBpsMoneyMarket.IsZero() { + cfg.MaxSpreadBpsMoneyMarket = decimal.NewFromInt(5) + } + if cfg.MaxSpreadBpsBondFunds.IsZero() { + cfg.MaxSpreadBpsBondFunds = decimal.NewFromInt(10) + } + if cfg.MaxSpreadBpsEquityFunds.IsZero() { + cfg.MaxSpreadBpsEquityFunds = decimal.NewFromInt(25) + } if cfg.MaxTickBps.IsZero() { cfg.MaxTickBps = decimal.NewFromInt(10) } @@ -132,8 +147,9 @@ func (cfg Config) withDefaults() Config { if cfg.AssumedTickBps.IsZero() { cfg.AssumedTickBps = cfg.MaxTickBps } - if !cfg.RequireZeroCommission && cfg.CommissionRoundtripBps.IsZero() { - cfg.RequireZeroCommission = true + if cfg.RequireZeroCommission == nil { + requireZero := true + cfg.RequireZeroCommission = &requireZero } if cfg.MaxPositions == 0 { cfg.MaxPositions = 5 @@ -260,7 +276,7 @@ func (e Engine) RunWithMinuteCandles(candlesByInstrument map[string][]domain.Can Lots: lots, Notional: notional, NetPnL: pnl, - SpreadBps: e.cfg.AssumedSpreadBps, + SpreadBps: c.spreadBps, SlippageBps: e.cfg.EntrySlippageBps.Add(e.cfg.ExitSlippageBps), OvernightGap: c.overnightGap, CapacityRUB: capacity, @@ -355,6 +371,7 @@ type candidate struct { buy decimal.Decimal sell decimal.Decimal netEdge decimal.Decimal + spreadBps decimal.Decimal adv decimal.Decimal q05Abs decimal.Decimal overnightGap decimal.Decimal @@ -381,7 +398,8 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle, return candidate{}, false, nil } rawEdge := decimal.NewFromFloat(short.Mean).Mul(decimal.NewFromInt(10_000)) - cost := e.cfg.AssumedSpreadBps. + spreadBps := e.assumedSpreadBps(instrumentUID) + cost := spreadBps. Add(e.cfg.EntrySlippageBps). Add(e.cfg.ExitSlippageBps). Add(e.cfg.CommissionRoundtripBps). @@ -389,7 +407,7 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle, netEdge := rawEdge.Sub(cost) adv := features.ADV(history, e.cfg.Lot, 20) switch { - case e.cfg.RequireZeroCommission && e.cfg.CommissionRoundtripBps.IsPositive(): + case e.requireZeroCommission() && e.cfg.CommissionRoundtripBps.IsPositive(): return candidate{}, false, nil case !decimal.NewFromFloat(short.Mean).IsPositive() || !decimal.NewFromFloat(long.Mean).IsPositive(): return candidate{}, false, nil @@ -399,7 +417,7 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle, return candidate{}, false, nil case netEdge.LessThan(e.cfg.MinNetEdgeBps): return candidate{}, false, nil - case e.cfg.AssumedSpreadBps.GreaterThan(e.cfg.MaxSpreadBps): + case spreadBps.GreaterThan(e.maxSpreadBps(instrumentUID)): return candidate{}, false, nil case e.cfg.AssumedTickBps.GreaterThan(e.cfg.MaxTickBps): return candidate{}, false, nil @@ -425,6 +443,7 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle, buy: buy, sell: sell, netEdge: netEdge, + spreadBps: spreadBps, adv: adv, q05Abs: q05Abs, overnightGap: gap, @@ -432,6 +451,58 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle, }, true, nil } +func (e Engine) requireZeroCommission() bool { + return e.cfg.RequireZeroCommission != nil && *e.cfg.RequireZeroCommission +} + +func (e Engine) assumedSpreadBps(instrumentUID string) decimal.Decimal { + fundType := normalizedFundType(e.cfg.InstrumentFundTypes[instrumentUID]) + if !fundType.IsZeroValue { + if spread, ok := e.cfg.AssumedSpreadBpsByFundType[fundType.Key]; ok { + return spread + } + return e.maxSpreadBpsForFundType(fundType.Raw) + } + return e.cfg.AssumedSpreadBps +} + +func (e Engine) maxSpreadBps(instrumentUID string) decimal.Decimal { + fundType := normalizedFundType(e.cfg.InstrumentFundTypes[instrumentUID]) + if fundType.IsZeroValue { + return e.cfg.MaxSpreadBps + } + return e.maxSpreadBpsForFundType(fundType.Raw) +} + +func (e Engine) maxSpreadBpsForFundType(fundType string) decimal.Decimal { + switch { + case strings.Contains(fundType, "money"): + return e.cfg.MaxSpreadBpsMoneyMarket + case strings.Contains(fundType, "bond"): + return e.cfg.MaxSpreadBpsBondFunds + case strings.Contains(fundType, "equity"): + return e.cfg.MaxSpreadBpsEquityFunds + default: + return e.cfg.MaxSpreadBps + } +} + +type normalizedType struct { + Raw string + Key string + IsZeroValue bool +} + +func normalizedFundType(raw string) normalizedType { + raw = strings.ToLower(strings.TrimSpace(raw)) + if raw == "" { + return normalizedType{IsZeroValue: true} + } + key := strings.ReplaceAll(raw, "-", "_") + key = strings.ReplaceAll(key, " ", "_") + return normalizedType{Raw: raw, Key: key} +} + func prepareCandles(candlesByInstrument map[string][]domain.Candle) map[string][]domain.Candle { prepared := make(map[string][]domain.Candle, len(candlesByInstrument)) for instrumentUID, candles := range candlesByInstrument { diff --git a/internal/config/config.go b/internal/config/config.go index dd4d640..62679a7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -68,14 +68,15 @@ type TelegramConfig struct { } type StrategyConfig struct { - RollingShort int `env:"ROLLING_SHORT" envDefault:"60"` - RollingLong int `env:"ROLLING_LONG" envDefault:"252"` - EWMALambda float64 `env:"EWMA_LAMBDA" envDefault:"0.08"` - MinTStat60 decimal.Decimal `env:"MIN_TSTAT_60" envDefault:"1.25"` - MinWinRate60 decimal.Decimal `env:"MIN_WIN_RATE_60" envDefault:"0.55"` - MinNetEdgeBps decimal.Decimal `env:"MIN_NET_EDGE_BPS" envDefault:"10"` - RiskBufferBps decimal.Decimal `env:"RISK_BUFFER_BPS" envDefault:"5"` - MaxPositions int `env:"MAX_POSITIONS" envDefault:"5"` + RollingShort int `env:"ROLLING_SHORT" envDefault:"60"` + RollingLong int `env:"ROLLING_LONG" envDefault:"252"` + EWMALambda float64 `env:"EWMA_LAMBDA" envDefault:"0.08"` + AllocationMethod string `env:"ALLOCATION_METHOD" envDefault:"equal_weight"` + MinTStat60 decimal.Decimal `env:"MIN_TSTAT_60" envDefault:"1.25"` + MinWinRate60 decimal.Decimal `env:"MIN_WIN_RATE_60" envDefault:"0.55"` + MinNetEdgeBps decimal.Decimal `env:"MIN_NET_EDGE_BPS" envDefault:"10"` + RiskBufferBps decimal.Decimal `env:"RISK_BUFFER_BPS" envDefault:"5"` + MaxPositions int `env:"MAX_POSITIONS" envDefault:"5"` } type ExecutionConfig struct { @@ -203,8 +204,19 @@ func (c *Config) Validate() error { if c.Risk.CommissionToleranceRUB.IsNegative() { return errors.New("RISK_COMMISSION_TOLERANCE_RUB must be non-negative") } - if c.Commission.FreeOrderCountPolicy != "submitted" { - return fmt.Errorf("COMM_FREE_ORDER_COUNT_POLICY must be submitted, got %q", c.Commission.FreeOrderCountPolicy) + if c.Commission.FreeOrderCountPolicy == "" { + c.Commission.FreeOrderCountPolicy = "submitted" + } + switch c.Commission.FreeOrderCountPolicy { + case "submitted", "cancel_counts": + default: + return fmt.Errorf("COMM_FREE_ORDER_COUNT_POLICY must be submitted or cancel_counts, got %q", c.Commission.FreeOrderCountPolicy) + } + if c.Strategy.AllocationMethod == "" { + c.Strategy.AllocationMethod = "equal_weight" + } + if c.Strategy.AllocationMethod != "equal_weight" { + return fmt.Errorf("STRATEGY_ALLOCATION_METHOD must be equal_weight, got %q", c.Strategy.AllocationMethod) } if err := c.validateWindows(); err != nil { return err diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 455d635..051c6a7 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -19,6 +19,14 @@ func TestValidateRequiresAccountIDForBrokerModes(t *testing.T) { } } +func TestValidateAllowsCancelCountsFreeOrderPolicy(t *testing.T) { + cfg := minimalBrokerConfig(domain.ModeSandbox) + cfg.Commission.FreeOrderCountPolicy = "cancel_counts" + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate cancel_counts: %v", err) + } +} + func minimalBrokerConfig(mode domain.Mode) Config { return Config{ App: AppConfig{ @@ -44,6 +52,7 @@ func minimalBrokerConfig(mode domain.Mode) Config { QuoteDepth: 20, OrderPollIntervalMS: 500, }, + Strategy: StrategyConfig{AllocationMethod: "equal_weight"}, Risk: RiskConfig{ APIOutageHaltSec: 180, ReconciliationWindowHours: 72, diff --git a/internal/execution/engine.go b/internal/execution/engine.go index bd56939..881e19b 100644 --- a/internal/execution/engine.go +++ b/internal/execution/engine.go @@ -17,6 +17,11 @@ import ( var ErrBrokerOrdersDisabled = errors.New("broker orders are disabled for current mode") var ErrEmptyOrderBook = errors.New("order book has no usable bid/ask") +const ( + FreeOrderPolicySubmitted = "submitted" + FreeOrderPolicyCancelCounts = "cancel_counts" +) + type Gateway interface { PostLimitOrder(ctx context.Context, accountID, instrumentUID string, side domain.Side, lots int64, price decimal.Decimal, clientOrderID string) (domain.Order, error) CancelOrder(ctx context.Context, accountID, orderID string) error @@ -24,12 +29,13 @@ type Gateway interface { } type Engine struct { - mode domain.Mode - accountID string - gateway Gateway - store repository.Repository - maxQuoteAge time.Duration - mu sync.Map + mode domain.Mode + accountID string + gateway Gateway + store repository.Repository + maxQuoteAge time.Duration + freeOrderCountPolicy string + mu sync.Map } type MonitorConfig struct { @@ -44,13 +50,22 @@ type MonitorConfig struct { } func NewEngine(mode domain.Mode, accountID string, gateway Gateway, store repository.Repository) Engine { - return Engine{mode: mode, accountID: accountID, gateway: gateway, store: store} + return Engine{mode: mode, accountID: accountID, gateway: gateway, store: store, freeOrderCountPolicy: FreeOrderPolicySubmitted} } func (e *Engine) SetMaxQuoteAge(maxQuoteAge time.Duration) { e.maxQuoteAge = maxQuoteAge } +func (e *Engine) SetFreeOrderCountPolicy(policy string) { + switch policy { + case FreeOrderPolicyCancelCounts: + e.freeOrderCountPolicy = policy + default: + e.freeOrderCountPolicy = FreeOrderPolicySubmitted + } +} + func (e *Engine) PlaceEntry(ctx context.Context, accountIDHash string, instrument domain.Instrument, tradeDate time.Time, lots int64, book domain.OrderBook, improveTicks int, attempt int) (domain.Order, error) { if err := e.checkQuoteFresh(book); err != nil { return domain.Order{}, err @@ -251,7 +266,15 @@ func (e *Engine) Cancel(ctx context.Context, order domain.Order) error { return err } if e.store != nil { - return e.store.UpdateOrderStatus(ctx, order.ClientOrderID, domain.OrderStatusCancelled, order.FilledLots, order.RawStateJSON) + return e.store.RunInTx(ctx, func(ctx context.Context, repo repository.Repository) error { + if err := repo.UpdateOrderStatus(ctx, order.ClientOrderID, domain.OrderStatusCancelled, order.FilledLots, order.RawStateJSON); err != nil { + return err + } + if e.cancelCountsAsFreeOrder() { + return repo.IncrementFreeOrders(ctx, order.TradeDate, order.InstrumentUID, 1) + } + return nil + }) } return nil } @@ -485,12 +508,21 @@ func (e *Engine) ensureRepostBudget(ctx context.Context, order domain.Order, ins if err != nil { return err } - if instrument.FreeOrderLimitPerDay-sent < 1 { - return fmt.Errorf("%w: %s remaining=0", risk.ErrFreeOrderBudget, instrument.InstrumentUID) + needed := 1 + if e.cancelCountsAsFreeOrder() { + needed = 2 + } + remaining := instrument.FreeOrderLimitPerDay - sent + if remaining < needed { + return fmt.Errorf("%w: %s remaining=%d needed=%d", risk.ErrFreeOrderBudget, instrument.InstrumentUID, remaining, needed) } return nil } +func (e *Engine) cancelCountsAsFreeOrder() bool { + return e.freeOrderCountPolicy == FreeOrderPolicyCancelCounts +} + func (e *Engine) checkQuoteFresh(book domain.OrderBook) error { if e.maxQuoteAge <= 0 { return nil diff --git a/internal/execution/state_test.go b/internal/execution/state_test.go index 0c031a1..830f62c 100644 --- a/internal/execution/state_test.go +++ b/internal/execution/state_test.go @@ -96,6 +96,40 @@ func TestPaperPlaceEntryFillsAndCountsSubmittedOrder(t *testing.T) { } } +func TestCancelCountsAsFreeOrderWhenPolicyEnabled(t *testing.T) { + ctx := context.Background() + repo := testutil.NewMemoryRepository() + gateway := tinvest.NewFakeGateway() + engine := NewEngine(domain.ModeSandbox, "account", gateway, repo) + engine.SetFreeOrderCountPolicy(FreeOrderPolicyCancelCounts) + tradeDate := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC) + order, err := engine.PlaceLimit(ctx, domain.Order{ + ClientOrderID: "order-1", + AccountIDHash: "hash", + InstrumentUID: "uid", + TradeDate: tradeDate, + Side: domain.SideBuy, + OrderType: domain.OrderTypeLimit, + LimitPrice: decimal.NewFromInt(100), + QuantityLots: 1, + Status: domain.OrderStatusNew, + AttemptNo: 1, + }) + if err != nil { + t.Fatal(err) + } + if err := engine.Cancel(ctx, order); err != nil { + t.Fatal(err) + } + sent, err := repo.GetFreeOrdersSent(ctx, tradeDate, "uid") + if err != nil { + t.Fatal(err) + } + if sent != 2 { + t.Fatalf("free order counter=%d, want submit+cancel", sent) + } +} + func TestPlaceEntryRejectsStaleQuote(t *testing.T) { ctx := context.Background() engine := NewEngine(domain.ModeSandbox, "account", tinvest.NewFakeGateway(), testutil.NewMemoryRepository()) diff --git a/internal/features/pipeline.go b/internal/features/pipeline.go index 627953c..86024ca 100644 --- a/internal/features/pipeline.go +++ b/internal/features/pipeline.go @@ -40,7 +40,8 @@ func NewPipeline(repo repository.Repository, cfg PipelineConfig) Pipeline { func (p Pipeline) Recompute(ctx context.Context, instrument domain.Instrument, tradeDate time.Time, spread SpreadResult) (domain.FeatureSet, error) { from := tradeDate.AddDate(0, 0, -p.cfg.RollingLong-5) - candles, err := p.repo.ListDailyCandles(ctx, instrument.InstrumentUID, from, tradeDate) + to := dateOnly(tradeDate).AddDate(0, 0, -1) + candles, err := p.repo.ListDailyCandles(ctx, instrument.InstrumentUID, from, to) if err != nil { return domain.FeatureSet{}, err } @@ -74,8 +75,9 @@ func (p Pipeline) intervalVolume(ctx context.Context, instrument domain.Instrume if lookback <= 0 { lookback = defaultIntervalVolumeLookback } - from := window.Start.On(date.AddDate(0, 0, -lookback), loc).UTC() - to := window.End.On(date, loc).UTC() + toDate := dateOnly(date).AddDate(0, 0, -1) + from := window.Start.On(toDate.AddDate(0, 0, -lookback+1), loc).UTC() + to := window.End.On(toDate, loc).UTC() candles, err := p.repo.ListMinuteCandles(ctx, instrument.InstrumentUID, from, to) if err != nil { return decimal.Zero, err @@ -84,6 +86,7 @@ func (p Pipeline) intervalVolume(ctx context.Context, instrument domain.Instrume } func Compute(instrument domain.Instrument, candles []domain.Candle, tradeDate time.Time, spread SpreadResult, cfg PipelineConfig, entryVolume, exitVolume decimal.Decimal) (domain.FeatureSet, error) { + candles = historicalDailyCandles(candles, tradeDate) if len(candles) < 2 { return domain.FeatureSet{}, fmt.Errorf("need at least 2 candles, got %d", len(candles)) } @@ -138,6 +141,22 @@ func Compute(instrument domain.Instrument, candles []domain.Candle, tradeDate ti }, nil } +func historicalDailyCandles(candles []domain.Candle, tradeDate time.Time) []domain.Candle { + tradeDay := dateOnly(tradeDate) + out := make([]domain.Candle, 0, len(candles)) + for _, candle := range candles { + if dateOnly(candle.TradeDate).Before(tradeDay) { + out = append(out, candle) + } + } + return out +} + +func dateOnly(ts time.Time) time.Time { + year, month, day := ts.UTC().Date() + return time.Date(year, month, day, 0, 0, 0, 0, time.UTC) +} + func IntervalVolume(candles []domain.Candle, lot int64) decimal.Decimal { if lot <= 0 { return decimal.Zero diff --git a/internal/features/pipeline_test.go b/internal/features/pipeline_test.go index ad365bb..4db1d97 100644 --- a/internal/features/pipeline_test.go +++ b/internal/features/pipeline_test.go @@ -1,12 +1,14 @@ package features import ( + "context" "testing" "time" "github.com/shopspring/decimal" "overnight-trading-bot/internal/domain" + "overnight-trading-bot/internal/testutil" "overnight-trading-bot/internal/timeutil" ) @@ -74,6 +76,41 @@ func TestAverageIntervalVolumeUsesExecutionWindowsAcrossDays(t *testing.T) { } } +func TestRecomputeExcludesTradeDateDailyCandle(t *testing.T) { + ctx := context.Background() + repo := testutil.NewMemoryRepository() + start := time.Date(2026, 6, 1, 0, 0, 0, 0, time.UTC) + var candles []domain.Candle + for i := 0; i < 6; i++ { + closePrice := decimal.NewFromInt(100) + if i == 5 { + closePrice = decimal.NewFromInt(100000) + } + candles = append(candles, domain.Candle{ + InstrumentUID: "uid", + TradeDate: start.AddDate(0, 0, i), + Open: decimal.NewFromInt(100), + Close: closePrice, + VolumeLots: decimal.NewFromInt(1), + }) + } + if err := repo.UpsertDailyCandles(ctx, candles); err != nil { + t.Fatal(err) + } + pipeline := NewPipeline(repo, PipelineConfig{ + RollingShort: 2, + RollingLong: 2, + EWMALambda: 0.08, + }) + got, err := pipeline.Recompute(ctx, domain.Instrument{InstrumentUID: "uid", Lot: 1}, start.AddDate(0, 0, 5), SpreadResult{}) + if err != nil { + t.Fatal(err) + } + if !got.ADV20.Equal(decimal.NewFromInt(100)) { + t.Fatalf("ADV20=%s, want tradeDate candle excluded", got.ADV20) + } +} + func mustTOD(raw string) timeutil.TimeOfDay { tod, err := timeutil.ParseTimeOfDay(raw) if err != nil { diff --git a/internal/logging/logging.go b/internal/logging/logging.go index e0fb652..3377eee 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -1,9 +1,11 @@ package logging import ( + "fmt" "io" "log/slog" "os" + "regexp" "strings" ) @@ -22,7 +24,10 @@ func New(level string, out io.Writer) *slog.Logger { default: slogLevel = slog.LevelInfo } - return slog.New(slog.NewJSONHandler(out, &slog.HandlerOptions{Level: slogLevel})) + return slog.New(slog.NewJSONHandler(out, &slog.HandlerOptions{ + Level: slogLevel, + ReplaceAttr: redactAttr, + })) } type SDKLogger struct { @@ -31,18 +36,69 @@ type SDKLogger struct { func (l SDKLogger) Infof(template string, args ...any) { if l.Logger != nil { - l.Logger.Info(template, "args", args) + l.Logger.Info(RedactString(template), "args", redactArgs(args)) } } func (l SDKLogger) Errorf(template string, args ...any) { if l.Logger != nil { - l.Logger.Error(template, "args", args) + l.Logger.Error(RedactString(template), "args", redactArgs(args)) } } func (l SDKLogger) Fatalf(template string, args ...any) { if l.Logger != nil { - l.Logger.Error(template, "args", args) + l.Logger.Error(RedactString(template), "args", redactArgs(args)) } } + +var sensitiveStringPatterns = []*regexp.Regexp{ + regexp.MustCompile(`(?i)((?:account[_-]?id|token)\s*[:=]\s*)("[^"]+"|'[^']+'|[^\s,}]+)`), + regexp.MustCompile(`(?i)("(?:accountId|account_id|token)"\s*:\s*)("[^"]*"|null)`), +} + +func redactAttr(_ []string, attr slog.Attr) slog.Attr { + if attr.Value.Kind() == slog.KindString { + attr.Value = slog.StringValue(RedactString(attr.Value.String())) + } + return attr +} + +func redactArgs(args []any) []any { + out := make([]any, len(args)) + for i, arg := range args { + out[i] = redactAny(arg) + } + return out +} + +func redactAny(value any) any { + switch typed := value.(type) { + case string: + return RedactString(typed) + case []string: + out := make([]string, len(typed)) + for i, item := range typed { + out[i] = RedactString(item) + } + return out + case []any: + out := make([]any, len(typed)) + for i, item := range typed { + out[i] = redactAny(item) + } + return out + case fmt.Stringer: + return RedactString(typed.String()) + default: + return value + } +} + +func RedactString(raw string) string { + redacted := raw + for _, pattern := range sensitiveStringPatterns { + redacted = pattern.ReplaceAllString(redacted, `${1}"[REDACTED]"`) + } + return redacted +} diff --git a/internal/reconciliation/engine.go b/internal/reconciliation/engine.go index b86e2a7..5d60a68 100644 --- a/internal/reconciliation/engine.go +++ b/internal/reconciliation/engine.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "strings" + "sync" "time" "github.com/shopspring/decimal" @@ -18,6 +19,7 @@ import ( var defaultCommissionTolerance = decimal.RequireFromString("0.01") type Engine struct { + mu *sync.Mutex repo repository.Repository gateway tinvest.Gateway accountID string @@ -31,6 +33,7 @@ type Engine struct { func New(repo repository.Repository, gateway tinvest.Gateway, accountID, accountIDHash string) Engine { return Engine{ + mu: &sync.Mutex{}, repo: repo, gateway: gateway, accountID: accountID, @@ -64,6 +67,10 @@ func (e Engine) WithCommissionPolicy(requireZero, quarantineOnNonZero bool, tole } func (e Engine) Run(ctx context.Context) ([]domain.ReconciliationDiff, error) { + if e.mu != nil { + e.mu.Lock() + defer e.mu.Unlock() + } localOrders, err := e.repo.ListActiveOrders(ctx, e.accountIDHash) if err != nil { return nil, err @@ -150,6 +157,7 @@ func (e Engine) Run(ctx context.Context) ([]domain.ReconciliationDiff, error) { }) } } + diffs = append(diffs, compareCash(localPositions, portfolio, e.commissionTolerance)...) from := now.Add(-e.window) recentOrders, err := e.repo.ListOrders(ctx, e.accountIDHash, from, now) if err != nil { @@ -204,6 +212,55 @@ func compareOperations(orders []domain.Order, operations []domain.Operation) []d return compareOperationsWithPolicy(orders, operations, false, defaultCommissionTolerance) } +func compareCash(localPositions []domain.Position, portfolio domain.Portfolio, tolerance decimal.Decimal) []domain.ReconciliationDiff { + if tolerance.IsNegative() { + tolerance = decimal.Zero + } + expectedCash, ok := expectedCashFromLocalPositions(localPositions, portfolio) + if !ok { + return nil + } + diff := money.Abs(expectedCash.Sub(portfolio.Cash)) + if diff.LessThanOrEqual(tolerance) { + return nil + } + return []domain.ReconciliationDiff{{ + Kind: "cash_mismatch", + Message: fmt.Sprintf("expected cash=%s broker cash=%s diff=%s", expectedCash.StringFixed(2), portfolio.Cash.StringFixed(2), diff.StringFixed(2)), + Critical: true, + }} +} + +func expectedCashFromLocalPositions(localPositions []domain.Position, portfolio domain.Portfolio) (decimal.Decimal, bool) { + if !portfolio.Equity.IsPositive() { + return decimal.Zero, false + } + if len(localPositions) == 0 { + if len(portfolio.Holdings) != 0 { + return decimal.Zero, false + } + return portfolio.Equity, true + } + holdingByInstrument := make(map[string]domain.Holding, len(portfolio.Holdings)) + for _, holding := range portfolio.Holdings { + holdingByInstrument[holding.InstrumentUID] = holding + } + positionMarketValue := decimal.Zero + for _, pos := range localPositions { + if pos.Lots <= 0 { + continue + } + holding, ok := holdingByInstrument[pos.InstrumentUID] + if !ok || holding.QuantityLots <= 0 || !holding.MarketValue.IsPositive() { + return decimal.Zero, false + } + positionMarketValue = positionMarketValue.Add(holding.MarketValue. + Mul(decimal.NewFromInt(pos.Lots)). + Div(decimal.NewFromInt(holding.QuantityLots))) + } + return portfolio.Equity.Sub(positionMarketValue), true +} + func compareOperationsWithPolicy(orders []domain.Order, operations []domain.Operation, requireZeroCommission bool, commissionTolerance decimal.Decimal) []domain.ReconciliationDiff { var diffs []domain.ReconciliationDiff if commissionTolerance.IsNegative() { diff --git a/internal/reconciliation/engine_test.go b/internal/reconciliation/engine_test.go index 4154284..2acbc82 100644 --- a/internal/reconciliation/engine_test.go +++ b/internal/reconciliation/engine_test.go @@ -170,3 +170,37 @@ func TestReconciliationSkipsFreshInFlightLocalOrders(t *testing.T) { } } } + +func TestReconciliationFindsCashMismatch(t *testing.T) { + ctx := context.Background() + repo := testutil.NewMemoryRepository() + gateway := tinvest.NewFakeGateway() + if err := repo.UpsertPosition(ctx, domain.Position{ + AccountIDHash: "hash", + InstrumentUID: "uid", + OpenTradeDate: time.Now().UTC(), + Lots: 2, + Status: domain.PositionHoldingOvernight, + }); err != nil { + t.Fatal(err) + } + gateway.Portfolio = domain.Portfolio{ + Equity: decimal.NewFromInt(1000), + Cash: decimal.NewFromInt(700), + Holdings: []domain.Holding{{ + InstrumentUID: "uid", + QuantityLots: 2, + MarketValue: decimal.NewFromInt(200), + }}, + } + diffs, err := New(repo, gateway, "account", "hash").Run(ctx) + if err != nil { + t.Fatal(err) + } + for _, diff := range diffs { + if diff.Kind == "cash_mismatch" && diff.Critical { + return + } + } + t.Fatalf("missing cash_mismatch in %+v", diffs) +} diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go index d0aeb7f..c5451b2 100644 --- a/internal/scheduler/scheduler.go +++ b/internal/scheduler/scheduler.go @@ -45,6 +45,7 @@ type Config struct { EntryWindowEnd timeutil.TimeOfDay NoNewEntryAfter timeutil.TimeOfDay ExitWatchStart timeutil.TimeOfDay + ExitNotBefore timeutil.TimeOfDay ExitWindowStart timeutil.TimeOfDay ExitWindowEnd timeutil.TimeOfDay HardExitDeadline timeutil.TimeOfDay @@ -60,6 +61,7 @@ type Config struct { APIOutageHalt time.Duration RequireZeroCommission bool QuarantineOnNonZero bool + FreeOrderCountPolicy string ReconciliationInterval time.Duration MaxOpenPositions int } @@ -133,6 +135,13 @@ func (s *Scheduler) Step(ctx context.Context) error { return err } now := s.clock.Now().In(s.cfg.Location) + reported, err := s.sendMissedDailyReport(ctx, now) + if err != nil { + return err + } + if reported { + return nil + } phase := s.phase(now) switch phase { case domain.StateWaitExitWindow: @@ -158,10 +167,14 @@ func (s *Scheduler) Step(ctx context.Context) error { func (s Scheduler) phase(now time.Time) domain.SystemState { tod := sinceMidnight(now) + exitWindowStart := s.cfg.ExitWindowStart.Duration + if s.cfg.ExitNotBefore.Duration > exitWindowStart { + exitWindowStart = s.cfg.ExitNotBefore.Duration + } switch { - case tod >= s.cfg.ExitWatchStart.Duration && tod < s.cfg.ExitWindowStart.Duration: + case tod >= s.cfg.ExitWatchStart.Duration && tod < exitWindowStart: return domain.StateWaitExitWindow - case tod >= s.cfg.ExitWindowStart.Duration && tod < s.cfg.ExitWindowEnd.Duration: + case tod >= exitWindowStart && tod < s.cfg.ExitWindowEnd.Duration: return domain.StatePlaceExitOrders case tod >= s.cfg.ExitWindowEnd.Duration && tod < s.cfg.HardExitDeadline.Duration: return domain.StateMonitorExitOrders @@ -463,7 +476,7 @@ func (s *Scheduler) placeEntryOrders(ctx context.Context, now time.Time) error { } continue } - pre, err := s.preTradeCheck(ctx, now, portfolio, projectedOpenPositions, tradingStatus, book.ReceivedAt) + pre, err := s.preTradeCheck(ctx, now, sig.InstrumentUID, portfolio, projectedOpenPositions, tradingStatus, book.ReceivedAt) if err != nil { return err } @@ -585,7 +598,7 @@ func (s *Scheduler) placeExitOrders(ctx context.Context, now time.Time) error { if !ok { return fmt.Errorf("instrument %s is not in registry", pos.InstrumentUID) } - if _, err := s.svc.FreeOrders.Check(ctx, exitTradeDate, instrument, s.cfg.MaxExitOrderAttempts); err != nil { + if _, err := s.svc.FreeOrders.Check(ctx, exitTradeDate, instrument, s.orderBudgetNeededForAttempts(s.cfg.MaxExitOrderAttempts)); err != nil { if insertErr := s.recordPreTradeReject(ctx, pos.InstrumentUID, err.Error(), `{"reason":"free_order_budget_insufficient"}`); insertErr != nil { return insertErr } @@ -609,7 +622,7 @@ func (s *Scheduler) placeExitOrders(ctx context.Context, now time.Time) error { if err != nil { return err } - pre, err := s.preTradeCheck(ctx, now, portfolio, len(positionsList), tradingStatus, book.ReceivedAt) + pre, err := s.preTradeCheck(ctx, now, pos.InstrumentUID, portfolio, len(positionsList), tradingStatus, book.ReceivedAt) if err != nil { return err } @@ -722,12 +735,47 @@ func (s *Scheduler) reconcileAndReport(ctx context.Context, now time.Time) error if err := s.transitionTo(ctx, domain.StateReconcile); err != nil { return err } - if err := s.reconcileCritical(ctx, "reconciliation_critical"); err != nil { - return err + if s.cfg.Mode.AllowsBrokerOrders() { + if err := s.reconcileCritical(ctx, "reconciliation_critical"); err != nil { + return err + } } return s.sendDailyReport(ctx, now, "ok") } +func (s *Scheduler) sendMissedDailyReport(ctx context.Context, now time.Time) (bool, error) { + if s.svc.Repo == nil || !s.hasStateMachine() { + return false, nil + } + tod := sinceMidnight(now) + if tod < s.cfg.EntrySignalTime.Duration { + return false, nil + } + phase := s.phase(now) + if phase == domain.StateReconcile || phase == domain.StateReport { + return false, nil + } + state, halted, _, err := s.svc.Repo.GetSystemState(ctx) + if err != nil { + return false, err + } + if halted || state == domain.StateHalted { + return false, nil + } + if state != domain.StateInit && state != domain.StateSleep { + return false, nil + } + tradeDate := tradingDate(now) + sent, err := s.svc.Repo.WasDailyReportSent(ctx, tradeDate, s.svc.AccountIDHash) + if err != nil { + return false, err + } + if sent { + return false, nil + } + return true, s.reconcileAndReport(ctx, now) +} + func (s *Scheduler) sendDailyReport(ctx context.Context, now time.Time, riskStatus string) error { tradeDate := tradingDate(now) sent, err := s.svc.Repo.WasDailyReportSent(ctx, tradeDate, s.svc.AccountIDHash) @@ -1081,7 +1129,7 @@ func (s Scheduler) repostPreTradeCheck(ctx context.Context, now time.Time, order if err != nil { return err } - pre, err := s.preTradeCheck(ctx, now, portfolio, len(openPositions), tradingStatus, book.ReceivedAt) + pre, err := s.preTradeCheck(ctx, now, order.InstrumentUID, portfolio, len(openPositions), tradingStatus, book.ReceivedAt) if err != nil { return err } @@ -1092,23 +1140,96 @@ func (s Scheduler) repostPreTradeCheck(ctx context.Context, now time.Time, order return nil } -func (s Scheduler) preTradeCheck(ctx context.Context, now time.Time, portfolio domain.Portfolio, openPositions int, tradingStatus domain.TradingStatus, quoteReceivedAt time.Time) (risk.PreTradeResult, error) { +func (s Scheduler) preTradeCheck(ctx context.Context, now time.Time, instrumentUID string, portfolio domain.Portfolio, openPositions int, tradingStatus domain.TradingStatus, quoteReceivedAt time.Time) (risk.PreTradeResult, error) { metrics, err := s.riskMetrics(ctx, now, portfolio) + if err != nil { + if haltErr := s.halt(ctx, "database_unavailable", fmt.Sprintf("pre-trade risk metrics unavailable: %s", err), instrumentUID); haltErr != nil { + return risk.PreTradeResult{}, fmt.Errorf("database_unavailable: %w; halt failed: %v", err, haltErr) + } + return risk.PreTradeResult{Allowed: false, Reason: "database_unavailable"}, fmt.Errorf("%w: database_unavailable", statemachine.ErrSystemHalted) + } + unknownOrder, unknownHolding, err := s.unknownBrokerState(ctx, portfolio) if err != nil { return risk.PreTradeResult{}, err } - return s.svc.Risk.PreTradeCheck(risk.PreTradeInput{ - Portfolio: portfolio, - OpenPositions: openPositions, - DailyPnL: metrics.dailyPnL, - WeeklyPnL: metrics.weeklyPnL, - MonthlyDrawdownPct: metrics.monthlyDrawdownPct, - AvgSlippageBps10: metrics.avgSlippageBps10, - TradingStatus: tradingStatus, - QuoteReceivedAt: quoteReceivedAt, - Now: now.UTC(), - MarketClose: s.marketCloseOn(now), - }), nil + result := s.svc.Risk.PreTradeCheck(risk.PreTradeInput{ + Portfolio: portfolio, + OpenPositions: openPositions, + DailyPnL: metrics.dailyPnL, + WeeklyPnL: metrics.weeklyPnL, + MonthlyDrawdownPct: metrics.monthlyDrawdownPct, + AvgSlippageBps10: metrics.avgSlippageBps10, + TradingStatus: tradingStatus, + QuoteReceivedAt: quoteReceivedAt, + Now: now.UTC(), + MarketClose: s.marketCloseOn(now), + UnknownBrokerOrder: unknownOrder, + UnknownBrokerHolding: unknownHolding, + }) + if !result.Allowed && isHardHaltPreTradeReason(result.Reason) { + if err := s.halt(ctx, result.Reason, fmt.Sprintf("pre-trade hard limit breached: %s", result.Reason), instrumentUID); err != nil { + return result, err + } + return result, fmt.Errorf("%w: %s", statemachine.ErrSystemHalted, result.Reason) + } + return result, nil +} + +func (s Scheduler) unknownBrokerState(ctx context.Context, portfolio domain.Portfolio) (bool, bool, error) { + if !s.cfg.Mode.AllowsBrokerOrders() { + return false, false, nil + } + localOrders, err := s.svc.Repo.ListActiveOrders(ctx, s.svc.AccountIDHash) + if err != nil { + return false, false, err + } + localByBroker := make(map[string]struct{}, len(localOrders)) + for _, order := range localOrders { + if order.BrokerOrderID != "" { + localByBroker[order.BrokerOrderID] = struct{}{} + } + } + brokerOrders, err := s.svc.Gateway.GetActiveOrders(ctx, s.svc.AccountID) + if err != nil { + return false, false, err + } + for _, brokerOrder := range brokerOrders { + if brokerOrder.BrokerOrderID == "" { + continue + } + if _, ok := localByBroker[brokerOrder.BrokerOrderID]; !ok { + return true, false, nil + } + } + localPositions, err := s.svc.Repo.ListOpenPositions(ctx, s.svc.AccountIDHash) + if err != nil { + return false, false, err + } + localLots := make(map[string]int64, len(localPositions)) + for _, pos := range localPositions { + localLots[pos.InstrumentUID] += pos.Lots + } + for _, holding := range portfolio.Holdings { + if holding.QuantityLots > 0 && localLots[holding.InstrumentUID] == 0 { + return false, true, nil + } + } + return false, false, nil +} + +func isHardHaltPreTradeReason(reason string) bool { + switch reason { + case "database_unavailable", + "unknown_broker_order", + "unknown_broker_position", + "trading_status_unknown_before_order", + "max_daily_loss", + "max_weekly_loss", + "max_monthly_drawdown": + return true + default: + return false + } } type preTradeMetrics struct { @@ -1255,13 +1376,24 @@ func repostAfter(now, deadline time.Time, attempts int, poll time.Duration) time } func (s Scheduler) maxOrderAttemptsPerTrade() int { - needed := s.cfg.MaxEntryOrderAttempts + s.cfg.MaxExitOrderAttempts + needed := s.orderBudgetNeededForAttempts(s.cfg.MaxEntryOrderAttempts) + s.orderBudgetNeededForAttempts(s.cfg.MaxExitOrderAttempts) if needed <= 0 { return 1 } return needed } +func (s Scheduler) orderBudgetNeededForAttempts(attempts int) int { + if attempts <= 0 { + attempts = 1 + } + needed := attempts + if s.cfg.FreeOrderCountPolicy == execution.FreeOrderPolicyCancelCounts { + needed += attempts - 1 + } + return needed +} + func isSizingSkipReason(reason string) bool { return reason == "lots_below_one" || reason == "min_order_notional" } diff --git a/internal/scheduler/scheduler_test.go b/internal/scheduler/scheduler_test.go index 110520b..52fcad6 100644 --- a/internal/scheduler/scheduler_test.go +++ b/internal/scheduler/scheduler_test.go @@ -2,6 +2,7 @@ package scheduler import ( "context" + "errors" "testing" "time" @@ -57,6 +58,33 @@ func TestPhaseUsesMoscowWindows(t *testing.T) { } } +func TestPhaseHonorsExitNotBeforeWhenWindowStartsEarlier(t *testing.T) { + loc := time.FixedZone("MSK", 3*60*60) + s := Scheduler{cfg: Config{ + Location: loc, + EntrySignalTime: mustTOD("18:10:00"), + ExitWatchStart: mustTOD("09:50:00"), + ExitNotBefore: mustTOD("10:03:00"), + ExitWindowStart: mustTOD("10:00:00"), + ExitWindowEnd: mustTOD("10:25:00"), + HardExitDeadline: mustTOD("10:45:00"), + }} + at, err := time.Parse(time.RFC3339, "2026-06-06T10:01:00+03:00") + if err != nil { + t.Fatal(err) + } + if got := s.phase(at.In(loc)); got != domain.StateWaitExitWindow { + t.Fatalf("phase before ExitNotBefore=%s, want WAIT_EXIT_WINDOW", got) + } + at, err = time.Parse(time.RFC3339, "2026-06-06T10:04:00+03:00") + if err != nil { + t.Fatal(err) + } + if got := s.phase(at.In(loc)); got != domain.StatePlaceExitOrders { + t.Fatalf("phase after ExitNotBefore=%s, want PLACE_EXIT_ORDERS", got) + } +} + func TestInfrastructureOutageRequiresThreshold(t *testing.T) { gateway := tinvest.NewFakeGateway() gateway.ServerTime = time.Now().UTC().Add(-10 * time.Second) @@ -260,6 +288,80 @@ func TestNonZeroCommissionQuarantinesInstrumentAndHalts(t *testing.T) { } } +func TestPreTradeDailyLossBreachHalts(t *testing.T) { + ctx := context.Background() + repo := testutil.NewMemoryRepository() + now := time.Date(2026, 6, 8, 18, 20, 0, 0, time.UTC) + closedAt := now.Add(-time.Hour) + if err := repo.UpsertPosition(ctx, domain.Position{ + AccountIDHash: "hash", + InstrumentUID: "uid", + OpenTradeDate: tradingDate(now), + Status: domain.PositionExitFilled, + NetPnL: decimal.NewFromInt(-200), + ClosedAt: &closedAt, + }); err != nil { + t.Fatal(err) + } + notifier := &countNotifier{} + s := Scheduler{ + cfg: Config{Mode: domain.ModePaper, Location: time.UTC}, + svc: Services{ + Repo: repo, + Risk: risk.NewManager(repo, risk.ManagerConfig{MaxDailyLossPct: decimal.RequireFromString("0.01")}), + Notifier: notifier, + AccountIDHash: "hash", + }, + } + _, err := s.preTradeCheck(ctx, now, "uid", domain.Portfolio{ + Equity: decimal.NewFromInt(10000), + Cash: decimal.NewFromInt(10000), + }, 0, domain.TradingStatusNormal, now) + if !errors.Is(err, statemachine.ErrSystemHalted) { + t.Fatalf("err=%v, want ErrSystemHalted", err) + } + if !repo.Halted || repo.HaltReason != "pre-trade hard limit breached: max_daily_loss" { + t.Fatalf("halted=%v reason=%q", repo.Halted, repo.HaltReason) + } + if notifier.alerts != 1 { + t.Fatalf("alerts=%d, want 1", notifier.alerts) + } +} + +func TestStepSendsMissedDailyReportAfterEntrySignalTime(t *testing.T) { + ctx := context.Background() + repo := testutil.NewMemoryRepository() + notifier := &countNotifier{} + now := time.Date(2026, 6, 8, 18, 15, 0, 0, time.UTC) + s := Scheduler{ + clock: fixedClock{now: now}, + cfg: Config{ + Mode: domain.ModePaper, + Location: time.UTC, + EntrySignalTime: mustTOD("18:10:00"), + }, + sm: statemachine.New(repo, domain.ModePaper), + svc: Services{ + Repo: repo, + Notifier: notifier, + AccountIDHash: "hash", + }, + } + if err := s.Step(ctx); err != nil { + t.Fatal(err) + } + if notifier.reports != 1 { + t.Fatalf("reports=%d, want catch-up report", notifier.reports) + } + sent, err := repo.WasDailyReportSent(ctx, now, "hash") + if err != nil { + t.Fatal(err) + } + if !sent { + t.Fatal("daily report was not marked as sent") + } +} + func TestSizeReductionRuleCutsSizerAfterBadExpectedErrors(t *testing.T) { ctx := context.Background() repo := testutil.NewMemoryRepository() diff --git a/internal/statemachine/system.go b/internal/statemachine/system.go index 335c151..f4cada1 100644 --- a/internal/statemachine/system.go +++ b/internal/statemachine/system.go @@ -37,6 +37,9 @@ func (s System) Recover(ctx context.Context, reconcile reconciliation.Engine) (d case domain.StatePlaceEntryOrders, domain.StateMonitorEntryOrders, domain.StatePlaceExitOrders, domain.StateMonitorExitOrders, domain.StateHoldOvernight: + if !s.mode.AllowsBrokerOrders() { + return state, nil + } diffs, err := reconcile.Run(ctx) if err != nil { return "", err diff --git a/internal/statemachine/system_test.go b/internal/statemachine/system_test.go index 0524642..6fc65f8 100644 --- a/internal/statemachine/system_test.go +++ b/internal/statemachine/system_test.go @@ -80,7 +80,7 @@ func TestCalendarRecoveryAllowsRestartInsideExitWindow(t *testing.T) { func TestRecoverFromMonitorEntryHaltsOnCriticalReconciliationDiff(t *testing.T) { ctx := context.Background() repo := testutil.NewMemoryRepository() - if err := repo.SaveSystemState(ctx, domain.StateMonitorEntryOrders, domain.ModePaper, false, "", "{}"); err != nil { + if err := repo.SaveSystemState(ctx, domain.StateMonitorEntryOrders, domain.ModeSandbox, false, "", "{}"); err != nil { t.Fatal(err) } if err := repo.UpsertOrder(ctx, domain.Order{ @@ -97,7 +97,7 @@ func TestRecoverFromMonitorEntryHaltsOnCriticalReconciliationDiff(t *testing.T) }); err != nil { t.Fatal(err) } - system := New(repo, domain.ModePaper) + system := New(repo, domain.ModeSandbox) state, err := system.Recover(ctx, reconciliation.New(repo, tinvest.NewFakeGateway(), "account", "hash")) if err == nil { t.Fatal("expected critical reconciliation error") diff --git a/internal/tinvest/real.go b/internal/tinvest/real.go index 81fc555..530a487 100644 --- a/internal/tinvest/real.go +++ b/internal/tinvest/real.go @@ -15,6 +15,7 @@ import ( "github.com/shopspring/decimal" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" "overnight-trading-bot/internal/domain" "overnight-trading-bot/internal/logging" @@ -407,6 +408,13 @@ func orderFromPostResponse(resp *pb.PostOrderResponse, accountID, clientOrderID return domain.Order{} } now := time.Now().UTC() + avgFillPrice := decimal.Zero + if resp.GetLotsExecuted() > 0 { + avgFillPrice = money.MoneyValueToDecimal(resp.GetExecutedOrderPrice()) + if !avgFillPrice.IsPositive() { + avgFillPrice = limitPrice + } + } return domain.Order{ ClientOrderID: clientOrderID, BrokerOrderID: resp.GetOrderId(), @@ -417,7 +425,7 @@ func orderFromPostResponse(resp *pb.PostOrderResponse, accountID, clientOrderID LimitPrice: limitPrice, QuantityLots: resp.GetLotsRequested(), FilledLots: resp.GetLotsExecuted(), - AvgFillPrice: limitPrice, + AvgFillPrice: avgFillPrice, Status: mapOrderStatus(resp.GetExecutionReportStatus()), Commission: money.MoneyValueToDecimal(resp.GetExecutedCommission()), RawStateJSON: marshalProto(resp), @@ -438,6 +446,10 @@ func orderFromState(state *pb.OrderState, accountID string) domain.Order { if state.GetOrderDate() != nil { orderDate = state.GetOrderDate().AsTime().UTC() } + avgFillPrice := decimal.Zero + if state.GetLotsExecuted() > 0 { + avgFillPrice = money.MoneyValueToDecimal(state.GetAveragePositionPrice()) + } return domain.Order{ ClientOrderID: state.GetOrderRequestId(), BrokerOrderID: state.GetOrderId(), @@ -448,7 +460,7 @@ func orderFromState(state *pb.OrderState, accountID string) domain.Order { LimitPrice: money.MoneyValueToDecimal(state.GetInitialSecurityPrice()), QuantityLots: state.GetLotsRequested(), FilledLots: state.GetLotsExecuted(), - AvgFillPrice: money.MoneyValueToDecimal(state.GetAveragePositionPrice()), + AvgFillPrice: avgFillPrice, Status: mapOrderStatus(state.GetExecutionReportStatus()), Commission: money.MoneyValueToDecimal(state.GetExecutedCommission()), RawStateJSON: marshalProto(state), @@ -478,10 +490,52 @@ func marshalProto(msg proto.Message) string { if msg == nil { return "{}" } - raw, err := protojson.Marshal(msg) + sanitized := proto.Clone(msg) + clearSensitiveProtoFields(sanitized.ProtoReflect()) + raw, err := protojson.Marshal(sanitized) if err != nil { fallback, _ := json.Marshal(map[string]string{"marshal_error": err.Error()}) return string(fallback) } return string(raw) } + +func clearSensitiveProtoFields(message protoreflect.Message) { + if !message.IsValid() { + return + } + fields := message.Descriptor().Fields() + for i := 0; i < fields.Len(); i++ { + field := fields.Get(i) + if isSensitiveProtoField(field.Name()) { + message.Clear(field) + continue + } + value := message.Get(field) + switch { + case field.IsList(): + list := value.List() + if field.Kind() == protoreflect.MessageKind || field.Kind() == protoreflect.GroupKind { + for j := 0; j < list.Len(); j++ { + clearSensitiveProtoFields(list.Get(j).Message()) + } + } + case field.IsMap(): + if field.MapValue().Kind() == protoreflect.MessageKind || field.MapValue().Kind() == protoreflect.GroupKind { + value.Map().Range(func(_ protoreflect.MapKey, value protoreflect.Value) bool { + clearSensitiveProtoFields(value.Message()) + return true + }) + } + case field.Kind() == protoreflect.MessageKind || field.Kind() == protoreflect.GroupKind: + if message.Has(field) { + clearSensitiveProtoFields(value.Message()) + } + } + } +} + +func isSensitiveProtoField(name protoreflect.Name) bool { + normalized := strings.ReplaceAll(strings.ToLower(string(name)), "_", "") + return normalized == "accountid" +} diff --git a/internal/tinvest/real_test.go b/internal/tinvest/real_test.go new file mode 100644 index 0000000..4dbfd89 --- /dev/null +++ b/internal/tinvest/real_test.go @@ -0,0 +1,39 @@ +package tinvest + +import ( + "strings" + "testing" + + pb "github.com/russianinvestments/invest-api-go-sdk/proto" + "github.com/shopspring/decimal" + + "overnight-trading-bot/internal/domain" +) + +func TestOrderFromPostResponseZeroFillHasZeroAvgPrice(t *testing.T) { + order := orderFromPostResponse(&pb.PostOrderResponse{ + OrderId: "broker", + ExecutionReportStatus: pb.OrderExecutionReportStatus_EXECUTION_REPORT_STATUS_NEW, + LotsRequested: 1, + LotsExecuted: 0, + ExecutedOrderPrice: &pb.MoneyValue{Currency: "rub", Units: 100}, + InstrumentUid: "uid", + }, "account", "client", domain.SideBuy, decimal.NewFromInt(100)) + if !order.AvgFillPrice.IsZero() { + t.Fatalf("avg fill price=%s, want zero for unfilled order", order.AvgFillPrice) + } +} + +func TestMarshalProtoRedactsAccountID(t *testing.T) { + raw := marshalProto(&pb.OrderTrades{ + OrderId: "order", + AccountId: "plain-account-id", + InstrumentUid: "uid", + }) + if strings.Contains(raw, "plain-account-id") || strings.Contains(raw, "accountId") || strings.Contains(raw, "account_id") { + t.Fatalf("raw proto leaked account id: %s", raw) + } + if !strings.Contains(raw, "order") { + t.Fatalf("sanitizer removed non-sensitive data: %s", raw) + } +}