diff --git a/cmd/backtest/main.go b/cmd/backtest/main.go index 3b93151..c77ab75 100644 --- a/cmd/backtest/main.go +++ b/cmd/backtest/main.go @@ -49,7 +49,7 @@ func run() error { defer func() { _ = file.Close() }() - candles, err := backtest.LoadCandlesCSV(file) + candles, metadata, err := backtest.LoadCandlesCSVWithMetadata(file) if err != nil { return fmt.Errorf("load candles: %w", err) } @@ -62,10 +62,12 @@ func run() error { defer func() { _ = minuteFile.Close() }() - minuteCandles, err = backtest.LoadCandlesCSV(minuteFile) + var minuteMetadata map[string]backtest.InstrumentMetadata + minuteCandles, minuteMetadata, err = backtest.LoadCandlesCSVWithMetadata(minuteFile) if err != nil { return fmt.Errorf("load minute candles: %w", err) } + mergeMetadata(metadata, minuteMetadata) } if *useMinuteModel && len(minuteCandles) == 0 { return fmt.Errorf("-minute-candles is required when -use-minute-model=true") @@ -114,24 +116,27 @@ func run() error { if err != nil { return fmt.Errorf("max tick: %w", err) } + lotsByInstrument, ticksByInstrument := metadataMaps(metadata) engine := backtest.New(backtest.Config{ - EntrySlippageBps: entry, - ExitSlippageBps: exit, - CommissionRoundtripBps: comm, - RiskBufferBps: riskBuf, - OutputDir: *outputDir, - RollingShort: *rollingShort, - RollingLong: *rollingLong, - EWMALambda: *ewmaLambda, - MinTStat60: tstat, - MinWinRate60: winRate, - MinNetEdgeBps: netEdge, - MinADVRUB: adv, - MaxSpreadBps: spread, - MaxTickBps: tick, - AssumedSpreadBps: assumed, - RequireZeroCommission: requireZeroCommission, - UseMinuteModel: *useMinuteModel, + EntrySlippageBps: entry, + ExitSlippageBps: exit, + CommissionRoundtripBps: comm, + RiskBufferBps: riskBuf, + OutputDir: *outputDir, + RollingShort: *rollingShort, + RollingLong: *rollingLong, + EWMALambda: *ewmaLambda, + MinTStat60: tstat, + MinWinRate60: winRate, + MinNetEdgeBps: netEdge, + MinADVRUB: adv, + MaxSpreadBps: spread, + MaxTickBps: tick, + AssumedSpreadBps: assumed, + RequireZeroCommission: requireZeroCommission, + LotsByInstrument: lotsByInstrument, + MinPriceIncrementsByInstrument: ticksByInstrument, + UseMinuteModel: *useMinuteModel, }) result, err := engine.RunWithMinuteCandles(candles, minuteCandles) if err != nil { @@ -143,3 +148,32 @@ func run() error { fmt.Printf("backtest complete: trades=%d total_return=%.6f\n", result.Metrics.NumberOfTrades, result.Metrics.TotalReturn) return nil } + +func mergeMetadata(dst, src map[string]backtest.InstrumentMetadata) { + for uid, meta := range src { + current := dst[uid] + if current.Lot <= 0 { + current.Lot = meta.Lot + } + if !current.MinPriceIncrement.IsPositive() { + current.MinPriceIncrement = meta.MinPriceIncrement + } + if current.Lot > 0 || current.MinPriceIncrement.IsPositive() { + dst[uid] = current + } + } +} + +func metadataMaps(metadata map[string]backtest.InstrumentMetadata) (map[string]int64, map[string]decimal.Decimal) { + lots := make(map[string]int64) + ticks := make(map[string]decimal.Decimal) + for uid, meta := range metadata { + if meta.Lot > 0 { + lots[uid] = meta.Lot + } + if meta.MinPriceIncrement.IsPositive() { + ticks[uid] = meta.MinPriceIncrement + } + } + return lots, ticks +} diff --git a/internal/app/app.go b/internal/app/app.go index 71c5276..e1499b4 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -114,10 +114,12 @@ func Run(ctx context.Context, opts Options) error { defer closer() } accountIDHash := accountHash(cfg.TInvest.AccountID) + clock := timeutil.RealClock{Loc: cfg.Location} recon := reconciliation.New(repo, gateway, cfg.TInvest.AccountID, accountIDHash). WithWindow(time.Duration(cfg.Risk.ReconciliationWindowHours)*time.Hour). WithInFlightGrace(time.Duration(cfg.Risk.ReconciliationSkewSec)*time.Second). - WithCommissionPolicy(cfg.Commission.RequireZeroCommission, cfg.Commission.QuarantineOnNonZero, cfg.Risk.CommissionToleranceRUB) + WithCommissionPolicy(cfg.Commission.RequireZeroCommission, cfg.Commission.QuarantineOnNonZero, cfg.Risk.CommissionToleranceRUB). + WithClock(clock) diffs, err := recon.Run(ctx) if err != nil { return fmt.Errorf("pre-unhalt reconciliation: %w", err) @@ -159,10 +161,12 @@ func Run(ctx context.Context, opts Options) error { }() accountIDHash := accountHash(cfg.TInvest.AccountID) + clock := timeutil.RealClock{Loc: cfg.Location} recon := reconciliation.New(repo, gateway, cfg.TInvest.AccountID, accountIDHash). WithWindow(time.Duration(cfg.Risk.ReconciliationWindowHours)*time.Hour). WithInFlightGrace(time.Duration(cfg.Risk.ReconciliationSkewSec)*time.Second). - WithCommissionPolicy(cfg.Commission.RequireZeroCommission, cfg.Commission.QuarantineOnNonZero, cfg.Risk.CommissionToleranceRUB) + WithCommissionPolicy(cfg.Commission.RequireZeroCommission, cfg.Commission.QuarantineOnNonZero, cfg.Risk.CommissionToleranceRUB). + WithClock(clock) sm := statemachine.New(repo, cfg.App.Mode) if _, err := sm.Recover(ctx, recon); err != nil { _ = notifier.Alert(ctx, fmt.Sprintf("state recovery failed: %s", err)) @@ -184,14 +188,25 @@ func Run(ctx context.Context, opts Options) error { } runCtx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM) defer stop() - clock := timeutil.RealClock{Loc: cfg.Location} runtime := buildScheduler(clock, sm, cfg, repo, gateway, notifier, recon, accountIDHash, log) - return runtime.Run(runCtx) + if err := runtime.Run(runCtx); err != nil { + if runCtx.Err() == nil { + return err + } + shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Duration(cfg.App.ShutdownTimeoutSec)*time.Second) + defer cancel() + if shutdownErr := runtime.GracefulShutdown(shutdownCtx); shutdownErr != nil { + return fmt.Errorf("%w; graceful shutdown: %v", err, shutdownErr) + } + return err + } + return nil } func buildScheduler(clock timeutil.Clock, sm statemachine.System, cfg config.Config, repo *mysqlrepo.Repository, gateway tinvest.Gateway, notifier notify.Notifier, recon reconciliation.Engine, accountIDHash string, log *slog.Logger) scheduler.Scheduler { registry := instruments.NewRegistry(repo, gateway) loader := marketdata.NewLoader(repo, gateway) + loader.SetClock(clock) pipeline := features.NewPipeline(repo, features.PipelineConfig{ RollingShort: cfg.Strategy.RollingShort, RollingLong: cfg.Strategy.RollingLong, @@ -243,6 +258,7 @@ func buildScheduler(clock timeutil.Clock, sm statemachine.System, cfg config.Con MaxQuoteAge: time.Duration(cfg.Execution.MaxQuoteAgeSec) * time.Second, }) execEngine := execution.NewEngine(cfg.App.Mode, cfg.TInvest.AccountID, gateway, repo) + execEngine.SetClock(clock) execEngine.SetMaxQuoteAge(time.Duration(cfg.Execution.MaxQuoteAgeSec) * time.Second) execEngine.SetFreeOrderCountPolicy(cfg.Commission.FreeOrderCountPolicy) services := scheduler.Services{ diff --git a/internal/backtest/engine.go b/internal/backtest/engine.go index bbea3cc..4c16835 100644 --- a/internal/backtest/engine.go +++ b/internal/backtest/engine.go @@ -8,6 +8,7 @@ import ( "os" "path/filepath" "sort" + "strconv" "strings" "time" @@ -20,40 +21,48 @@ import ( ) type Config struct { - EntrySlippageBps decimal.Decimal - ExitSlippageBps decimal.Decimal - CommissionRoundtripBps decimal.Decimal - RiskBufferBps decimal.Decimal - InitialEquity decimal.Decimal - OutputDir string - RollingShort int - RollingLong int - EWMALambda float64 - MinTStat60 decimal.Decimal - MinWinRate60 decimal.Decimal - MinNetEdgeBps decimal.Decimal - MinADVRUB decimal.Decimal - MaxSpreadBps decimal.Decimal - MaxSpreadBpsMoneyMarket decimal.Decimal - MaxSpreadBpsBondFunds decimal.Decimal - MaxSpreadBpsEquityFunds decimal.Decimal - MaxTickBps decimal.Decimal - RequireZeroCommission *bool - MaxPositions int - MaxPositionPct decimal.Decimal - MaxTotalExposurePct decimal.Decimal - MaxParticipationRate decimal.Decimal - CashUsageBuffer decimal.Decimal - RiskBudgetPct decimal.Decimal - MinOrderNotionalRUB decimal.Decimal - AssumedSpreadBps decimal.Decimal - AssumedSpreadBpsByFundType map[string]decimal.Decimal - InstrumentFundTypes map[string]string - AssumedTickBps decimal.Decimal - Lot int64 - UseMinuteModel bool - EntryWindow TimeWindow - ExitWindow TimeWindow + EntrySlippageBps decimal.Decimal + ExitSlippageBps decimal.Decimal + CommissionRoundtripBps decimal.Decimal + RiskBufferBps decimal.Decimal + InitialEquity decimal.Decimal + OutputDir string + RollingShort int + RollingLong int + EWMALambda float64 + MinTStat60 decimal.Decimal + MinWinRate60 decimal.Decimal + MinNetEdgeBps decimal.Decimal + MinADVRUB decimal.Decimal + MaxSpreadBps decimal.Decimal + MaxSpreadBpsMoneyMarket decimal.Decimal + MaxSpreadBpsBondFunds decimal.Decimal + MaxSpreadBpsEquityFunds decimal.Decimal + MaxTickBps decimal.Decimal + RequireZeroCommission *bool + MaxPositions int + MaxPositionPct decimal.Decimal + MaxTotalExposurePct decimal.Decimal + MaxParticipationRate decimal.Decimal + CashUsageBuffer decimal.Decimal + RiskBudgetPct decimal.Decimal + MinOrderNotionalRUB decimal.Decimal + AssumedSpreadBps decimal.Decimal + AssumedSpreadBpsByFundType map[string]decimal.Decimal + InstrumentFundTypes map[string]string + AssumedTickBps decimal.Decimal + Lot int64 + LotsByInstrument map[string]int64 + MinPriceIncrement decimal.Decimal + MinPriceIncrementsByInstrument map[string]decimal.Decimal + UseMinuteModel bool + EntryWindow TimeWindow + ExitWindow TimeWindow +} + +type InstrumentMetadata struct { + Lot int64 + MinPriceIncrement decimal.Decimal } type TimeWindow struct { @@ -203,6 +212,9 @@ func (e Engine) RunWithMinuteCandles(candlesByInstrument map[string][]domain.Can return Result{}, err } if ok { + if capacity := e.windowCapacity(candidate, preparedMinutes[instrumentUID]); capacity.IsPositive() { + candidate.capacity = capacity + } candidatesByExitDate[candidate.exit.TradeDate.Format("2006-01-02")] = append(candidatesByExitDate[candidate.exit.TradeDate.Format("2006-01-02")], candidate) } } @@ -243,7 +255,7 @@ func (e Engine) RunWithMinuteCandles(candlesByInstrument map[string][]domain.Can Portfolio: domain.Portfolio{Equity: equity, Cash: cash}, SelectedInstruments: len(dayCandidates), LimitPrice: c.buy, - Lot: e.cfg.Lot, + Lot: c.lot, EntryIntervalVolume: c.adv, ExitIntervalVolume: c.adv, Q05OvernightAbs: c.q05Abs, @@ -261,7 +273,7 @@ func (e Engine) RunWithMinuteCandles(candlesByInstrument map[string][]domain.Can lots = executedLots capacity = minuteCapacity } - notional := c.buy.Mul(decimal.NewFromInt(lots)).Mul(decimal.NewFromInt(e.cfg.Lot)) + notional := c.buy.Mul(decimal.NewFromInt(lots)).Mul(decimal.NewFromInt(c.lot)) ret := c.sell.Div(c.buy).Sub(decimal.NewFromInt(1)).Sub(money.FromBps(e.cfg.CommissionRoundtripBps)) pnl := notional.Mul(ret) dayPnL = dayPnL.Add(pnl) @@ -311,8 +323,12 @@ func (e Engine) minuteExecution(c candidate, minutes []domain.Candle, requestedL if requestedLots <= 0 || len(minutes) == 0 { return 0, decimal.Zero, false } - entryLots, entryCapacity := e.fillableMinuteLots(minutes, c.entry.TradeDate, c.buy, domain.SideBuy, e.cfg.EntryWindow) - exitLots, exitCapacity := e.fillableMinuteLots(minutes, c.exit.TradeDate, c.sell, domain.SideSell, e.cfg.ExitWindow) + lot := c.lot + if lot <= 0 { + lot = e.lotFor(c.instrumentUID) + } + entryLots, entryCapacity := e.fillableMinuteLots(minutes, c.entry.TradeDate, c.buy, lot, domain.SideBuy, e.cfg.EntryWindow) + exitLots, exitCapacity := e.fillableMinuteLots(minutes, c.exit.TradeDate, c.sell, lot, domain.SideSell, e.cfg.ExitWindow) lots := min(requestedLots, entryLots) lots = min(lots, exitLots) if lots <= 0 { @@ -321,11 +337,11 @@ func (e Engine) minuteExecution(c candidate, minutes []domain.Candle, requestedL return lots, money.Min(entryCapacity, exitCapacity), true } -func (e Engine) fillableMinuteLots(minutes []domain.Candle, date time.Time, limitPrice decimal.Decimal, side domain.Side, window TimeWindow) (int64, decimal.Decimal) { - if !limitPrice.IsPositive() || e.cfg.Lot <= 0 { +func (e Engine) fillableMinuteLots(minutes []domain.Candle, date time.Time, limitPrice decimal.Decimal, lot int64, side domain.Side, window TimeWindow) (int64, decimal.Decimal) { + if !limitPrice.IsPositive() || lot <= 0 { return 0, decimal.Zero } - lotNotional := limitPrice.Mul(decimal.NewFromInt(e.cfg.Lot)) + lotNotional := limitPrice.Mul(decimal.NewFromInt(lot)) if !lotNotional.IsPositive() { return 0, decimal.Zero } @@ -348,6 +364,36 @@ func (e Engine) fillableMinuteLots(minutes []domain.Candle, date time.Time, limi return capacity.Div(lotNotional).Floor().IntPart(), capacity } +func (e Engine) windowCapacity(c candidate, minutes []domain.Candle) decimal.Decimal { + if len(minutes) == 0 { + return decimal.Zero + } + lot := c.lot + if lot <= 0 { + lot = e.lotFor(c.instrumentUID) + } + if lot <= 0 { + return decimal.Zero + } + entryVolume := e.windowNotional(minutes, c.entry.TradeDate, e.cfg.EntryWindow, lot) + exitVolume := e.windowNotional(minutes, c.exit.TradeDate, e.cfg.ExitWindow, lot) + if !entryVolume.IsPositive() || !exitVolume.IsPositive() { + return decimal.Zero + } + return money.Min(entryVolume, exitVolume).Mul(e.cfg.MaxParticipationRate) +} + +func (e Engine) windowNotional(minutes []domain.Candle, date time.Time, window TimeWindow, lot int64) decimal.Decimal { + total := decimal.Zero + for _, candle := range minutes { + if !sameDate(candle.TradeDate, date) || !window.Contains(candle.TradeDate) { + continue + } + total = total.Add(candle.VolumeLots.Mul(decimal.NewFromInt(lot)).Mul(candle.Close)) + } + return total +} + func (w TimeWindow) Contains(ts time.Time) bool { if w.Start == 0 && w.End == 0 { return true @@ -376,12 +422,14 @@ type candidate struct { q05Abs decimal.Decimal overnightGap decimal.Decimal capacity decimal.Decimal + lot int64 } func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle, exitIndex int) (candidate, bool, error) { if exitIndex < e.cfg.RollingShort || exitIndex <= 0 { return candidate{}, false, nil } + lot := e.lotFor(instrumentUID) history := candles[:exitIndex] returns := make([]float64, 0, len(history)-1) for j := 1; j < len(history); j++ { @@ -405,7 +453,7 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle, Add(e.cfg.CommissionRoundtripBps). Add(e.cfg.RiskBufferBps) netEdge := rawEdge.Sub(cost) - adv := features.ADV(history, e.cfg.Lot, 20) + adv := features.ADV(history, lot, 20) switch { case e.requireZeroCommission() && e.cfg.CommissionRoundtripBps.IsPositive(): return candidate{}, false, nil @@ -428,6 +476,17 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle, exit := candles[exitIndex] buy := entry.Close.Mul(decimal.NewFromInt(1).Add(money.FromBps(e.cfg.EntrySlippageBps))) sell := exit.Open.Mul(decimal.NewFromInt(1).Sub(money.FromBps(e.cfg.ExitSlippageBps))) + if tick := e.minPriceIncrementFor(instrumentUID); tick.IsPositive() { + var err error + buy, err = money.RoundToTick(buy, tick, money.RoundCeil) + if err != nil { + return candidate{}, false, err + } + sell, err = money.RoundToTick(sell, tick, money.RoundFloor) + if err != nil { + return candidate{}, false, err + } + } gap, err := features.OvernightReturn(exit.Open, entry.Close) if err != nil { return candidate{}, false, err @@ -448,9 +507,31 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle, q05Abs: q05Abs, overnightGap: gap, capacity: adv.Mul(e.cfg.MaxParticipationRate), + lot: lot, }, true, nil } +func (e Engine) lotFor(instrumentUID string) int64 { + if e.cfg.LotsByInstrument != nil { + if lot := e.cfg.LotsByInstrument[instrumentUID]; lot > 0 { + return lot + } + } + if e.cfg.Lot > 0 { + return e.cfg.Lot + } + return 1 +} + +func (e Engine) minPriceIncrementFor(instrumentUID string) decimal.Decimal { + if e.cfg.MinPriceIncrementsByInstrument != nil { + if tick := e.cfg.MinPriceIncrementsByInstrument[instrumentUID]; tick.IsPositive() { + return tick + } + } + return e.cfg.MinPriceIncrement +} + func (e Engine) requireZeroCommission() bool { return e.cfg.RequireZeroCommission != nil && *e.cfg.RequireZeroCommission } @@ -536,46 +617,60 @@ func (r Result) Write(outputDir string) error { } func LoadCandlesCSV(reader io.Reader) (map[string][]domain.Candle, error) { + candles, _, err := LoadCandlesCSVWithMetadata(reader) + return candles, err +} + +func LoadCandlesCSVWithMetadata(reader io.Reader) (map[string][]domain.Candle, map[string]InstrumentMetadata, error) { r := csv.NewReader(reader) r.FieldsPerRecord = -1 records, err := r.ReadAll() if err != nil { - return nil, err + return nil, nil, err } out := make(map[string][]domain.Candle) - for i, record := range records { - if i == 0 && len(record) > 0 && record[0] == "instrument_uid" { - continue + metadata := make(map[string]InstrumentMetadata) + header := map[string]int(nil) + start := 0 + if len(records) > 0 && len(records[0]) > 0 && strings.EqualFold(strings.TrimSpace(records[0][0]), "instrument_uid") { + header = make(map[string]int, len(records[0])) + for i, name := range records[0] { + header[strings.ToLower(strings.TrimSpace(name))] = i } + start = 1 + } + for i := start; i < len(records); i++ { + record := records[i] if len(record) < 7 { - return nil, fmt.Errorf("line %d: expected 7 fields", i+1) + return nil, nil, fmt.Errorf("line %d: expected at least 7 fields", i+1) } - date, err := parseCandleTime(record[1]) + instrumentUID := csvValue(record, header, "instrument_uid", 0) + date, err := parseCandleTime(csvValue(record, header, "trade_date", 1)) if err != nil { - return nil, err + return nil, nil, err } - open, err := decimal.NewFromString(record[2]) + open, err := decimal.NewFromString(csvValue(record, header, "open", 2)) if err != nil { - return nil, err + return nil, nil, err } - high, err := decimal.NewFromString(record[3]) + high, err := decimal.NewFromString(csvValue(record, header, "high", 3)) if err != nil { - return nil, err + return nil, nil, err } - low, err := decimal.NewFromString(record[4]) + low, err := decimal.NewFromString(csvValue(record, header, "low", 4)) if err != nil { - return nil, err + return nil, nil, err } - closePrice, err := decimal.NewFromString(record[5]) + closePrice, err := decimal.NewFromString(csvValue(record, header, "close", 5)) if err != nil { - return nil, err + return nil, nil, err } - volume, err := decimal.NewFromString(record[6]) + volume, err := decimal.NewFromString(csvValue(record, header, "volume_lots", 6)) if err != nil { - return nil, err + return nil, nil, err } candle := domain.Candle{ - InstrumentUID: record[0], + InstrumentUID: instrumentUID, TradeDate: date, Open: open, High: high, @@ -586,8 +681,49 @@ func LoadCandlesCSV(reader io.Reader) (map[string][]domain.Candle, error) { LoadedAt: time.Now().UTC(), } out[candle.InstrumentUID] = append(out[candle.InstrumentUID], candle) + meta := metadata[candle.InstrumentUID] + if raw, ok := optionalCSVValue(record, header, "lot", 7); ok && strings.TrimSpace(raw) != "" { + lot, err := strconv.ParseInt(strings.TrimSpace(raw), 10, 64) + if err != nil { + return nil, nil, fmt.Errorf("line %d: parse lot: %w", i+1, err) + } + if lot > 0 { + meta.Lot = lot + } + } + if raw, ok := optionalCSVValue(record, header, "min_price_increment", 8); ok && strings.TrimSpace(raw) != "" { + tick, err := decimal.NewFromString(strings.TrimSpace(raw)) + if err != nil { + return nil, nil, fmt.Errorf("line %d: parse min_price_increment: %w", i+1, err) + } + if tick.IsPositive() { + meta.MinPriceIncrement = tick + } + } + if meta.Lot > 0 || meta.MinPriceIncrement.IsPositive() { + metadata[candle.InstrumentUID] = meta + } } - return out, nil + return out, metadata, nil +} + +func csvValue(record []string, header map[string]int, name string, fallback int) string { + value, _ := optionalCSVValue(record, header, name, fallback) + return strings.TrimSpace(value) +} + +func optionalCSVValue(record []string, header map[string]int, name string, fallback int) (string, bool) { + if header != nil { + idx, ok := header[name] + if !ok || idx < 0 || idx >= len(record) { + return "", false + } + return record[idx], true + } + if fallback < 0 || fallback >= len(record) { + return "", false + } + return record[fallback], true } func parseCandleTime(raw string) (time.Time, error) { diff --git a/internal/backtest/lookahead_test.go b/internal/backtest/lookahead_test.go index e50543b..d4b7371 100644 --- a/internal/backtest/lookahead_test.go +++ b/internal/backtest/lookahead_test.go @@ -1,6 +1,7 @@ package backtest import ( + "strings" "testing" "time" @@ -69,3 +70,86 @@ func TestMinuteExecutionRequiresReachableLimitAndParticipation(t *testing.T) { t.Fatal("sell limit should be unreachable") } } + +func TestEvaluateCandidateUsesInstrumentLotAndTick(t *testing.T) { + engine := New(Config{ + RollingShort: 2, + RollingLong: 2, + MinTStat60: decimal.NewFromInt(-1), + MinWinRate60: decimal.NewFromFloat(0.1), + MinNetEdgeBps: decimal.NewFromInt(-1000), + MinADVRUB: decimal.NewFromInt(1), + Lot: 1, + LotsByInstrument: map[string]int64{"uid": 10}, + MinPriceIncrementsByInstrument: map[string]decimal.Decimal{"uid": decimal.NewFromFloat(0.05)}, + EntrySlippageBps: decimal.NewFromInt(13), + ExitSlippageBps: decimal.NewFromInt(13), + }) + candles := candidateCandles("uid") + got, ok, err := engine.evaluateCandidate("uid", candles, 3) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("expected candidate") + } + if got.lot != 10 { + t.Fatalf("lot=%d, want 10", got.lot) + } + if !got.adv.Equal(decimal.NewFromInt(10_000)) { + t.Fatalf("adv=%s, want 10000", got.adv) + } + if !got.buy.Equal(decimal.NewFromFloat(100.15)) { + t.Fatalf("buy=%s, want rounded 100.15", got.buy) + } + if !got.sell.Equal(decimal.NewFromFloat(104.85)) { + t.Fatalf("sell=%s, want rounded 104.85", got.sell) + } +} + +func TestWindowCapacityUsesMinuteEntryAndExitWindows(t *testing.T) { + engine := New(Config{ + Lot: 10, + MaxParticipationRate: decimal.NewFromFloat(0.10), + }) + entryDate := time.Date(2024, 1, 2, 18, 25, 0, 0, time.UTC) + exitDate := time.Date(2024, 1, 3, 10, 5, 0, 0, time.UTC) + got := engine.windowCapacity(candidate{ + instrumentUID: "uid", + entry: domain.Candle{TradeDate: entryDate}, + exit: domain.Candle{TradeDate: exitDate}, + }, []domain.Candle{ + {TradeDate: entryDate, Close: decimal.NewFromInt(100), VolumeLots: decimal.NewFromInt(20)}, + {TradeDate: exitDate, Close: decimal.NewFromInt(200), VolumeLots: decimal.NewFromInt(5)}, + {TradeDate: time.Date(2024, 1, 3, 12, 0, 0, 0, time.UTC), Close: decimal.NewFromInt(999), VolumeLots: decimal.NewFromInt(999)}, + }) + if !got.Equal(decimal.NewFromInt(1000)) { + t.Fatalf("capacity=%s, want min(entry=20000, exit=10000)*0.10 = 1000", got) + } +} + +func TestLoadCandlesCSVWithMetadata(t *testing.T) { + raw := strings.NewReader(`instrument_uid,trade_date,open,high,low,close,volume_lots,lot,min_price_increment +uid,2024-01-02,100,101,99,100,10,10,0.05 +`) + candles, metadata, err := LoadCandlesCSVWithMetadata(raw) + if err != nil { + t.Fatal(err) + } + if len(candles["uid"]) != 1 { + t.Fatalf("candles=%+v", candles) + } + if metadata["uid"].Lot != 10 || !metadata["uid"].MinPriceIncrement.Equal(decimal.NewFromFloat(0.05)) { + t.Fatalf("metadata=%+v", metadata["uid"]) + } +} + +func candidateCandles(uid string) []domain.Candle { + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + return []domain.Candle{ + {InstrumentUID: uid, TradeDate: start, Open: decimal.NewFromInt(100), Close: decimal.NewFromInt(100), VolumeLots: decimal.NewFromInt(10)}, + {InstrumentUID: uid, TradeDate: start.AddDate(0, 0, 1), Open: decimal.NewFromInt(101), Close: decimal.NewFromInt(100), VolumeLots: decimal.NewFromInt(10)}, + {InstrumentUID: uid, TradeDate: start.AddDate(0, 0, 2), Open: decimal.NewFromInt(102), Close: decimal.NewFromInt(100), VolumeLots: decimal.NewFromInt(10)}, + {InstrumentUID: uid, TradeDate: start.AddDate(0, 0, 3), Open: decimal.NewFromInt(105), Close: decimal.NewFromInt(100), VolumeLots: decimal.NewFromInt(10)}, + } +} diff --git a/internal/domain/types.go b/internal/domain/types.go index 68ed72b..d83a2d0 100644 --- a/internal/domain/types.go +++ b/internal/domain/types.go @@ -175,6 +175,7 @@ type FeatureSet struct { MuOn60 decimal.Decimal MuOn252 decimal.Decimal SigmaOn60 decimal.Decimal + Q05On60Abs decimal.Decimal TStatOn60 decimal.Decimal WinOn60 decimal.Decimal EWMAOn decimal.Decimal diff --git a/internal/execution/engine.go b/internal/execution/engine.go index 881e19b..476e3e0 100644 --- a/internal/execution/engine.go +++ b/internal/execution/engine.go @@ -12,6 +12,7 @@ import ( "overnight-trading-bot/internal/domain" "overnight-trading-bot/internal/repository" "overnight-trading-bot/internal/risk" + "overnight-trading-bot/internal/timeutil" ) var ErrBrokerOrdersDisabled = errors.New("broker orders are disabled for current mode") @@ -35,6 +36,7 @@ type Engine struct { store repository.Repository maxQuoteAge time.Duration freeOrderCountPolicy string + clock timeutil.Clock mu sync.Map } @@ -50,13 +52,26 @@ type MonitorConfig struct { } func NewEngine(mode domain.Mode, accountID string, gateway Gateway, store repository.Repository) Engine { - return Engine{mode: mode, accountID: accountID, gateway: gateway, store: store, freeOrderCountPolicy: FreeOrderPolicySubmitted} + return Engine{ + mode: mode, + accountID: accountID, + gateway: gateway, + store: store, + freeOrderCountPolicy: FreeOrderPolicySubmitted, + clock: timeutil.RealClock{}, + } } func (e *Engine) SetMaxQuoteAge(maxQuoteAge time.Duration) { e.maxQuoteAge = maxQuoteAge } +func (e *Engine) SetClock(clock timeutil.Clock) { + if clock != nil { + e.clock = clock + } +} + func (e *Engine) SetFreeOrderCountPolicy(policy string) { switch policy { case FreeOrderPolicyCancelCounts: @@ -78,7 +93,7 @@ func (e *Engine) PlaceEntry(ctx context.Context, accountIDHash string, instrumen if err != nil { return domain.Order{}, err } - return e.PlaceLimit(ctx, domain.Order{ + return e.placeLimit(ctx, domain.Order{ ClientOrderID: ClientOrderID(tradeDate, instrument.InstrumentUID, domain.SideBuy, attempt), AccountIDHash: accountIDHash, InstrumentUID: instrument.InstrumentUID, @@ -90,7 +105,7 @@ func (e *Engine) PlaceEntry(ctx context.Context, accountIDHash string, instrumen Status: domain.OrderStatusNew, AttemptNo: attempt, RawStateJSON: "{}", - }) + }, instrument.FreeOrderLimitPerDay) } func (e *Engine) PlaceExit(ctx context.Context, accountIDHash string, instrument domain.Instrument, tradeDate time.Time, lots int64, book domain.OrderBook, improveTicks int, attempt int) (domain.Order, error) { @@ -105,7 +120,7 @@ func (e *Engine) PlaceExit(ctx context.Context, accountIDHash string, instrument if err != nil { return domain.Order{}, err } - return e.PlaceLimit(ctx, domain.Order{ + return e.placeLimit(ctx, domain.Order{ ClientOrderID: ClientOrderID(tradeDate, instrument.InstrumentUID, domain.SideSell, attempt), AccountIDHash: accountIDHash, InstrumentUID: instrument.InstrumentUID, @@ -117,10 +132,14 @@ func (e *Engine) PlaceExit(ctx context.Context, accountIDHash string, instrument Status: domain.OrderStatusNew, AttemptNo: attempt, RawStateJSON: "{}", - }) + }, instrument.FreeOrderLimitPerDay) } func (e *Engine) PlaceLimit(ctx context.Context, order domain.Order) (domain.Order, error) { + return e.placeLimit(ctx, order, 0) +} + +func (e *Engine) placeLimit(ctx context.Context, order domain.Order, freeOrderLimit int) (domain.Order, error) { lock := e.lockFor(order.InstrumentUID) lock.Lock() defer lock.Unlock() @@ -134,7 +153,7 @@ func (e *Engine) PlaceLimit(ctx context.Context, order domain.Order) (domain.Ord } } if e.mode == domain.ModePaper { - return e.placePaperLimit(ctx, order) + return e.placePaperLimit(ctx, order, freeOrderLimit) } if !e.mode.AllowsBrokerOrders() { order.Status = domain.OrderStatusNew @@ -147,7 +166,7 @@ func (e *Engine) PlaceLimit(ctx context.Context, order domain.Order) (domain.Ord return domain.Order{}, errors.New("gateway is nil") } - now := time.Now().UTC() + now := e.nowUTC() draft := order draft.Status = domain.OrderStatusSent draft.CreatedAt = now @@ -156,8 +175,13 @@ func (e *Engine) PlaceLimit(ctx context.Context, order domain.Order) (domain.Ord draft.RawStateJSON = "{}" } if e.store != nil { - if err := e.store.UpsertOrder(ctx, draft); err != nil { - return domain.Order{}, fmt.Errorf("persist draft order: %w", err) + if err := e.store.RunInTx(ctx, func(ctx context.Context, repo repository.Repository) error { + if err := repo.UpsertOrder(ctx, draft); err != nil { + return fmt.Errorf("persist draft order: %w", err) + } + return repo.ReserveFreeOrders(ctx, order.TradeDate, order.InstrumentUID, 1, freeOrderLimit) + }); err != nil { + return domain.Order{}, err } } posted, err := e.gateway.PostLimitOrder(ctx, e.accountID, order.InstrumentUID, order.Side, order.QuantityLots, order.LimitPrice, order.ClientOrderID) @@ -180,20 +204,15 @@ func (e *Engine) PlaceLimit(ctx context.Context, order domain.Order) (domain.Ord posted.CreatedAt = now posted.UpdatedAt = posted.CreatedAt if e.store != nil { - if err := e.store.RunInTx(ctx, func(ctx context.Context, repo repository.Repository) error { - if err := repo.UpsertOrder(ctx, posted); err != nil { - return fmt.Errorf("persist posted order: %w", err) - } - return repo.IncrementFreeOrders(ctx, order.TradeDate, order.InstrumentUID, 1) - }); err != nil { + if err := e.store.UpsertOrder(ctx, posted); err != nil { return domain.Order{}, err } } return posted, nil } -func (e *Engine) placePaperLimit(ctx context.Context, order domain.Order) (domain.Order, error) { - now := time.Now().UTC() +func (e *Engine) placePaperLimit(ctx context.Context, order domain.Order, freeOrderLimit int) (domain.Order, error) { + now := e.nowUTC() order.BrokerOrderID = "paper-" + order.ClientOrderID order.FilledLots = order.QuantityLots order.AvgFillPrice = order.LimitPrice @@ -206,7 +225,7 @@ func (e *Engine) placePaperLimit(ctx context.Context, order domain.Order) (domai if err := repo.UpsertOrder(ctx, order); err != nil { return fmt.Errorf("persist paper order: %w", err) } - return repo.IncrementFreeOrders(ctx, order.TradeDate, order.InstrumentUID, 1) + return repo.ReserveFreeOrders(ctx, order.TradeDate, order.InstrumentUID, 1, freeOrderLimit) }); err != nil { return domain.Order{}, err } @@ -286,12 +305,10 @@ func (e *Engine) MonitorUntil(ctx context.Context, order domain.Order, cfg Monit if cfg.MaxAttempts <= 0 { cfg.MaxAttempts = 1 } - lastPost := time.Now() + lastPost := e.nowUTC() current := order aggregate := order seen := map[string]domain.Order{order.ClientOrderID: order} - ticker := time.NewTicker(cfg.PollInterval) - defer ticker.Stop() for { previous := seen[current.ClientOrderID] refreshed, err := e.Refresh(ctx, current) @@ -311,7 +328,7 @@ func (e *Engine) MonitorUntil(ctx context.Context, order domain.Order, cfg Monit if isTerminal(current.Status) { return aggregate, nil } - if !cfg.Deadline.IsZero() && !time.Now().Before(cfg.Deadline) { + if !cfg.Deadline.IsZero() && !e.nowUTC().Before(cfg.Deadline) { if err := e.Cancel(ctx, current); err != nil { return aggregate, err } @@ -324,7 +341,7 @@ func (e *Engine) MonitorUntil(ctx context.Context, order domain.Order, cfg Monit return aggregate, nil } shouldRepost := cfg.RepostAfter > 0 && - time.Since(lastPost) >= cfg.RepostAfter && + e.nowUTC().Sub(lastPost) >= cfg.RepostAfter && current.AttemptNo < cfg.MaxAttempts && aggregate.FilledLots < aggregate.QuantityLots && cfg.Quote != nil @@ -337,13 +354,11 @@ func (e *Engine) MonitorUntil(ctx context.Context, order domain.Order, cfg Monit current = next seen[current.ClientOrderID] = current } - lastPost = time.Now() + lastPost = e.nowUTC() continue } - select { - case <-ctx.Done(): + if !e.sleep(ctx, cfg.PollInterval) { return aggregate, ctx.Err() - case <-ticker.C: } } } @@ -372,7 +387,7 @@ func (e *Engine) MonitorOnce(ctx context.Context, order domain.Order, cfg Monito if isTerminal(current.Status) { return aggregate, nil } - if !cfg.Deadline.IsZero() && !time.Now().Before(cfg.Deadline) { + if !cfg.Deadline.IsZero() && !e.nowUTC().Before(cfg.Deadline) { if err := e.Cancel(ctx, current); err != nil { return aggregate, err } @@ -385,7 +400,7 @@ func (e *Engine) MonitorOnce(ctx context.Context, order domain.Order, cfg Monito return aggregate, nil } shouldRepost := cfg.RepostAfter > 0 && - repostDue(current, cfg.RepostAfter) && + e.repostDue(current, cfg.RepostAfter) && current.AttemptNo < cfg.MaxAttempts && aggregate.FilledLots < aggregate.QuantityLots && cfg.Quote != nil @@ -409,7 +424,7 @@ func (e *Engine) repost(ctx context.Context, order domain.Order, cfg MonitorConf if err := e.ensureRepostBudget(ctx, order, cfg.Instrument); err != nil { return domain.Order{}, false, err } - if !cfg.Deadline.IsZero() && !time.Now().Before(cfg.Deadline) { + if !cfg.Deadline.IsZero() && !e.nowUTC().Before(cfg.Deadline) { return order, false, nil } book, err := cfg.Quote(ctx, order.InstrumentUID) @@ -432,7 +447,7 @@ func (e *Engine) repost(ctx context.Context, order domain.Order, cfg MonitorConf cancelled.Status = domain.OrderStatusFilled return cancelled, true, nil } - if !cfg.Deadline.IsZero() && !time.Now().Before(cfg.Deadline) { + if !cfg.Deadline.IsZero() && !e.nowUTC().Before(cfg.Deadline) { return cancelled, true, nil } book, err = cfg.Quote(ctx, order.InstrumentUID) @@ -440,7 +455,7 @@ func (e *Engine) repost(ctx context.Context, order domain.Order, cfg MonitorConf return domain.Order{}, false, err } if cfg.RepostCheck != nil { - if err := cfg.RepostCheck(ctx, order, cfg.Instrument, book); err != nil { + if err := cfg.RepostCheck(ctx, cancelled, cfg.Instrument, book); err != nil { return cancelled, true, nil } } @@ -468,25 +483,16 @@ func (e *Engine) waitTerminal(ctx context.Context, order domain.Order, cfg Monit if isTerminal(current.Status) { return current, nil } - if !cfg.Deadline.IsZero() && !time.Now().Before(cfg.Deadline) { + if !cfg.Deadline.IsZero() && !e.nowUTC().Before(cfg.Deadline) { return current, nil } - timer := time.NewTimer(cfg.PollInterval) - select { - case <-ctx.Done(): - if !timer.Stop() { - select { - case <-timer.C: - default: - } - } + if !e.sleep(ctx, cfg.PollInterval) { return domain.Order{}, ctx.Err() - case <-timer.C: } } } -func repostDue(order domain.Order, after time.Duration) bool { +func (e *Engine) repostDue(order domain.Order, after time.Duration) bool { if after <= 0 { return false } @@ -497,7 +503,7 @@ func repostDue(order domain.Order, after time.Duration) bool { if basis.IsZero() { return true } - return time.Since(basis) >= after + return e.nowUTC().Sub(basis) >= after } func (e *Engine) ensureRepostBudget(ctx context.Context, order domain.Order, instrument domain.Instrument) error { @@ -530,7 +536,7 @@ func (e *Engine) checkQuoteFresh(book domain.OrderBook) error { if book.ReceivedAt.IsZero() { return fmt.Errorf("quote received timestamp is missing") } - age := time.Since(book.ReceivedAt) + age := e.nowUTC().Sub(book.ReceivedAt) if age > e.maxQuoteAge { return fmt.Errorf("quote age %s exceeds %s", age, e.maxQuoteAge) } @@ -541,11 +547,29 @@ func (e *Engine) lockFor(instrumentUID string) *sync.Mutex { value, _ := e.mu.LoadOrStore(instrumentUID, &sync.Mutex{}) lock, ok := value.(*sync.Mutex) if !ok { - panic("execution lock has unexpected type") + lock = &sync.Mutex{} + e.mu.Store(instrumentUID, lock) } return lock } +func (e *Engine) nowUTC() time.Time { + if e.clock == nil { + return time.Now().UTC() + } + return e.clock.Now().UTC() +} + +func (e *Engine) sleep(ctx context.Context, d time.Duration) bool { + if d <= 0 { + return true + } + if e.clock == nil { + return timeutil.RealClock{}.Sleep(ctx.Done(), d) + } + return e.clock.Sleep(ctx.Done(), d) +} + func bestBidAsk(book domain.OrderBook) (decimal.Decimal, decimal.Decimal, error) { bid, ok := book.BestBid() if !ok { diff --git a/internal/execution/state_test.go b/internal/execution/state_test.go index 830f62c..3c18e0c 100644 --- a/internal/execution/state_test.go +++ b/internal/execution/state_test.go @@ -2,16 +2,30 @@ package execution import ( "context" + "errors" "testing" "time" "github.com/shopspring/decimal" "overnight-trading-bot/internal/domain" + "overnight-trading-bot/internal/risk" "overnight-trading-bot/internal/testutil" "overnight-trading-bot/internal/tinvest" ) +type fixedClock struct { + now time.Time +} + +func (c *fixedClock) Now() time.Time { + return c.now +} + +func (c *fixedClock) Sleep(<-chan struct{}, time.Duration) bool { + return true +} + func TestClientOrderIDIncludesAttempt(t *testing.T) { date := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC) first := ClientOrderID(date, "uid:TRUR", domain.SideBuy, 1) @@ -66,6 +80,86 @@ func TestPlaceLimitSuppressesDuplicateSubmit(t *testing.T) { } } +func TestPlaceEntryReservesFreeOrderBudgetAtomically(t *testing.T) { + ctx := context.Background() + repo := testutil.NewMemoryRepository() + gateway := tinvest.NewFakeGateway() + engine := NewEngine(domain.ModeSandbox, "account", gateway, repo) + instrument := domain.Instrument{ + InstrumentUID: "uid", + Lot: 1, + MinPriceIncrement: decimal.NewFromInt(1), + FreeOrderLimitPerDay: 1, + } + book := domain.OrderBook{ + InstrumentUID: "uid", + Bids: []domain.OrderBookLevel{{Price: decimal.NewFromInt(99), QuantityLots: 10}}, + Asks: []domain.OrderBookLevel{{Price: decimal.NewFromInt(101), QuantityLots: 10}}, + ReceivedAt: time.Now().UTC(), + } + tradeDate := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC) + if _, err := engine.PlaceEntry(ctx, "hash", instrument, tradeDate, 1, book, 1, 1); err != nil { + t.Fatal(err) + } + _, err := engine.PlaceEntry(ctx, "hash", instrument, tradeDate, 1, book, 1, 2) + if !errors.Is(err, risk.ErrFreeOrderBudget) { + t.Fatalf("expected free order budget error, got %v", err) + } + if got := len(gateway.Orders); got != 1 { + t.Fatalf("broker orders=%d, want no second post", got) + } +} + +func TestMonitorOnceUsesInjectedClockForDeadline(t *testing.T) { + ctx := context.Background() + repo := testutil.NewMemoryRepository() + gateway := tinvest.NewFakeGateway() + engine := NewEngine(domain.ModeSandbox, "account", gateway, repo) + clock := &fixedClock{now: time.Date(2030, 1, 1, 10, 0, 0, 0, time.UTC)} + engine.SetClock(clock) + order, err := engine.PlaceLimit(ctx, domain.Order{ + ClientOrderID: "clocked", + AccountIDHash: "hash", + InstrumentUID: "uid", + TradeDate: clock.now, + Side: domain.SideBuy, + OrderType: domain.OrderTypeLimit, + LimitPrice: decimal.NewFromInt(100), + QuantityLots: 1, + Status: domain.OrderStatusNew, + AttemptNo: 1, + }) + if err != nil { + t.Fatal(err) + } + if !order.CreatedAt.Equal(clock.now) { + t.Fatalf("created_at=%s, want injected clock %s", order.CreatedAt, clock.now) + } + monitored, err := engine.MonitorOnce(ctx, order, MonitorConfig{ + Deadline: clock.now.Add(time.Minute), + PollInterval: time.Millisecond, + MaxAttempts: 1, + }) + if err != nil { + t.Fatal(err) + } + if monitored.Status == domain.OrderStatusExpired { + t.Fatalf("order expired before injected deadline: %+v", monitored) + } + clock.now = clock.now.Add(time.Minute) + monitored, err = engine.MonitorOnce(ctx, order, MonitorConfig{ + Deadline: clock.now, + PollInterval: time.Millisecond, + MaxAttempts: 1, + }) + if err != nil { + t.Fatal(err) + } + if monitored.Status != domain.OrderStatusExpired { + t.Fatalf("status=%s, want EXPIRED at injected deadline", monitored.Status) + } +} + func TestPaperPlaceEntryFillsAndCountsSubmittedOrder(t *testing.T) { ctx := context.Background() repo := testutil.NewMemoryRepository() diff --git a/internal/features/pipeline.go b/internal/features/pipeline.go index 86024ca..a6442ca 100644 --- a/internal/features/pipeline.go +++ b/internal/features/pipeline.go @@ -109,14 +109,14 @@ func Compute(instrument domain.Instrument, candles []domain.Candle, tradeDate ti } short := Rolling(overnight, cfg.RollingShort, cfg.EWMALambda) long := Rolling(overnight, cfg.RollingLong, cfg.EWMALambda) + q05Abs := rollingQ05Abs(overnight, cfg.RollingShort) adv := ADV(candles, instrument.Lot, 20) rawEdgeBps := decimal.NewFromFloat(short.Mean).Mul(decimal.NewFromInt(10_000)) - instrumentCommission := instrument.ExpectedCommissionBpsPerSide.Mul(decimal.NewFromInt(2)) + commission := roundTripCommissionBps(instrument, cfg) expectedCost := spread.SpreadBps. Add(cfg.EntrySlippageBps). Add(cfg.ExitSlippageBps). - Add(cfg.CommissionRoundtripBps). - Add(instrumentCommission). + Add(commission). Add(cfg.RiskBufferBps) return domain.FeatureSet{ InstrumentUID: instrument.InstrumentUID, @@ -126,6 +126,7 @@ func Compute(instrument domain.Instrument, candles []domain.Candle, tradeDate ti MuOn60: decimal.NewFromFloat(short.Mean), MuOn252: decimal.NewFromFloat(long.Mean), SigmaOn60: decimal.NewFromFloat(short.StdDev), + Q05On60Abs: q05Abs, TStatOn60: decimal.NewFromFloat(short.TStat), WinOn60: decimal.NewFromFloat(short.WinRate), EWMAOn: decimal.NewFromFloat(short.EWMA), @@ -141,6 +142,26 @@ func Compute(instrument domain.Instrument, candles []domain.Candle, tradeDate ti }, nil } +func rollingQ05Abs(values []float64, window int) decimal.Decimal { + if window <= 0 || len(values) < window { + return decimal.Zero + } + sample := values[len(values)-window:] + q05 := decimal.NewFromFloat(Quantile(sample, 0.05)) + if q05.IsNegative() { + return q05.Neg() + } + return q05 +} + +func roundTripCommissionBps(instrument domain.Instrument, cfg PipelineConfig) decimal.Decimal { + instrumentCommission := instrument.ExpectedCommissionBpsPerSide.Mul(decimal.NewFromInt(2)) + if instrumentCommission.IsPositive() { + return instrumentCommission + } + return cfg.CommissionRoundtripBps +} + func historicalDailyCandles(candles []domain.Candle, tradeDate time.Time) []domain.Candle { tradeDay := dateOnly(tradeDate) out := make([]domain.Candle, 0, len(candles)) diff --git a/internal/features/pipeline_test.go b/internal/features/pipeline_test.go index 4db1d97..178fadd 100644 --- a/internal/features/pipeline_test.go +++ b/internal/features/pipeline_test.go @@ -41,14 +41,93 @@ func TestComputeExpectedCostIncludesCommissionAndSlippage(t *testing.T) { if err != nil { t.Fatal(err) } - if !got.ExpectedCostBps.Equal(decimal.NewFromInt(26)) { - t.Fatalf("expected cost=%s, want 26", got.ExpectedCostBps) + if !got.ExpectedCostBps.Equal(decimal.NewFromInt(22)) { + t.Fatalf("expected cost=%s, want 22", got.ExpectedCostBps) } if !got.EntryIntervalVolume.Equal(decimal.NewFromInt(10000)) || !got.ExitIntervalVolume.Equal(decimal.NewFromInt(9000)) { t.Fatalf("interval volumes were not preserved: %+v", got) } } +func TestComputeExpectedCostFallsBackToConfigCommission(t *testing.T) { + candles := flatCandles(time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC), 6) + got, err := Compute(domain.Instrument{ + InstrumentUID: "uid", + Lot: 1, + }, candles, candles[5].TradeDate, SpreadResult{SpreadBps: decimal.NewFromInt(10)}, PipelineConfig{ + RollingShort: 2, + RollingLong: 2, + EWMALambda: 0.08, + RiskBufferBps: decimal.NewFromInt(5), + EntrySlippageBps: decimal.NewFromInt(2), + ExitSlippageBps: decimal.NewFromInt(3), + CommissionRoundtripBps: decimal.NewFromInt(4), + }, decimal.Zero, decimal.Zero) + if err != nil { + t.Fatal(err) + } + if !got.ExpectedCostBps.Equal(decimal.NewFromInt(24)) { + t.Fatalf("expected cost=%s, want 24", got.ExpectedCostBps) + } +} + +func TestComputeStoresHistoricalQ05Abs(t *testing.T) { + start := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + returns := []string{"-0.10", "0.01", "0.02", "0.03", "0.04"} + candles := []domain.Candle{{ + InstrumentUID: "uid", + TradeDate: start, + Open: decimal.NewFromInt(100), + Close: decimal.NewFromInt(100), + VolumeLots: decimal.NewFromInt(1), + }} + for i, raw := range returns { + r, err := decimal.NewFromString(raw) + if err != nil { + t.Fatal(err) + } + open := decimal.NewFromInt(100).Mul(decimal.NewFromInt(1).Add(r)) + candles = append(candles, domain.Candle{ + InstrumentUID: "uid", + TradeDate: start.AddDate(0, 0, i+1), + Open: open, + Close: decimal.NewFromInt(100), + VolumeLots: decimal.NewFromInt(1), + }) + } + got, err := Compute(domain.Instrument{InstrumentUID: "uid", Lot: 1}, candles, start.AddDate(0, 0, 6), SpreadResult{}, PipelineConfig{ + RollingShort: 5, + RollingLong: 5, + EWMALambda: 0.08, + }, decimal.Zero, decimal.Zero) + if err != nil { + t.Fatal(err) + } + want := decimal.NewFromFloat(0.078) + diff := got.Q05On60Abs.Sub(want) + if diff.IsNegative() { + diff = diff.Neg() + } + if diff.GreaterThan(decimal.NewFromFloat(0.000001)) { + t.Fatalf("Q05On60Abs=%s, want about %s", got.Q05On60Abs, want) + } +} + +func flatCandles(start time.Time, count int) []domain.Candle { + candles := make([]domain.Candle, 0, count) + for i := 0; i < count; i++ { + price := decimal.NewFromInt(int64(100 + i)) + candles = append(candles, domain.Candle{ + InstrumentUID: "uid", + TradeDate: start.AddDate(0, 0, i), + Open: price, + Close: price, + VolumeLots: decimal.NewFromInt(1000), + }) + } + return candles +} + func TestIntervalVolume(t *testing.T) { got := IntervalVolume([]domain.Candle{ {Close: decimal.NewFromInt(100), VolumeLots: decimal.NewFromInt(10)}, diff --git a/internal/marketdata/loader.go b/internal/marketdata/loader.go index 26f2936..c884292 100644 --- a/internal/marketdata/loader.go +++ b/internal/marketdata/loader.go @@ -7,16 +7,24 @@ import ( "overnight-trading-bot/internal/domain" "overnight-trading-bot/internal/repository" + "overnight-trading-bot/internal/timeutil" "overnight-trading-bot/internal/tinvest" ) type Loader struct { repo repository.Repository gateway tinvest.Gateway + clock timeutil.Clock } func NewLoader(repo repository.Repository, gateway tinvest.Gateway) Loader { - return Loader{repo: repo, gateway: gateway} + return Loader{repo: repo, gateway: gateway, clock: timeutil.RealClock{}} +} + +func (l *Loader) SetClock(clock timeutil.Clock) { + if clock != nil { + l.clock = clock + } } func (l Loader) BackfillDaily(ctx context.Context, instruments []domain.Instrument, from, to time.Time) error { @@ -59,9 +67,16 @@ func (l Loader) LatestQuote(ctx context.Context, instrumentUID string, depth int if book.ReceivedAt.IsZero() { return domain.OrderBook{}, fmt.Errorf("quote received timestamp is missing") } - age := time.Since(book.ReceivedAt) + age := l.nowUTC().Sub(book.ReceivedAt) if maxAge > 0 && age > maxAge { return domain.OrderBook{}, fmt.Errorf("quote age %s exceeds %s", age, maxAge) } return book, nil } + +func (l Loader) nowUTC() time.Time { + if l.clock == nil { + return time.Now().UTC() + } + return l.clock.Now().UTC() +} diff --git a/internal/money/money.go b/internal/money/money.go index 297764d..6c9f1ea 100644 --- a/internal/money/money.go +++ b/internal/money/money.go @@ -9,8 +9,9 @@ import ( ) var ( - ErrInvalidTick = errors.New("tick must be positive") - ErrInvalidBase = errors.New("base must be positive") + ErrInvalidTick = errors.New("tick must be positive") + ErrInvalidBase = errors.New("base must be positive") + ErrInvalidQuotation = errors.New("decimal cannot be represented as protobuf quotation") ) type RoundMode int @@ -28,7 +29,7 @@ func QuotationToDecimal(q *pb.Quotation) decimal.Decimal { return decimal.NewFromInt(q.GetUnits()).Add(decimal.New(int64(q.GetNano()), -9)) } -func DecimalToQuotation(d decimal.Decimal) *pb.Quotation { +func DecimalToQuotation(d decimal.Decimal) (*pb.Quotation, error) { units := d.Truncate(0) nano := d.Sub(units).Mul(decimal.NewFromInt(1_000_000_000)).Round(0) if nano.Equal(decimal.NewFromInt(1_000_000_000)) { @@ -41,12 +42,12 @@ func DecimalToQuotation(d decimal.Decimal) *pb.Quotation { } nanoPart := nano.IntPart() if nanoPart < -999_999_999 || nanoPart > 999_999_999 { - panic("decimal quotation nano is out of protobuf range") + return nil, ErrInvalidQuotation } return &pb.Quotation{ Units: units.IntPart(), Nano: int32(nanoPart), // #nosec G115 -- nanoPart is bounded above. - } + }, nil } func MoneyValueToDecimal(v *pb.MoneyValue) decimal.Decimal { diff --git a/internal/money/rounding_test.go b/internal/money/rounding_test.go index 3f0f6a9..e0528d6 100644 --- a/internal/money/rounding_test.go +++ b/internal/money/rounding_test.go @@ -37,3 +37,18 @@ func TestRoundToTick(t *testing.T) { } } } + +func TestDecimalToQuotationHandlesRoundingCarry(t *testing.T) { + tooPrecise := d("0.0000000005") + if _, err := DecimalToQuotation(tooPrecise); err != nil { + t.Fatalf("roundable quotation returned error: %v", err) + } + hugeNano := d("0.9999999996") + got, err := DecimalToQuotation(hugeNano) + if err != nil { + t.Fatalf("carry quotation returned error: %v", err) + } + if got.Units != 1 || got.Nano != 0 { + t.Fatalf("quotation=%+v, want carry to 1/0", got) + } +} diff --git a/internal/position/manager.go b/internal/position/manager.go index 956f3e1..822e33a 100644 --- a/internal/position/manager.go +++ b/internal/position/manager.go @@ -33,7 +33,7 @@ func (m Manager) OnEntryFill(ctx context.Context, accountIDHash string, instrume Lot: lot, AvgBuyPrice: order.AvgFillPrice, CommissionTotal: order.Commission, - Status: domain.PositionHoldingOvernight, + Status: domain.PositionEntryFilled, OpenedAt: &now, UpdatedAt: now, } diff --git a/internal/position/manager_test.go b/internal/position/manager_test.go index 475a8a0..1a89130 100644 --- a/internal/position/manager_test.go +++ b/internal/position/manager_test.go @@ -28,6 +28,9 @@ func TestOnEntryFillKeepsBuyCommission(t *testing.T) { if !pos.CommissionTotal.Equal(decimal.NewFromInt(3)) { t.Fatalf("commission=%s, want 3", pos.CommissionTotal) } + if pos.Status != domain.PositionEntryFilled { + t.Fatalf("status=%s, want ENTRY_FILLED", pos.Status) + } } func TestOnExitFillPartialUsesExecutedLots(t *testing.T) { diff --git a/internal/reconciliation/engine.go b/internal/reconciliation/engine.go index 5d60a68..bc698bf 100644 --- a/internal/reconciliation/engine.go +++ b/internal/reconciliation/engine.go @@ -13,6 +13,7 @@ import ( "overnight-trading-bot/internal/domain" "overnight-trading-bot/internal/money" "overnight-trading-bot/internal/repository" + "overnight-trading-bot/internal/timeutil" "overnight-trading-bot/internal/tinvest" ) @@ -29,6 +30,7 @@ type Engine struct { commissionTolerance decimal.Decimal requireZeroCommission bool quarantineOnNonZero bool + clock timeutil.Clock } func New(repo repository.Repository, gateway tinvest.Gateway, accountID, accountIDHash string) Engine { @@ -40,6 +42,7 @@ func New(repo repository.Repository, gateway tinvest.Gateway, accountID, account accountIDHash: accountIDHash, window: 72 * time.Hour, commissionTolerance: defaultCommissionTolerance, + clock: timeutil.RealClock{}, } } @@ -66,6 +69,13 @@ func (e Engine) WithCommissionPolicy(requireZero, quarantineOnNonZero bool, tole return e } +func (e Engine) WithClock(clock timeutil.Clock) Engine { + if clock != nil { + e.clock = clock + } + return e +} + func (e Engine) Run(ctx context.Context) ([]domain.ReconciliationDiff, error) { if e.mu != nil { e.mu.Lock() @@ -79,7 +89,7 @@ func (e Engine) Run(ctx context.Context) ([]domain.ReconciliationDiff, error) { if err != nil { return nil, err } - now := time.Now().UTC() + now := e.nowUTC() localByBroker := make(map[string]domain.Order, len(localOrders)) brokerByID := make(map[string]domain.Order, len(brokerOrders)) for _, order := range localOrders { @@ -175,7 +185,7 @@ func (e Engine) Run(ctx context.Context) ([]domain.ReconciliationDiff, error) { } if err := e.repo.QuarantineInstrument(ctx, diff.InstrumentUID, diff.Message); err != nil { _ = e.repo.InsertRiskEvent(ctx, domain.RiskEvent{ - TS: time.Now().UTC(), + TS: now, Severity: domain.SeverityCritical, EventType: "quarantine_failed", InstrumentUID: diff.InstrumentUID, @@ -192,6 +202,13 @@ func (e Engine) Run(ctx context.Context) ([]domain.ReconciliationDiff, error) { return diffs, nil } +func (e Engine) nowUTC() time.Time { + if e.clock == nil { + return time.Now().UTC() + } + return e.clock.Now().UTC() +} + func (e Engine) isInFlight(order domain.Order, now time.Time) bool { if e.inFlightGrace <= 0 || order.CreatedAt.IsZero() { return false diff --git a/internal/reconciliation/engine_test.go b/internal/reconciliation/engine_test.go index 2acbc82..8a6be71 100644 --- a/internal/reconciliation/engine_test.go +++ b/internal/reconciliation/engine_test.go @@ -171,6 +171,50 @@ func TestReconciliationSkipsFreshInFlightLocalOrders(t *testing.T) { } } +func TestReconciliationUsesInjectedClockForInFlightGrace(t *testing.T) { + ctx := context.Background() + repo := testutil.NewMemoryRepository() + gateway := tinvest.NewFakeGateway() + clock := &fixedClock{now: time.Date(2000, 1, 1, 10, 0, 0, 0, time.UTC)} + if err := repo.UpsertOrder(ctx, domain.Order{ + ClientOrderID: "fresh", + AccountIDHash: "hash", + InstrumentUID: "uid", + TradeDate: clock.now, + Side: domain.SideBuy, + OrderType: domain.OrderTypeLimit, + QuantityLots: 1, + Status: domain.OrderStatusSent, + CreatedAt: clock.now.Add(-5 * time.Second), + }); err != nil { + t.Fatal(err) + } + diffs, err := New(repo, gateway, "account", "hash"). + WithClock(clock). + WithInFlightGrace(10 * time.Second). + Run(ctx) + if err != nil { + t.Fatal(err) + } + for _, diff := range diffs { + if diff.Kind == "local_order_without_broker_id" || diff.Kind == "missing_local_order" { + t.Fatalf("fresh in-flight order produced diff: %+v", diffs) + } + } +} + +type fixedClock struct { + now time.Time +} + +func (c *fixedClock) Now() time.Time { + return c.now +} + +func (c *fixedClock) Sleep(<-chan struct{}, time.Duration) bool { + return true +} + func TestReconciliationFindsCashMismatch(t *testing.T) { ctx := context.Background() repo := testutil.NewMemoryRepository() diff --git a/internal/repository/migrations/0007_feature_q05.down.sql b/internal/repository/migrations/0007_feature_q05.down.sql new file mode 100644 index 0000000..f4d3629 --- /dev/null +++ b/internal/repository/migrations/0007_feature_q05.down.sql @@ -0,0 +1,3 @@ +ALTER TABLE features DROP COLUMN q05_on_60_abs; + +UPDATE schema_meta SET meta_value='0006' WHERE meta_key='schema_version'; diff --git a/internal/repository/migrations/0007_feature_q05.up.sql b/internal/repository/migrations/0007_feature_q05.up.sql new file mode 100644 index 0000000..ff74871 --- /dev/null +++ b/internal/repository/migrations/0007_feature_q05.up.sql @@ -0,0 +1,4 @@ +ALTER TABLE features + ADD COLUMN q05_on_60_abs DECIMAL(20,10) NOT NULL DEFAULT 0 AFTER sigma_on_60; + +UPDATE schema_meta SET meta_value='0007' WHERE meta_key='schema_version'; diff --git a/internal/repository/mysql/repository.go b/internal/repository/mysql/repository.go index ab1e390..3d3b7e5 100644 --- a/internal/repository/mysql/repository.go +++ b/internal/repository/mysql/repository.go @@ -13,6 +13,7 @@ import ( "overnight-trading-bot/internal/domain" "overnight-trading-bot/internal/repository" + "overnight-trading-bot/internal/risk" ) var _ repository.Repository = (*Repository)(nil) @@ -224,13 +225,13 @@ ON DUPLICATE KEY UPDATE func (r *Repository) mergeFeatures(ctx context.Context, oldInstrumentUID, newInstrumentUID string) error { _, err := r.execer().ExecContext(ctx, ` INSERT INTO features ( - instrument_uid, trade_date, r_on, r_day, mu_on_60, mu_on_252, sigma_on_60, + instrument_uid, trade_date, r_on, r_day, mu_on_60, mu_on_252, sigma_on_60, q05_on_60_abs, tstat_on_60, win_on_60, ewma_on, spread_bps, half_spread_bps, tick_bps, adv_20, expected_cost_bps, net_edge_bps, entry_interval_volume, exit_interval_volume, calculated_at ) SELECT - ?, trade_date, r_on, r_day, mu_on_60, mu_on_252, sigma_on_60, + ?, trade_date, r_on, r_day, mu_on_60, mu_on_252, sigma_on_60, q05_on_60_abs, tstat_on_60, win_on_60, ewma_on, spread_bps, half_spread_bps, tick_bps, adv_20, expected_cost_bps, net_edge_bps, entry_interval_volume, exit_interval_volume, calculated_at @@ -238,7 +239,7 @@ FROM features WHERE instrument_uid=? ON DUPLICATE KEY UPDATE r_on=VALUES(r_on), r_day=VALUES(r_day), mu_on_60=VALUES(mu_on_60), mu_on_252=VALUES(mu_on_252), sigma_on_60=VALUES(sigma_on_60), - tstat_on_60=VALUES(tstat_on_60), win_on_60=VALUES(win_on_60), + q05_on_60_abs=VALUES(q05_on_60_abs), tstat_on_60=VALUES(tstat_on_60), win_on_60=VALUES(win_on_60), ewma_on=VALUES(ewma_on), spread_bps=VALUES(spread_bps), half_spread_bps=VALUES(half_spread_bps), tick_bps=VALUES(tick_bps), adv_20=VALUES(adv_20), expected_cost_bps=VALUES(expected_cost_bps), @@ -385,19 +386,19 @@ func (r *Repository) UpsertFeature(ctx context.Context, feature domain.FeatureSe } _, err := sqlx.NamedExecContext(ctx, r.execer(), ` INSERT INTO features ( - instrument_uid, trade_date, r_on, r_day, mu_on_60, mu_on_252, sigma_on_60, + instrument_uid, trade_date, r_on, r_day, mu_on_60, mu_on_252, sigma_on_60, q05_on_60_abs, tstat_on_60, win_on_60, ewma_on, spread_bps, half_spread_bps, tick_bps, adv_20, expected_cost_bps, net_edge_bps, entry_interval_volume, exit_interval_volume, calculated_at ) VALUES ( - :instrument_uid, :trade_date, :r_on, :r_day, :mu_on_60, :mu_on_252, :sigma_on_60, + :instrument_uid, :trade_date, :r_on, :r_day, :mu_on_60, :mu_on_252, :sigma_on_60, :q05_on_60_abs, :tstat_on_60, :win_on_60, :ewma_on, :spread_bps, :half_spread_bps, :tick_bps, :adv_20, :expected_cost_bps, :net_edge_bps, :entry_interval_volume, :exit_interval_volume, :calculated_at ) ON DUPLICATE KEY UPDATE r_on=VALUES(r_on), r_day=VALUES(r_day), mu_on_60=VALUES(mu_on_60), mu_on_252=VALUES(mu_on_252), sigma_on_60=VALUES(sigma_on_60), - tstat_on_60=VALUES(tstat_on_60), win_on_60=VALUES(win_on_60), + q05_on_60_abs=VALUES(q05_on_60_abs), tstat_on_60=VALUES(tstat_on_60), win_on_60=VALUES(win_on_60), ewma_on=VALUES(ewma_on), spread_bps=VALUES(spread_bps), half_spread_bps=VALUES(half_spread_bps), tick_bps=VALUES(tick_bps), adv_20=VALUES(adv_20), expected_cost_bps=VALUES(expected_cost_bps), @@ -453,7 +454,9 @@ func (r *Repository) UpsertOrder(ctx context.Context, order domain.Order) error if order.CreatedAt.IsZero() { order.CreatedAt = now } - order.UpdatedAt = now + if order.UpdatedAt.IsZero() { + order.UpdatedAt = now + } if order.RawStateJSON == "" { order.RawStateJSON = "{}" } @@ -596,6 +599,47 @@ ON DUPLICATE KEY UPDATE orders_sent=orders_sent+VALUES(orders_sent)`, dateOnly(t return err } +func (r *Repository) ReserveFreeOrders(ctx context.Context, tradeDate time.Time, instrumentUID string, delta int, limit int) error { + if delta <= 0 { + return nil + } + if limit <= 0 { + return r.IncrementFreeOrders(ctx, tradeDate, instrumentUID, delta) + } + return r.RunInTx(ctx, func(ctx context.Context, repo repository.Repository) error { + txRepo, ok := repo.(*Repository) + if !ok { + return errors.New("unexpected repository implementation") + } + return txRepo.reserveFreeOrdersLocked(ctx, tradeDate, instrumentUID, delta, limit) + }) +} + +func (r *Repository) reserveFreeOrdersLocked(ctx context.Context, tradeDate time.Time, instrumentUID string, delta int, limit int) error { + tradeDay := dateOnly(tradeDate) + if _, err := r.execer().ExecContext(ctx, ` +INSERT IGNORE INTO free_order_counters (trade_date, instrument_uid, orders_sent) +VALUES (?, ?, 0)`, tradeDay, instrumentUID); err != nil { + return err + } + var sent int + if err := r.getContext(ctx, &sent, ` +SELECT orders_sent FROM free_order_counters +WHERE trade_date=? AND instrument_uid=? +FOR UPDATE`, tradeDay, instrumentUID); err != nil { + return err + } + remaining := limit - sent + if remaining < delta { + return fmt.Errorf("%w: %s remaining=%d needed=%d", risk.ErrFreeOrderBudget, instrumentUID, remaining, delta) + } + _, err := r.execer().ExecContext(ctx, ` +UPDATE free_order_counters +SET orders_sent=orders_sent+? +WHERE trade_date=? AND instrument_uid=?`, delta, tradeDay, instrumentUID) + return err +} + func (r *Repository) GetSystemState(ctx context.Context) (domain.SystemState, bool, string, error) { var row struct { State string `db:"state"` diff --git a/internal/repository/mysql/rows.go b/internal/repository/mysql/rows.go index a5770d3..528b450 100644 --- a/internal/repository/mysql/rows.go +++ b/internal/repository/mysql/rows.go @@ -57,6 +57,7 @@ type featureRow struct { MuOn60 decimal.Decimal `db:"mu_on_60"` MuOn252 decimal.Decimal `db:"mu_on_252"` SigmaOn60 decimal.Decimal `db:"sigma_on_60"` + Q05On60Abs decimal.Decimal `db:"q05_on_60_abs"` TStatOn60 decimal.Decimal `db:"tstat_on_60"` WinOn60 decimal.Decimal `db:"win_on_60"` EWMAOn decimal.Decimal `db:"ewma_on"` @@ -80,6 +81,7 @@ func featureRowFromDomain(feature domain.FeatureSet) featureRow { MuOn60: feature.MuOn60, MuOn252: feature.MuOn252, SigmaOn60: feature.SigmaOn60, + Q05On60Abs: feature.Q05On60Abs, TStatOn60: feature.TStatOn60, WinOn60: feature.WinOn60, EWMAOn: feature.EWMAOn, @@ -104,6 +106,7 @@ func (r featureRow) domain() domain.FeatureSet { MuOn60: r.MuOn60, MuOn252: r.MuOn252, SigmaOn60: r.SigmaOn60, + Q05On60Abs: r.Q05On60Abs, TStatOn60: r.TStatOn60, WinOn60: r.WinOn60, EWMAOn: r.EWMAOn, diff --git a/internal/repository/repository.go b/internal/repository/repository.go index d90b173..911cc4a 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -38,6 +38,7 @@ type Repository interface { InsertRiskEvent(ctx context.Context, event domain.RiskEvent) error GetFreeOrdersSent(ctx context.Context, tradeDate time.Time, instrumentUID string) (int, error) IncrementFreeOrders(ctx context.Context, tradeDate time.Time, instrumentUID string, delta int) error + ReserveFreeOrders(ctx context.Context, tradeDate time.Time, instrumentUID string, delta int, limit int) error GetSystemState(ctx context.Context) (domain.SystemState, bool, string, error) SaveSystemState(ctx context.Context, state domain.SystemState, mode domain.Mode, halted bool, reason string, contextJSON string) error diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go index c5451b2..adb5b72 100644 --- a/internal/scheduler/scheduler.go +++ b/internal/scheduler/scheduler.go @@ -130,6 +130,16 @@ func (s *Scheduler) Run(ctx context.Context) error { } } +func (s Scheduler) GracefulShutdown(ctx context.Context) error { + if s.svc.Repo == nil || s.svc.Execution == nil { + return nil + } + if err := s.cancelActiveOrders(ctx, domain.SideBuy, domain.OrderStatusCancelled, "shutdown_cancel_active_orders"); err != nil { + return err + } + return s.cancelActiveOrders(ctx, domain.SideSell, domain.OrderStatusCancelled, "shutdown_cancel_active_orders") +} + func (s *Scheduler) Step(ctx context.Context) error { if err := s.checkInfrastructure(ctx); err != nil { return err @@ -370,7 +380,7 @@ func (s Scheduler) sizeSignal(portfolio domain.Portfolio, instrument domain.Inst Lot: instrument.Lot, EntryIntervalVolume: feature.EntryIntervalVolume, ExitIntervalVolume: feature.ExitIntervalVolume, - Q05OvernightAbs: money.Abs(feature.SigmaOn60).Mul(decimal.NewFromFloat(1.65)), + Q05OvernightAbs: feature.Q05On60Abs, }), nil } @@ -544,7 +554,7 @@ func (s *Scheduler) monitorEntryOrders(ctx context.Context, now time.Time) error return s.svc.MarketData.LatestQuote(ctx, instrumentUID, s.cfg.QuoteDepth, s.cfg.MaxQuoteAge) }, RepostCheck: func(ctx context.Context, order domain.Order, instrument domain.Instrument, book domain.OrderBook) error { - return s.repostPreTradeCheck(ctx, now, order, instrument, book) + return s.repostPreTradeCheck(ctx, s.nowUTC().In(s.cfg.Location), order, instrument, book) }, }) if err != nil { @@ -570,9 +580,31 @@ func (s *Scheduler) holdOvernight(ctx context.Context) error { if err := s.closeEntryWindow(ctx); err != nil { return err } + if err := s.promoteEntryFilledPositions(ctx); err != nil { + return err + } return s.periodicReconcile(ctx) } +func (s Scheduler) promoteEntryFilledPositions(ctx context.Context) error { + positionsList, err := s.svc.Repo.ListOpenPositions(ctx, s.svc.AccountIDHash) + if err != nil { + return err + } + now := s.nowUTC() + for _, pos := range positionsList { + if pos.Status != domain.PositionEntryFilled { + continue + } + pos.Status = domain.PositionHoldingOvernight + pos.UpdatedAt = now + if err := s.svc.Repo.UpsertPosition(ctx, pos); err != nil { + return err + } + } + return nil +} + func (s *Scheduler) placeExitOrders(ctx context.Context, now time.Time) error { if err := s.transitionTo(ctx, domain.StatePlaceExitOrders); err != nil { return err @@ -694,7 +726,7 @@ func (s *Scheduler) monitorExitOrders(ctx context.Context, now time.Time) error return s.svc.MarketData.LatestQuote(ctx, instrumentUID, s.cfg.QuoteDepth, s.cfg.MaxQuoteAge) }, RepostCheck: func(ctx context.Context, order domain.Order, instrument domain.Instrument, book domain.OrderBook) error { - return s.repostPreTradeCheck(ctx, now, order, instrument, book) + return s.repostPreTradeCheck(ctx, s.nowUTC().In(s.cfg.Location), order, instrument, book) }, }) if err != nil { @@ -1080,7 +1112,7 @@ func (s *Scheduler) failOpenPositionsAtHardDeadline(ctx context.Context) error { now := s.nowUTC() for _, pos := range positionsList { switch pos.Status { - case domain.PositionHoldingOvernight, domain.PositionExitPartiallyFilled, domain.PositionExitOrderSent: + case domain.PositionEntryFilled, domain.PositionHoldingOvernight, domain.PositionExitPartiallyFilled, domain.PositionExitOrderSent: pos.Status = domain.PositionExitFailed pos.UpdatedAt = now if err := s.svc.Repo.UpsertPosition(ctx, pos); err != nil { diff --git a/internal/scheduler/scheduler_test.go b/internal/scheduler/scheduler_test.go index 52fcad6..2aa1015 100644 --- a/internal/scheduler/scheduler_test.go +++ b/internal/scheduler/scheduler_test.go @@ -650,6 +650,49 @@ func TestPlaceExitUsesCurrentTradeDateForOrderAndFreeCounter(t *testing.T) { } } +func TestGracefulShutdownCancelsActiveOrders(t *testing.T) { + ctx := context.Background() + repo := testutil.NewMemoryRepository() + gateway := tinvest.NewFakeGateway() + tradeDate := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC) + order := domain.Order{ + ClientOrderID: "shutdown-order", + BrokerOrderID: "broker-shutdown-order", + AccountIDHash: "hash", + InstrumentUID: "uid", + TradeDate: tradeDate, + Side: domain.SideBuy, + OrderType: domain.OrderTypeLimit, + LimitPrice: decimal.NewFromInt(100), + QuantityLots: 1, + Status: domain.OrderStatusSent, + RawStateJSON: "{}", + } + if err := repo.UpsertOrder(ctx, order); err != nil { + t.Fatal(err) + } + gateway.Orders[order.BrokerOrderID] = order + execEngine := execution.NewEngine(domain.ModeSandbox, "account", gateway, repo) + s := Scheduler{ + cfg: Config{Mode: domain.ModeSandbox}, + svc: Services{ + Repo: repo, + Execution: &execEngine, + AccountIDHash: "hash", + }, + } + if err := s.GracefulShutdown(ctx); err != nil { + t.Fatal(err) + } + orders, err := repo.ListOrders(ctx, "hash", tradeDate, tradeDate) + if err != nil { + t.Fatal(err) + } + if len(orders) != 1 || orders[0].Status != domain.OrderStatusCancelled { + t.Fatalf("orders=%+v, want cancelled", orders) + } +} + func mustTOD(raw string) timeutil.TimeOfDay { tod, err := timeutil.ParseTimeOfDay(raw) if err != nil { diff --git a/internal/testutil/memory_repository.go b/internal/testutil/memory_repository.go index 6297a3e..773035f 100644 --- a/internal/testutil/memory_repository.go +++ b/internal/testutil/memory_repository.go @@ -9,6 +9,7 @@ import ( "overnight-trading-bot/internal/domain" "overnight-trading-bot/internal/repository" + "overnight-trading-bot/internal/risk" ) var _ repository.Repository = (*MemoryRepository)(nil) @@ -263,6 +264,20 @@ func (r *MemoryRepository) IncrementFreeOrders(_ context.Context, tradeDate time return nil } +func (r *MemoryRepository) ReserveFreeOrders(_ context.Context, tradeDate time.Time, instrumentUID string, delta int, limit int) error { + r.mu.Lock() + defer r.mu.Unlock() + if delta <= 0 { + return nil + } + key := featureKey(instrumentUID, tradeDate) + if limit > 0 && r.FreeOrders[key]+delta > limit { + return risk.ErrFreeOrderBudget + } + r.FreeOrders[key] += delta + return nil +} + func (r *MemoryRepository) GetSystemState(_ context.Context) (domain.SystemState, bool, string, error) { r.mu.Lock() defer r.mu.Unlock() diff --git a/internal/tinvest/real.go b/internal/tinvest/real.go index 530a487..14f0fc6 100644 --- a/internal/tinvest/real.go +++ b/internal/tinvest/real.go @@ -177,11 +177,15 @@ func (g *RealGateway) PostLimitOrder(ctx context.Context, accountID, instrumentU if side == domain.SideSell { direction = pb.OrderDirection_ORDER_DIRECTION_SELL } + quotation, err := money.DecimalToQuotation(price) + if err != nil { + return domain.Order{}, err + } resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.PostOrderResponse, error) { return g.orders.PostOrder(&investgo.PostOrderRequest{ InstrumentId: instrumentUID, Quantity: lots, - Price: money.DecimalToQuotation(price), + Price: quotation, Direction: direction, AccountId: accountID, OrderType: pb.OrderType_ORDER_TYPE_LIMIT, diff --git a/internal/tinvest/sandbox.go b/internal/tinvest/sandbox.go index 117fd8f..c968650 100644 --- a/internal/tinvest/sandbox.go +++ b/internal/tinvest/sandbox.go @@ -39,11 +39,15 @@ func (g *SandboxGateway) PostLimitOrder(ctx context.Context, accountID, instrume if side == domain.SideSell { direction = pb.OrderDirection_ORDER_DIRECTION_SELL } + quotation, err := money.DecimalToQuotation(price) + if err != nil { + return domain.Order{}, err + } resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.PostOrderResponse, error) { return g.sandbox.PostSandboxOrder(&investgo.PostOrderRequest{ InstrumentId: instrumentUID, Quantity: lots, - Price: money.DecimalToQuotation(price), + Price: quotation, Direction: direction, AccountId: accountID, OrderType: pb.OrderType_ORDER_TYPE_LIMIT,