third version

This commit is contained in:
2026-06-08 07:05:01 +00:00
parent 282c841e11
commit 52a935b8b4
20 changed files with 1371 additions and 151 deletions
+12
View File
@@ -26,6 +26,8 @@ func run() error {
entrySlip := flag.String("entry-slippage-bps", "8", "entry slippage in bps") entrySlip := flag.String("entry-slippage-bps", "8", "entry slippage in bps")
exitSlip := flag.String("exit-slippage-bps", "8", "exit slippage in bps") exitSlip := flag.String("exit-slippage-bps", "8", "exit slippage in bps")
commission := flag.String("commission-roundtrip-bps", "0", "roundtrip commission in bps") commission := flag.String("commission-roundtrip-bps", "0", "roundtrip commission in bps")
riskBuffer := flag.String("risk-buffer-bps", "5", "risk buffer in bps included in signal cost")
assumedSpread := flag.String("assumed-spread-bps", "20", "assumed executable spread cost in bps")
rollingShort := flag.Int("rolling-short", 60, "short rolling window") rollingShort := flag.Int("rolling-short", 60, "short rolling window")
rollingLong := flag.Int("rolling-long", 252, "long rolling window") rollingLong := flag.Int("rolling-long", 252, "long rolling window")
ewmaLambda := flag.Float64("ewma-lambda", 0.08, "EWMA lambda") ewmaLambda := flag.Float64("ewma-lambda", 0.08, "EWMA lambda")
@@ -80,6 +82,14 @@ func run() error {
if err != nil { if err != nil {
return fmt.Errorf("commission: %w", err) return fmt.Errorf("commission: %w", err)
} }
riskBuf, err := decimal.NewFromString(*riskBuffer)
if err != nil {
return fmt.Errorf("risk buffer: %w", err)
}
assumed, err := decimal.NewFromString(*assumedSpread)
if err != nil {
return fmt.Errorf("assumed spread: %w", err)
}
tstat, err := decimal.NewFromString(*minTStat) tstat, err := decimal.NewFromString(*minTStat)
if err != nil { if err != nil {
return fmt.Errorf("min tstat: %w", err) return fmt.Errorf("min tstat: %w", err)
@@ -108,6 +118,7 @@ func run() error {
EntrySlippageBps: entry, EntrySlippageBps: entry,
ExitSlippageBps: exit, ExitSlippageBps: exit,
CommissionRoundtripBps: comm, CommissionRoundtripBps: comm,
RiskBufferBps: riskBuf,
OutputDir: *outputDir, OutputDir: *outputDir,
RollingShort: *rollingShort, RollingShort: *rollingShort,
RollingLong: *rollingLong, RollingLong: *rollingLong,
@@ -118,6 +129,7 @@ func run() error {
MinADVRUB: adv, MinADVRUB: adv,
MaxSpreadBps: spread, MaxSpreadBps: spread,
MaxTickBps: tick, MaxTickBps: tick,
AssumedSpreadBps: assumed,
RequireZeroCommission: *requireZeroCommission, RequireZeroCommission: *requireZeroCommission,
UseMinuteModel: *useMinuteModel, UseMinuteModel: *useMinuteModel,
}) })
+57 -1
View File
@@ -17,6 +17,7 @@ import (
"time" "time"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/config" "overnight-trading-bot/internal/config"
"overnight-trading-bot/internal/domain" "overnight-trading-bot/internal/domain"
@@ -138,6 +139,9 @@ func Run(ctx context.Context, opts Options) error {
if closer != nil { if closer != nil {
defer closer() defer closer()
} }
if err := seedPaperGateway(ctx, repo, gateway); err != nil {
return err
}
notifier, err := notify.NewTelegram(notify.TelegramConfig{ notifier, err := notify.NewTelegram(notify.TelegramConfig{
BotToken: cfg.Telegram.BotToken, BotToken: cfg.Telegram.BotToken,
ChatID: cfg.Telegram.ChatID, ChatID: cfg.Telegram.ChatID,
@@ -161,7 +165,8 @@ func Run(ctx context.Context, opts Options) error {
WithCommissionPolicy(cfg.Commission.RequireZeroCommission, cfg.Commission.QuarantineOnNonZero, cfg.Risk.CommissionToleranceRUB) WithCommissionPolicy(cfg.Commission.RequireZeroCommission, cfg.Commission.QuarantineOnNonZero, cfg.Risk.CommissionToleranceRUB)
sm := statemachine.New(repo, cfg.App.Mode) sm := statemachine.New(repo, cfg.App.Mode)
if _, err := sm.Recover(ctx, recon); err != nil { if _, err := sm.Recover(ctx, recon); err != nil {
log.Warn("state recovery did not resume trading", "err", err) _ = notifier.Alert(ctx, fmt.Sprintf("state recovery failed: %s", err))
return fmt.Errorf("state recovery: %w", err)
} }
health := healthcheck.New(db.DB, gateway, time.Duration(cfg.Risk.MaxClockDriftSec)*time.Second) health := healthcheck.New(db.DB, gateway, time.Duration(cfg.Risk.MaxClockDriftSec)*time.Second)
health.Start(cfg.App.HealthcheckAddr) health.Start(cfg.App.HealthcheckAddr)
@@ -270,6 +275,7 @@ func buildScheduler(clock timeutil.Clock, sm statemachine.System, cfg config.Con
ExitWindowStart: cfg.Execution.ExitWindowStart, ExitWindowStart: cfg.Execution.ExitWindowStart,
ExitWindowEnd: cfg.Execution.ExitWindowEnd, ExitWindowEnd: cfg.Execution.ExitWindowEnd,
HardExitDeadline: cfg.Execution.HardExitDeadline, HardExitDeadline: cfg.Execution.HardExitDeadline,
MarketClose: cfg.Execution.MarketClose,
QuoteDepth: cfg.Execution.QuoteDepth, QuoteDepth: cfg.Execution.QuoteDepth,
MaxQuoteAge: time.Duration(cfg.Execution.MaxQuoteAgeSec) * time.Second, MaxQuoteAge: time.Duration(cfg.Execution.MaxQuoteAgeSec) * time.Second,
OrderPollInterval: time.Duration(cfg.Execution.OrderPollIntervalMS) * time.Millisecond, OrderPollInterval: time.Duration(cfg.Execution.OrderPollIntervalMS) * time.Millisecond,
@@ -282,9 +288,23 @@ func buildScheduler(clock timeutil.Clock, sm statemachine.System, cfg config.Con
RequireZeroCommission: cfg.Commission.RequireZeroCommission, RequireZeroCommission: cfg.Commission.RequireZeroCommission,
QuarantineOnNonZero: cfg.Commission.QuarantineOnNonZero, QuarantineOnNonZero: cfg.Commission.QuarantineOnNonZero,
ReconciliationInterval: 5 * time.Minute, ReconciliationInterval: 5 * time.Minute,
MaxOpenPositions: minPositive(cfg.Strategy.MaxPositions, cfg.Risk.MaxOpenPositions),
}, services) }, services)
} }
func minPositive(a, b int) int {
switch {
case a <= 0:
return b
case b <= 0:
return a
case a < b:
return a
default:
return b
}
}
func openDB(ctx context.Context, cfg config.Config) (*sqlx.DB, error) { func openDB(ctx context.Context, cfg config.Config) (*sqlx.DB, error) {
db, err := sqlx.Open("mysql", cfg.DB.DSN) db, err := sqlx.Open("mysql", cfg.DB.DSN)
if err != nil { if err != nil {
@@ -340,6 +360,42 @@ func buildGateway(ctx context.Context, cfg config.Config, log *slog.Logger) (tin
} }
} }
func seedPaperGateway(ctx context.Context, repo interface {
ListInstruments(context.Context, bool) ([]domain.Instrument, error)
}, gateway tinvest.Gateway) error {
fake, ok := gateway.(*tinvest.FakeGateway)
if !ok {
return nil
}
instrumentsList, err := repo.ListInstruments(ctx, true)
if err != nil {
return err
}
for _, instrument := range instrumentsList {
remote := instrument
if remote.InstrumentUID == "" || strings.HasPrefix(remote.InstrumentUID, "PENDING:") {
remote.InstrumentUID = "paper-" + strings.ToUpper(remote.Ticker)
}
if remote.Figi == "" {
remote.Figi = remote.InstrumentUID
}
if remote.Lot <= 0 {
remote.Lot = 1
}
if !remote.MinPriceIncrement.IsPositive() {
remote.MinPriceIncrement = decimal.RequireFromString("0.01")
}
if remote.Currency == "" {
remote.Currency = "RUB"
}
remote.Enabled = true
remote.UpdatedAt = time.Now().UTC()
fake.Instruments[remote.InstrumentUID] = remote
fake.Statuses[remote.InstrumentUID] = domain.TradingStatusNormal
}
return nil
}
func accountHash(accountID string) string { func accountHash(accountID string) string {
sum := sha256.Sum256([]byte(accountID)) sum := sha256.Sum256([]byte(accountID))
return hex.EncodeToString(sum[:]) return hex.EncodeToString(sum[:])
+34
View File
@@ -5,6 +5,12 @@ import (
"context" "context"
"strings" "strings"
"testing" "testing"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/testutil"
"overnight-trading-bot/internal/tinvest"
) )
func TestRunRequiresAppMode(t *testing.T) { func TestRunRequiresAppMode(t *testing.T) {
@@ -29,3 +35,31 @@ func TestRunBacktestModeWithoutDB(t *testing.T) {
t.Fatalf("unexpected stdout: %s", stdout.String()) t.Fatalf("unexpected stdout: %s", stdout.String())
} }
} }
func TestSeedPaperGatewayMakesSeedInstrumentsDiscoverable(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
if err := repo.UpsertInstrument(ctx, domain.Instrument{
InstrumentUID: "PENDING:TRUR",
Ticker: "TRUR",
ClassCode: "TQTF",
Name: "TRUR",
Lot: 1,
MinPriceIncrement: decimal.RequireFromString("0.0001"),
Currency: "RUB",
Enabled: true,
}); err != nil {
t.Fatal(err)
}
gateway := tinvest.NewFakeGateway()
if err := seedPaperGateway(ctx, repo, gateway); err != nil {
t.Fatal(err)
}
instrument, err := gateway.GetInstrument(ctx, "TRUR", "TQTF")
if err != nil {
t.Fatal(err)
}
if !instrument.MetadataValid() || strings.HasPrefix(instrument.InstrumentUID, "PENDING:") {
t.Fatalf("instrument was not made runnable for paper: %+v", instrument)
}
}
+14 -3
View File
@@ -22,6 +22,7 @@ type Config struct {
EntrySlippageBps decimal.Decimal EntrySlippageBps decimal.Decimal
ExitSlippageBps decimal.Decimal ExitSlippageBps decimal.Decimal
CommissionRoundtripBps decimal.Decimal CommissionRoundtripBps decimal.Decimal
RiskBufferBps decimal.Decimal
InitialEquity decimal.Decimal InitialEquity decimal.Decimal
OutputDir string OutputDir string
RollingShort int RollingShort int
@@ -122,6 +123,15 @@ func (cfg Config) withDefaults() Config {
if cfg.MaxTickBps.IsZero() { if cfg.MaxTickBps.IsZero() {
cfg.MaxTickBps = decimal.NewFromInt(10) cfg.MaxTickBps = decimal.NewFromInt(10)
} }
if cfg.RiskBufferBps.IsZero() {
cfg.RiskBufferBps = decimal.NewFromInt(5)
}
if cfg.AssumedSpreadBps.IsZero() {
cfg.AssumedSpreadBps = cfg.MaxSpreadBps
}
if cfg.AssumedTickBps.IsZero() {
cfg.AssumedTickBps = cfg.MaxTickBps
}
if !cfg.RequireZeroCommission && cfg.CommissionRoundtripBps.IsZero() { if !cfg.RequireZeroCommission && cfg.CommissionRoundtripBps.IsZero() {
cfg.RequireZeroCommission = true cfg.RequireZeroCommission = true
} }
@@ -169,7 +179,7 @@ func (e Engine) RunWithMinuteCandles(candlesByInstrument map[string][]domain.Can
tradingDateSet := make(map[string]struct{}) tradingDateSet := make(map[string]struct{})
for instrumentUID, candles := range prepared { for instrumentUID, candles := range prepared {
for i := 1; i < len(candles); i++ { for i := 1; i < len(candles); i++ {
if i >= e.cfg.RollingShort { if i >= max(e.cfg.RollingShort, e.cfg.RollingLong) {
tradingDateSet[candles[i].TradeDate.Format("2006-01-02")] = struct{}{} tradingDateSet[candles[i].TradeDate.Format("2006-01-02")] = struct{}{}
} }
candidate, ok, err := e.evaluateCandidate(instrumentUID, candles, i) candidate, ok, err := e.evaluateCandidate(instrumentUID, candles, i)
@@ -366,7 +376,7 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle,
returns = append(returns, rf) returns = append(returns, rf)
} }
short := features.Rolling(returns, e.cfg.RollingShort, e.cfg.EWMALambda) short := features.Rolling(returns, e.cfg.RollingShort, e.cfg.EWMALambda)
long := features.Rolling(returns, min(e.cfg.RollingLong, len(returns)), e.cfg.EWMALambda) long := features.Rolling(returns, e.cfg.RollingLong, e.cfg.EWMALambda)
if !short.Available || !long.Available || short.StdDev == 0 { if !short.Available || !long.Available || short.StdDev == 0 {
return candidate{}, false, nil return candidate{}, false, nil
} }
@@ -374,7 +384,8 @@ func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle,
cost := e.cfg.AssumedSpreadBps. cost := e.cfg.AssumedSpreadBps.
Add(e.cfg.EntrySlippageBps). Add(e.cfg.EntrySlippageBps).
Add(e.cfg.ExitSlippageBps). Add(e.cfg.ExitSlippageBps).
Add(e.cfg.CommissionRoundtripBps) Add(e.cfg.CommissionRoundtripBps).
Add(e.cfg.RiskBufferBps)
netEdge := rawEdge.Sub(cost) netEdge := rawEdge.Sub(cost)
adv := features.ADV(history, e.cfg.Lot, 20) adv := features.ADV(history, e.cfg.Lot, 20)
switch { switch {
+6
View File
@@ -88,6 +88,7 @@ type ExecutionConfig struct {
ExitWindowStart timeutil.TimeOfDay `env:"EXIT_WINDOW_START" envDefault:"10:05:00"` ExitWindowStart timeutil.TimeOfDay `env:"EXIT_WINDOW_START" envDefault:"10:05:00"`
ExitWindowEnd timeutil.TimeOfDay `env:"EXIT_WINDOW_END" envDefault:"10:25:00"` ExitWindowEnd timeutil.TimeOfDay `env:"EXIT_WINDOW_END" envDefault:"10:25:00"`
HardExitDeadline timeutil.TimeOfDay `env:"HARD_EXIT_DEADLINE" envDefault:"10:45:00"` HardExitDeadline timeutil.TimeOfDay `env:"HARD_EXIT_DEADLINE" envDefault:"10:45:00"`
MarketClose timeutil.TimeOfDay `env:"MARKET_CLOSE" envDefault:"18:50:00"`
MinTimeToCloseSec int `env:"MIN_TIME_TO_CLOSE_SEC" envDefault:"90"` MinTimeToCloseSec int `env:"MIN_TIME_TO_CLOSE_SEC" envDefault:"90"`
AllowMarketOrders bool `env:"ALLOW_MARKET_ORDERS" envDefault:"false"` AllowMarketOrders bool `env:"ALLOW_MARKET_ORDERS" envDefault:"false"`
MaxEntryOrderAttempts int `env:"MAX_ENTRY_ORDER_ATTEMPTS" envDefault:"3"` MaxEntryOrderAttempts int `env:"MAX_ENTRY_ORDER_ATTEMPTS" envDefault:"3"`
@@ -237,5 +238,10 @@ func (c Config) validateWindows() error {
c.Execution.ExitWindowEnd.Duration > c.Execution.HardExitDeadline.Duration { c.Execution.ExitWindowEnd.Duration > c.Execution.HardExitDeadline.Duration {
return errors.New("exit windows must be monotonic from EXEC_EXIT_WATCH_START to EXEC_HARD_EXIT_DEADLINE") return errors.New("exit windows must be monotonic from EXEC_EXIT_WATCH_START to EXEC_HARD_EXIT_DEADLINE")
} }
if c.Execution.MarketClose.Duration > 0 &&
(c.Execution.MarketClose.Duration <= c.Execution.NoNewEntryAfter.Duration ||
c.Execution.MarketClose.Duration <= c.Execution.HardExitDeadline.Duration) {
return errors.New("EXEC_MARKET_CLOSE must be after entry and exit trading windows")
}
return nil return nil
} }
+164 -24
View File
@@ -40,6 +40,7 @@ type MonitorConfig struct {
Instrument domain.Instrument Instrument domain.Instrument
ImproveTicks int ImproveTicks int
Quote func(ctx context.Context, instrumentUID string) (domain.OrderBook, error) Quote func(ctx context.Context, instrumentUID string) (domain.OrderBook, error)
RepostCheck func(ctx context.Context, order domain.Order, instrument domain.Instrument, book domain.OrderBook) error
} }
func NewEngine(mode domain.Mode, accountID string, gateway Gateway, store repository.Repository) Engine { func NewEngine(mode domain.Mode, accountID string, gateway Gateway, store repository.Repository) Engine {
@@ -105,6 +106,9 @@ func (e *Engine) PlaceExit(ctx context.Context, accountIDHash string, instrument
} }
func (e *Engine) PlaceLimit(ctx context.Context, order domain.Order) (domain.Order, error) { func (e *Engine) PlaceLimit(ctx context.Context, order domain.Order) (domain.Order, error) {
lock := e.lockFor(order.InstrumentUID)
lock.Lock()
defer lock.Unlock()
if e.store != nil { if e.store != nil {
existing, err := e.findExisting(ctx, order) existing, err := e.findExisting(ctx, order)
if err != nil { if err != nil {
@@ -127,15 +131,25 @@ func (e *Engine) PlaceLimit(ctx context.Context, order domain.Order) (domain.Ord
if e.gateway == nil { if e.gateway == nil {
return domain.Order{}, errors.New("gateway is nil") return domain.Order{}, errors.New("gateway is nil")
} }
lock := e.lockFor(order.InstrumentUID)
lock.Lock()
defer lock.Unlock()
now := time.Now().UTC()
draft := order
draft.Status = domain.OrderStatusSent
draft.CreatedAt = now
draft.UpdatedAt = now
if draft.RawStateJSON == "" {
draft.RawStateJSON = "{}"
}
if e.store != nil {
if err := e.store.UpsertOrder(ctx, draft); err != nil {
return domain.Order{}, fmt.Errorf("persist draft order: %w", err)
}
}
posted, err := e.gateway.PostLimitOrder(ctx, e.accountID, order.InstrumentUID, order.Side, order.QuantityLots, order.LimitPrice, order.ClientOrderID) posted, err := e.gateway.PostLimitOrder(ctx, e.accountID, order.InstrumentUID, order.Side, order.QuantityLots, order.LimitPrice, order.ClientOrderID)
if err != nil { if err != nil {
order.Status = domain.OrderStatusFailed draft.Status = domain.OrderStatusFailed
if e.store != nil { if e.store != nil {
_ = e.store.UpsertOrder(ctx, order) _ = e.store.UpsertOrder(ctx, draft)
} }
return domain.Order{}, err return domain.Order{}, err
} }
@@ -148,7 +162,7 @@ func (e *Engine) PlaceLimit(ctx context.Context, order domain.Order) (domain.Ord
posted.QuantityLots = order.QuantityLots posted.QuantityLots = order.QuantityLots
posted.AttemptNo = order.AttemptNo posted.AttemptNo = order.AttemptNo
posted.TradeDate = order.TradeDate posted.TradeDate = order.TradeDate
posted.CreatedAt = time.Now().UTC() posted.CreatedAt = now
posted.UpdatedAt = posted.CreatedAt posted.UpdatedAt = posted.CreatedAt
if e.store != nil { if e.store != nil {
if err := e.store.RunInTx(ctx, func(ctx context.Context, repo repository.Repository) error { if err := e.store.RunInTx(ctx, func(ctx context.Context, repo repository.Repository) error {
@@ -191,9 +205,7 @@ func (e *Engine) findExisting(ctx context.Context, order domain.Order) (domain.O
return domain.Order{}, err return domain.Order{}, err
} }
for _, existing := range orders { for _, existing := range orders {
if existing.ClientOrderID == order.ClientOrderID && if existing.ClientOrderID == order.ClientOrderID {
existing.Status != domain.OrderStatusFailed &&
existing.Status != domain.OrderStatusRejected {
return existing, nil return existing, nil
} }
} }
@@ -294,12 +306,14 @@ func (e *Engine) MonitorUntil(ctx context.Context, order domain.Order, cfg Monit
aggregate.FilledLots < aggregate.QuantityLots && aggregate.FilledLots < aggregate.QuantityLots &&
cfg.Quote != nil cfg.Quote != nil
if shouldRepost { if shouldRepost {
next, err := e.repost(ctx, current, cfg, aggregate.QuantityLots-aggregate.FilledLots) next, reposted, err := e.repost(ctx, current, cfg, aggregate.QuantityLots-aggregate.FilledLots)
if err != nil { if err != nil {
return aggregate, err return aggregate, err
} }
current = next if reposted {
seen[current.ClientOrderID] = current current = next
seen[current.ClientOrderID] = current
}
lastPost = time.Now() lastPost = time.Now()
continue continue
} }
@@ -311,32 +325,158 @@ func (e *Engine) MonitorUntil(ctx context.Context, order domain.Order, cfg Monit
} }
} }
func (e *Engine) repost(ctx context.Context, order domain.Order, cfg MonitorConfig, remaining int64) (domain.Order, error) { func (e *Engine) MonitorOnce(ctx context.Context, order domain.Order, cfg MonitorConfig) (domain.Order, error) {
if cfg.PollInterval <= 0 {
cfg.PollInterval = 500 * time.Millisecond
}
if cfg.MaxAttempts <= 0 {
cfg.MaxAttempts = 1
}
previous := order
refreshed, err := e.Refresh(ctx, order)
if err != nil {
return order, err
}
aggregate := mergeAggregateFill(order, previous, refreshed)
current := mergeOrderState(order, refreshed)
aggregate.Status = current.Status
aggregate.UpdatedAt = current.UpdatedAt
aggregate.RawStateJSON = current.RawStateJSON
if aggregate.FilledLots >= aggregate.QuantityLots {
aggregate.Status = domain.OrderStatusFilled
return aggregate, nil
}
if isTerminal(current.Status) {
return aggregate, nil
}
if !cfg.Deadline.IsZero() && !time.Now().Before(cfg.Deadline) {
if err := e.Cancel(ctx, current); err != nil {
return aggregate, err
}
aggregate.Status = domain.OrderStatusExpired
if e.store != nil {
if err := e.store.UpdateOrderStatus(ctx, current.ClientOrderID, aggregate.Status, current.FilledLots, current.RawStateJSON); err != nil {
return aggregate, err
}
}
return aggregate, nil
}
shouldRepost := cfg.RepostAfter > 0 &&
repostDue(current, cfg.RepostAfter) &&
current.AttemptNo < cfg.MaxAttempts &&
aggregate.FilledLots < aggregate.QuantityLots &&
cfg.Quote != nil
if shouldRepost {
next, reposted, err := e.repost(ctx, current, cfg, aggregate.QuantityLots-aggregate.FilledLots)
if err != nil {
return aggregate, err
}
if reposted {
aggregate.BrokerOrderID = next.BrokerOrderID
aggregate.ClientOrderID = next.ClientOrderID
aggregate.Status = next.Status
aggregate.RawStateJSON = next.RawStateJSON
aggregate.UpdatedAt = next.UpdatedAt
}
}
return aggregate, nil
}
func (e *Engine) repost(ctx context.Context, order domain.Order, cfg MonitorConfig, remaining int64) (domain.Order, bool, error) {
if err := e.ensureRepostBudget(ctx, order, cfg.Instrument); err != nil { if err := e.ensureRepostBudget(ctx, order, cfg.Instrument); err != nil {
return domain.Order{}, err return domain.Order{}, false, err
} }
if err := e.Cancel(ctx, order); err != nil { if !cfg.Deadline.IsZero() && !time.Now().Before(cfg.Deadline) {
return domain.Order{}, err return order, false, nil
}
if remaining <= 0 {
order.Status = domain.OrderStatusFilled
return order, nil
} }
book, err := cfg.Quote(ctx, order.InstrumentUID) book, err := cfg.Quote(ctx, order.InstrumentUID)
if err != nil { if err != nil {
return domain.Order{}, err return domain.Order{}, false, err
}
if cfg.RepostCheck != nil {
if err := cfg.RepostCheck(ctx, order, cfg.Instrument, book); err != nil {
return order, false, nil
}
}
if err := e.Cancel(ctx, order); err != nil {
return domain.Order{}, false, err
}
cancelled, err := e.waitTerminal(ctx, order, cfg)
if err != nil {
return domain.Order{}, false, err
}
if remaining <= 0 {
cancelled.Status = domain.OrderStatusFilled
return cancelled, true, nil
}
if !cfg.Deadline.IsZero() && !time.Now().Before(cfg.Deadline) {
return cancelled, true, nil
}
book, err = cfg.Quote(ctx, order.InstrumentUID)
if err != nil {
return domain.Order{}, false, err
}
if cfg.RepostCheck != nil {
if err := cfg.RepostCheck(ctx, order, cfg.Instrument, book); err != nil {
return cancelled, true, nil
}
} }
attempt := order.AttemptNo + 1 attempt := order.AttemptNo + 1
switch order.Side { switch order.Side {
case domain.SideBuy: case domain.SideBuy:
return e.PlaceEntry(ctx, order.AccountIDHash, cfg.Instrument, order.TradeDate, remaining, book, cfg.ImproveTicks, attempt) next, err := e.PlaceEntry(ctx, order.AccountIDHash, cfg.Instrument, order.TradeDate, remaining, book, cfg.ImproveTicks, attempt)
return next, true, err
case domain.SideSell: case domain.SideSell:
return e.PlaceExit(ctx, order.AccountIDHash, cfg.Instrument, order.TradeDate, remaining, book, cfg.ImproveTicks, attempt) next, err := e.PlaceExit(ctx, order.AccountIDHash, cfg.Instrument, order.TradeDate, remaining, book, cfg.ImproveTicks, attempt)
return next, true, err
default: default:
return domain.Order{}, fmt.Errorf("unsupported side %s", order.Side) return domain.Order{}, false, fmt.Errorf("unsupported side %s", order.Side)
} }
} }
func (e *Engine) waitTerminal(ctx context.Context, order domain.Order, cfg MonitorConfig) (domain.Order, error) {
current := order
for {
refreshed, err := e.Refresh(ctx, current)
if err != nil {
return domain.Order{}, err
}
current = mergeOrderState(current, refreshed)
if isTerminal(current.Status) {
return current, nil
}
if !cfg.Deadline.IsZero() && !time.Now().Before(cfg.Deadline) {
return current, nil
}
timer := time.NewTimer(cfg.PollInterval)
select {
case <-ctx.Done():
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
return domain.Order{}, ctx.Err()
case <-timer.C:
}
}
}
func repostDue(order domain.Order, after time.Duration) bool {
if after <= 0 {
return false
}
basis := order.CreatedAt
if basis.IsZero() {
basis = order.UpdatedAt
}
if basis.IsZero() {
return true
}
return time.Since(basis) >= after
}
func (e *Engine) ensureRepostBudget(ctx context.Context, order domain.Order, instrument domain.Instrument) error { func (e *Engine) ensureRepostBudget(ctx context.Context, order domain.Order, instrument domain.Instrument) error {
if e.store == nil || instrument.FreeOrderLimitPerDay <= 0 { if e.store == nil || instrument.FreeOrderLimitPerDay <= 0 {
return nil return nil
+47
View File
@@ -165,3 +165,50 @@ func TestMonitorUntilRepostsAndExpiresAtDeadline(t *testing.T) {
t.Fatalf("free order counter=%d, want 2", sent) t.Fatalf("free order counter=%d, want 2", sent)
} }
} }
func TestMonitorOnceDoesNotRepostWhenCheckRejects(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
gateway := tinvest.NewFakeGateway()
engine := NewEngine(domain.ModeSandbox, "account", gateway, repo)
instrument := domain.Instrument{
InstrumentUID: "uid",
Lot: 1,
MinPriceIncrement: decimal.NewFromInt(1),
}
book := domain.OrderBook{
InstrumentUID: "uid",
Bids: []domain.OrderBookLevel{{Price: decimal.NewFromInt(99), QuantityLots: 10}},
Asks: []domain.OrderBookLevel{{Price: decimal.NewFromInt(101), QuantityLots: 10}},
ReceivedAt: time.Now().UTC(),
}
tradeDate := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC)
order, err := engine.PlaceEntry(ctx, "hash", instrument, tradeDate, 3, book, 1, 1)
if err != nil {
t.Fatal(err)
}
order.CreatedAt = time.Now().UTC().Add(-time.Minute)
if err := repo.UpsertOrder(ctx, order); err != nil {
t.Fatal(err)
}
if _, err := engine.MonitorOnce(ctx, order, MonitorConfig{
Deadline: time.Now().Add(time.Minute),
PollInterval: time.Millisecond,
MaxAttempts: 2,
RepostAfter: time.Second,
Instrument: instrument,
ImproveTicks: 1,
Quote: func(context.Context, string) (domain.OrderBook, error) {
book.ReceivedAt = time.Now().UTC()
return book, nil
},
RepostCheck: func(context.Context, domain.Order, domain.Instrument, domain.OrderBook) error {
return context.Canceled
},
}); err != nil {
t.Fatal(err)
}
if got := len(gateway.Orders); got != 1 {
t.Fatalf("broker orders=%d, want no repost", got)
}
}
-6
View File
@@ -108,12 +108,6 @@ func Compute(instrument domain.Instrument, candles []domain.Candle, tradeDate ti
long := Rolling(overnight, cfg.RollingLong, cfg.EWMALambda) long := Rolling(overnight, cfg.RollingLong, cfg.EWMALambda)
adv := ADV(candles, instrument.Lot, 20) adv := ADV(candles, instrument.Lot, 20)
rawEdgeBps := decimal.NewFromFloat(short.Mean).Mul(decimal.NewFromInt(10_000)) rawEdgeBps := decimal.NewFromFloat(short.Mean).Mul(decimal.NewFromInt(10_000))
if !entryVolume.IsPositive() {
entryVolume = adv
}
if !exitVolume.IsPositive() {
exitVolume = adv
}
instrumentCommission := instrument.ExpectedCommissionBpsPerSide.Mul(decimal.NewFromInt(2)) instrumentCommission := instrument.ExpectedCommissionBpsPerSide.Mul(decimal.NewFromInt(2))
expectedCost := spread.SpreadBps. expectedCost := spread.SpreadBps.
Add(cfg.EntrySlippageBps). Add(cfg.EntrySlippageBps).
+40 -3
View File
@@ -13,6 +13,8 @@ import (
"overnight-trading-bot/internal/domain" "overnight-trading-bot/internal/domain"
) )
const mustDeliverEnqueueTimeout = 2 * time.Second
type Notifier interface { type Notifier interface {
Info(ctx context.Context, msg string) error Info(ctx context.Context, msg string) error
Warn(ctx context.Context, msg string) error Warn(ctx context.Context, msg string) error
@@ -53,8 +55,9 @@ type Telegram struct {
} }
type outbound struct { type outbound struct {
level domain.Severity level domain.Severity
text string text string
mustDeliver bool
} }
func NewTelegram(cfg TelegramConfig, log *slog.Logger) (Notifier, error) { func NewTelegram(cfg TelegramConfig, log *slog.Logger) (Notifier, error) {
@@ -119,13 +122,19 @@ func (t *Telegram) enqueue(ctx context.Context, level domain.Severity, msg strin
} }
func (t *Telegram) enqueueText(ctx context.Context, level domain.Severity, text string, mustDeliver bool) error { func (t *Telegram) enqueueText(ctx context.Context, level domain.Severity, text string, mustDeliver bool) error {
item := outbound{level: level, text: text} item := outbound{level: level, text: text, mustDeliver: mustDeliver}
if mustDeliver { if mustDeliver {
timer := time.NewTimer(mustDeliverEnqueueTimeout)
defer timer.Stop()
select { select {
case t.queue <- item: case t.queue <- item:
return nil return nil
case <-ctx.Done(): case <-ctx.Done():
t.auditNotificationFailure(context.Background(), item, "notification_context_cancelled", ctx.Err().Error())
return ctx.Err() return ctx.Err()
case <-timer.C:
t.auditNotificationFailure(ctx, item, "notification_undeliverable", "telegram queue full")
return nil
} }
} }
select { select {
@@ -168,8 +177,10 @@ func (t *Telegram) dispatch() {
func (t *Telegram) send(item outbound) { func (t *Telegram) send(item outbound) {
msg := tgbotapi.NewMessage(t.cfg.ChatID, item.text) msg := tgbotapi.NewMessage(t.cfg.ChatID, item.text)
var lastErr error
for attempt := 0; attempt < 3; attempt++ { for attempt := 0; attempt < 3; attempt++ {
if _, err := t.bot.Send(msg); err != nil { if _, err := t.bot.Send(msg); err != nil {
lastErr = err
delay := telegramRetryDelay(err, attempt) delay := telegramRetryDelay(err, attempt)
if t.log != nil { if t.log != nil {
t.log.Warn("telegram send failed", "attempt", attempt+1, "err", err, "retry_in", delay) t.log.Warn("telegram send failed", "attempt", attempt+1, "err", err, "retry_in", delay)
@@ -190,6 +201,32 @@ func (t *Telegram) send(item outbound) {
} }
return return
} }
if item.mustDeliver {
message := "telegram send failed"
if lastErr != nil {
message = lastErr.Error()
}
t.auditNotificationFailure(context.Background(), item, "notification_undeliverable", message)
}
}
func (t *Telegram) auditNotificationFailure(ctx context.Context, item outbound, eventType, message string) {
if t.cfg.AuditSink == nil {
return
}
severity := domain.SeverityWarn
if item.mustDeliver {
severity = domain.SeverityCritical
}
if err := t.cfg.AuditSink.InsertRiskEvent(ctx, domain.RiskEvent{
TS: time.Now().UTC(),
Severity: severity,
EventType: eventType,
Message: message,
ContextJSON: fmt.Sprintf(`{"level":%q}`, item.level),
}); err != nil && t.log != nil {
t.log.Warn("telegram audit fallback failed", "err", err)
}
} }
func telegramRetryDelay(err error, attempt int) time.Duration { func telegramRetryDelay(err error, attempt int) time.Duration {
+12 -3
View File
@@ -15,6 +15,8 @@ import (
"overnight-trading-bot/internal/tinvest" "overnight-trading-bot/internal/tinvest"
) )
var defaultCommissionTolerance = decimal.RequireFromString("0.01")
type Engine struct { type Engine struct {
repo repository.Repository repo repository.Repository
gateway tinvest.Gateway gateway tinvest.Gateway
@@ -34,7 +36,7 @@ func New(repo repository.Repository, gateway tinvest.Gateway, accountID, account
accountID: accountID, accountID: accountID,
accountIDHash: accountIDHash, accountIDHash: accountIDHash,
window: 72 * time.Hour, window: 72 * time.Hour,
commissionTolerance: decimal.NewFromFloat(0.01), commissionTolerance: defaultCommissionTolerance,
} }
} }
@@ -164,7 +166,14 @@ func (e Engine) Run(ctx context.Context) ([]domain.ReconciliationDiff, error) {
continue continue
} }
if err := e.repo.QuarantineInstrument(ctx, diff.InstrumentUID, diff.Message); err != nil { if err := e.repo.QuarantineInstrument(ctx, diff.InstrumentUID, diff.Message); err != nil {
return nil, err _ = e.repo.InsertRiskEvent(ctx, domain.RiskEvent{
TS: time.Now().UTC(),
Severity: domain.SeverityCritical,
EventType: "quarantine_failed",
InstrumentUID: diff.InstrumentUID,
Message: err.Error(),
ContextJSON: fmt.Sprintf(`{"reconciliation_diff":%q}`, diff.Message),
})
} }
} }
} }
@@ -192,7 +201,7 @@ func HasCritical(diffs []domain.ReconciliationDiff) bool {
} }
func compareOperations(orders []domain.Order, operations []domain.Operation) []domain.ReconciliationDiff { func compareOperations(orders []domain.Order, operations []domain.Operation) []domain.ReconciliationDiff {
return compareOperationsWithPolicy(orders, operations, false, decimal.NewFromFloat(0.01)) return compareOperationsWithPolicy(orders, operations, false, defaultCommissionTolerance)
} }
func compareOperationsWithPolicy(orders []domain.Order, operations []domain.Operation, requireZeroCommission bool, commissionTolerance decimal.Decimal) []domain.ReconciliationDiff { func compareOperationsWithPolicy(orders []domain.Order, operations []domain.Operation, requireZeroCommission bool, commissionTolerance decimal.Decimal) []domain.ReconciliationDiff {
+20
View File
@@ -615,6 +615,23 @@ func (r *Repository) SaveSystemState(ctx context.Context, state domain.SystemSta
_, err := r.execer().ExecContext(ctx, ` _, err := r.execer().ExecContext(ctx, `
INSERT INTO system_state (id, state, mode, halted, halt_reason, last_heartbeat, context_json) INSERT INTO system_state (id, state, mode, halted, halt_reason, last_heartbeat, context_json)
VALUES (1, ?, ?, ?, ?, UTC_TIMESTAMP(3), ?) VALUES (1, ?, ?, ?, ?, UTC_TIMESTAMP(3), ?)
ON DUPLICATE KEY UPDATE
state=IF(halted=1 AND VALUES(halted)=0, state, VALUES(state)),
mode=VALUES(mode),
halted=IF(halted=1 AND VALUES(halted)=0, halted, VALUES(halted)),
halt_reason=IF(halted=1 AND VALUES(halted)=0, halt_reason, VALUES(halt_reason)),
last_heartbeat=VALUES(last_heartbeat),
context_json=VALUES(context_json)`, state, mode, halted, nullableString(reason), contextJSON)
return err
}
func (r *Repository) forceSaveSystemState(ctx context.Context, state domain.SystemState, mode domain.Mode, halted bool, reason string, contextJSON string) error {
if contextJSON == "" {
contextJSON = "{}"
}
_, err := r.execer().ExecContext(ctx, `
INSERT INTO system_state (id, state, mode, halted, halt_reason, last_heartbeat, context_json)
VALUES (1, ?, ?, ?, ?, UTC_TIMESTAMP(3), ?)
ON DUPLICATE KEY UPDATE ON DUPLICATE KEY UPDATE
state=VALUES(state), mode=VALUES(mode), halted=VALUES(halted), state=VALUES(state), mode=VALUES(mode), halted=VALUES(halted),
halt_reason=VALUES(halt_reason), last_heartbeat=VALUES(last_heartbeat), halt_reason=VALUES(halt_reason), last_heartbeat=VALUES(last_heartbeat),
@@ -647,6 +664,9 @@ func (r *Repository) Unhalt(ctx context.Context, reason string) error {
} }
mode = currentMode mode = currentMode
} }
if txRepo, ok := repo.(*Repository); ok {
return txRepo.forceSaveSystemState(ctx, domain.StateInit, mode, false, "", `{"manual_unhalt":true}`)
}
return repo.SaveSystemState(ctx, domain.StateInit, mode, false, "", `{"manual_unhalt":true}`) return repo.SaveSystemState(ctx, domain.StateInit, mode, false, "", `{"manual_unhalt":true}`)
}) })
} }
+13 -3
View File
@@ -29,6 +29,8 @@ type SizingConfig struct {
type SizingInput struct { type SizingInput struct {
Portfolio domain.Portfolio Portfolio domain.Portfolio
SelectedInstruments int SelectedInstruments int
ExistingExposure decimal.Decimal
ReservedCash decimal.Decimal
LimitPrice decimal.Decimal LimitPrice decimal.Decimal
Lot int64 Lot int64
EntryIntervalVolume decimal.Decimal EntryIntervalVolume decimal.Decimal
@@ -66,11 +68,19 @@ func (s Sizer) Size(input SizingInput) SizingResult {
input.SelectedInstruments = 1 input.SelectedInstruments = 1
} }
capLimit := input.Portfolio.Equity.Mul(s.cfg.MaxPositionPct) capLimit := input.Portfolio.Equity.Mul(s.cfg.MaxPositionPct)
exposureLimit := input.Portfolio.Equity.Mul(s.cfg.MaxTotalExposurePct). totalExposureLimit := input.Portfolio.Equity.Mul(s.cfg.MaxTotalExposurePct)
Div(decimal.NewFromInt(int64(input.SelectedInstruments))) remainingExposure := totalExposureLimit.Sub(input.ExistingExposure)
if remainingExposure.IsNegative() {
remainingExposure = decimal.Zero
}
exposureLimit := remainingExposure.Div(decimal.NewFromInt(int64(input.SelectedInstruments)))
liquidityLimit := money.Min(input.EntryIntervalVolume, input.ExitIntervalVolume). liquidityLimit := money.Min(input.EntryIntervalVolume, input.ExitIntervalVolume).
Mul(s.cfg.MaxParticipationRate) Mul(s.cfg.MaxParticipationRate)
cashLimit := input.Portfolio.Cash.Mul(s.cfg.CashUsageBuffer) availableCash := input.Portfolio.Cash.Sub(input.ReservedCash)
if availableCash.IsNegative() {
availableCash = decimal.Zero
}
cashLimit := availableCash.Mul(s.cfg.CashUsageBuffer)
riskLimit := capLimit riskLimit := capLimit
if input.Q05OvernightAbs.IsPositive() { if input.Q05OvernightAbs.IsPositive() {
riskBudget := input.Portfolio.Equity.Mul(s.cfg.RiskBudgetPerInstrumentPct) riskBudget := input.Portfolio.Equity.Mul(s.cfg.RiskBudgetPerInstrumentPct)
+25
View File
@@ -170,3 +170,28 @@ func TestSizerAppliesSizeReductionFactor(t *testing.T) {
t.Fatalf("unexpected reduced sizing: %+v", got) t.Fatalf("unexpected reduced sizing: %+v", got)
} }
} }
func TestSizerSubtractsExistingExposureAndReservedCash(t *testing.T) {
sizer := NewSizer(SizingConfig{
MaxPositionPct: rd("1"),
MaxTotalExposurePct: rd("0.50"),
MaxParticipationRate: rd("1"),
CashUsageBuffer: rd("1"),
RiskBudgetPerInstrumentPct: rd("1"),
MinOrderNotionalRUB: rd("1"),
})
got := sizer.Size(SizingInput{
Portfolio: domain.Portfolio{Equity: rd("100000"), Cash: rd("50000")},
SelectedInstruments: 2,
ExistingExposure: rd("30000"),
ReservedCash: rd("10000"),
LimitPrice: rd("100"),
Lot: 1,
EntryIntervalVolume: rd("1000000"),
ExitIntervalVolume: rd("1000000"),
Q05OvernightAbs: rd("1"),
})
if got.Lots != 100 || !got.TargetNotional.Equal(rd("10000")) {
t.Fatalf("unexpected sizing with reserved exposure: %+v", got)
}
}
+486 -64
View File
@@ -48,6 +48,7 @@ type Config struct {
ExitWindowStart timeutil.TimeOfDay ExitWindowStart timeutil.TimeOfDay
ExitWindowEnd timeutil.TimeOfDay ExitWindowEnd timeutil.TimeOfDay
HardExitDeadline timeutil.TimeOfDay HardExitDeadline timeutil.TimeOfDay
MarketClose timeutil.TimeOfDay
QuoteDepth int32 QuoteDepth int32
MaxQuoteAge time.Duration MaxQuoteAge time.Duration
OrderPollInterval time.Duration OrderPollInterval time.Duration
@@ -60,6 +61,7 @@ type Config struct {
RequireZeroCommission bool RequireZeroCommission bool
QuarantineOnNonZero bool QuarantineOnNonZero bool
ReconciliationInterval time.Duration ReconciliationInterval time.Duration
MaxOpenPositions int
} }
type Services struct { type Services struct {
@@ -91,6 +93,13 @@ type Scheduler struct {
lastReconciledAt time.Time lastReconciledAt time.Time
} }
type signalCandidate struct {
Signal domain.Signal
Instrument domain.Instrument
Feature domain.FeatureSet
Book domain.OrderBook
}
func New(clock timeutil.Clock, sm statemachine.System, cfg Config, svc Services) Scheduler { func New(clock timeutil.Clock, sm statemachine.System, cfg Config, svc Services) Scheduler {
if cfg.TickInterval <= 0 { if cfg.TickInterval <= 0 {
cfg.TickInterval = 30 * time.Second cfg.TickInterval = 30 * time.Second
@@ -205,22 +214,39 @@ func (s *Scheduler) prepareSignals(ctx context.Context, now time.Time) error {
if err != nil { if err != nil {
return err return err
} }
instrumentByUID := make(map[string]domain.Instrument, len(instrumentsList))
for _, instrument := range instrumentsList { for _, instrument := range instrumentsList {
if err := s.generateInstrumentSignal(ctx, now, tradeDate, portfolio, len(openPositions), instrument); err != nil { instrumentByUID[instrument.InstrumentUID] = instrument
}
existingExposure := positionsExposure(openPositions, instrumentByUID, portfolio)
generated := make([]signalCandidate, 0, len(instrumentsList))
for _, instrument := range instrumentsList {
candidate, err := s.generateInstrumentSignal(ctx, tradeDate, len(openPositions), instrument)
if err != nil {
return err
}
generated = append(generated, candidate)
}
s.applyBatchSignalLimits(portfolio, existingExposure, len(openPositions), generated)
for _, candidate := range generated {
if err := s.svc.Repo.UpsertSignal(ctx, candidate.Signal); err != nil {
return err
}
if err := s.notifySignal(ctx, now, candidate.Signal); err != nil {
return err return err
} }
} }
return s.transitionTo(ctx, domain.StateWaitEntryWindow) return s.transitionTo(ctx, domain.StateWaitEntryWindow)
} }
func (s Scheduler) generateInstrumentSignal(ctx context.Context, now, tradeDate time.Time, portfolio domain.Portfolio, openPositionCount int, instrument domain.Instrument) error { func (s Scheduler) generateInstrumentSignal(ctx context.Context, tradeDate time.Time, openPositionCount int, instrument domain.Instrument) (signalCandidate, error) {
book, err := s.svc.MarketData.LatestQuote(ctx, instrument.InstrumentUID, s.cfg.QuoteDepth, s.cfg.MaxQuoteAge) book, err := s.svc.MarketData.LatestQuote(ctx, instrument.InstrumentUID, s.cfg.QuoteDepth, s.cfg.MaxQuoteAge)
if err != nil { if err != nil {
return s.saveRejectedSignal(ctx, tradeDate, instrument, "quote_unavailable", err) return s.rejectedSignal(tradeDate, instrument, "quote_unavailable", err), nil
} }
spread, err := spreadFromBook(book, instrument.MinPriceIncrement) spread, err := spreadFromBook(book, instrument.MinPriceIncrement)
if err != nil { if err != nil {
return s.saveRejectedSignal(ctx, tradeDate, instrument, "spread_unavailable", err) return s.rejectedSignal(tradeDate, instrument, "spread_unavailable", err), nil
} }
tradingStatus, err := s.svc.Gateway.GetTradingStatus(ctx, instrument.InstrumentUID) tradingStatus, err := s.svc.Gateway.GetTradingStatus(ctx, instrument.InstrumentUID)
if err != nil { if err != nil {
@@ -228,7 +254,7 @@ func (s Scheduler) generateInstrumentSignal(ctx context.Context, now, tradeDate
} }
feature, err := s.svc.Features.Recompute(ctx, instrument, tradeDate, spread) feature, err := s.svc.Features.Recompute(ctx, instrument, tradeDate, spread)
if err != nil { if err != nil {
return s.saveRejectedSignal(ctx, tradeDate, instrument, "features_unavailable", err) return s.rejectedSignal(tradeDate, instrument, "features_unavailable", err), nil
} }
remaining, err := s.svc.FreeOrders.Check(ctx, tradeDate, instrument, s.maxOrderAttemptsPerTrade()) remaining, err := s.svc.FreeOrders.Check(ctx, tradeDate, instrument, s.maxOrderAttemptsPerTrade())
freeOrderOK := err == nil freeOrderOK := err == nil
@@ -245,30 +271,10 @@ func (s Scheduler) generateInstrumentSignal(ctx context.Context, now, tradeDate
"spread_bps": spread.SpreadBps.String(), "spread_bps": spread.SpreadBps.String(),
}, },
}) })
if sig.Decision == domain.DecisionEnter { return signalCandidate{Signal: sig, Instrument: instrument, Feature: feature, Book: book}, nil
sized, sizingErr := s.sizeSignal(ctx, portfolio, instrument, feature, book, 1)
switch {
case sizingErr != nil:
sig.Decision = domain.DecisionReject
sig.RejectReason = sizingErr.Error()
case sized.Lots <= 0:
sig.Decision = domain.DecisionReject
if isSizingSkipReason(sized.Reason) {
sig.Decision = domain.DecisionSkip
}
sig.RejectReason = sized.Reason
default:
sig.TargetLots = sized.Lots
sig.TargetNotional = sized.TargetNotional
}
}
if err := s.svc.Repo.UpsertSignal(ctx, sig); err != nil {
return err
}
return s.notifySignal(ctx, now, sig)
} }
func (s Scheduler) saveRejectedSignal(ctx context.Context, tradeDate time.Time, instrument domain.Instrument, reason string, cause error) error { func (s Scheduler) rejectedSignal(tradeDate time.Time, instrument domain.Instrument, reason string, cause error) signalCandidate {
sig := domain.Signal{ sig := domain.Signal{
TradeDate: tradeDate, TradeDate: tradeDate,
InstrumentUID: instrument.InstrumentUID, InstrumentUID: instrument.InstrumentUID,
@@ -277,10 +283,63 @@ func (s Scheduler) saveRejectedSignal(ctx context.Context, tradeDate time.Time,
ContextJSON: fmt.Sprintf(`{"error":%q}`, cause.Error()), ContextJSON: fmt.Sprintf(`{"error":%q}`, cause.Error()),
CreatedAt: s.nowUTC(), CreatedAt: s.nowUTC(),
} }
return s.svc.Repo.UpsertSignal(ctx, sig) return signalCandidate{Signal: sig, Instrument: instrument}
} }
func (s Scheduler) sizeSignal(_ context.Context, portfolio domain.Portfolio, instrument domain.Instrument, feature domain.FeatureSet, book domain.OrderBook, selected int) (risk.SizingResult, error) { func (s Scheduler) applyBatchSignalLimits(portfolio domain.Portfolio, existingExposure decimal.Decimal, openPositionCount int, generated []signalCandidate) {
enterIndexes := make([]int, 0, len(generated))
for i := range generated {
if generated[i].Signal.Decision == domain.DecisionEnter {
enterIndexes = append(enterIndexes, i)
}
}
sort.SliceStable(enterIndexes, func(i, j int) bool {
left := generated[enterIndexes[i]].Signal
right := generated[enterIndexes[j]].Signal
if left.Score.Equal(right.Score) {
return left.InstrumentUID < right.InstrumentUID
}
return left.Score.GreaterThan(right.Score)
})
remainingSlots := len(enterIndexes)
if s.cfg.MaxOpenPositions > 0 {
remainingSlots = s.cfg.MaxOpenPositions - openPositionCount
if remainingSlots < 0 {
remainingSlots = 0
}
if remainingSlots > len(enterIndexes) {
remainingSlots = len(enterIndexes)
}
}
selectedCount := remainingSlots
for rank, index := range enterIndexes {
candidate := &generated[index]
if rank >= remainingSlots {
candidate.Signal.Decision = domain.DecisionSkip
candidate.Signal.TargetLots = 0
candidate.Signal.TargetNotional = decimal.Zero
candidate.Signal.RejectReason = signal.ReasonMaxPositions
continue
}
sized, sizingErr := s.sizeSignal(portfolio, candidate.Instrument, candidate.Feature, candidate.Book, selectedCount, existingExposure, decimal.Zero)
switch {
case sizingErr != nil:
candidate.Signal.Decision = domain.DecisionReject
candidate.Signal.RejectReason = sizingErr.Error()
case sized.Lots <= 0:
candidate.Signal.Decision = domain.DecisionReject
if isSizingSkipReason(sized.Reason) {
candidate.Signal.Decision = domain.DecisionSkip
}
candidate.Signal.RejectReason = sized.Reason
default:
candidate.Signal.TargetLots = sized.Lots
candidate.Signal.TargetNotional = sized.TargetNotional
}
}
}
func (s Scheduler) sizeSignal(portfolio domain.Portfolio, instrument domain.Instrument, feature domain.FeatureSet, book domain.OrderBook, selected int, existingExposure, reservedCash decimal.Decimal) (risk.SizingResult, error) {
bid, ask, err := bestBidAsk(book) bid, ask, err := bestBidAsk(book)
if err != nil { if err != nil {
return risk.SizingResult{}, err return risk.SizingResult{}, err
@@ -292,6 +351,8 @@ func (s Scheduler) sizeSignal(_ context.Context, portfolio domain.Portfolio, ins
return s.svc.Sizer.Size(risk.SizingInput{ return s.svc.Sizer.Size(risk.SizingInput{
Portfolio: portfolio, Portfolio: portfolio,
SelectedInstruments: selected, SelectedInstruments: selected,
ExistingExposure: existingExposure,
ReservedCash: reservedCash,
LimitPrice: price, LimitPrice: price,
Lot: instrument.Lot, Lot: instrument.Lot,
EntryIntervalVolume: feature.EntryIntervalVolume, EntryIntervalVolume: feature.EntryIntervalVolume,
@@ -313,6 +374,7 @@ func (s *Scheduler) placeEntryOrders(ctx context.Context, now time.Time) error {
if err != nil { if err != nil {
return err return err
} }
sortSignalsForEntry(signals)
existing, err := s.svc.Repo.ListOrders(ctx, s.svc.AccountIDHash, tradeDate, tradeDate) existing, err := s.svc.Repo.ListOrders(ctx, s.svc.AccountIDHash, tradeDate, tradeDate)
if err != nil { if err != nil {
return err return err
@@ -325,10 +387,26 @@ func (s *Scheduler) placeEntryOrders(ctx context.Context, now time.Time) error {
if err != nil { if err != nil {
return err return err
} }
portfolio, err := s.svc.Gateway.GetPortfolio(ctx, s.svc.AccountID)
if err != nil {
return err
}
baseExposure := positionsExposure(openPositions, instrumentByUID, portfolio)
pendingExposure := ordersExposure(existing, instrumentByUID, domain.SideBuy, true)
reservedCash := pendingExposure
projectedOpenPositions := len(openPositions) + countActiveOrders(existing, domain.SideBuy, tradeDate)
entryCandidates := entryOrderCandidates(signals, existing)
for _, sig := range signals { for _, sig := range signals {
if sig.Decision != domain.DecisionEnter || sig.TargetLots <= 0 || hasOrder(existing, sig.InstrumentUID, domain.SideBuy) { if sig.Decision != domain.DecisionEnter || sig.TargetLots <= 0 || hasOrder(existing, sig.InstrumentUID, domain.SideBuy) {
continue continue
} }
remainingSelections := remainingSignalCount(entryCandidates, sig.InstrumentUID)
if s.cfg.MaxOpenPositions > 0 && projectedOpenPositions >= s.cfg.MaxOpenPositions {
if err := s.recordPreTradeReject(ctx, sig.InstrumentUID, signal.ReasonMaxPositions, `{"reason":"max_positions_reached"}`); err != nil {
return err
}
continue
}
instrument, ok := instrumentByUID[sig.InstrumentUID] instrument, ok := instrumentByUID[sig.InstrumentUID]
if !ok { if !ok {
return fmt.Errorf("instrument %s is not in registry", sig.InstrumentUID) return fmt.Errorf("instrument %s is not in registry", sig.InstrumentUID)
@@ -352,38 +430,56 @@ func (s *Scheduler) placeEntryOrders(ctx context.Context, now time.Time) error {
if err != nil { if err != nil {
return err return err
} }
if err := s.checkSpreadBeforeOrder(ctx, instrument, book); err != nil {
if insertErr := s.recordPreTradeReject(ctx, sig.InstrumentUID, err.Error(), `{"reason":"spread_limit"}`); insertErr != nil {
return insertErr
}
continue
}
tradingStatus, err := s.svc.Gateway.GetTradingStatus(ctx, sig.InstrumentUID) tradingStatus, err := s.svc.Gateway.GetTradingStatus(ctx, sig.InstrumentUID)
if err != nil { if err != nil {
tradingStatus = domain.TradingStatusUnknown tradingStatus = domain.TradingStatusUnknown
} }
portfolio, err := s.svc.Gateway.GetPortfolio(ctx, s.svc.AccountID) portfolio, err = s.svc.Gateway.GetPortfolio(ctx, s.svc.AccountID)
if err != nil { if err != nil {
return err return err
} }
pre := s.svc.Risk.PreTradeCheck(risk.PreTradeInput{ feature, err := s.svc.Repo.GetFeature(ctx, sig.InstrumentUID, tradeDate)
Portfolio: portfolio, if err != nil {
OpenPositions: len(openPositions), return err
TradingStatus: tradingStatus, }
QuoteReceivedAt: book.ReceivedAt, sized, err := s.sizeSignal(portfolio, instrument, feature, book, remainingSelections, baseExposure.Add(pendingExposure), reservedCash)
Now: now.UTC(), if err != nil {
MarketClose: s.cfg.EntryWindowEnd.On(now, s.cfg.Location).UTC(), return err
}) }
if !pre.Allowed { lots := min(sig.TargetLots, sized.Lots)
if err := s.svc.Repo.InsertRiskEvent(ctx, domain.RiskEvent{ if lots <= 0 {
Severity: domain.SeverityWarn, reason := sized.Reason
EventType: "pre_trade_reject", if reason == "" {
InstrumentUID: sig.InstrumentUID, reason = risk.ErrNoSizingCapacity.Error()
Message: pre.Reason, }
ContextJSON: "{}", if err := s.recordPreTradeReject(ctx, sig.InstrumentUID, reason, `{"reason":"sizing"}`); err != nil {
}); err != nil {
return err return err
} }
continue continue
} }
placed, err := s.svc.Execution.PlaceEntry(ctx, s.svc.AccountIDHash, instrument, tradeDate, sig.TargetLots, book, s.cfg.PassiveImproveTicks, 1) pre, err := s.preTradeCheck(ctx, now, portfolio, projectedOpenPositions, tradingStatus, book.ReceivedAt)
if err != nil {
return err
}
if !pre.Allowed {
if err := s.recordPreTradeReject(ctx, sig.InstrumentUID, pre.Reason, "{}"); err != nil {
return err
}
continue
}
placed, err := s.svc.Execution.PlaceEntry(ctx, s.svc.AccountIDHash, instrument, tradeDate, lots, book, s.cfg.PassiveImproveTicks, 1)
if err != nil && !errors.Is(err, execution.ErrBrokerOrdersDisabled) { if err != nil && !errors.Is(err, execution.ErrBrokerOrdersDisabled) {
return err return err
} }
if errors.Is(err, execution.ErrBrokerOrdersDisabled) {
continue
}
_ = s.svc.Notifier.Info(ctx, fmt.Sprintf("entry order %s %s lots=%d status=%s", instrument.Ticker, placed.Side, placed.QuantityLots, placed.Status)) _ = s.svc.Notifier.Info(ctx, fmt.Sprintf("entry order %s %s lots=%d status=%s", instrument.Ticker, placed.Side, placed.QuantityLots, placed.Status))
if placed.FilledLots > 0 { if placed.FilledLots > 0 {
if err := s.recordEntryFill(ctx, instrument, placed); err != nil { if err := s.recordEntryFill(ctx, instrument, placed); err != nil {
@@ -391,6 +487,10 @@ func (s *Scheduler) placeEntryOrders(ctx context.Context, now time.Time) error {
} }
} }
existing = append(existing, placed) existing = append(existing, placed)
notional := orderNotional(placed, instrument)
pendingExposure = pendingExposure.Add(notional)
reservedCash = reservedCash.Add(notional)
projectedOpenPositions++
} }
return s.transitionTo(ctx, domain.StateMonitorEntryOrders) return s.transitionTo(ctx, domain.StateMonitorEntryOrders)
} }
@@ -411,15 +511,16 @@ func (s *Scheduler) monitorEntryOrders(ctx context.Context, now time.Time) error
if !s.nowUTC().Before(deadline) { if !s.nowUTC().Before(deadline) {
return s.closeEntryWindow(ctx) return s.closeEntryWindow(ctx)
} }
tradeDate := tradingDate(now)
for _, order := range orders { for _, order := range orders {
if order.Side != domain.SideBuy || order.BrokerOrderID == "" { if order.Side != domain.SideBuy || order.BrokerOrderID == "" || !sameTradingDate(order.TradeDate, tradeDate) {
continue continue
} }
instrument, ok := instrumentByUID[order.InstrumentUID] instrument, ok := instrumentByUID[order.InstrumentUID]
if !ok { if !ok {
return fmt.Errorf("instrument %s is not in registry", order.InstrumentUID) return fmt.Errorf("instrument %s is not in registry", order.InstrumentUID)
} }
monitored, err := s.svc.Execution.MonitorUntil(ctx, order, execution.MonitorConfig{ monitored, err := s.svc.Execution.MonitorOnce(ctx, order, execution.MonitorConfig{
Deadline: deadline, Deadline: deadline,
PollInterval: s.cfg.OrderPollInterval, PollInterval: s.cfg.OrderPollInterval,
MaxAttempts: s.cfg.MaxEntryOrderAttempts, MaxAttempts: s.cfg.MaxEntryOrderAttempts,
@@ -429,6 +530,9 @@ func (s *Scheduler) monitorEntryOrders(ctx context.Context, now time.Time) error
Quote: func(ctx context.Context, instrumentUID string) (domain.OrderBook, error) { Quote: func(ctx context.Context, instrumentUID string) (domain.OrderBook, error) {
return s.svc.MarketData.LatestQuote(ctx, instrumentUID, s.cfg.QuoteDepth, s.cfg.MaxQuoteAge) return s.svc.MarketData.LatestQuote(ctx, instrumentUID, s.cfg.QuoteDepth, s.cfg.MaxQuoteAge)
}, },
RepostCheck: func(ctx context.Context, order domain.Order, instrument domain.Instrument, book domain.OrderBook) error {
return s.repostPreTradeCheck(ctx, now, order, instrument, book)
},
}) })
if err != nil { if err != nil {
return err return err
@@ -460,11 +564,12 @@ func (s *Scheduler) placeExitOrders(ctx context.Context, now time.Time) error {
if err := s.transitionTo(ctx, domain.StatePlaceExitOrders); err != nil { if err := s.transitionTo(ctx, domain.StatePlaceExitOrders); err != nil {
return err return err
} }
exitTradeDate := tradingDate(now)
positionsList, err := s.svc.Repo.ListOpenPositions(ctx, s.svc.AccountIDHash) positionsList, err := s.svc.Repo.ListOpenPositions(ctx, s.svc.AccountIDHash)
if err != nil { if err != nil {
return err return err
} }
existing, err := s.svc.Repo.ListOrders(ctx, s.svc.AccountIDHash, tradingDate(now).AddDate(0, 0, -1), tradingDate(now)) existing, err := s.svc.Repo.ListOrders(ctx, s.svc.AccountIDHash, exitTradeDate.AddDate(0, 0, -1), exitTradeDate)
if err != nil { if err != nil {
return err return err
} }
@@ -480,10 +585,22 @@ func (s *Scheduler) placeExitOrders(ctx context.Context, now time.Time) error {
if !ok { if !ok {
return fmt.Errorf("instrument %s is not in registry", pos.InstrumentUID) return fmt.Errorf("instrument %s is not in registry", pos.InstrumentUID)
} }
if _, err := s.svc.FreeOrders.Check(ctx, exitTradeDate, instrument, s.cfg.MaxExitOrderAttempts); err != nil {
if insertErr := s.recordPreTradeReject(ctx, pos.InstrumentUID, err.Error(), `{"reason":"free_order_budget_insufficient"}`); insertErr != nil {
return insertErr
}
continue
}
book, err := s.svc.MarketData.LatestQuote(ctx, pos.InstrumentUID, s.cfg.QuoteDepth, s.cfg.MaxQuoteAge) book, err := s.svc.MarketData.LatestQuote(ctx, pos.InstrumentUID, s.cfg.QuoteDepth, s.cfg.MaxQuoteAge)
if err != nil { if err != nil {
return err return err
} }
if err := s.checkSpreadBeforeOrder(ctx, instrument, book); err != nil {
if insertErr := s.recordPreTradeReject(ctx, pos.InstrumentUID, err.Error(), `{"reason":"spread_limit"}`); insertErr != nil {
return insertErr
}
continue
}
tradingStatus, err := s.svc.Gateway.GetTradingStatus(ctx, pos.InstrumentUID) tradingStatus, err := s.svc.Gateway.GetTradingStatus(ctx, pos.InstrumentUID)
if err != nil { if err != nil {
tradingStatus = domain.TradingStatusUnknown tradingStatus = domain.TradingStatusUnknown
@@ -492,21 +609,20 @@ func (s *Scheduler) placeExitOrders(ctx context.Context, now time.Time) error {
if err != nil { if err != nil {
return err return err
} }
pre := s.svc.Risk.PreTradeCheck(risk.PreTradeInput{ pre, err := s.preTradeCheck(ctx, now, portfolio, len(positionsList), tradingStatus, book.ReceivedAt)
Portfolio: portfolio, if err != nil {
OpenPositions: len(positionsList), return err
TradingStatus: tradingStatus, }
QuoteReceivedAt: book.ReceivedAt,
Now: now.UTC(),
MarketClose: s.cfg.HardExitDeadline.On(now, s.cfg.Location).UTC(),
})
if !pre.Allowed { if !pre.Allowed {
return fmt.Errorf("exit pre-trade rejected: %s", pre.Reason) return fmt.Errorf("exit pre-trade rejected: %s", pre.Reason)
} }
placed, err := s.svc.Execution.PlaceExit(ctx, s.svc.AccountIDHash, instrument, pos.OpenTradeDate, pos.Lots, book, s.cfg.PassiveImproveTicks, 1) placed, err := s.svc.Execution.PlaceExit(ctx, s.svc.AccountIDHash, instrument, exitTradeDate, pos.Lots, book, s.cfg.PassiveImproveTicks, 1)
if err != nil && !errors.Is(err, execution.ErrBrokerOrdersDisabled) { if err != nil && !errors.Is(err, execution.ErrBrokerOrdersDisabled) {
return err return err
} }
if errors.Is(err, execution.ErrBrokerOrdersDisabled) {
continue
}
if placed.FilledLots > 0 || placed.Commission.IsPositive() { if placed.FilledLots > 0 || placed.Commission.IsPositive() {
if err := s.recordExitFill(ctx, pos, placed); err != nil { if err := s.recordExitFill(ctx, pos, placed); err != nil {
return err return err
@@ -545,15 +661,16 @@ func (s *Scheduler) monitorExitOrders(ctx context.Context, now time.Time) error
return err return err
} }
deadline := s.cfg.HardExitDeadline.On(now, s.cfg.Location).UTC() deadline := s.cfg.HardExitDeadline.On(now, s.cfg.Location).UTC()
exitTradeDate := tradingDate(now)
for _, order := range orders { for _, order := range orders {
if order.Side != domain.SideSell || order.BrokerOrderID == "" { if order.Side != domain.SideSell || order.BrokerOrderID == "" || !sameTradingDate(order.TradeDate, exitTradeDate) {
continue continue
} }
instrument, ok := instrumentByUID[order.InstrumentUID] instrument, ok := instrumentByUID[order.InstrumentUID]
if !ok { if !ok {
return fmt.Errorf("instrument %s is not in registry", order.InstrumentUID) return fmt.Errorf("instrument %s is not in registry", order.InstrumentUID)
} }
monitored, err := s.svc.Execution.MonitorUntil(ctx, order, execution.MonitorConfig{ monitored, err := s.svc.Execution.MonitorOnce(ctx, order, execution.MonitorConfig{
Deadline: deadline, Deadline: deadline,
PollInterval: s.cfg.OrderPollInterval, PollInterval: s.cfg.OrderPollInterval,
MaxAttempts: s.cfg.MaxExitOrderAttempts, MaxAttempts: s.cfg.MaxExitOrderAttempts,
@@ -563,6 +680,9 @@ func (s *Scheduler) monitorExitOrders(ctx context.Context, now time.Time) error
Quote: func(ctx context.Context, instrumentUID string) (domain.OrderBook, error) { Quote: func(ctx context.Context, instrumentUID string) (domain.OrderBook, error) {
return s.svc.MarketData.LatestQuote(ctx, instrumentUID, s.cfg.QuoteDepth, s.cfg.MaxQuoteAge) return s.svc.MarketData.LatestQuote(ctx, instrumentUID, s.cfg.QuoteDepth, s.cfg.MaxQuoteAge)
}, },
RepostCheck: func(ctx context.Context, order domain.Order, instrument domain.Instrument, book domain.OrderBook) error {
return s.repostPreTradeCheck(ctx, now, order, instrument, book)
},
}) })
if err != nil { if err != nil {
return err return err
@@ -740,21 +860,32 @@ func (s *Scheduler) checkInfrastructure(ctx context.Context) error {
s.infraFailedSince = time.Time{} s.infraFailedSince = time.Time{}
return nil return nil
} }
return s.recordInfrastructureFailure(fmt.Errorf("server_time_unavailable: %w", err)) return s.recordInfrastructureFailure(ctx, fmt.Errorf("server_time_unavailable: %w", err))
} }
drift := timeutil.Drift(s.nowUTC(), serverTime) drift := timeutil.Drift(s.nowUTC(), serverTime)
if drift > s.cfg.MaxClockDrift { if drift > s.cfg.MaxClockDrift {
return s.recordInfrastructureFailure(fmt.Errorf("server_clock_drift_too_high: %s > %s", drift, s.cfg.MaxClockDrift)) return s.recordInfrastructureFailure(ctx, fmt.Errorf("server_clock_drift_too_high: %s > %s", drift, s.cfg.MaxClockDrift))
} }
s.infraFailedSince = time.Time{} s.infraFailedSince = time.Time{}
return nil return nil
} }
func (s *Scheduler) recordInfrastructureFailure(err error) error { func (s *Scheduler) recordInfrastructureFailure(ctx context.Context, err error) error {
now := s.nowUTC() now := s.nowUTC()
if s.infraFailedSince.IsZero() { if s.infraFailedSince.IsZero() {
s.infraFailedSince = now s.infraFailedSince = now
s.logWarn("infrastructure check failed; waiting for outage threshold", "err", err, "threshold", s.cfg.APIOutageHalt) s.logWarn("infrastructure check failed; waiting for outage threshold", "err", err, "threshold", s.cfg.APIOutageHalt)
if s.svc.Repo != nil {
if insertErr := s.svc.Repo.InsertRiskEvent(ctx, domain.RiskEvent{
TS: now,
Severity: domain.SeverityWarn,
EventType: "infrastructure_outage_started",
Message: err.Error(),
ContextJSON: fmt.Sprintf(`{"threshold_sec":%d}`, int(s.cfg.APIOutageHalt.Seconds())),
}); insertErr != nil {
return insertErr
}
}
return nil return nil
} }
if s.cfg.APIOutageHalt <= 0 || now.Sub(s.infraFailedSince) >= s.cfg.APIOutageHalt { if s.cfg.APIOutageHalt <= 0 || now.Sub(s.infraFailedSince) >= s.cfg.APIOutageHalt {
@@ -921,6 +1052,183 @@ func (s *Scheduler) failOpenPositionsAtHardDeadline(ctx context.Context) error {
return s.svc.Risk.Halt(ctx, s.cfg.Mode, "hard_exit_deadline_missed", fmt.Sprintf("%d positions remain open after hard deadline", len(failed)), "") return s.svc.Risk.Halt(ctx, s.cfg.Mode, "hard_exit_deadline_missed", fmt.Sprintf("%d positions remain open after hard deadline", len(failed)), "")
} }
func (s Scheduler) checkSpreadBeforeOrder(_ context.Context, instrument domain.Instrument, book domain.OrderBook) error {
spread, err := spreadFromBook(book, instrument.MinPriceIncrement)
if err != nil {
return err
}
limit := s.svc.Signals.SpreadLimit(instrument)
if limit.IsPositive() && spread.SpreadBps.GreaterThan(limit) {
return fmt.Errorf("%s: spread_bps=%s max_spread_bps=%s", signal.ReasonSpread, spread.SpreadBps.String(), limit.String())
}
return nil
}
func (s Scheduler) repostPreTradeCheck(ctx context.Context, now time.Time, order domain.Order, instrument domain.Instrument, book domain.OrderBook) error {
if err := s.checkSpreadBeforeOrder(ctx, instrument, book); err != nil {
_ = s.recordPreTradeReject(ctx, order.InstrumentUID, err.Error(), `{"reason":"spread_limit","stage":"repost"}`)
return err
}
tradingStatus, err := s.svc.Gateway.GetTradingStatus(ctx, order.InstrumentUID)
if err != nil {
tradingStatus = domain.TradingStatusUnknown
}
portfolio, err := s.svc.Gateway.GetPortfolio(ctx, s.svc.AccountID)
if err != nil {
return err
}
openPositions, err := s.svc.Repo.ListOpenPositions(ctx, s.svc.AccountIDHash)
if err != nil {
return err
}
pre, err := s.preTradeCheck(ctx, now, portfolio, len(openPositions), tradingStatus, book.ReceivedAt)
if err != nil {
return err
}
if !pre.Allowed {
_ = s.recordPreTradeReject(ctx, order.InstrumentUID, pre.Reason, `{"stage":"repost"}`)
return errors.New(pre.Reason)
}
return nil
}
func (s Scheduler) preTradeCheck(ctx context.Context, now time.Time, portfolio domain.Portfolio, openPositions int, tradingStatus domain.TradingStatus, quoteReceivedAt time.Time) (risk.PreTradeResult, error) {
metrics, err := s.riskMetrics(ctx, now, portfolio)
if err != nil {
return risk.PreTradeResult{}, err
}
return s.svc.Risk.PreTradeCheck(risk.PreTradeInput{
Portfolio: portfolio,
OpenPositions: openPositions,
DailyPnL: metrics.dailyPnL,
WeeklyPnL: metrics.weeklyPnL,
MonthlyDrawdownPct: metrics.monthlyDrawdownPct,
AvgSlippageBps10: metrics.avgSlippageBps10,
TradingStatus: tradingStatus,
QuoteReceivedAt: quoteReceivedAt,
Now: now.UTC(),
MarketClose: s.marketCloseOn(now),
}), nil
}
type preTradeMetrics struct {
dailyPnL decimal.Decimal
weeklyPnL decimal.Decimal
monthlyDrawdownPct decimal.Decimal
avgSlippageBps10 decimal.Decimal
}
func (s Scheduler) riskMetrics(ctx context.Context, now time.Time, portfolio domain.Portfolio) (preTradeMetrics, error) {
today := tradingDate(now)
monthStart := today.AddDate(0, -1, 0)
positionsList, err := s.svc.Repo.ListPositions(ctx, s.svc.AccountIDHash, monthStart.AddDate(0, 0, -7), today)
if err != nil {
return preTradeMetrics{}, err
}
weekStart := today.AddDate(0, 0, -6)
var metrics preTradeMetrics
monthlyPnL := decimal.Zero
var closed []domain.Position
for _, pos := range positionsList {
if pos.Status != domain.PositionExitFilled {
continue
}
closedAt := positionCloseTime(pos)
if closedAt.IsZero() {
continue
}
closeDate := tradingDate(closedAt)
if closeDate.Equal(today) {
metrics.dailyPnL = metrics.dailyPnL.Add(pos.NetPnL)
}
if !closeDate.Before(weekStart) {
metrics.weeklyPnL = metrics.weeklyPnL.Add(pos.NetPnL)
}
if !closeDate.Before(monthStart) {
monthlyPnL = monthlyPnL.Add(pos.NetPnL)
}
closed = append(closed, pos)
}
if monthlyPnL.IsNegative() && portfolio.Equity.IsPositive() {
metrics.monthlyDrawdownPct = monthlyPnL.Neg().Div(portfolio.Equity)
}
avg, err := s.averageAdverseSlippageBps(ctx, closed, 10)
if err != nil {
return preTradeMetrics{}, err
}
metrics.avgSlippageBps10 = avg
return metrics, nil
}
func (s Scheduler) averageAdverseSlippageBps(ctx context.Context, positionsList []domain.Position, limit int) (decimal.Decimal, error) {
if limit <= 0 {
return decimal.Zero, nil
}
sort.Slice(positionsList, func(i, j int) bool {
return positionCloseTime(positionsList[i]).After(positionCloseTime(positionsList[j]))
})
signalsByDate := make(map[string][]domain.Signal)
var values []decimal.Decimal
for _, pos := range positionsList {
key := tradingDate(pos.OpenTradeDate).Format("2006-01-02")
signals, ok := signalsByDate[key]
if !ok {
var err error
signals, err = s.svc.Repo.ListSignals(ctx, tradingDate(pos.OpenTradeDate))
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return decimal.Zero, err
}
signalsByDate[key] = signals
}
for _, sig := range signals {
if sig.InstrumentUID != pos.InstrumentUID || sig.Decision != domain.DecisionEnter {
continue
}
adverse := sig.NetEdgeBps.Sub(pos.RealizedEdgeBps)
if adverse.IsNegative() {
adverse = decimal.Zero
}
values = append(values, adverse)
break
}
if len(values) == limit {
break
}
}
if len(values) == 0 {
return decimal.Zero, nil
}
sum := decimal.Zero
for _, value := range values {
sum = sum.Add(value)
}
return sum.Div(decimal.NewFromInt(int64(len(values)))), nil
}
func positionCloseTime(pos domain.Position) time.Time {
if pos.ClosedAt != nil {
return pos.ClosedAt.UTC()
}
return pos.UpdatedAt.UTC()
}
func (s Scheduler) marketCloseOn(now time.Time) time.Time {
if s.cfg.MarketClose.Duration <= 0 {
return time.Time{}
}
return s.cfg.MarketClose.On(now, s.cfg.Location).UTC()
}
func (s Scheduler) recordPreTradeReject(ctx context.Context, instrumentUID, message, contextJSON string) error {
return s.svc.Repo.InsertRiskEvent(ctx, domain.RiskEvent{
Severity: domain.SeverityWarn,
EventType: "pre_trade_reject",
InstrumentUID: instrumentUID,
Message: message,
ContextJSON: contextJSON,
})
}
func (s Scheduler) nowUTC() time.Time { func (s Scheduler) nowUTC() time.Time {
if s.clock != nil { if s.clock != nil {
return s.clock.Now().UTC() return s.clock.Now().UTC()
@@ -1062,6 +1370,120 @@ func hasOrder(orders []domain.Order, instrumentUID string, side domain.Side) boo
return false return false
} }
func sortSignalsForEntry(signals []domain.Signal) {
sort.SliceStable(signals, func(i, j int) bool {
if signals[i].Decision != signals[j].Decision {
return signals[i].Decision == domain.DecisionEnter
}
if signals[i].Score.Equal(signals[j].Score) {
return signals[i].InstrumentUID < signals[j].InstrumentUID
}
return signals[i].Score.GreaterThan(signals[j].Score)
})
}
func entryOrderCandidates(signals []domain.Signal, existing []domain.Order) []string {
out := make([]string, 0, len(signals))
for _, sig := range signals {
if sig.Decision == domain.DecisionEnter && sig.TargetLots > 0 && !hasOrder(existing, sig.InstrumentUID, domain.SideBuy) {
out = append(out, sig.InstrumentUID)
}
}
return out
}
func remainingSignalCount(candidates []string, instrumentUID string) int {
for i, candidate := range candidates {
if candidate == instrumentUID {
return len(candidates) - i
}
}
return 1
}
func countActiveOrders(orders []domain.Order, side domain.Side, tradeDate time.Time) int {
count := 0
for _, order := range orders {
if order.Side == side && sameTradingDate(order.TradeDate, tradeDate) && isActiveOrder(order.Status) {
count++
}
}
return count
}
func ordersExposure(orders []domain.Order, instruments map[string]domain.Instrument, side domain.Side, activeOnly bool) decimal.Decimal {
total := decimal.Zero
for _, order := range orders {
if order.Side != side {
continue
}
if activeOnly && !isActiveOrder(order.Status) {
continue
}
instrument := instruments[order.InstrumentUID]
total = total.Add(orderRemainingNotional(order, instrument))
}
return total
}
func positionsExposure(positions []domain.Position, instruments map[string]domain.Instrument, portfolio domain.Portfolio) decimal.Decimal {
local := decimal.Zero
for _, pos := range positions {
instrument := instruments[pos.InstrumentUID]
lot := pos.Lot
if lot <= 0 {
lot = instrument.Lot
}
if lot <= 0 || !pos.AvgBuyPrice.IsPositive() || pos.Lots <= 0 {
continue
}
local = local.Add(pos.AvgBuyPrice.Mul(decimal.NewFromInt(pos.Lots)).Mul(decimal.NewFromInt(lot)))
}
return money.Max(local, portfolioExposure(portfolio))
}
func portfolioExposure(portfolio domain.Portfolio) decimal.Decimal {
total := decimal.Zero
for _, holding := range portfolio.Holdings {
if holding.MarketValue.IsPositive() {
total = total.Add(holding.MarketValue)
}
}
return total
}
func orderNotional(order domain.Order, instrument domain.Instrument) decimal.Decimal {
lot := instrument.Lot
if lot <= 0 {
lot = 1
}
lots := order.QuantityLots
if lots <= 0 {
lots = order.FilledLots
}
return order.LimitPrice.Mul(decimal.NewFromInt(lots)).Mul(decimal.NewFromInt(lot))
}
func orderRemainingNotional(order domain.Order, instrument domain.Instrument) decimal.Decimal {
remaining := order.QuantityLots - order.FilledLots
if remaining <= 0 {
return decimal.Zero
}
lot := instrument.Lot
if lot <= 0 {
lot = 1
}
return order.LimitPrice.Mul(decimal.NewFromInt(remaining)).Mul(decimal.NewFromInt(lot))
}
func isActiveOrder(status domain.OrderStatus) bool {
return status == domain.OrderStatusNew || status == domain.OrderStatusSent || status == domain.OrderStatusPartiallyFilled
}
func sameTradingDate(a, b time.Time) bool {
return tradingDate(a).Equal(tradingDate(b))
}
func sinceMidnight(t time.Time) time.Duration { func sinceMidnight(t time.Time) time.Duration {
h, m, s := t.Clock() h, m, s := t.Clock()
return time.Duration(h)*time.Hour + time.Duration(m)*time.Minute + time.Duration(s)*time.Second return time.Duration(h)*time.Hour + time.Duration(m)*time.Minute + time.Duration(s)*time.Second
+254
View File
@@ -9,8 +9,11 @@ import (
"overnight-trading-bot/internal/domain" "overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/execution" "overnight-trading-bot/internal/execution"
"overnight-trading-bot/internal/marketdata"
"overnight-trading-bot/internal/position"
"overnight-trading-bot/internal/reconciliation" "overnight-trading-bot/internal/reconciliation"
"overnight-trading-bot/internal/risk" "overnight-trading-bot/internal/risk"
signalengine "overnight-trading-bot/internal/signal"
"overnight-trading-bot/internal/statemachine" "overnight-trading-bot/internal/statemachine"
"overnight-trading-bot/internal/testutil" "overnight-trading-bot/internal/testutil"
"overnight-trading-bot/internal/timeutil" "overnight-trading-bot/internal/timeutil"
@@ -317,6 +320,234 @@ func TestSizeReductionRuleCutsSizerAfterBadExpectedErrors(t *testing.T) {
} }
} }
func TestBatchSignalLimitsCapSlotsAndExposure(t *testing.T) {
s := Scheduler{
cfg: Config{MaxOpenPositions: 5},
svc: Services{Sizer: risk.NewSizer(risk.SizingConfig{
MaxPositionPct: decimal.NewFromInt(1),
MaxTotalExposurePct: decimal.RequireFromString("0.50"),
MaxParticipationRate: decimal.NewFromInt(1),
CashUsageBuffer: decimal.NewFromInt(1),
RiskBudgetPerInstrumentPct: decimal.NewFromInt(1),
MinOrderNotionalRUB: decimal.NewFromInt(1),
})},
}
book := domain.OrderBook{
Bids: []domain.OrderBookLevel{{Price: decimal.NewFromInt(99), QuantityLots: 10}},
Asks: []domain.OrderBookLevel{{Price: decimal.NewFromInt(101), QuantityLots: 10}},
}
generated := make([]signalCandidate, 0, 9)
for i := 0; i < 9; i++ {
uid := string(rune('a' + i))
generated = append(generated, signalCandidate{
Signal: domain.Signal{
InstrumentUID: uid,
Decision: domain.DecisionEnter,
Score: decimal.NewFromInt(int64(100 - i)),
},
Instrument: domain.Instrument{InstrumentUID: uid, Lot: 1, MinPriceIncrement: decimal.NewFromInt(1)},
Feature: domain.FeatureSet{
EntryIntervalVolume: decimal.NewFromInt(1_000_000),
ExitIntervalVolume: decimal.NewFromInt(1_000_000),
SigmaOn60: decimal.NewFromInt(1),
},
Book: book,
})
}
s.applyBatchSignalLimits(domain.Portfolio{Equity: decimal.NewFromInt(100_000), Cash: decimal.NewFromInt(100_000)}, decimal.Zero, 0, generated)
enters := 0
total := decimal.Zero
for _, candidate := range generated {
if candidate.Signal.Decision == domain.DecisionEnter {
enters++
total = total.Add(candidate.Signal.TargetNotional)
}
}
if enters != 5 {
t.Fatalf("enter signals=%d, want 5", enters)
}
if total.GreaterThan(decimal.NewFromInt(50_000)) {
t.Fatalf("total target notional=%s exceeds 50%% exposure", total)
}
if generated[5].Signal.RejectReason != signalengine.ReasonMaxPositions {
t.Fatalf("sixth signal reason=%q, want max positions", generated[5].Signal.RejectReason)
}
}
func TestPlaceEntryRejectsWideSpreadBeforeOrder(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
tradeDate := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC)
instrument := domain.Instrument{
InstrumentUID: "uid",
Ticker: "TRUR",
ClassCode: "TQTF",
Enabled: true,
Lot: 1,
MinPriceIncrement: decimal.RequireFromString("0.01"),
Currency: "RUB",
}
if err := repo.UpsertInstrument(ctx, instrument); err != nil {
t.Fatal(err)
}
if err := repo.UpsertSignal(ctx, domain.Signal{
TradeDate: tradeDate,
InstrumentUID: "uid",
Decision: domain.DecisionEnter,
Score: decimal.NewFromInt(10),
TargetLots: 1,
}); err != nil {
t.Fatal(err)
}
if err := repo.UpsertFeature(ctx, domain.FeatureSet{
InstrumentUID: "uid",
TradeDate: tradeDate,
EntryIntervalVolume: decimal.NewFromInt(1_000_000),
ExitIntervalVolume: decimal.NewFromInt(1_000_000),
SigmaOn60: decimal.NewFromInt(1),
}); err != nil {
t.Fatal(err)
}
gateway := tinvest.NewFakeGateway()
gateway.OrderBooks["uid"] = domain.OrderBook{
InstrumentUID: "uid",
Bids: []domain.OrderBookLevel{{Price: decimal.NewFromInt(90), QuantityLots: 10}},
Asks: []domain.OrderBookLevel{{Price: decimal.NewFromInt(110), QuantityLots: 10}},
ReceivedAt: time.Now().UTC(),
}
execEngine := execution.NewEngine(domain.ModePaper, "account", gateway, repo)
now := tradeDate.Add(18 * time.Hour)
s := Scheduler{
clock: fixedClock{now: now},
cfg: Config{
Mode: domain.ModePaper,
Location: time.UTC,
NoNewEntryAfter: mustTOD("23:00:00"),
MaxQuoteAge: time.Minute,
MarketClose: mustTOD("23:30:00"),
MaxOpenPositions: 5,
},
sm: statemachine.New(repo, domain.ModePaper),
svc: Services{
Repo: repo,
Gateway: gateway,
MarketData: marketdata.NewLoader(repo, gateway),
Signals: signalengine.New(signalengine.Config{MaxSpreadBpsDefault: decimal.NewFromInt(20)}),
Sizer: risk.NewSizer(testSizingConfig()),
FreeOrders: risk.NewFreeOrderBudget(repo),
Risk: risk.NewManager(repo, risk.ManagerConfig{MaxOpenPositions: 5}),
Execution: &execEngine,
Positions: position.NewManager(repo),
Notifier: &countNotifier{},
AccountID: "account",
AccountIDHash: "hash",
},
}
if err := repo.SaveSystemState(ctx, domain.StateWaitEntryWindow, domain.ModePaper, false, "", "{}"); err != nil {
t.Fatal(err)
}
if err := s.placeEntryOrders(ctx, now); err != nil {
t.Fatal(err)
}
orders, err := repo.ListOrders(ctx, "hash", tradeDate, tradeDate)
if err != nil {
t.Fatal(err)
}
if len(orders) != 0 {
t.Fatalf("orders=%+v, want no order on wide spread", orders)
}
if len(repo.RiskEvents) != 1 || repo.RiskEvents[0].EventType != "pre_trade_reject" {
t.Fatalf("risk events=%+v", repo.RiskEvents)
}
}
func TestPlaceExitUsesCurrentTradeDateForOrderAndFreeCounter(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
openDate := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC)
exitDate := openDate.AddDate(0, 0, 1)
instrument := domain.Instrument{
InstrumentUID: "uid",
Ticker: "TRUR",
ClassCode: "TQTF",
Enabled: true,
Lot: 1,
MinPriceIncrement: decimal.RequireFromString("0.01"),
Currency: "RUB",
FreeOrderLimitPerDay: 10,
}
if err := repo.UpsertInstrument(ctx, instrument); err != nil {
t.Fatal(err)
}
if err := repo.UpsertPosition(ctx, domain.Position{
AccountIDHash: "hash",
InstrumentUID: "uid",
OpenTradeDate: openDate,
Lots: 2,
Lot: 1,
AvgBuyPrice: decimal.NewFromInt(100),
Status: domain.PositionHoldingOvernight,
}); err != nil {
t.Fatal(err)
}
gateway := tinvest.NewFakeGateway()
gateway.OrderBooks["uid"] = domain.OrderBook{
InstrumentUID: "uid",
Bids: []domain.OrderBookLevel{{Price: decimal.NewFromInt(100), QuantityLots: 10}},
Asks: []domain.OrderBookLevel{{Price: decimal.RequireFromString("100.10"), QuantityLots: 10}},
ReceivedAt: time.Now().UTC(),
}
execEngine := execution.NewEngine(domain.ModePaper, "account", gateway, repo)
s := Scheduler{
cfg: Config{
Mode: domain.ModePaper,
Location: time.UTC,
HardExitDeadline: mustTOD("23:00:00"),
MaxQuoteAge: time.Minute,
MarketClose: mustTOD("23:30:00"),
},
sm: statemachine.New(repo, domain.ModePaper),
svc: Services{
Repo: repo,
Gateway: gateway,
MarketData: marketdata.NewLoader(repo, gateway),
Signals: signalengine.New(signalengine.Config{MaxSpreadBpsDefault: decimal.NewFromInt(20)}),
FreeOrders: risk.NewFreeOrderBudget(repo),
Risk: risk.NewManager(repo, risk.ManagerConfig{}),
Execution: &execEngine,
Positions: position.NewManager(repo),
Reconcile: reconciliation.New(repo, gateway, "account", "hash"),
Notifier: &countNotifier{},
AccountID: "account",
AccountIDHash: "hash",
},
}
if err := repo.SaveSystemState(ctx, domain.StateWaitExitWindow, domain.ModePaper, false, "", "{}"); err != nil {
t.Fatal(err)
}
if err := s.placeExitOrders(ctx, exitDate.Add(10*time.Hour)); err != nil {
t.Fatal(err)
}
orders, err := repo.ListOrders(ctx, "hash", exitDate, exitDate)
if err != nil {
t.Fatal(err)
}
if len(orders) != 1 || !sameTradingDate(orders[0].TradeDate, exitDate) {
t.Fatalf("orders=%+v, want one exit order on current date", orders)
}
sentToday, err := repo.GetFreeOrdersSent(ctx, exitDate, "uid")
if err != nil {
t.Fatal(err)
}
sentOpenDate, err := repo.GetFreeOrdersSent(ctx, openDate, "uid")
if err != nil {
t.Fatal(err)
}
if sentToday != 1 || sentOpenDate != 0 {
t.Fatalf("free counters today=%d openDate=%d, want 1/0", sentToday, sentOpenDate)
}
}
func mustTOD(raw string) timeutil.TimeOfDay { func mustTOD(raw string) timeutil.TimeOfDay {
tod, err := timeutil.ParseTimeOfDay(raw) tod, err := timeutil.ParseTimeOfDay(raw)
if err != nil { if err != nil {
@@ -325,6 +556,29 @@ func mustTOD(raw string) timeutil.TimeOfDay {
return tod return tod
} }
func testSizingConfig() risk.SizingConfig {
return risk.SizingConfig{
MaxPositionPct: decimal.NewFromInt(1),
MaxTotalExposurePct: decimal.NewFromInt(1),
MaxParticipationRate: decimal.NewFromInt(1),
CashUsageBuffer: decimal.NewFromInt(1),
RiskBudgetPerInstrumentPct: decimal.NewFromInt(1),
MinOrderNotionalRUB: decimal.NewFromInt(1),
}
}
type fixedClock struct {
now time.Time
}
func (c fixedClock) Now() time.Time {
return c.now
}
func (fixedClock) Sleep(<-chan struct{}, time.Duration) bool {
return true
}
type countNotifier struct { type countNotifier struct {
reports int reports int
alerts int alerts int
+3 -3
View File
@@ -74,7 +74,7 @@ func (e Engine) Evaluate(c Candidate) domain.Signal {
"ticker": c.Instrument.Ticker, "ticker": c.Instrument.Ticker,
"fund_type": c.Instrument.FundType, "fund_type": c.Instrument.FundType,
"trading_status": c.TradingStatus, "trading_status": c.TradingStatus,
"spread_limit": e.spreadLimit(c.Instrument).String(), "spread_limit": e.SpreadLimit(c.Instrument).String(),
} }
for k, v := range c.ExtraContext { for k, v := range c.ExtraContext {
context[k] = v context[k] = v
@@ -122,7 +122,7 @@ func (e Engine) firstRejectReason(c Candidate) string {
return ReasonWinRate return ReasonWinRate
case features.NetEdgeBps.LessThan(e.cfg.MinNetEdgeBps): case features.NetEdgeBps.LessThan(e.cfg.MinNetEdgeBps):
return ReasonNetEdge return ReasonNetEdge
case features.SpreadBps.GreaterThan(e.spreadLimit(instr)): case features.SpreadBps.GreaterThan(e.SpreadLimit(instr)):
return ReasonSpread return ReasonSpread
case features.TickBps.GreaterThan(e.cfg.MaxTickBps): case features.TickBps.GreaterThan(e.cfg.MaxTickBps):
return ReasonTick return ReasonTick
@@ -137,7 +137,7 @@ func (e Engine) firstRejectReason(c Candidate) string {
} }
} }
func (e Engine) spreadLimit(instr domain.Instrument) decimal.Decimal { func (e Engine) SpreadLimit(instr domain.Instrument) decimal.Decimal {
fundType := strings.ToLower(instr.FundType) fundType := strings.ToLower(instr.FundType)
switch { switch {
case strings.Contains(fundType, "money"): case strings.Contains(fundType, "money"):
+7 -7
View File
@@ -96,20 +96,20 @@ func legalTransition(from, to domain.SystemState) bool {
return true return true
} }
allowed := map[domain.SystemState][]domain.SystemState{ allowed := map[domain.SystemState][]domain.SystemState{
domain.StateInit: {domain.StateSyncInstruments, domain.StateWaitExitWindow}, domain.StateInit: {domain.StateSyncInstruments, domain.StateWaitExitWindow, domain.StatePlaceExitOrders, domain.StateMonitorExitOrders, domain.StateGenerateSignals, domain.StatePlaceEntryOrders, domain.StateHoldOvernight, domain.StateReconcile, domain.StateSleep},
domain.StateSyncInstruments: {domain.StateSyncMarketData}, domain.StateSyncInstruments: {domain.StateSyncMarketData},
domain.StateSyncMarketData: {domain.StateGenerateSignals}, domain.StateSyncMarketData: {domain.StateGenerateSignals},
domain.StateGenerateSignals: {domain.StateWaitEntryWindow}, domain.StateGenerateSignals: {domain.StateWaitEntryWindow, domain.StatePlaceEntryOrders, domain.StateHoldOvernight, domain.StateSleep},
domain.StateWaitEntryWindow: {domain.StatePlaceEntryOrders, domain.StateSleep}, domain.StateWaitEntryWindow: {domain.StatePlaceEntryOrders, domain.StateSleep},
domain.StatePlaceEntryOrders: {domain.StateMonitorEntryOrders, domain.StateReconcile}, domain.StatePlaceEntryOrders: {domain.StateMonitorEntryOrders, domain.StateHoldOvernight, domain.StateWaitExitWindow, domain.StatePlaceExitOrders, domain.StateMonitorExitOrders, domain.StateReconcile},
domain.StateMonitorEntryOrders: {domain.StateHoldOvernight, domain.StateReconcile}, domain.StateMonitorEntryOrders: {domain.StateHoldOvernight, domain.StateWaitExitWindow, domain.StatePlaceExitOrders, domain.StateMonitorExitOrders, domain.StateReconcile},
domain.StateHoldOvernight: {domain.StateWaitExitWindow}, domain.StateHoldOvernight: {domain.StateWaitExitWindow, domain.StatePlaceExitOrders, domain.StateMonitorExitOrders, domain.StateReconcile},
domain.StateWaitExitWindow: {domain.StatePlaceExitOrders}, domain.StateWaitExitWindow: {domain.StatePlaceExitOrders},
domain.StatePlaceExitOrders: {domain.StateMonitorExitOrders, domain.StateReconcile}, domain.StatePlaceExitOrders: {domain.StateMonitorExitOrders, domain.StateReconcile},
domain.StateMonitorExitOrders: {domain.StateReconcile}, domain.StateMonitorExitOrders: {domain.StateReconcile},
domain.StateReconcile: {domain.StateReport, domain.StateHalted}, domain.StateReconcile: {domain.StateReport, domain.StateHalted, domain.StateGenerateSignals, domain.StateSleep},
domain.StateReport: {domain.StateSleep}, domain.StateReport: {domain.StateSleep},
domain.StateSleep: {domain.StateInit, domain.StateWaitExitWindow, domain.StateGenerateSignals}, domain.StateSleep: {domain.StateInit, domain.StateWaitExitWindow, domain.StatePlaceExitOrders, domain.StateMonitorExitOrders, domain.StateGenerateSignals, domain.StatePlaceEntryOrders, domain.StateHoldOvernight, domain.StateReconcile},
} }
for _, candidate := range allowed[from] { for _, candidate := range allowed[from] {
if candidate == to { if candidate == to {
+15
View File
@@ -62,6 +62,21 @@ func TestUnhaltPreservesMode(t *testing.T) {
} }
} }
func TestCalendarRecoveryAllowsRestartInsideExitWindow(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
system := New(repo, domain.ModePaper)
if err := system.Transition(ctx, domain.StateInit, domain.StatePlaceExitOrders); err != nil {
t.Fatalf("INIT -> PLACE_EXIT_ORDERS should be legal on restart: %v", err)
}
if err := repo.SaveSystemState(ctx, domain.StateHoldOvernight, domain.ModePaper, false, "", "{}"); err != nil {
t.Fatal(err)
}
if err := system.Transition(ctx, domain.StateHoldOvernight, domain.StatePlaceExitOrders); err != nil {
t.Fatalf("HOLD_OVERNIGHT -> PLACE_EXIT_ORDERS should be legal on restart: %v", err)
}
}
func TestRecoverFromMonitorEntryHaltsOnCriticalReconciliationDiff(t *testing.T) { func TestRecoverFromMonitorEntryHaltsOnCriticalReconciliationDiff(t *testing.T) {
ctx := context.Background() ctx := context.Background()
repo := testutil.NewMemoryRepository() repo := testutil.NewMemoryRepository()
+38 -30
View File
@@ -246,6 +246,44 @@ func (g *RealGateway) GetPortfolio(ctx context.Context, accountID string) (domai
if err != nil { if err != nil {
return domain.Portfolio{}, err return domain.Portfolio{}, err
} }
return portfolioFromResponse(resp.PortfolioResponse)
}
func (g *RealGateway) GetOperations(ctx context.Context, accountID string, from, to time.Time) ([]domain.Operation, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.OperationsResponse, error) {
return g.operations.GetOperations(&investgo.GetOperationsRequest{
AccountId: accountID,
From: from,
To: to,
})
})
if err != nil {
return nil, err
}
return operationsFromResponse(resp.OperationsResponse), nil
}
func operationsFromResponse(resp *pb.OperationsResponse) []domain.Operation {
ops := resp.GetOperations()
out := make([]domain.Operation, 0, len(ops))
for _, op := range ops {
payment := money.MoneyValueToDecimal(op.GetPayment())
out = append(out, domain.Operation{
ID: op.GetId(),
InstrumentUID: op.GetInstrumentUid(),
Type: op.GetOperationType().String(),
Payment: payment,
Commission: operationCommission(op.GetOperationType(), payment),
ExecutedAt: op.GetDate().AsTime().UTC(),
})
}
return out
}
func portfolioFromResponse(resp *pb.PortfolioResponse) (domain.Portfolio, error) {
positions := resp.GetPositions() positions := resp.GetPositions()
holdings := make([]domain.Holding, 0, len(positions)) holdings := make([]domain.Holding, 0, len(positions))
for _, position := range positions { for _, position := range positions {
@@ -272,36 +310,6 @@ func (g *RealGateway) GetPortfolio(ctx context.Context, accountID string) (domai
}, nil }, nil
} }
func (g *RealGateway) GetOperations(ctx context.Context, accountID string, from, to time.Time) ([]domain.Operation, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.OperationsResponse, error) {
return g.operations.GetOperations(&investgo.GetOperationsRequest{
AccountId: accountID,
From: from,
To: to,
})
})
if err != nil {
return nil, err
}
ops := resp.GetOperations()
out := make([]domain.Operation, 0, len(ops))
for _, op := range ops {
payment := money.MoneyValueToDecimal(op.GetPayment())
out = append(out, domain.Operation{
ID: op.GetId(),
InstrumentUID: op.GetInstrumentUid(),
Type: op.GetOperationType().String(),
Payment: payment,
Commission: operationCommission(op.GetOperationType(), payment),
ExecutedAt: op.GetDate().AsTime().UTC(),
})
}
return out, nil
}
func (g *RealGateway) GetServerTime(ctx context.Context) (time.Time, error) { func (g *RealGateway) GetServerTime(ctx context.Context) (time.Time, error) {
if err := ctx.Err(); err != nil { if err := ctx.Err(); err != nil {
return time.Time{}, err return time.Time{}, err
+124 -4
View File
@@ -1,10 +1,130 @@
package tinvest package tinvest
import "context" import (
"context"
"time"
"github.com/russianinvestments/invest-api-go-sdk/investgo"
pb "github.com/russianinvestments/invest-api-go-sdk/proto"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/money"
)
const sandboxEndpoint = "sandbox-invest-public-api.tinkoff.ru:443" const sandboxEndpoint = "sandbox-invest-public-api.tinkoff.ru:443"
func NewSandboxGateway(ctx context.Context, opts Options) (*RealGateway, error) { type SandboxGateway struct {
opts.Endpoint = sandboxEndpoint *RealGateway
return NewRealGateway(ctx, opts) sandbox *investgo.SandboxServiceClient
}
func NewSandboxGateway(ctx context.Context, opts Options) (*SandboxGateway, error) {
opts.Endpoint = sandboxEndpoint
realGateway, err := NewRealGateway(ctx, opts)
if err != nil {
return nil, err
}
return &SandboxGateway{
RealGateway: realGateway,
sandbox: realGateway.client.NewSandboxServiceClient(),
}, nil
}
func (g *SandboxGateway) PostLimitOrder(ctx context.Context, accountID, instrumentUID string, side domain.Side, lots int64, price decimal.Decimal, clientOrderID string) (domain.Order, error) {
if err := ctx.Err(); err != nil {
return domain.Order{}, err
}
direction := pb.OrderDirection_ORDER_DIRECTION_BUY
if side == domain.SideSell {
direction = pb.OrderDirection_ORDER_DIRECTION_SELL
}
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.PostOrderResponse, error) {
return g.sandbox.PostSandboxOrder(&investgo.PostOrderRequest{
InstrumentId: instrumentUID,
Quantity: lots,
Price: money.DecimalToQuotation(price),
Direction: direction,
AccountId: accountID,
OrderType: pb.OrderType_ORDER_TYPE_LIMIT,
OrderId: clientOrderID,
TimeInForce: pb.TimeInForceType_TIME_IN_FORCE_DAY,
PriceType: pb.PriceType_PRICE_TYPE_CURRENCY,
})
})
if err != nil {
return domain.Order{}, err
}
return orderFromPostResponse(resp.PostOrderResponse, accountID, clientOrderID, side, price), nil
}
func (g *SandboxGateway) CancelOrder(ctx context.Context, accountID, orderID string) error {
if err := ctx.Err(); err != nil {
return err
}
return withRetry(ctx, g.retryAttempts, g.retryBackoff, func() error {
_, err := g.sandbox.CancelSandboxOrder(accountID, orderID)
return err
})
}
func (g *SandboxGateway) GetOrderState(ctx context.Context, accountID, orderID string) (domain.Order, error) {
if err := ctx.Err(); err != nil {
return domain.Order{}, err
}
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrderStateResponse, error) {
return g.sandbox.GetSandboxOrderState(accountID, orderID)
})
if err != nil {
return domain.Order{}, err
}
return orderFromState(resp.OrderState, accountID), nil
}
func (g *SandboxGateway) GetActiveOrders(ctx context.Context, accountID string) ([]domain.Order, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrdersResponse, error) {
return g.sandbox.GetSandboxOrders(accountID)
})
if err != nil {
return nil, err
}
states := resp.GetOrders()
out := make([]domain.Order, 0, len(states))
for _, state := range states {
out = append(out, orderFromState(state, accountID))
}
return out, nil
}
func (g *SandboxGateway) GetPortfolio(ctx context.Context, accountID string) (domain.Portfolio, error) {
if err := ctx.Err(); err != nil {
return domain.Portfolio{}, err
}
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.PortfolioResponse, error) {
return g.sandbox.GetSandboxPortfolio(accountID, pb.PortfolioRequest_RUB)
})
if err != nil {
return domain.Portfolio{}, err
}
return portfolioFromResponse(resp.PortfolioResponse)
}
func (g *SandboxGateway) GetOperations(ctx context.Context, accountID string, from, to time.Time) ([]domain.Operation, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.OperationsResponse, error) {
return g.sandbox.GetSandboxOperations(&investgo.GetOperationsRequest{
AccountId: accountID,
From: from,
To: to,
})
})
if err != nil {
return nil, err
}
return operationsFromResponse(resp.OperationsResponse), nil
} }