fifth version

This commit is contained in:
2026-06-08 09:03:37 +00:00
parent b9efa98758
commit 2d57c4ff1f
26 changed files with 896 additions and 159 deletions
+53 -19
View File
@@ -49,7 +49,7 @@ func run() error {
defer func() {
_ = file.Close()
}()
candles, err := backtest.LoadCandlesCSV(file)
candles, metadata, err := backtest.LoadCandlesCSVWithMetadata(file)
if err != nil {
return fmt.Errorf("load candles: %w", err)
}
@@ -62,10 +62,12 @@ func run() error {
defer func() {
_ = minuteFile.Close()
}()
minuteCandles, err = backtest.LoadCandlesCSV(minuteFile)
var minuteMetadata map[string]backtest.InstrumentMetadata
minuteCandles, minuteMetadata, err = backtest.LoadCandlesCSVWithMetadata(minuteFile)
if err != nil {
return fmt.Errorf("load minute candles: %w", err)
}
mergeMetadata(metadata, minuteMetadata)
}
if *useMinuteModel && len(minuteCandles) == 0 {
return fmt.Errorf("-minute-candles is required when -use-minute-model=true")
@@ -114,24 +116,27 @@ func run() error {
if err != nil {
return fmt.Errorf("max tick: %w", err)
}
lotsByInstrument, ticksByInstrument := metadataMaps(metadata)
engine := backtest.New(backtest.Config{
EntrySlippageBps: entry,
ExitSlippageBps: exit,
CommissionRoundtripBps: comm,
RiskBufferBps: riskBuf,
OutputDir: *outputDir,
RollingShort: *rollingShort,
RollingLong: *rollingLong,
EWMALambda: *ewmaLambda,
MinTStat60: tstat,
MinWinRate60: winRate,
MinNetEdgeBps: netEdge,
MinADVRUB: adv,
MaxSpreadBps: spread,
MaxTickBps: tick,
AssumedSpreadBps: assumed,
RequireZeroCommission: requireZeroCommission,
UseMinuteModel: *useMinuteModel,
EntrySlippageBps: entry,
ExitSlippageBps: exit,
CommissionRoundtripBps: comm,
RiskBufferBps: riskBuf,
OutputDir: *outputDir,
RollingShort: *rollingShort,
RollingLong: *rollingLong,
EWMALambda: *ewmaLambda,
MinTStat60: tstat,
MinWinRate60: winRate,
MinNetEdgeBps: netEdge,
MinADVRUB: adv,
MaxSpreadBps: spread,
MaxTickBps: tick,
AssumedSpreadBps: assumed,
RequireZeroCommission: requireZeroCommission,
LotsByInstrument: lotsByInstrument,
MinPriceIncrementsByInstrument: ticksByInstrument,
UseMinuteModel: *useMinuteModel,
})
result, err := engine.RunWithMinuteCandles(candles, minuteCandles)
if err != nil {
@@ -143,3 +148,32 @@ func run() error {
fmt.Printf("backtest complete: trades=%d total_return=%.6f\n", result.Metrics.NumberOfTrades, result.Metrics.TotalReturn)
return nil
}
func mergeMetadata(dst, src map[string]backtest.InstrumentMetadata) {
for uid, meta := range src {
current := dst[uid]
if current.Lot <= 0 {
current.Lot = meta.Lot
}
if !current.MinPriceIncrement.IsPositive() {
current.MinPriceIncrement = meta.MinPriceIncrement
}
if current.Lot > 0 || current.MinPriceIncrement.IsPositive() {
dst[uid] = current
}
}
}
func metadataMaps(metadata map[string]backtest.InstrumentMetadata) (map[string]int64, map[string]decimal.Decimal) {
lots := make(map[string]int64)
ticks := make(map[string]decimal.Decimal)
for uid, meta := range metadata {
if meta.Lot > 0 {
lots[uid] = meta.Lot
}
if meta.MinPriceIncrement.IsPositive() {
ticks[uid] = meta.MinPriceIncrement
}
}
return lots, ticks
}
+20 -4
View File
@@ -114,10 +114,12 @@ func Run(ctx context.Context, opts Options) error {
defer closer()
}
accountIDHash := accountHash(cfg.TInvest.AccountID)
clock := timeutil.RealClock{Loc: cfg.Location}
recon := reconciliation.New(repo, gateway, cfg.TInvest.AccountID, accountIDHash).
WithWindow(time.Duration(cfg.Risk.ReconciliationWindowHours)*time.Hour).
WithInFlightGrace(time.Duration(cfg.Risk.ReconciliationSkewSec)*time.Second).
WithCommissionPolicy(cfg.Commission.RequireZeroCommission, cfg.Commission.QuarantineOnNonZero, cfg.Risk.CommissionToleranceRUB)
WithCommissionPolicy(cfg.Commission.RequireZeroCommission, cfg.Commission.QuarantineOnNonZero, cfg.Risk.CommissionToleranceRUB).
WithClock(clock)
diffs, err := recon.Run(ctx)
if err != nil {
return fmt.Errorf("pre-unhalt reconciliation: %w", err)
@@ -159,10 +161,12 @@ func Run(ctx context.Context, opts Options) error {
}()
accountIDHash := accountHash(cfg.TInvest.AccountID)
clock := timeutil.RealClock{Loc: cfg.Location}
recon := reconciliation.New(repo, gateway, cfg.TInvest.AccountID, accountIDHash).
WithWindow(time.Duration(cfg.Risk.ReconciliationWindowHours)*time.Hour).
WithInFlightGrace(time.Duration(cfg.Risk.ReconciliationSkewSec)*time.Second).
WithCommissionPolicy(cfg.Commission.RequireZeroCommission, cfg.Commission.QuarantineOnNonZero, cfg.Risk.CommissionToleranceRUB)
WithCommissionPolicy(cfg.Commission.RequireZeroCommission, cfg.Commission.QuarantineOnNonZero, cfg.Risk.CommissionToleranceRUB).
WithClock(clock)
sm := statemachine.New(repo, cfg.App.Mode)
if _, err := sm.Recover(ctx, recon); err != nil {
_ = notifier.Alert(ctx, fmt.Sprintf("state recovery failed: %s", err))
@@ -184,14 +188,25 @@ func Run(ctx context.Context, opts Options) error {
}
runCtx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
defer stop()
clock := timeutil.RealClock{Loc: cfg.Location}
runtime := buildScheduler(clock, sm, cfg, repo, gateway, notifier, recon, accountIDHash, log)
return runtime.Run(runCtx)
if err := runtime.Run(runCtx); err != nil {
if runCtx.Err() == nil {
return err
}
shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Duration(cfg.App.ShutdownTimeoutSec)*time.Second)
defer cancel()
if shutdownErr := runtime.GracefulShutdown(shutdownCtx); shutdownErr != nil {
return fmt.Errorf("%w; graceful shutdown: %v", err, shutdownErr)
}
return err
}
return nil
}
func buildScheduler(clock timeutil.Clock, sm statemachine.System, cfg config.Config, repo *mysqlrepo.Repository, gateway tinvest.Gateway, notifier notify.Notifier, recon reconciliation.Engine, accountIDHash string, log *slog.Logger) scheduler.Scheduler {
registry := instruments.NewRegistry(repo, gateway)
loader := marketdata.NewLoader(repo, gateway)
loader.SetClock(clock)
pipeline := features.NewPipeline(repo, features.PipelineConfig{
RollingShort: cfg.Strategy.RollingShort,
RollingLong: cfg.Strategy.RollingLong,
@@ -243,6 +258,7 @@ func buildScheduler(clock timeutil.Clock, sm statemachine.System, cfg config.Con
MaxQuoteAge: time.Duration(cfg.Execution.MaxQuoteAgeSec) * time.Second,
})
execEngine := execution.NewEngine(cfg.App.Mode, cfg.TInvest.AccountID, gateway, repo)
execEngine.SetClock(clock)
execEngine.SetMaxQuoteAge(time.Duration(cfg.Execution.MaxQuoteAgeSec) * time.Second)
execEngine.SetFreeOrderCountPolicy(cfg.Commission.FreeOrderCountPolicy)
services := scheduler.Services{
+197 -61
View File
@@ -8,6 +8,7 @@ import (
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"time"
@@ -20,40 +21,48 @@ import (
)
type Config struct {
EntrySlippageBps decimal.Decimal
ExitSlippageBps decimal.Decimal
CommissionRoundtripBps decimal.Decimal
RiskBufferBps decimal.Decimal
InitialEquity decimal.Decimal
OutputDir string
RollingShort int
RollingLong int
EWMALambda float64
MinTStat60 decimal.Decimal
MinWinRate60 decimal.Decimal
MinNetEdgeBps decimal.Decimal
MinADVRUB decimal.Decimal
MaxSpreadBps decimal.Decimal
MaxSpreadBpsMoneyMarket decimal.Decimal
MaxSpreadBpsBondFunds decimal.Decimal
MaxSpreadBpsEquityFunds decimal.Decimal
MaxTickBps decimal.Decimal
RequireZeroCommission *bool
MaxPositions int
MaxPositionPct decimal.Decimal
MaxTotalExposurePct decimal.Decimal
MaxParticipationRate decimal.Decimal
CashUsageBuffer decimal.Decimal
RiskBudgetPct decimal.Decimal
MinOrderNotionalRUB decimal.Decimal
AssumedSpreadBps decimal.Decimal
AssumedSpreadBpsByFundType map[string]decimal.Decimal
InstrumentFundTypes map[string]string
AssumedTickBps decimal.Decimal
Lot int64
UseMinuteModel bool
EntryWindow TimeWindow
ExitWindow TimeWindow
EntrySlippageBps decimal.Decimal
ExitSlippageBps decimal.Decimal
CommissionRoundtripBps decimal.Decimal
RiskBufferBps decimal.Decimal
InitialEquity decimal.Decimal
OutputDir string
RollingShort int
RollingLong int
EWMALambda float64
MinTStat60 decimal.Decimal
MinWinRate60 decimal.Decimal
MinNetEdgeBps decimal.Decimal
MinADVRUB decimal.Decimal
MaxSpreadBps decimal.Decimal
MaxSpreadBpsMoneyMarket decimal.Decimal
MaxSpreadBpsBondFunds decimal.Decimal
MaxSpreadBpsEquityFunds decimal.Decimal
MaxTickBps decimal.Decimal
RequireZeroCommission *bool
MaxPositions int
MaxPositionPct decimal.Decimal
MaxTotalExposurePct decimal.Decimal
MaxParticipationRate decimal.Decimal
CashUsageBuffer decimal.Decimal
RiskBudgetPct decimal.Decimal
MinOrderNotionalRUB decimal.Decimal
AssumedSpreadBps decimal.Decimal
AssumedSpreadBpsByFundType map[string]decimal.Decimal
InstrumentFundTypes map[string]string
AssumedTickBps decimal.Decimal
Lot int64
LotsByInstrument map[string]int64
MinPriceIncrement decimal.Decimal
MinPriceIncrementsByInstrument map[string]decimal.Decimal
UseMinuteModel bool
EntryWindow TimeWindow
ExitWindow TimeWindow
}
type InstrumentMetadata struct {
Lot int64
MinPriceIncrement decimal.Decimal
}
type TimeWindow struct {
@@ -203,6 +212,9 @@ func (e Engine) RunWithMinuteCandles(candlesByInstrument map[string][]domain.Can
return Result{}, err
}
if ok {
if capacity := e.windowCapacity(candidate, preparedMinutes[instrumentUID]); capacity.IsPositive() {
candidate.capacity = capacity
}
candidatesByExitDate[candidate.exit.TradeDate.Format("2006-01-02")] = append(candidatesByExitDate[candidate.exit.TradeDate.Format("2006-01-02")], candidate)
}
}
@@ -243,7 +255,7 @@ func (e Engine) RunWithMinuteCandles(candlesByInstrument map[string][]domain.Can
Portfolio: domain.Portfolio{Equity: equity, Cash: cash},
SelectedInstruments: len(dayCandidates),
LimitPrice: c.buy,
Lot: e.cfg.Lot,
Lot: c.lot,
EntryIntervalVolume: c.adv,
ExitIntervalVolume: c.adv,
Q05OvernightAbs: c.q05Abs,
@@ -261,7 +273,7 @@ func (e Engine) RunWithMinuteCandles(candlesByInstrument map[string][]domain.Can
lots = executedLots
capacity = minuteCapacity
}
notional := c.buy.Mul(decimal.NewFromInt(lots)).Mul(decimal.NewFromInt(e.cfg.Lot))
notional := c.buy.Mul(decimal.NewFromInt(lots)).Mul(decimal.NewFromInt(c.lot))
ret := c.sell.Div(c.buy).Sub(decimal.NewFromInt(1)).Sub(money.FromBps(e.cfg.CommissionRoundtripBps))
pnl := notional.Mul(ret)
dayPnL = dayPnL.Add(pnl)
@@ -311,8 +323,12 @@ func (e Engine) minuteExecution(c candidate, minutes []domain.Candle, requestedL
if requestedLots <= 0 || len(minutes) == 0 {
return 0, decimal.Zero, false
}
entryLots, entryCapacity := e.fillableMinuteLots(minutes, c.entry.TradeDate, c.buy, domain.SideBuy, e.cfg.EntryWindow)
exitLots, exitCapacity := e.fillableMinuteLots(minutes, c.exit.TradeDate, c.sell, domain.SideSell, e.cfg.ExitWindow)
lot := c.lot
if lot <= 0 {
lot = e.lotFor(c.instrumentUID)
}
entryLots, entryCapacity := e.fillableMinuteLots(minutes, c.entry.TradeDate, c.buy, lot, domain.SideBuy, e.cfg.EntryWindow)
exitLots, exitCapacity := e.fillableMinuteLots(minutes, c.exit.TradeDate, c.sell, lot, domain.SideSell, e.cfg.ExitWindow)
lots := min(requestedLots, entryLots)
lots = min(lots, exitLots)
if lots <= 0 {
@@ -321,11 +337,11 @@ func (e Engine) minuteExecution(c candidate, minutes []domain.Candle, requestedL
return lots, money.Min(entryCapacity, exitCapacity), true
}
func (e Engine) fillableMinuteLots(minutes []domain.Candle, date time.Time, limitPrice decimal.Decimal, side domain.Side, window TimeWindow) (int64, decimal.Decimal) {
if !limitPrice.IsPositive() || e.cfg.Lot <= 0 {
func (e Engine) fillableMinuteLots(minutes []domain.Candle, date time.Time, limitPrice decimal.Decimal, lot int64, side domain.Side, window TimeWindow) (int64, decimal.Decimal) {
if !limitPrice.IsPositive() || lot <= 0 {
return 0, decimal.Zero
}
lotNotional := limitPrice.Mul(decimal.NewFromInt(e.cfg.Lot))
lotNotional := limitPrice.Mul(decimal.NewFromInt(lot))
if !lotNotional.IsPositive() {
return 0, decimal.Zero
}
@@ -348,6 +364,36 @@ func (e Engine) fillableMinuteLots(minutes []domain.Candle, date time.Time, limi
return capacity.Div(lotNotional).Floor().IntPart(), capacity
}
func (e Engine) windowCapacity(c candidate, minutes []domain.Candle) decimal.Decimal {
if len(minutes) == 0 {
return decimal.Zero
}
lot := c.lot
if lot <= 0 {
lot = e.lotFor(c.instrumentUID)
}
if lot <= 0 {
return decimal.Zero
}
entryVolume := e.windowNotional(minutes, c.entry.TradeDate, e.cfg.EntryWindow, lot)
exitVolume := e.windowNotional(minutes, c.exit.TradeDate, e.cfg.ExitWindow, lot)
if !entryVolume.IsPositive() || !exitVolume.IsPositive() {
return decimal.Zero
}
return money.Min(entryVolume, exitVolume).Mul(e.cfg.MaxParticipationRate)
}
func (e Engine) windowNotional(minutes []domain.Candle, date time.Time, window TimeWindow, lot int64) decimal.Decimal {
total := decimal.Zero
for _, candle := range minutes {
if !sameDate(candle.TradeDate, date) || !window.Contains(candle.TradeDate) {
continue
}
total = total.Add(candle.VolumeLots.Mul(decimal.NewFromInt(lot)).Mul(candle.Close))
}
return total
}
func (w TimeWindow) Contains(ts time.Time) bool {
if w.Start == 0 && w.End == 0 {
return true
@@ -376,12 +422,14 @@ type candidate struct {
q05Abs decimal.Decimal
overnightGap decimal.Decimal
capacity decimal.Decimal
lot int64
}
func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle, exitIndex int) (candidate, bool, error) {
if exitIndex < e.cfg.RollingShort || exitIndex <= 0 {
return candidate{}, false, nil
}
lot := e.lotFor(instrumentUID)
history := candles[:exitIndex]
returns := make([]float64, 0, len(history)-1)
for j := 1; j < len(history); j++ {
@@ -405,7 +453,7 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle,
Add(e.cfg.CommissionRoundtripBps).
Add(e.cfg.RiskBufferBps)
netEdge := rawEdge.Sub(cost)
adv := features.ADV(history, e.cfg.Lot, 20)
adv := features.ADV(history, lot, 20)
switch {
case e.requireZeroCommission() && e.cfg.CommissionRoundtripBps.IsPositive():
return candidate{}, false, nil
@@ -428,6 +476,17 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle,
exit := candles[exitIndex]
buy := entry.Close.Mul(decimal.NewFromInt(1).Add(money.FromBps(e.cfg.EntrySlippageBps)))
sell := exit.Open.Mul(decimal.NewFromInt(1).Sub(money.FromBps(e.cfg.ExitSlippageBps)))
if tick := e.minPriceIncrementFor(instrumentUID); tick.IsPositive() {
var err error
buy, err = money.RoundToTick(buy, tick, money.RoundCeil)
if err != nil {
return candidate{}, false, err
}
sell, err = money.RoundToTick(sell, tick, money.RoundFloor)
if err != nil {
return candidate{}, false, err
}
}
gap, err := features.OvernightReturn(exit.Open, entry.Close)
if err != nil {
return candidate{}, false, err
@@ -448,9 +507,31 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle,
q05Abs: q05Abs,
overnightGap: gap,
capacity: adv.Mul(e.cfg.MaxParticipationRate),
lot: lot,
}, true, nil
}
func (e Engine) lotFor(instrumentUID string) int64 {
if e.cfg.LotsByInstrument != nil {
if lot := e.cfg.LotsByInstrument[instrumentUID]; lot > 0 {
return lot
}
}
if e.cfg.Lot > 0 {
return e.cfg.Lot
}
return 1
}
func (e Engine) minPriceIncrementFor(instrumentUID string) decimal.Decimal {
if e.cfg.MinPriceIncrementsByInstrument != nil {
if tick := e.cfg.MinPriceIncrementsByInstrument[instrumentUID]; tick.IsPositive() {
return tick
}
}
return e.cfg.MinPriceIncrement
}
func (e Engine) requireZeroCommission() bool {
return e.cfg.RequireZeroCommission != nil && *e.cfg.RequireZeroCommission
}
@@ -536,46 +617,60 @@ func (r Result) Write(outputDir string) error {
}
func LoadCandlesCSV(reader io.Reader) (map[string][]domain.Candle, error) {
candles, _, err := LoadCandlesCSVWithMetadata(reader)
return candles, err
}
func LoadCandlesCSVWithMetadata(reader io.Reader) (map[string][]domain.Candle, map[string]InstrumentMetadata, error) {
r := csv.NewReader(reader)
r.FieldsPerRecord = -1
records, err := r.ReadAll()
if err != nil {
return nil, err
return nil, nil, err
}
out := make(map[string][]domain.Candle)
for i, record := range records {
if i == 0 && len(record) > 0 && record[0] == "instrument_uid" {
continue
metadata := make(map[string]InstrumentMetadata)
header := map[string]int(nil)
start := 0
if len(records) > 0 && len(records[0]) > 0 && strings.EqualFold(strings.TrimSpace(records[0][0]), "instrument_uid") {
header = make(map[string]int, len(records[0]))
for i, name := range records[0] {
header[strings.ToLower(strings.TrimSpace(name))] = i
}
start = 1
}
for i := start; i < len(records); i++ {
record := records[i]
if len(record) < 7 {
return nil, fmt.Errorf("line %d: expected 7 fields", i+1)
return nil, nil, fmt.Errorf("line %d: expected at least 7 fields", i+1)
}
date, err := parseCandleTime(record[1])
instrumentUID := csvValue(record, header, "instrument_uid", 0)
date, err := parseCandleTime(csvValue(record, header, "trade_date", 1))
if err != nil {
return nil, err
return nil, nil, err
}
open, err := decimal.NewFromString(record[2])
open, err := decimal.NewFromString(csvValue(record, header, "open", 2))
if err != nil {
return nil, err
return nil, nil, err
}
high, err := decimal.NewFromString(record[3])
high, err := decimal.NewFromString(csvValue(record, header, "high", 3))
if err != nil {
return nil, err
return nil, nil, err
}
low, err := decimal.NewFromString(record[4])
low, err := decimal.NewFromString(csvValue(record, header, "low", 4))
if err != nil {
return nil, err
return nil, nil, err
}
closePrice, err := decimal.NewFromString(record[5])
closePrice, err := decimal.NewFromString(csvValue(record, header, "close", 5))
if err != nil {
return nil, err
return nil, nil, err
}
volume, err := decimal.NewFromString(record[6])
volume, err := decimal.NewFromString(csvValue(record, header, "volume_lots", 6))
if err != nil {
return nil, err
return nil, nil, err
}
candle := domain.Candle{
InstrumentUID: record[0],
InstrumentUID: instrumentUID,
TradeDate: date,
Open: open,
High: high,
@@ -586,8 +681,49 @@ func LoadCandlesCSV(reader io.Reader) (map[string][]domain.Candle, error) {
LoadedAt: time.Now().UTC(),
}
out[candle.InstrumentUID] = append(out[candle.InstrumentUID], candle)
meta := metadata[candle.InstrumentUID]
if raw, ok := optionalCSVValue(record, header, "lot", 7); ok && strings.TrimSpace(raw) != "" {
lot, err := strconv.ParseInt(strings.TrimSpace(raw), 10, 64)
if err != nil {
return nil, nil, fmt.Errorf("line %d: parse lot: %w", i+1, err)
}
if lot > 0 {
meta.Lot = lot
}
}
if raw, ok := optionalCSVValue(record, header, "min_price_increment", 8); ok && strings.TrimSpace(raw) != "" {
tick, err := decimal.NewFromString(strings.TrimSpace(raw))
if err != nil {
return nil, nil, fmt.Errorf("line %d: parse min_price_increment: %w", i+1, err)
}
if tick.IsPositive() {
meta.MinPriceIncrement = tick
}
}
if meta.Lot > 0 || meta.MinPriceIncrement.IsPositive() {
metadata[candle.InstrumentUID] = meta
}
}
return out, nil
return out, metadata, nil
}
func csvValue(record []string, header map[string]int, name string, fallback int) string {
value, _ := optionalCSVValue(record, header, name, fallback)
return strings.TrimSpace(value)
}
func optionalCSVValue(record []string, header map[string]int, name string, fallback int) (string, bool) {
if header != nil {
idx, ok := header[name]
if !ok || idx < 0 || idx >= len(record) {
return "", false
}
return record[idx], true
}
if fallback < 0 || fallback >= len(record) {
return "", false
}
return record[fallback], true
}
func parseCandleTime(raw string) (time.Time, error) {
+84
View File
@@ -1,6 +1,7 @@
package backtest
import (
"strings"
"testing"
"time"
@@ -69,3 +70,86 @@ func TestMinuteExecutionRequiresReachableLimitAndParticipation(t *testing.T) {
t.Fatal("sell limit should be unreachable")
}
}
func TestEvaluateCandidateUsesInstrumentLotAndTick(t *testing.T) {
engine := New(Config{
RollingShort: 2,
RollingLong: 2,
MinTStat60: decimal.NewFromInt(-1),
MinWinRate60: decimal.NewFromFloat(0.1),
MinNetEdgeBps: decimal.NewFromInt(-1000),
MinADVRUB: decimal.NewFromInt(1),
Lot: 1,
LotsByInstrument: map[string]int64{"uid": 10},
MinPriceIncrementsByInstrument: map[string]decimal.Decimal{"uid": decimal.NewFromFloat(0.05)},
EntrySlippageBps: decimal.NewFromInt(13),
ExitSlippageBps: decimal.NewFromInt(13),
})
candles := candidateCandles("uid")
got, ok, err := engine.evaluateCandidate("uid", candles, 3)
if err != nil {
t.Fatal(err)
}
if !ok {
t.Fatal("expected candidate")
}
if got.lot != 10 {
t.Fatalf("lot=%d, want 10", got.lot)
}
if !got.adv.Equal(decimal.NewFromInt(10_000)) {
t.Fatalf("adv=%s, want 10000", got.adv)
}
if !got.buy.Equal(decimal.NewFromFloat(100.15)) {
t.Fatalf("buy=%s, want rounded 100.15", got.buy)
}
if !got.sell.Equal(decimal.NewFromFloat(104.85)) {
t.Fatalf("sell=%s, want rounded 104.85", got.sell)
}
}
func TestWindowCapacityUsesMinuteEntryAndExitWindows(t *testing.T) {
engine := New(Config{
Lot: 10,
MaxParticipationRate: decimal.NewFromFloat(0.10),
})
entryDate := time.Date(2024, 1, 2, 18, 25, 0, 0, time.UTC)
exitDate := time.Date(2024, 1, 3, 10, 5, 0, 0, time.UTC)
got := engine.windowCapacity(candidate{
instrumentUID: "uid",
entry: domain.Candle{TradeDate: entryDate},
exit: domain.Candle{TradeDate: exitDate},
}, []domain.Candle{
{TradeDate: entryDate, Close: decimal.NewFromInt(100), VolumeLots: decimal.NewFromInt(20)},
{TradeDate: exitDate, Close: decimal.NewFromInt(200), VolumeLots: decimal.NewFromInt(5)},
{TradeDate: time.Date(2024, 1, 3, 12, 0, 0, 0, time.UTC), Close: decimal.NewFromInt(999), VolumeLots: decimal.NewFromInt(999)},
})
if !got.Equal(decimal.NewFromInt(1000)) {
t.Fatalf("capacity=%s, want min(entry=20000, exit=10000)*0.10 = 1000", got)
}
}
func TestLoadCandlesCSVWithMetadata(t *testing.T) {
raw := strings.NewReader(`instrument_uid,trade_date,open,high,low,close,volume_lots,lot,min_price_increment
uid,2024-01-02,100,101,99,100,10,10,0.05
`)
candles, metadata, err := LoadCandlesCSVWithMetadata(raw)
if err != nil {
t.Fatal(err)
}
if len(candles["uid"]) != 1 {
t.Fatalf("candles=%+v", candles)
}
if metadata["uid"].Lot != 10 || !metadata["uid"].MinPriceIncrement.Equal(decimal.NewFromFloat(0.05)) {
t.Fatalf("metadata=%+v", metadata["uid"])
}
}
func candidateCandles(uid string) []domain.Candle {
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
return []domain.Candle{
{InstrumentUID: uid, TradeDate: start, Open: decimal.NewFromInt(100), Close: decimal.NewFromInt(100), VolumeLots: decimal.NewFromInt(10)},
{InstrumentUID: uid, TradeDate: start.AddDate(0, 0, 1), Open: decimal.NewFromInt(101), Close: decimal.NewFromInt(100), VolumeLots: decimal.NewFromInt(10)},
{InstrumentUID: uid, TradeDate: start.AddDate(0, 0, 2), Open: decimal.NewFromInt(102), Close: decimal.NewFromInt(100), VolumeLots: decimal.NewFromInt(10)},
{InstrumentUID: uid, TradeDate: start.AddDate(0, 0, 3), Open: decimal.NewFromInt(105), Close: decimal.NewFromInt(100), VolumeLots: decimal.NewFromInt(10)},
}
}
+1
View File
@@ -175,6 +175,7 @@ type FeatureSet struct {
MuOn60 decimal.Decimal
MuOn252 decimal.Decimal
SigmaOn60 decimal.Decimal
Q05On60Abs decimal.Decimal
TStatOn60 decimal.Decimal
WinOn60 decimal.Decimal
EWMAOn decimal.Decimal
+71 -47
View File
@@ -12,6 +12,7 @@ import (
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/repository"
"overnight-trading-bot/internal/risk"
"overnight-trading-bot/internal/timeutil"
)
var ErrBrokerOrdersDisabled = errors.New("broker orders are disabled for current mode")
@@ -35,6 +36,7 @@ type Engine struct {
store repository.Repository
maxQuoteAge time.Duration
freeOrderCountPolicy string
clock timeutil.Clock
mu sync.Map
}
@@ -50,13 +52,26 @@ type MonitorConfig struct {
}
func NewEngine(mode domain.Mode, accountID string, gateway Gateway, store repository.Repository) Engine {
return Engine{mode: mode, accountID: accountID, gateway: gateway, store: store, freeOrderCountPolicy: FreeOrderPolicySubmitted}
return Engine{
mode: mode,
accountID: accountID,
gateway: gateway,
store: store,
freeOrderCountPolicy: FreeOrderPolicySubmitted,
clock: timeutil.RealClock{},
}
}
func (e *Engine) SetMaxQuoteAge(maxQuoteAge time.Duration) {
e.maxQuoteAge = maxQuoteAge
}
func (e *Engine) SetClock(clock timeutil.Clock) {
if clock != nil {
e.clock = clock
}
}
func (e *Engine) SetFreeOrderCountPolicy(policy string) {
switch policy {
case FreeOrderPolicyCancelCounts:
@@ -78,7 +93,7 @@ func (e *Engine) PlaceEntry(ctx context.Context, accountIDHash string, instrumen
if err != nil {
return domain.Order{}, err
}
return e.PlaceLimit(ctx, domain.Order{
return e.placeLimit(ctx, domain.Order{
ClientOrderID: ClientOrderID(tradeDate, instrument.InstrumentUID, domain.SideBuy, attempt),
AccountIDHash: accountIDHash,
InstrumentUID: instrument.InstrumentUID,
@@ -90,7 +105,7 @@ func (e *Engine) PlaceEntry(ctx context.Context, accountIDHash string, instrumen
Status: domain.OrderStatusNew,
AttemptNo: attempt,
RawStateJSON: "{}",
})
}, instrument.FreeOrderLimitPerDay)
}
func (e *Engine) PlaceExit(ctx context.Context, accountIDHash string, instrument domain.Instrument, tradeDate time.Time, lots int64, book domain.OrderBook, improveTicks int, attempt int) (domain.Order, error) {
@@ -105,7 +120,7 @@ func (e *Engine) PlaceExit(ctx context.Context, accountIDHash string, instrument
if err != nil {
return domain.Order{}, err
}
return e.PlaceLimit(ctx, domain.Order{
return e.placeLimit(ctx, domain.Order{
ClientOrderID: ClientOrderID(tradeDate, instrument.InstrumentUID, domain.SideSell, attempt),
AccountIDHash: accountIDHash,
InstrumentUID: instrument.InstrumentUID,
@@ -117,10 +132,14 @@ func (e *Engine) PlaceExit(ctx context.Context, accountIDHash string, instrument
Status: domain.OrderStatusNew,
AttemptNo: attempt,
RawStateJSON: "{}",
})
}, instrument.FreeOrderLimitPerDay)
}
func (e *Engine) PlaceLimit(ctx context.Context, order domain.Order) (domain.Order, error) {
return e.placeLimit(ctx, order, 0)
}
func (e *Engine) placeLimit(ctx context.Context, order domain.Order, freeOrderLimit int) (domain.Order, error) {
lock := e.lockFor(order.InstrumentUID)
lock.Lock()
defer lock.Unlock()
@@ -134,7 +153,7 @@ func (e *Engine) PlaceLimit(ctx context.Context, order domain.Order) (domain.Ord
}
}
if e.mode == domain.ModePaper {
return e.placePaperLimit(ctx, order)
return e.placePaperLimit(ctx, order, freeOrderLimit)
}
if !e.mode.AllowsBrokerOrders() {
order.Status = domain.OrderStatusNew
@@ -147,7 +166,7 @@ func (e *Engine) PlaceLimit(ctx context.Context, order domain.Order) (domain.Ord
return domain.Order{}, errors.New("gateway is nil")
}
now := time.Now().UTC()
now := e.nowUTC()
draft := order
draft.Status = domain.OrderStatusSent
draft.CreatedAt = now
@@ -156,8 +175,13 @@ func (e *Engine) PlaceLimit(ctx context.Context, order domain.Order) (domain.Ord
draft.RawStateJSON = "{}"
}
if e.store != nil {
if err := e.store.UpsertOrder(ctx, draft); err != nil {
return domain.Order{}, fmt.Errorf("persist draft order: %w", err)
if err := e.store.RunInTx(ctx, func(ctx context.Context, repo repository.Repository) error {
if err := repo.UpsertOrder(ctx, draft); err != nil {
return fmt.Errorf("persist draft order: %w", err)
}
return repo.ReserveFreeOrders(ctx, order.TradeDate, order.InstrumentUID, 1, freeOrderLimit)
}); err != nil {
return domain.Order{}, err
}
}
posted, err := e.gateway.PostLimitOrder(ctx, e.accountID, order.InstrumentUID, order.Side, order.QuantityLots, order.LimitPrice, order.ClientOrderID)
@@ -180,20 +204,15 @@ func (e *Engine) PlaceLimit(ctx context.Context, order domain.Order) (domain.Ord
posted.CreatedAt = now
posted.UpdatedAt = posted.CreatedAt
if e.store != nil {
if err := e.store.RunInTx(ctx, func(ctx context.Context, repo repository.Repository) error {
if err := repo.UpsertOrder(ctx, posted); err != nil {
return fmt.Errorf("persist posted order: %w", err)
}
return repo.IncrementFreeOrders(ctx, order.TradeDate, order.InstrumentUID, 1)
}); err != nil {
if err := e.store.UpsertOrder(ctx, posted); err != nil {
return domain.Order{}, err
}
}
return posted, nil
}
func (e *Engine) placePaperLimit(ctx context.Context, order domain.Order) (domain.Order, error) {
now := time.Now().UTC()
func (e *Engine) placePaperLimit(ctx context.Context, order domain.Order, freeOrderLimit int) (domain.Order, error) {
now := e.nowUTC()
order.BrokerOrderID = "paper-" + order.ClientOrderID
order.FilledLots = order.QuantityLots
order.AvgFillPrice = order.LimitPrice
@@ -206,7 +225,7 @@ func (e *Engine) placePaperLimit(ctx context.Context, order domain.Order) (domai
if err := repo.UpsertOrder(ctx, order); err != nil {
return fmt.Errorf("persist paper order: %w", err)
}
return repo.IncrementFreeOrders(ctx, order.TradeDate, order.InstrumentUID, 1)
return repo.ReserveFreeOrders(ctx, order.TradeDate, order.InstrumentUID, 1, freeOrderLimit)
}); err != nil {
return domain.Order{}, err
}
@@ -286,12 +305,10 @@ func (e *Engine) MonitorUntil(ctx context.Context, order domain.Order, cfg Monit
if cfg.MaxAttempts <= 0 {
cfg.MaxAttempts = 1
}
lastPost := time.Now()
lastPost := e.nowUTC()
current := order
aggregate := order
seen := map[string]domain.Order{order.ClientOrderID: order}
ticker := time.NewTicker(cfg.PollInterval)
defer ticker.Stop()
for {
previous := seen[current.ClientOrderID]
refreshed, err := e.Refresh(ctx, current)
@@ -311,7 +328,7 @@ func (e *Engine) MonitorUntil(ctx context.Context, order domain.Order, cfg Monit
if isTerminal(current.Status) {
return aggregate, nil
}
if !cfg.Deadline.IsZero() && !time.Now().Before(cfg.Deadline) {
if !cfg.Deadline.IsZero() && !e.nowUTC().Before(cfg.Deadline) {
if err := e.Cancel(ctx, current); err != nil {
return aggregate, err
}
@@ -324,7 +341,7 @@ func (e *Engine) MonitorUntil(ctx context.Context, order domain.Order, cfg Monit
return aggregate, nil
}
shouldRepost := cfg.RepostAfter > 0 &&
time.Since(lastPost) >= cfg.RepostAfter &&
e.nowUTC().Sub(lastPost) >= cfg.RepostAfter &&
current.AttemptNo < cfg.MaxAttempts &&
aggregate.FilledLots < aggregate.QuantityLots &&
cfg.Quote != nil
@@ -337,13 +354,11 @@ func (e *Engine) MonitorUntil(ctx context.Context, order domain.Order, cfg Monit
current = next
seen[current.ClientOrderID] = current
}
lastPost = time.Now()
lastPost = e.nowUTC()
continue
}
select {
case <-ctx.Done():
if !e.sleep(ctx, cfg.PollInterval) {
return aggregate, ctx.Err()
case <-ticker.C:
}
}
}
@@ -372,7 +387,7 @@ func (e *Engine) MonitorOnce(ctx context.Context, order domain.Order, cfg Monito
if isTerminal(current.Status) {
return aggregate, nil
}
if !cfg.Deadline.IsZero() && !time.Now().Before(cfg.Deadline) {
if !cfg.Deadline.IsZero() && !e.nowUTC().Before(cfg.Deadline) {
if err := e.Cancel(ctx, current); err != nil {
return aggregate, err
}
@@ -385,7 +400,7 @@ func (e *Engine) MonitorOnce(ctx context.Context, order domain.Order, cfg Monito
return aggregate, nil
}
shouldRepost := cfg.RepostAfter > 0 &&
repostDue(current, cfg.RepostAfter) &&
e.repostDue(current, cfg.RepostAfter) &&
current.AttemptNo < cfg.MaxAttempts &&
aggregate.FilledLots < aggregate.QuantityLots &&
cfg.Quote != nil
@@ -409,7 +424,7 @@ func (e *Engine) repost(ctx context.Context, order domain.Order, cfg MonitorConf
if err := e.ensureRepostBudget(ctx, order, cfg.Instrument); err != nil {
return domain.Order{}, false, err
}
if !cfg.Deadline.IsZero() && !time.Now().Before(cfg.Deadline) {
if !cfg.Deadline.IsZero() && !e.nowUTC().Before(cfg.Deadline) {
return order, false, nil
}
book, err := cfg.Quote(ctx, order.InstrumentUID)
@@ -432,7 +447,7 @@ func (e *Engine) repost(ctx context.Context, order domain.Order, cfg MonitorConf
cancelled.Status = domain.OrderStatusFilled
return cancelled, true, nil
}
if !cfg.Deadline.IsZero() && !time.Now().Before(cfg.Deadline) {
if !cfg.Deadline.IsZero() && !e.nowUTC().Before(cfg.Deadline) {
return cancelled, true, nil
}
book, err = cfg.Quote(ctx, order.InstrumentUID)
@@ -440,7 +455,7 @@ func (e *Engine) repost(ctx context.Context, order domain.Order, cfg MonitorConf
return domain.Order{}, false, err
}
if cfg.RepostCheck != nil {
if err := cfg.RepostCheck(ctx, order, cfg.Instrument, book); err != nil {
if err := cfg.RepostCheck(ctx, cancelled, cfg.Instrument, book); err != nil {
return cancelled, true, nil
}
}
@@ -468,25 +483,16 @@ func (e *Engine) waitTerminal(ctx context.Context, order domain.Order, cfg Monit
if isTerminal(current.Status) {
return current, nil
}
if !cfg.Deadline.IsZero() && !time.Now().Before(cfg.Deadline) {
if !cfg.Deadline.IsZero() && !e.nowUTC().Before(cfg.Deadline) {
return current, nil
}
timer := time.NewTimer(cfg.PollInterval)
select {
case <-ctx.Done():
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
if !e.sleep(ctx, cfg.PollInterval) {
return domain.Order{}, ctx.Err()
case <-timer.C:
}
}
}
func repostDue(order domain.Order, after time.Duration) bool {
func (e *Engine) repostDue(order domain.Order, after time.Duration) bool {
if after <= 0 {
return false
}
@@ -497,7 +503,7 @@ func repostDue(order domain.Order, after time.Duration) bool {
if basis.IsZero() {
return true
}
return time.Since(basis) >= after
return e.nowUTC().Sub(basis) >= after
}
func (e *Engine) ensureRepostBudget(ctx context.Context, order domain.Order, instrument domain.Instrument) error {
@@ -530,7 +536,7 @@ func (e *Engine) checkQuoteFresh(book domain.OrderBook) error {
if book.ReceivedAt.IsZero() {
return fmt.Errorf("quote received timestamp is missing")
}
age := time.Since(book.ReceivedAt)
age := e.nowUTC().Sub(book.ReceivedAt)
if age > e.maxQuoteAge {
return fmt.Errorf("quote age %s exceeds %s", age, e.maxQuoteAge)
}
@@ -541,11 +547,29 @@ func (e *Engine) lockFor(instrumentUID string) *sync.Mutex {
value, _ := e.mu.LoadOrStore(instrumentUID, &sync.Mutex{})
lock, ok := value.(*sync.Mutex)
if !ok {
panic("execution lock has unexpected type")
lock = &sync.Mutex{}
e.mu.Store(instrumentUID, lock)
}
return lock
}
func (e *Engine) nowUTC() time.Time {
if e.clock == nil {
return time.Now().UTC()
}
return e.clock.Now().UTC()
}
func (e *Engine) sleep(ctx context.Context, d time.Duration) bool {
if d <= 0 {
return true
}
if e.clock == nil {
return timeutil.RealClock{}.Sleep(ctx.Done(), d)
}
return e.clock.Sleep(ctx.Done(), d)
}
func bestBidAsk(book domain.OrderBook) (decimal.Decimal, decimal.Decimal, error) {
bid, ok := book.BestBid()
if !ok {
+94
View File
@@ -2,16 +2,30 @@ package execution
import (
"context"
"errors"
"testing"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/risk"
"overnight-trading-bot/internal/testutil"
"overnight-trading-bot/internal/tinvest"
)
type fixedClock struct {
now time.Time
}
func (c *fixedClock) Now() time.Time {
return c.now
}
func (c *fixedClock) Sleep(<-chan struct{}, time.Duration) bool {
return true
}
func TestClientOrderIDIncludesAttempt(t *testing.T) {
date := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC)
first := ClientOrderID(date, "uid:TRUR", domain.SideBuy, 1)
@@ -66,6 +80,86 @@ func TestPlaceLimitSuppressesDuplicateSubmit(t *testing.T) {
}
}
func TestPlaceEntryReservesFreeOrderBudgetAtomically(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
gateway := tinvest.NewFakeGateway()
engine := NewEngine(domain.ModeSandbox, "account", gateway, repo)
instrument := domain.Instrument{
InstrumentUID: "uid",
Lot: 1,
MinPriceIncrement: decimal.NewFromInt(1),
FreeOrderLimitPerDay: 1,
}
book := domain.OrderBook{
InstrumentUID: "uid",
Bids: []domain.OrderBookLevel{{Price: decimal.NewFromInt(99), QuantityLots: 10}},
Asks: []domain.OrderBookLevel{{Price: decimal.NewFromInt(101), QuantityLots: 10}},
ReceivedAt: time.Now().UTC(),
}
tradeDate := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC)
if _, err := engine.PlaceEntry(ctx, "hash", instrument, tradeDate, 1, book, 1, 1); err != nil {
t.Fatal(err)
}
_, err := engine.PlaceEntry(ctx, "hash", instrument, tradeDate, 1, book, 1, 2)
if !errors.Is(err, risk.ErrFreeOrderBudget) {
t.Fatalf("expected free order budget error, got %v", err)
}
if got := len(gateway.Orders); got != 1 {
t.Fatalf("broker orders=%d, want no second post", got)
}
}
func TestMonitorOnceUsesInjectedClockForDeadline(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
gateway := tinvest.NewFakeGateway()
engine := NewEngine(domain.ModeSandbox, "account", gateway, repo)
clock := &fixedClock{now: time.Date(2030, 1, 1, 10, 0, 0, 0, time.UTC)}
engine.SetClock(clock)
order, err := engine.PlaceLimit(ctx, domain.Order{
ClientOrderID: "clocked",
AccountIDHash: "hash",
InstrumentUID: "uid",
TradeDate: clock.now,
Side: domain.SideBuy,
OrderType: domain.OrderTypeLimit,
LimitPrice: decimal.NewFromInt(100),
QuantityLots: 1,
Status: domain.OrderStatusNew,
AttemptNo: 1,
})
if err != nil {
t.Fatal(err)
}
if !order.CreatedAt.Equal(clock.now) {
t.Fatalf("created_at=%s, want injected clock %s", order.CreatedAt, clock.now)
}
monitored, err := engine.MonitorOnce(ctx, order, MonitorConfig{
Deadline: clock.now.Add(time.Minute),
PollInterval: time.Millisecond,
MaxAttempts: 1,
})
if err != nil {
t.Fatal(err)
}
if monitored.Status == domain.OrderStatusExpired {
t.Fatalf("order expired before injected deadline: %+v", monitored)
}
clock.now = clock.now.Add(time.Minute)
monitored, err = engine.MonitorOnce(ctx, order, MonitorConfig{
Deadline: clock.now,
PollInterval: time.Millisecond,
MaxAttempts: 1,
})
if err != nil {
t.Fatal(err)
}
if monitored.Status != domain.OrderStatusExpired {
t.Fatalf("status=%s, want EXPIRED at injected deadline", monitored.Status)
}
}
func TestPaperPlaceEntryFillsAndCountsSubmittedOrder(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
+24 -3
View File
@@ -109,14 +109,14 @@ func Compute(instrument domain.Instrument, candles []domain.Candle, tradeDate ti
}
short := Rolling(overnight, cfg.RollingShort, cfg.EWMALambda)
long := Rolling(overnight, cfg.RollingLong, cfg.EWMALambda)
q05Abs := rollingQ05Abs(overnight, cfg.RollingShort)
adv := ADV(candles, instrument.Lot, 20)
rawEdgeBps := decimal.NewFromFloat(short.Mean).Mul(decimal.NewFromInt(10_000))
instrumentCommission := instrument.ExpectedCommissionBpsPerSide.Mul(decimal.NewFromInt(2))
commission := roundTripCommissionBps(instrument, cfg)
expectedCost := spread.SpreadBps.
Add(cfg.EntrySlippageBps).
Add(cfg.ExitSlippageBps).
Add(cfg.CommissionRoundtripBps).
Add(instrumentCommission).
Add(commission).
Add(cfg.RiskBufferBps)
return domain.FeatureSet{
InstrumentUID: instrument.InstrumentUID,
@@ -126,6 +126,7 @@ func Compute(instrument domain.Instrument, candles []domain.Candle, tradeDate ti
MuOn60: decimal.NewFromFloat(short.Mean),
MuOn252: decimal.NewFromFloat(long.Mean),
SigmaOn60: decimal.NewFromFloat(short.StdDev),
Q05On60Abs: q05Abs,
TStatOn60: decimal.NewFromFloat(short.TStat),
WinOn60: decimal.NewFromFloat(short.WinRate),
EWMAOn: decimal.NewFromFloat(short.EWMA),
@@ -141,6 +142,26 @@ func Compute(instrument domain.Instrument, candles []domain.Candle, tradeDate ti
}, nil
}
func rollingQ05Abs(values []float64, window int) decimal.Decimal {
if window <= 0 || len(values) < window {
return decimal.Zero
}
sample := values[len(values)-window:]
q05 := decimal.NewFromFloat(Quantile(sample, 0.05))
if q05.IsNegative() {
return q05.Neg()
}
return q05
}
func roundTripCommissionBps(instrument domain.Instrument, cfg PipelineConfig) decimal.Decimal {
instrumentCommission := instrument.ExpectedCommissionBpsPerSide.Mul(decimal.NewFromInt(2))
if instrumentCommission.IsPositive() {
return instrumentCommission
}
return cfg.CommissionRoundtripBps
}
func historicalDailyCandles(candles []domain.Candle, tradeDate time.Time) []domain.Candle {
tradeDay := dateOnly(tradeDate)
out := make([]domain.Candle, 0, len(candles))
+81 -2
View File
@@ -41,14 +41,93 @@ func TestComputeExpectedCostIncludesCommissionAndSlippage(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if !got.ExpectedCostBps.Equal(decimal.NewFromInt(26)) {
t.Fatalf("expected cost=%s, want 26", got.ExpectedCostBps)
if !got.ExpectedCostBps.Equal(decimal.NewFromInt(22)) {
t.Fatalf("expected cost=%s, want 22", got.ExpectedCostBps)
}
if !got.EntryIntervalVolume.Equal(decimal.NewFromInt(10000)) || !got.ExitIntervalVolume.Equal(decimal.NewFromInt(9000)) {
t.Fatalf("interval volumes were not preserved: %+v", got)
}
}
func TestComputeExpectedCostFallsBackToConfigCommission(t *testing.T) {
candles := flatCandles(time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC), 6)
got, err := Compute(domain.Instrument{
InstrumentUID: "uid",
Lot: 1,
}, candles, candles[5].TradeDate, SpreadResult{SpreadBps: decimal.NewFromInt(10)}, PipelineConfig{
RollingShort: 2,
RollingLong: 2,
EWMALambda: 0.08,
RiskBufferBps: decimal.NewFromInt(5),
EntrySlippageBps: decimal.NewFromInt(2),
ExitSlippageBps: decimal.NewFromInt(3),
CommissionRoundtripBps: decimal.NewFromInt(4),
}, decimal.Zero, decimal.Zero)
if err != nil {
t.Fatal(err)
}
if !got.ExpectedCostBps.Equal(decimal.NewFromInt(24)) {
t.Fatalf("expected cost=%s, want 24", got.ExpectedCostBps)
}
}
func TestComputeStoresHistoricalQ05Abs(t *testing.T) {
start := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
returns := []string{"-0.10", "0.01", "0.02", "0.03", "0.04"}
candles := []domain.Candle{{
InstrumentUID: "uid",
TradeDate: start,
Open: decimal.NewFromInt(100),
Close: decimal.NewFromInt(100),
VolumeLots: decimal.NewFromInt(1),
}}
for i, raw := range returns {
r, err := decimal.NewFromString(raw)
if err != nil {
t.Fatal(err)
}
open := decimal.NewFromInt(100).Mul(decimal.NewFromInt(1).Add(r))
candles = append(candles, domain.Candle{
InstrumentUID: "uid",
TradeDate: start.AddDate(0, 0, i+1),
Open: open,
Close: decimal.NewFromInt(100),
VolumeLots: decimal.NewFromInt(1),
})
}
got, err := Compute(domain.Instrument{InstrumentUID: "uid", Lot: 1}, candles, start.AddDate(0, 0, 6), SpreadResult{}, PipelineConfig{
RollingShort: 5,
RollingLong: 5,
EWMALambda: 0.08,
}, decimal.Zero, decimal.Zero)
if err != nil {
t.Fatal(err)
}
want := decimal.NewFromFloat(0.078)
diff := got.Q05On60Abs.Sub(want)
if diff.IsNegative() {
diff = diff.Neg()
}
if diff.GreaterThan(decimal.NewFromFloat(0.000001)) {
t.Fatalf("Q05On60Abs=%s, want about %s", got.Q05On60Abs, want)
}
}
func flatCandles(start time.Time, count int) []domain.Candle {
candles := make([]domain.Candle, 0, count)
for i := 0; i < count; i++ {
price := decimal.NewFromInt(int64(100 + i))
candles = append(candles, domain.Candle{
InstrumentUID: "uid",
TradeDate: start.AddDate(0, 0, i),
Open: price,
Close: price,
VolumeLots: decimal.NewFromInt(1000),
})
}
return candles
}
func TestIntervalVolume(t *testing.T) {
got := IntervalVolume([]domain.Candle{
{Close: decimal.NewFromInt(100), VolumeLots: decimal.NewFromInt(10)},
+17 -2
View File
@@ -7,16 +7,24 @@ import (
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/repository"
"overnight-trading-bot/internal/timeutil"
"overnight-trading-bot/internal/tinvest"
)
type Loader struct {
repo repository.Repository
gateway tinvest.Gateway
clock timeutil.Clock
}
func NewLoader(repo repository.Repository, gateway tinvest.Gateway) Loader {
return Loader{repo: repo, gateway: gateway}
return Loader{repo: repo, gateway: gateway, clock: timeutil.RealClock{}}
}
func (l *Loader) SetClock(clock timeutil.Clock) {
if clock != nil {
l.clock = clock
}
}
func (l Loader) BackfillDaily(ctx context.Context, instruments []domain.Instrument, from, to time.Time) error {
@@ -59,9 +67,16 @@ func (l Loader) LatestQuote(ctx context.Context, instrumentUID string, depth int
if book.ReceivedAt.IsZero() {
return domain.OrderBook{}, fmt.Errorf("quote received timestamp is missing")
}
age := time.Since(book.ReceivedAt)
age := l.nowUTC().Sub(book.ReceivedAt)
if maxAge > 0 && age > maxAge {
return domain.OrderBook{}, fmt.Errorf("quote age %s exceeds %s", age, maxAge)
}
return book, nil
}
func (l Loader) nowUTC() time.Time {
if l.clock == nil {
return time.Now().UTC()
}
return l.clock.Now().UTC()
}
+6 -5
View File
@@ -9,8 +9,9 @@ import (
)
var (
ErrInvalidTick = errors.New("tick must be positive")
ErrInvalidBase = errors.New("base must be positive")
ErrInvalidTick = errors.New("tick must be positive")
ErrInvalidBase = errors.New("base must be positive")
ErrInvalidQuotation = errors.New("decimal cannot be represented as protobuf quotation")
)
type RoundMode int
@@ -28,7 +29,7 @@ func QuotationToDecimal(q *pb.Quotation) decimal.Decimal {
return decimal.NewFromInt(q.GetUnits()).Add(decimal.New(int64(q.GetNano()), -9))
}
func DecimalToQuotation(d decimal.Decimal) *pb.Quotation {
func DecimalToQuotation(d decimal.Decimal) (*pb.Quotation, error) {
units := d.Truncate(0)
nano := d.Sub(units).Mul(decimal.NewFromInt(1_000_000_000)).Round(0)
if nano.Equal(decimal.NewFromInt(1_000_000_000)) {
@@ -41,12 +42,12 @@ func DecimalToQuotation(d decimal.Decimal) *pb.Quotation {
}
nanoPart := nano.IntPart()
if nanoPart < -999_999_999 || nanoPart > 999_999_999 {
panic("decimal quotation nano is out of protobuf range")
return nil, ErrInvalidQuotation
}
return &pb.Quotation{
Units: units.IntPart(),
Nano: int32(nanoPart), // #nosec G115 -- nanoPart is bounded above.
}
}, nil
}
func MoneyValueToDecimal(v *pb.MoneyValue) decimal.Decimal {
+15
View File
@@ -37,3 +37,18 @@ func TestRoundToTick(t *testing.T) {
}
}
}
func TestDecimalToQuotationHandlesRoundingCarry(t *testing.T) {
tooPrecise := d("0.0000000005")
if _, err := DecimalToQuotation(tooPrecise); err != nil {
t.Fatalf("roundable quotation returned error: %v", err)
}
hugeNano := d("0.9999999996")
got, err := DecimalToQuotation(hugeNano)
if err != nil {
t.Fatalf("carry quotation returned error: %v", err)
}
if got.Units != 1 || got.Nano != 0 {
t.Fatalf("quotation=%+v, want carry to 1/0", got)
}
}
+1 -1
View File
@@ -33,7 +33,7 @@ func (m Manager) OnEntryFill(ctx context.Context, accountIDHash string, instrume
Lot: lot,
AvgBuyPrice: order.AvgFillPrice,
CommissionTotal: order.Commission,
Status: domain.PositionHoldingOvernight,
Status: domain.PositionEntryFilled,
OpenedAt: &now,
UpdatedAt: now,
}
+3
View File
@@ -28,6 +28,9 @@ func TestOnEntryFillKeepsBuyCommission(t *testing.T) {
if !pos.CommissionTotal.Equal(decimal.NewFromInt(3)) {
t.Fatalf("commission=%s, want 3", pos.CommissionTotal)
}
if pos.Status != domain.PositionEntryFilled {
t.Fatalf("status=%s, want ENTRY_FILLED", pos.Status)
}
}
func TestOnExitFillPartialUsesExecutedLots(t *testing.T) {
+19 -2
View File
@@ -13,6 +13,7 @@ import (
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/money"
"overnight-trading-bot/internal/repository"
"overnight-trading-bot/internal/timeutil"
"overnight-trading-bot/internal/tinvest"
)
@@ -29,6 +30,7 @@ type Engine struct {
commissionTolerance decimal.Decimal
requireZeroCommission bool
quarantineOnNonZero bool
clock timeutil.Clock
}
func New(repo repository.Repository, gateway tinvest.Gateway, accountID, accountIDHash string) Engine {
@@ -40,6 +42,7 @@ func New(repo repository.Repository, gateway tinvest.Gateway, accountID, account
accountIDHash: accountIDHash,
window: 72 * time.Hour,
commissionTolerance: defaultCommissionTolerance,
clock: timeutil.RealClock{},
}
}
@@ -66,6 +69,13 @@ func (e Engine) WithCommissionPolicy(requireZero, quarantineOnNonZero bool, tole
return e
}
func (e Engine) WithClock(clock timeutil.Clock) Engine {
if clock != nil {
e.clock = clock
}
return e
}
func (e Engine) Run(ctx context.Context) ([]domain.ReconciliationDiff, error) {
if e.mu != nil {
e.mu.Lock()
@@ -79,7 +89,7 @@ func (e Engine) Run(ctx context.Context) ([]domain.ReconciliationDiff, error) {
if err != nil {
return nil, err
}
now := time.Now().UTC()
now := e.nowUTC()
localByBroker := make(map[string]domain.Order, len(localOrders))
brokerByID := make(map[string]domain.Order, len(brokerOrders))
for _, order := range localOrders {
@@ -175,7 +185,7 @@ func (e Engine) Run(ctx context.Context) ([]domain.ReconciliationDiff, error) {
}
if err := e.repo.QuarantineInstrument(ctx, diff.InstrumentUID, diff.Message); err != nil {
_ = e.repo.InsertRiskEvent(ctx, domain.RiskEvent{
TS: time.Now().UTC(),
TS: now,
Severity: domain.SeverityCritical,
EventType: "quarantine_failed",
InstrumentUID: diff.InstrumentUID,
@@ -192,6 +202,13 @@ func (e Engine) Run(ctx context.Context) ([]domain.ReconciliationDiff, error) {
return diffs, nil
}
func (e Engine) nowUTC() time.Time {
if e.clock == nil {
return time.Now().UTC()
}
return e.clock.Now().UTC()
}
func (e Engine) isInFlight(order domain.Order, now time.Time) bool {
if e.inFlightGrace <= 0 || order.CreatedAt.IsZero() {
return false
+44
View File
@@ -171,6 +171,50 @@ func TestReconciliationSkipsFreshInFlightLocalOrders(t *testing.T) {
}
}
func TestReconciliationUsesInjectedClockForInFlightGrace(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
gateway := tinvest.NewFakeGateway()
clock := &fixedClock{now: time.Date(2000, 1, 1, 10, 0, 0, 0, time.UTC)}
if err := repo.UpsertOrder(ctx, domain.Order{
ClientOrderID: "fresh",
AccountIDHash: "hash",
InstrumentUID: "uid",
TradeDate: clock.now,
Side: domain.SideBuy,
OrderType: domain.OrderTypeLimit,
QuantityLots: 1,
Status: domain.OrderStatusSent,
CreatedAt: clock.now.Add(-5 * time.Second),
}); err != nil {
t.Fatal(err)
}
diffs, err := New(repo, gateway, "account", "hash").
WithClock(clock).
WithInFlightGrace(10 * time.Second).
Run(ctx)
if err != nil {
t.Fatal(err)
}
for _, diff := range diffs {
if diff.Kind == "local_order_without_broker_id" || diff.Kind == "missing_local_order" {
t.Fatalf("fresh in-flight order produced diff: %+v", diffs)
}
}
}
type fixedClock struct {
now time.Time
}
func (c *fixedClock) Now() time.Time {
return c.now
}
func (c *fixedClock) Sleep(<-chan struct{}, time.Duration) bool {
return true
}
func TestReconciliationFindsCashMismatch(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
@@ -0,0 +1,3 @@
ALTER TABLE features DROP COLUMN q05_on_60_abs;
UPDATE schema_meta SET meta_value='0006' WHERE meta_key='schema_version';
@@ -0,0 +1,4 @@
ALTER TABLE features
ADD COLUMN q05_on_60_abs DECIMAL(20,10) NOT NULL DEFAULT 0 AFTER sigma_on_60;
UPDATE schema_meta SET meta_value='0007' WHERE meta_key='schema_version';
+51 -7
View File
@@ -13,6 +13,7 @@ import (
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/repository"
"overnight-trading-bot/internal/risk"
)
var _ repository.Repository = (*Repository)(nil)
@@ -224,13 +225,13 @@ ON DUPLICATE KEY UPDATE
func (r *Repository) mergeFeatures(ctx context.Context, oldInstrumentUID, newInstrumentUID string) error {
_, err := r.execer().ExecContext(ctx, `
INSERT INTO features (
instrument_uid, trade_date, r_on, r_day, mu_on_60, mu_on_252, sigma_on_60,
instrument_uid, trade_date, r_on, r_day, mu_on_60, mu_on_252, sigma_on_60, q05_on_60_abs,
tstat_on_60, win_on_60, ewma_on, spread_bps, half_spread_bps, tick_bps,
adv_20, expected_cost_bps, net_edge_bps, entry_interval_volume,
exit_interval_volume, calculated_at
)
SELECT
?, trade_date, r_on, r_day, mu_on_60, mu_on_252, sigma_on_60,
?, trade_date, r_on, r_day, mu_on_60, mu_on_252, sigma_on_60, q05_on_60_abs,
tstat_on_60, win_on_60, ewma_on, spread_bps, half_spread_bps, tick_bps,
adv_20, expected_cost_bps, net_edge_bps, entry_interval_volume,
exit_interval_volume, calculated_at
@@ -238,7 +239,7 @@ FROM features WHERE instrument_uid=?
ON DUPLICATE KEY UPDATE
r_on=VALUES(r_on), r_day=VALUES(r_day), mu_on_60=VALUES(mu_on_60),
mu_on_252=VALUES(mu_on_252), sigma_on_60=VALUES(sigma_on_60),
tstat_on_60=VALUES(tstat_on_60), win_on_60=VALUES(win_on_60),
q05_on_60_abs=VALUES(q05_on_60_abs), tstat_on_60=VALUES(tstat_on_60), win_on_60=VALUES(win_on_60),
ewma_on=VALUES(ewma_on), spread_bps=VALUES(spread_bps),
half_spread_bps=VALUES(half_spread_bps), tick_bps=VALUES(tick_bps),
adv_20=VALUES(adv_20), expected_cost_bps=VALUES(expected_cost_bps),
@@ -385,19 +386,19 @@ func (r *Repository) UpsertFeature(ctx context.Context, feature domain.FeatureSe
}
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
INSERT INTO features (
instrument_uid, trade_date, r_on, r_day, mu_on_60, mu_on_252, sigma_on_60,
instrument_uid, trade_date, r_on, r_day, mu_on_60, mu_on_252, sigma_on_60, q05_on_60_abs,
tstat_on_60, win_on_60, ewma_on, spread_bps, half_spread_bps, tick_bps,
adv_20, expected_cost_bps, net_edge_bps, entry_interval_volume,
exit_interval_volume, calculated_at
) VALUES (
:instrument_uid, :trade_date, :r_on, :r_day, :mu_on_60, :mu_on_252, :sigma_on_60,
:instrument_uid, :trade_date, :r_on, :r_day, :mu_on_60, :mu_on_252, :sigma_on_60, :q05_on_60_abs,
:tstat_on_60, :win_on_60, :ewma_on, :spread_bps, :half_spread_bps, :tick_bps,
:adv_20, :expected_cost_bps, :net_edge_bps, :entry_interval_volume,
:exit_interval_volume, :calculated_at
) ON DUPLICATE KEY UPDATE
r_on=VALUES(r_on), r_day=VALUES(r_day), mu_on_60=VALUES(mu_on_60),
mu_on_252=VALUES(mu_on_252), sigma_on_60=VALUES(sigma_on_60),
tstat_on_60=VALUES(tstat_on_60), win_on_60=VALUES(win_on_60),
q05_on_60_abs=VALUES(q05_on_60_abs), tstat_on_60=VALUES(tstat_on_60), win_on_60=VALUES(win_on_60),
ewma_on=VALUES(ewma_on), spread_bps=VALUES(spread_bps),
half_spread_bps=VALUES(half_spread_bps), tick_bps=VALUES(tick_bps),
adv_20=VALUES(adv_20), expected_cost_bps=VALUES(expected_cost_bps),
@@ -453,7 +454,9 @@ func (r *Repository) UpsertOrder(ctx context.Context, order domain.Order) error
if order.CreatedAt.IsZero() {
order.CreatedAt = now
}
order.UpdatedAt = now
if order.UpdatedAt.IsZero() {
order.UpdatedAt = now
}
if order.RawStateJSON == "" {
order.RawStateJSON = "{}"
}
@@ -596,6 +599,47 @@ ON DUPLICATE KEY UPDATE orders_sent=orders_sent+VALUES(orders_sent)`, dateOnly(t
return err
}
func (r *Repository) ReserveFreeOrders(ctx context.Context, tradeDate time.Time, instrumentUID string, delta int, limit int) error {
if delta <= 0 {
return nil
}
if limit <= 0 {
return r.IncrementFreeOrders(ctx, tradeDate, instrumentUID, delta)
}
return r.RunInTx(ctx, func(ctx context.Context, repo repository.Repository) error {
txRepo, ok := repo.(*Repository)
if !ok {
return errors.New("unexpected repository implementation")
}
return txRepo.reserveFreeOrdersLocked(ctx, tradeDate, instrumentUID, delta, limit)
})
}
func (r *Repository) reserveFreeOrdersLocked(ctx context.Context, tradeDate time.Time, instrumentUID string, delta int, limit int) error {
tradeDay := dateOnly(tradeDate)
if _, err := r.execer().ExecContext(ctx, `
INSERT IGNORE INTO free_order_counters (trade_date, instrument_uid, orders_sent)
VALUES (?, ?, 0)`, tradeDay, instrumentUID); err != nil {
return err
}
var sent int
if err := r.getContext(ctx, &sent, `
SELECT orders_sent FROM free_order_counters
WHERE trade_date=? AND instrument_uid=?
FOR UPDATE`, tradeDay, instrumentUID); err != nil {
return err
}
remaining := limit - sent
if remaining < delta {
return fmt.Errorf("%w: %s remaining=%d needed=%d", risk.ErrFreeOrderBudget, instrumentUID, remaining, delta)
}
_, err := r.execer().ExecContext(ctx, `
UPDATE free_order_counters
SET orders_sent=orders_sent+?
WHERE trade_date=? AND instrument_uid=?`, delta, tradeDay, instrumentUID)
return err
}
func (r *Repository) GetSystemState(ctx context.Context) (domain.SystemState, bool, string, error) {
var row struct {
State string `db:"state"`
+3
View File
@@ -57,6 +57,7 @@ type featureRow struct {
MuOn60 decimal.Decimal `db:"mu_on_60"`
MuOn252 decimal.Decimal `db:"mu_on_252"`
SigmaOn60 decimal.Decimal `db:"sigma_on_60"`
Q05On60Abs decimal.Decimal `db:"q05_on_60_abs"`
TStatOn60 decimal.Decimal `db:"tstat_on_60"`
WinOn60 decimal.Decimal `db:"win_on_60"`
EWMAOn decimal.Decimal `db:"ewma_on"`
@@ -80,6 +81,7 @@ func featureRowFromDomain(feature domain.FeatureSet) featureRow {
MuOn60: feature.MuOn60,
MuOn252: feature.MuOn252,
SigmaOn60: feature.SigmaOn60,
Q05On60Abs: feature.Q05On60Abs,
TStatOn60: feature.TStatOn60,
WinOn60: feature.WinOn60,
EWMAOn: feature.EWMAOn,
@@ -104,6 +106,7 @@ func (r featureRow) domain() domain.FeatureSet {
MuOn60: r.MuOn60,
MuOn252: r.MuOn252,
SigmaOn60: r.SigmaOn60,
Q05On60Abs: r.Q05On60Abs,
TStatOn60: r.TStatOn60,
WinOn60: r.WinOn60,
EWMAOn: r.EWMAOn,
+1
View File
@@ -38,6 +38,7 @@ type Repository interface {
InsertRiskEvent(ctx context.Context, event domain.RiskEvent) error
GetFreeOrdersSent(ctx context.Context, tradeDate time.Time, instrumentUID string) (int, error)
IncrementFreeOrders(ctx context.Context, tradeDate time.Time, instrumentUID string, delta int) error
ReserveFreeOrders(ctx context.Context, tradeDate time.Time, instrumentUID string, delta int, limit int) error
GetSystemState(ctx context.Context) (domain.SystemState, bool, string, error)
SaveSystemState(ctx context.Context, state domain.SystemState, mode domain.Mode, halted bool, reason string, contextJSON string) error
+36 -4
View File
@@ -130,6 +130,16 @@ func (s *Scheduler) Run(ctx context.Context) error {
}
}
func (s Scheduler) GracefulShutdown(ctx context.Context) error {
if s.svc.Repo == nil || s.svc.Execution == nil {
return nil
}
if err := s.cancelActiveOrders(ctx, domain.SideBuy, domain.OrderStatusCancelled, "shutdown_cancel_active_orders"); err != nil {
return err
}
return s.cancelActiveOrders(ctx, domain.SideSell, domain.OrderStatusCancelled, "shutdown_cancel_active_orders")
}
func (s *Scheduler) Step(ctx context.Context) error {
if err := s.checkInfrastructure(ctx); err != nil {
return err
@@ -370,7 +380,7 @@ func (s Scheduler) sizeSignal(portfolio domain.Portfolio, instrument domain.Inst
Lot: instrument.Lot,
EntryIntervalVolume: feature.EntryIntervalVolume,
ExitIntervalVolume: feature.ExitIntervalVolume,
Q05OvernightAbs: money.Abs(feature.SigmaOn60).Mul(decimal.NewFromFloat(1.65)),
Q05OvernightAbs: feature.Q05On60Abs,
}), nil
}
@@ -544,7 +554,7 @@ func (s *Scheduler) monitorEntryOrders(ctx context.Context, now time.Time) error
return s.svc.MarketData.LatestQuote(ctx, instrumentUID, s.cfg.QuoteDepth, s.cfg.MaxQuoteAge)
},
RepostCheck: func(ctx context.Context, order domain.Order, instrument domain.Instrument, book domain.OrderBook) error {
return s.repostPreTradeCheck(ctx, now, order, instrument, book)
return s.repostPreTradeCheck(ctx, s.nowUTC().In(s.cfg.Location), order, instrument, book)
},
})
if err != nil {
@@ -570,9 +580,31 @@ func (s *Scheduler) holdOvernight(ctx context.Context) error {
if err := s.closeEntryWindow(ctx); err != nil {
return err
}
if err := s.promoteEntryFilledPositions(ctx); err != nil {
return err
}
return s.periodicReconcile(ctx)
}
func (s Scheduler) promoteEntryFilledPositions(ctx context.Context) error {
positionsList, err := s.svc.Repo.ListOpenPositions(ctx, s.svc.AccountIDHash)
if err != nil {
return err
}
now := s.nowUTC()
for _, pos := range positionsList {
if pos.Status != domain.PositionEntryFilled {
continue
}
pos.Status = domain.PositionHoldingOvernight
pos.UpdatedAt = now
if err := s.svc.Repo.UpsertPosition(ctx, pos); err != nil {
return err
}
}
return nil
}
func (s *Scheduler) placeExitOrders(ctx context.Context, now time.Time) error {
if err := s.transitionTo(ctx, domain.StatePlaceExitOrders); err != nil {
return err
@@ -694,7 +726,7 @@ func (s *Scheduler) monitorExitOrders(ctx context.Context, now time.Time) error
return s.svc.MarketData.LatestQuote(ctx, instrumentUID, s.cfg.QuoteDepth, s.cfg.MaxQuoteAge)
},
RepostCheck: func(ctx context.Context, order domain.Order, instrument domain.Instrument, book domain.OrderBook) error {
return s.repostPreTradeCheck(ctx, now, order, instrument, book)
return s.repostPreTradeCheck(ctx, s.nowUTC().In(s.cfg.Location), order, instrument, book)
},
})
if err != nil {
@@ -1080,7 +1112,7 @@ func (s *Scheduler) failOpenPositionsAtHardDeadline(ctx context.Context) error {
now := s.nowUTC()
for _, pos := range positionsList {
switch pos.Status {
case domain.PositionHoldingOvernight, domain.PositionExitPartiallyFilled, domain.PositionExitOrderSent:
case domain.PositionEntryFilled, domain.PositionHoldingOvernight, domain.PositionExitPartiallyFilled, domain.PositionExitOrderSent:
pos.Status = domain.PositionExitFailed
pos.UpdatedAt = now
if err := s.svc.Repo.UpsertPosition(ctx, pos); err != nil {
+43
View File
@@ -650,6 +650,49 @@ func TestPlaceExitUsesCurrentTradeDateForOrderAndFreeCounter(t *testing.T) {
}
}
func TestGracefulShutdownCancelsActiveOrders(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
gateway := tinvest.NewFakeGateway()
tradeDate := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC)
order := domain.Order{
ClientOrderID: "shutdown-order",
BrokerOrderID: "broker-shutdown-order",
AccountIDHash: "hash",
InstrumentUID: "uid",
TradeDate: tradeDate,
Side: domain.SideBuy,
OrderType: domain.OrderTypeLimit,
LimitPrice: decimal.NewFromInt(100),
QuantityLots: 1,
Status: domain.OrderStatusSent,
RawStateJSON: "{}",
}
if err := repo.UpsertOrder(ctx, order); err != nil {
t.Fatal(err)
}
gateway.Orders[order.BrokerOrderID] = order
execEngine := execution.NewEngine(domain.ModeSandbox, "account", gateway, repo)
s := Scheduler{
cfg: Config{Mode: domain.ModeSandbox},
svc: Services{
Repo: repo,
Execution: &execEngine,
AccountIDHash: "hash",
},
}
if err := s.GracefulShutdown(ctx); err != nil {
t.Fatal(err)
}
orders, err := repo.ListOrders(ctx, "hash", tradeDate, tradeDate)
if err != nil {
t.Fatal(err)
}
if len(orders) != 1 || orders[0].Status != domain.OrderStatusCancelled {
t.Fatalf("orders=%+v, want cancelled", orders)
}
}
func mustTOD(raw string) timeutil.TimeOfDay {
tod, err := timeutil.ParseTimeOfDay(raw)
if err != nil {
+15
View File
@@ -9,6 +9,7 @@ import (
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/repository"
"overnight-trading-bot/internal/risk"
)
var _ repository.Repository = (*MemoryRepository)(nil)
@@ -263,6 +264,20 @@ func (r *MemoryRepository) IncrementFreeOrders(_ context.Context, tradeDate time
return nil
}
func (r *MemoryRepository) ReserveFreeOrders(_ context.Context, tradeDate time.Time, instrumentUID string, delta int, limit int) error {
r.mu.Lock()
defer r.mu.Unlock()
if delta <= 0 {
return nil
}
key := featureKey(instrumentUID, tradeDate)
if limit > 0 && r.FreeOrders[key]+delta > limit {
return risk.ErrFreeOrderBudget
}
r.FreeOrders[key] += delta
return nil
}
func (r *MemoryRepository) GetSystemState(_ context.Context) (domain.SystemState, bool, string, error) {
r.mu.Lock()
defer r.mu.Unlock()
+5 -1
View File
@@ -177,11 +177,15 @@ func (g *RealGateway) PostLimitOrder(ctx context.Context, accountID, instrumentU
if side == domain.SideSell {
direction = pb.OrderDirection_ORDER_DIRECTION_SELL
}
quotation, err := money.DecimalToQuotation(price)
if err != nil {
return domain.Order{}, err
}
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.PostOrderResponse, error) {
return g.orders.PostOrder(&investgo.PostOrderRequest{
InstrumentId: instrumentUID,
Quantity: lots,
Price: money.DecimalToQuotation(price),
Price: quotation,
Direction: direction,
AccountId: accountID,
OrderType: pb.OrderType_ORDER_TYPE_LIMIT,
+5 -1
View File
@@ -39,11 +39,15 @@ func (g *SandboxGateway) PostLimitOrder(ctx context.Context, accountID, instrume
if side == domain.SideSell {
direction = pb.OrderDirection_ORDER_DIRECTION_SELL
}
quotation, err := money.DecimalToQuotation(price)
if err != nil {
return domain.Order{}, err
}
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.PostOrderResponse, error) {
return g.sandbox.PostSandboxOrder(&investgo.PostOrderRequest{
InstrumentId: instrumentUID,
Quantity: lots,
Price: money.DecimalToQuotation(price),
Price: quotation,
Direction: direction,
AccountId: accountID,
OrderType: pb.OrderType_ORDER_TYPE_LIMIT,