first version

This commit is contained in:
2026-06-07 21:01:40 +00:00
parent ee7167accf
commit f19bab1100
79 changed files with 10355 additions and 145 deletions
+328 -13
View File
@@ -2,38 +2,353 @@ package app
import (
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"io"
"log/slog"
"net/url"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/jmoiron/sqlx"
"overnight-trading-bot/internal/config"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/execution"
"overnight-trading-bot/internal/features"
"overnight-trading-bot/internal/healthcheck"
"overnight-trading-bot/internal/instruments"
"overnight-trading-bot/internal/logging"
"overnight-trading-bot/internal/marketdata"
"overnight-trading-bot/internal/notify"
"overnight-trading-bot/internal/position"
"overnight-trading-bot/internal/reconciliation"
mysqlrepo "overnight-trading-bot/internal/repository/mysql"
"overnight-trading-bot/internal/risk"
"overnight-trading-bot/internal/scheduler"
signalengine "overnight-trading-bot/internal/signal"
"overnight-trading-bot/internal/statemachine"
"overnight-trading-bot/internal/timeutil"
"overnight-trading-bot/internal/tinvest"
)
type Options struct {
ConfigPath string
Stdout io.Writer
Stdout io.Writer
Stderr io.Writer
ModeOverride string
Unhalt bool
Reason string
Healthcheck bool
HealthcheckURL string
RunOnce bool
}
func Run(ctx context.Context, opts Options) error {
if err := ctx.Err(); err != nil {
return err
}
if opts.ConfigPath == "" {
return errors.New("config path is required")
if opts.Healthcheck {
target := opts.HealthcheckURL
if target == "" {
target = "http://127.0.0.1:3300/health"
}
return healthcheck.CheckEndpoint(ctx, target)
}
if opts.Stdout == nil {
opts.Stdout = io.Discard
}
if _, err := os.Stat(opts.ConfigPath); err != nil {
if errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("config file %q does not exist; copy config.example.yaml to config.yaml and fill credentials", opts.ConfigPath)
if opts.Stderr == nil {
opts.Stderr = io.Discard
}
cfg, err := config.Load()
if err != nil {
return fmt.Errorf("load ENV config: %w", err)
}
if opts.ModeOverride != "" {
mode, err := domain.ParseMode(opts.ModeOverride)
if err != nil {
return err
}
cfg.App.Mode = mode
if err := cfg.Validate(); err != nil {
return err
}
}
log := logging.New(cfg.App.LogLevel, opts.Stdout)
log.Info("overnight trading bot starting", "mode", cfg.App.Mode)
return fmt.Errorf("check config file %q: %w", opts.ConfigPath, err)
if cfg.App.Mode == domain.ModeBacktest && !opts.Unhalt {
_, _ = fmt.Fprintf(opts.Stdout, "overnight trading bot initialized in %s mode\n", cfg.App.Mode)
return nil
}
fmt.Fprintf(opts.Stdout, "overnight trading bot initialized with config %q\n", opts.ConfigPath)
return nil
db, err := openDB(ctx, cfg)
if err != nil {
return err
}
defer func() {
_ = db.Close()
}()
if cfg.DB.MigrationsAutoApply {
if err := mysqlrepo.ApplyMigrations(ctx, db.DB); err != nil {
return err
}
}
repo := mysqlrepo.NewRepository(db)
if opts.Unhalt {
if strings.TrimSpace(opts.Reason) == "" {
return errors.New("-unhalt requires -reason")
}
gateway, closer, err := buildGateway(ctx, cfg, log)
if err != nil {
return err
}
if closer != nil {
defer closer()
}
accountIDHash := accountHash(cfg.TInvest.AccountID)
recon := reconciliation.New(repo, gateway, cfg.TInvest.AccountID, accountIDHash).
WithWindow(time.Duration(cfg.Risk.ReconciliationWindowHours) * time.Hour).
WithInFlightGrace(time.Duration(cfg.Risk.ReconciliationSkewSec) * time.Second)
diffs, err := recon.Run(ctx)
if err != nil {
return fmt.Errorf("pre-unhalt reconciliation: %w", err)
}
if reconciliation.HasCritical(diffs) {
return fmt.Errorf("pre-unhalt reconciliation has critical diffs: %d", len(diffs))
}
if err := repo.Unhalt(ctx, opts.Reason); err != nil {
return err
}
_, _ = fmt.Fprintf(opts.Stdout, "system unhalted: %s\n", opts.Reason)
return nil
}
gateway, closer, err := buildGateway(ctx, cfg, log)
if err != nil {
return err
}
if closer != nil {
defer closer()
}
notifier, err := notify.NewTelegram(notify.TelegramConfig{
BotToken: cfg.Telegram.BotToken,
ChatID: cfg.Telegram.ChatID,
NotifyInfo: cfg.Telegram.NotifyInfo,
NotifyWarn: cfg.Telegram.NotifyWarn,
NotifyAlert: cfg.Telegram.NotifyAlert,
NotifyReport: cfg.Telegram.NotifyReport,
AuditSink: repo,
}, log)
if err != nil {
return fmt.Errorf("create notifier: %w", err)
}
defer func() {
_ = notifier.Close()
}()
accountIDHash := accountHash(cfg.TInvest.AccountID)
recon := reconciliation.New(repo, gateway, cfg.TInvest.AccountID, accountIDHash).
WithWindow(time.Duration(cfg.Risk.ReconciliationWindowHours) * time.Hour).
WithInFlightGrace(time.Duration(cfg.Risk.ReconciliationSkewSec) * time.Second)
sm := statemachine.New(repo, cfg.App.Mode)
if _, err := sm.Recover(ctx, recon); err != nil {
log.Warn("state recovery did not resume trading", "err", err)
}
health := healthcheck.New(db.DB, gateway, time.Duration(cfg.Risk.MaxClockDriftSec)*time.Second)
health.Start(cfg.App.HealthcheckAddr)
defer func() {
shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Duration(cfg.App.ShutdownTimeoutSec)*time.Second)
defer cancel()
_ = health.Shutdown(shutdownCtx)
}()
if err := notifier.Info(ctx, fmt.Sprintf("bot started in %s mode", cfg.App.Mode)); err != nil {
log.Warn("notify startup failed", "err", err)
}
if opts.RunOnce {
_, _ = fmt.Fprintf(opts.Stdout, "overnight trading bot initialized in %s mode\n", cfg.App.Mode)
return nil
}
runCtx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
defer stop()
clock := timeutil.RealClock{Loc: cfg.Location}
runtime := buildScheduler(clock, sm, cfg, repo, gateway, notifier, recon, accountIDHash, log)
return runtime.Run(runCtx)
}
func buildScheduler(clock timeutil.Clock, sm statemachine.System, cfg config.Config, repo *mysqlrepo.Repository, gateway tinvest.Gateway, notifier notify.Notifier, recon reconciliation.Engine, accountIDHash string, log *slog.Logger) scheduler.Scheduler {
registry := instruments.NewRegistry(repo, gateway)
loader := marketdata.NewLoader(repo, gateway)
pipeline := features.NewPipeline(repo, features.PipelineConfig{
RollingShort: cfg.Strategy.RollingShort,
RollingLong: cfg.Strategy.RollingLong,
EWMALambda: cfg.Strategy.EWMALambda,
RiskBufferBps: cfg.Strategy.RiskBufferBps,
EntrySlippageBps: cfg.Backtest.EntrySlippageBps,
ExitSlippageBps: cfg.Backtest.ExitSlippageBps,
CommissionRoundtripBps: cfg.Backtest.CommissionRoundtripBps,
EntryWindow: timeutil.Window{
Start: cfg.Execution.EntryWindowStart,
End: cfg.Execution.EntryWindowEnd,
},
ExitWindow: timeutil.Window{
Start: cfg.Execution.ExitWindowStart,
End: cfg.Execution.ExitWindowEnd,
},
Location: cfg.Location,
})
signalEngine := signalengine.New(signalengine.Config{
MinTStat60: cfg.Strategy.MinTStat60,
MinWinRate60: cfg.Strategy.MinWinRate60,
MinNetEdgeBps: cfg.Strategy.MinNetEdgeBps,
MinADVRUB: cfg.Liquidity.MinADVRUB,
MaxSpreadBpsDefault: cfg.Liquidity.MaxSpreadBpsDefault,
MaxSpreadBpsMoneyMarket: cfg.Liquidity.MaxSpreadBpsMoneyMarket,
MaxSpreadBpsBondFunds: cfg.Liquidity.MaxSpreadBpsBondFunds,
MaxSpreadBpsEquityFunds: cfg.Liquidity.MaxSpreadBpsEquityFunds,
MaxTickBps: cfg.Liquidity.MaxTickBps,
RequireZeroCommission: cfg.Commission.RequireZeroCommission,
MaxPositions: cfg.Strategy.MaxPositions,
})
sizer := risk.NewSizer(risk.SizingConfig{
MaxPositionPct: cfg.Risk.MaxPositionPct,
MaxTotalExposurePct: cfg.Risk.MaxTotalExposurePct,
MaxParticipationRate: cfg.Liquidity.MaxParticipationRate,
CashUsageBuffer: cfg.Risk.CashUsageBuffer,
RiskBudgetPerInstrumentPct: cfg.Risk.RiskBudgetPerInstrumentPct,
MinOrderNotionalRUB: cfg.Risk.MinOrderNotionalRUB,
})
freeOrders := risk.NewFreeOrderBudget(repo)
riskManager := risk.NewManager(repo, risk.ManagerConfig{
MaxDailyLossPct: cfg.Risk.MaxDailyLossPct,
MaxWeeklyLossPct: cfg.Risk.MaxWeeklyLossPct,
MaxMonthlyDrawdownPct: cfg.Risk.MaxMonthlyDrawdownPct,
MaxAvgSlippageBps10Trades: cfg.Risk.MaxAvgSlippageBps10Trades,
MaxOpenPositions: cfg.Risk.MaxOpenPositions,
MinTimeToClose: time.Duration(cfg.Execution.MinTimeToCloseSec) * time.Second,
MaxQuoteAge: time.Duration(cfg.Execution.MaxQuoteAgeSec) * time.Second,
})
execEngine := execution.NewEngine(cfg.App.Mode, cfg.TInvest.AccountID, gateway, repo)
execEngine.SetMaxQuoteAge(time.Duration(cfg.Execution.MaxQuoteAgeSec) * time.Second)
services := scheduler.Services{
Repo: repo,
Gateway: gateway,
Registry: registry,
MarketData: loader,
Features: pipeline,
Signals: signalEngine,
Sizer: sizer,
FreeOrders: freeOrders,
Risk: riskManager,
Execution: &execEngine,
Positions: position.NewManager(repo),
Reconcile: recon,
Notifier: notifier,
AccountID: cfg.TInvest.AccountID,
AccountIDHash: accountIDHash,
Log: log,
}
return scheduler.New(clock, sm, scheduler.Config{
Mode: cfg.App.Mode,
Location: cfg.Location,
RollingLong: cfg.Strategy.RollingLong,
TickInterval: 30 * time.Second,
EntrySignalTime: cfg.Execution.EntrySignalTime,
EntryWindowStart: cfg.Execution.EntryWindowStart,
EntryWindowEnd: cfg.Execution.EntryWindowEnd,
NoNewEntryAfter: cfg.Execution.NoNewEntryAfter,
ExitWatchStart: cfg.Execution.ExitWatchStart,
ExitWindowStart: cfg.Execution.ExitWindowStart,
ExitWindowEnd: cfg.Execution.ExitWindowEnd,
HardExitDeadline: cfg.Execution.HardExitDeadline,
QuoteDepth: cfg.Execution.QuoteDepth,
MaxQuoteAge: time.Duration(cfg.Execution.MaxQuoteAgeSec) * time.Second,
OrderPollInterval: time.Duration(cfg.Execution.OrderPollIntervalMS) * time.Millisecond,
PassiveImproveTicks: cfg.Execution.PassiveImproveTicks,
MaxEntryOrderAttempts: cfg.Execution.MaxEntryOrderAttempts,
MaxExitOrderAttempts: cfg.Execution.MaxExitOrderAttempts,
MinTimeToClose: time.Duration(cfg.Execution.MinTimeToCloseSec) * time.Second,
MaxClockDrift: time.Duration(cfg.Risk.MaxClockDriftSec) * time.Second,
APIOutageHalt: time.Duration(cfg.Risk.APIOutageHaltSec) * time.Second,
}, services)
}
func openDB(ctx context.Context, cfg config.Config) (*sqlx.DB, error) {
db, err := sqlx.Open("mysql", cfg.DB.DSN)
if err != nil {
return nil, err
}
db.SetMaxOpenConns(cfg.DB.MaxOpenConns)
db.SetMaxIdleConns(cfg.DB.MaxIdleConns)
db.SetConnMaxLifetime(time.Duration(cfg.DB.ConnMaxLifetimeMin) * time.Minute)
if err := db.PingContext(ctx); err != nil {
_ = db.Close()
return nil, fmt.Errorf("db ping: %w", err)
}
return db, nil
}
func buildGateway(ctx context.Context, cfg config.Config, log *slog.Logger) (tinvest.Gateway, func(), error) {
switch cfg.App.Mode {
case domain.ModePaper:
return tinvest.NewFakeGateway(), nil, nil
case domain.ModeSandbox:
gw, err := tinvest.NewSandboxGateway(ctx, tinvest.Options{
Token: cfg.TInvest.Token,
AccountID: cfg.TInvest.AccountID,
AppName: cfg.TInvest.AppName,
RetryCount: cfg.TInvest.RetryCount,
RetryBackoff: time.Duration(cfg.TInvest.RetryBackoffSec) * time.Second,
Logger: log,
})
if err != nil {
return nil, nil, err
}
return gw, func() { _ = gw.Close() }, nil
case domain.ModeLiveReadonly, domain.ModeLiveTrade:
endpoint := cfg.TInvest.Endpoint
if cfg.TInvest.UseSandbox {
return nil, nil, errors.New("TINVEST_USE_SANDBOX is only allowed with APP_MODE=sandbox")
}
gw, err := tinvest.NewRealGateway(ctx, tinvest.Options{
Token: cfg.TInvest.Token,
AccountID: cfg.TInvest.AccountID,
Endpoint: endpoint,
AppName: cfg.TInvest.AppName,
RetryCount: cfg.TInvest.RetryCount,
RetryBackoff: time.Duration(cfg.TInvest.RetryBackoffSec) * time.Second,
Logger: log,
})
if err != nil {
return nil, nil, err
}
return gw, func() { _ = gw.Close() }, nil
default:
return tinvest.NewFakeGateway(), nil, nil
}
}
func accountHash(accountID string) string {
sum := sha256.Sum256([]byte(accountID))
return hex.EncodeToString(sum[:])
}
func HealthURL(addr string) string {
if strings.HasPrefix(addr, ":") {
return "http://127.0.0.1" + addr + "/health"
}
if _, err := url.ParseRequestURI(addr); err == nil && strings.HasPrefix(addr, "http") {
return addr
}
return "http://" + addr + "/health"
}
func PingDB(ctx context.Context, db *sql.DB) error {
return db.PingContext(ctx)
}
+9 -29
View File
@@ -3,49 +3,29 @@ package app
import (
"bytes"
"context"
"os"
"strings"
"testing"
)
func TestRunRequiresConfigPath(t *testing.T) {
err := Run(context.Background(), Options{})
func TestRunRequiresAppMode(t *testing.T) {
t.Setenv("APP_MODE", "")
err := Run(context.Background(), Options{RunOnce: true})
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "config path is required") {
if !strings.Contains(err.Error(), "APP_MODE") && !strings.Contains(err.Error(), "MODE") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestRunReportsMissingConfig(t *testing.T) {
err := Run(context.Background(), Options{ConfigPath: "missing.yaml"})
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "does not exist") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestRunUsesExistingConfig(t *testing.T) {
touch := t.TempDir() + "/config.yaml"
if err := os.WriteFile(touch, []byte("instruments: []\n"), 0o600); err != nil {
t.Fatal(err)
}
func TestRunBacktestModeWithoutDB(t *testing.T) {
t.Setenv("APP_MODE", "backtest")
var stdout bytes.Buffer
err := Run(context.Background(), Options{
ConfigPath: touch,
Stdout: &stdout,
})
err := Run(context.Background(), Options{Stdout: &stdout, RunOnce: true})
if err != nil {
t.Fatal(err)
}
if !strings.Contains(stdout.String(), "initialized") {
t.Fatalf("unexpected stdout: %q", stdout.String())
if !strings.Contains(stdout.String(), "backtest") {
t.Fatalf("unexpected stdout: %s", stdout.String())
}
}
+563
View File
@@ -0,0 +1,563 @@
package backtest
import (
"encoding/csv"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"sort"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/features"
"overnight-trading-bot/internal/money"
"overnight-trading-bot/internal/risk"
)
type Config struct {
EntrySlippageBps decimal.Decimal
ExitSlippageBps decimal.Decimal
CommissionRoundtripBps decimal.Decimal
InitialEquity decimal.Decimal
OutputDir string
RollingShort int
RollingLong int
EWMALambda float64
MinTStat60 decimal.Decimal
MinWinRate60 decimal.Decimal
MinNetEdgeBps decimal.Decimal
MinADVRUB decimal.Decimal
MaxSpreadBps decimal.Decimal
MaxTickBps decimal.Decimal
RequireZeroCommission bool
MaxPositions int
MaxPositionPct decimal.Decimal
MaxTotalExposurePct decimal.Decimal
MaxParticipationRate decimal.Decimal
CashUsageBuffer decimal.Decimal
RiskBudgetPct decimal.Decimal
MinOrderNotionalRUB decimal.Decimal
AssumedSpreadBps decimal.Decimal
AssumedTickBps decimal.Decimal
Lot int64
UseMinuteModel bool
}
type Trade struct {
InstrumentUID string `json:"instrument_uid"`
EntryDate string `json:"entry_date"`
ExitDate string `json:"exit_date"`
BuyPrice decimal.Decimal `json:"buy_price"`
SellPrice decimal.Decimal `json:"sell_price"`
Return decimal.Decimal `json:"return"`
Lots int64 `json:"lots"`
Notional decimal.Decimal `json:"notional"`
NetPnL decimal.Decimal `json:"net_pnl"`
SpreadBps decimal.Decimal `json:"spread_bps"`
SlippageBps decimal.Decimal `json:"slippage_bps"`
OvernightGap decimal.Decimal `json:"overnight_gap"`
CapacityRUB decimal.Decimal `json:"capacity_rub"`
}
type Result struct {
Metrics Metrics `json:"metrics"`
Trades []Trade `json:"trades"`
Equity []Point `json:"equity"`
}
type Point struct {
Date string `json:"date"`
Equity decimal.Decimal `json:"equity"`
Return decimal.Decimal `json:"return"`
}
type Engine struct {
cfg Config
}
func New(cfg Config) Engine {
cfg = cfg.withDefaults()
return Engine{cfg: cfg}
}
func (cfg Config) withDefaults() Config {
if cfg.InitialEquity.IsZero() {
cfg.InitialEquity = decimal.NewFromInt(100_000)
}
if cfg.RollingShort == 0 {
cfg.RollingShort = 60
}
if cfg.RollingLong == 0 {
cfg.RollingLong = 252
}
if cfg.EWMALambda == 0 {
cfg.EWMALambda = 0.08
}
if cfg.MinTStat60.IsZero() {
cfg.MinTStat60 = decimal.NewFromFloat(1.25)
}
if cfg.MinWinRate60.IsZero() {
cfg.MinWinRate60 = decimal.NewFromFloat(0.55)
}
if cfg.MinNetEdgeBps.IsZero() {
cfg.MinNetEdgeBps = decimal.NewFromInt(10)
}
if cfg.MinADVRUB.IsZero() {
cfg.MinADVRUB = decimal.NewFromInt(5_000_000)
}
if cfg.MaxSpreadBps.IsZero() {
cfg.MaxSpreadBps = decimal.NewFromInt(20)
}
if cfg.MaxTickBps.IsZero() {
cfg.MaxTickBps = decimal.NewFromInt(10)
}
if !cfg.RequireZeroCommission && cfg.CommissionRoundtripBps.IsZero() {
cfg.RequireZeroCommission = true
}
if cfg.MaxPositions == 0 {
cfg.MaxPositions = 5
}
if cfg.MaxPositionPct.IsZero() {
cfg.MaxPositionPct = decimal.NewFromFloat(0.10)
}
if cfg.MaxTotalExposurePct.IsZero() {
cfg.MaxTotalExposurePct = decimal.NewFromFloat(0.50)
}
if cfg.MaxParticipationRate.IsZero() {
cfg.MaxParticipationRate = decimal.NewFromFloat(0.01)
}
if cfg.CashUsageBuffer.IsZero() {
cfg.CashUsageBuffer = decimal.NewFromFloat(0.95)
}
if cfg.RiskBudgetPct.IsZero() {
cfg.RiskBudgetPct = decimal.NewFromFloat(0.005)
}
if cfg.MinOrderNotionalRUB.IsZero() {
cfg.MinOrderNotionalRUB = decimal.NewFromInt(1000)
}
if cfg.Lot == 0 {
cfg.Lot = 1
}
return cfg
}
func (e Engine) Run(candlesByInstrument map[string][]domain.Candle) (Result, error) {
return e.RunWithMinuteCandles(candlesByInstrument, nil)
}
func (e Engine) RunWithMinuteCandles(candlesByInstrument map[string][]domain.Candle, minuteCandlesByInstrument map[string][]domain.Candle) (Result, error) {
prepared := prepareCandles(candlesByInstrument)
preparedMinutes := prepareCandles(minuteCandlesByInstrument)
candidatesByExitDate := make(map[string][]candidate)
for instrumentUID, candles := range prepared {
for i := 1; i < len(candles); i++ {
candidate, ok, err := e.evaluateCandidate(instrumentUID, candles, i)
if err != nil {
return Result{}, err
}
if ok {
candidatesByExitDate[candidate.exit.TradeDate.Format("2006-01-02")] = append(candidatesByExitDate[candidate.exit.TradeDate.Format("2006-01-02")], candidate)
}
}
}
dates := make([]string, 0, len(candidatesByExitDate))
for date := range candidatesByExitDate {
dates = append(dates, date)
}
sort.Strings(dates)
equity := e.cfg.InitialEquity
cash := e.cfg.InitialEquity
var trades []Trade
points := []Point{{Date: "START", Equity: equity}}
sizer := risk.NewSizer(risk.SizingConfig{
MaxPositionPct: e.cfg.MaxPositionPct,
MaxTotalExposurePct: e.cfg.MaxTotalExposurePct,
MaxParticipationRate: e.cfg.MaxParticipationRate,
CashUsageBuffer: e.cfg.CashUsageBuffer,
RiskBudgetPerInstrumentPct: e.cfg.RiskBudgetPct,
MinOrderNotionalRUB: e.cfg.MinOrderNotionalRUB,
})
for _, date := range dates {
dayCandidates := candidatesByExitDate[date]
sort.Slice(dayCandidates, func(i, j int) bool {
if dayCandidates[i].netEdge.Equal(dayCandidates[j].netEdge) {
return dayCandidates[i].instrumentUID < dayCandidates[j].instrumentUID
}
return dayCandidates[i].netEdge.GreaterThan(dayCandidates[j].netEdge)
})
if len(dayCandidates) > e.cfg.MaxPositions {
dayCandidates = dayCandidates[:e.cfg.MaxPositions]
}
dayStartEquity := equity
dayPnL := decimal.Zero
for _, c := range dayCandidates {
sized := sizer.Size(risk.SizingInput{
Portfolio: domain.Portfolio{Equity: equity, Cash: cash},
SelectedInstruments: len(dayCandidates),
LimitPrice: c.buy,
Lot: e.cfg.Lot,
EntryIntervalVolume: c.adv,
ExitIntervalVolume: c.adv,
Q05OvernightAbs: c.q05Abs,
})
if sized.Lots <= 0 {
continue
}
lots := sized.Lots
capacity := c.capacity
if e.cfg.UseMinuteModel {
executedLots, minuteCapacity, ok := e.minuteExecution(c, preparedMinutes[c.instrumentUID], sized.Lots)
if !ok {
continue
}
lots = executedLots
capacity = minuteCapacity
}
notional := c.buy.Mul(decimal.NewFromInt(lots)).Mul(decimal.NewFromInt(e.cfg.Lot))
ret := c.sell.Div(c.buy).Sub(decimal.NewFromInt(1)).Sub(money.FromBps(e.cfg.CommissionRoundtripBps))
pnl := notional.Mul(ret)
dayPnL = dayPnL.Add(pnl)
cash = cash.Sub(notional)
trades = append(trades, Trade{
InstrumentUID: c.instrumentUID,
EntryDate: c.entry.TradeDate.Format("2006-01-02"),
ExitDate: c.exit.TradeDate.Format("2006-01-02"),
BuyPrice: c.buy,
SellPrice: c.sell,
Return: ret,
Lots: lots,
Notional: notional,
NetPnL: pnl,
SpreadBps: e.cfg.AssumedSpreadBps,
SlippageBps: e.cfg.EntrySlippageBps.Add(e.cfg.ExitSlippageBps),
OvernightGap: c.overnightGap,
CapacityRUB: capacity,
})
}
if !dayPnL.IsZero() {
equity = equity.Add(dayPnL)
cash = equity
points = append(points, Point{
Date: date,
Equity: equity,
Return: dayPnL.Div(dayStartEquity),
})
}
}
sort.Slice(trades, func(i, j int) bool {
if trades[i].ExitDate == trades[j].ExitDate {
return trades[i].InstrumentUID < trades[j].InstrumentUID
}
return trades[i].ExitDate < trades[j].ExitDate
})
return Result{
Metrics: ComputeMetrics(points, trades),
Trades: trades,
Equity: points,
}, nil
}
func (e Engine) minuteExecution(c candidate, minutes []domain.Candle, requestedLots int64) (int64, decimal.Decimal, bool) {
if requestedLots <= 0 || len(minutes) == 0 {
return 0, decimal.Zero, false
}
entryLots, entryCapacity := e.fillableMinuteLots(minutes, c.entry.TradeDate, c.buy, domain.SideBuy)
exitLots, exitCapacity := e.fillableMinuteLots(minutes, c.exit.TradeDate, c.sell, domain.SideSell)
lots := min(requestedLots, entryLots)
lots = min(lots, exitLots)
if lots <= 0 {
return 0, decimal.Zero, false
}
return lots, money.Min(entryCapacity, exitCapacity), true
}
func (e Engine) fillableMinuteLots(minutes []domain.Candle, date time.Time, limitPrice decimal.Decimal, side domain.Side) (int64, decimal.Decimal) {
if !limitPrice.IsPositive() || e.cfg.Lot <= 0 {
return 0, decimal.Zero
}
lotNotional := limitPrice.Mul(decimal.NewFromInt(e.cfg.Lot))
if !lotNotional.IsPositive() {
return 0, decimal.Zero
}
capacity := decimal.Zero
for _, candle := range minutes {
if !sameDate(candle.TradeDate, date) {
continue
}
reachable := side == domain.SideBuy && candle.Low.LessThanOrEqual(limitPrice)
reachable = reachable || side == domain.SideSell && candle.High.GreaterThanOrEqual(limitPrice)
if !reachable {
continue
}
minuteCapacity := candle.VolumeLots.Mul(lotNotional).Mul(e.cfg.MaxParticipationRate)
capacity = capacity.Add(minuteCapacity)
}
return capacity.Div(lotNotional).Floor().IntPart(), capacity
}
type candidate struct {
instrumentUID string
entry domain.Candle
exit domain.Candle
buy decimal.Decimal
sell decimal.Decimal
netEdge decimal.Decimal
adv decimal.Decimal
q05Abs decimal.Decimal
overnightGap decimal.Decimal
capacity decimal.Decimal
}
func (e Engine) evaluateCandidate(instrumentUID string, candles []domain.Candle, exitIndex int) (candidate, bool, error) {
if exitIndex < e.cfg.RollingShort || exitIndex <= 0 {
return candidate{}, false, nil
}
history := candles[:exitIndex]
returns := make([]float64, 0, len(history)-1)
for j := 1; j < len(history); j++ {
r, err := features.OvernightReturn(history[j].Open, history[j-1].Close)
if err != nil {
return candidate{}, false, err
}
rf, _ := r.Float64()
returns = append(returns, rf)
}
short := features.Rolling(returns, e.cfg.RollingShort, e.cfg.EWMALambda)
long := features.Rolling(returns, min(e.cfg.RollingLong, len(returns)), e.cfg.EWMALambda)
if !short.Available || !long.Available || short.StdDev == 0 {
return candidate{}, false, nil
}
rawEdge := decimal.NewFromFloat(short.Mean).Mul(decimal.NewFromInt(10_000))
cost := e.cfg.AssumedSpreadBps.
Add(e.cfg.EntrySlippageBps).
Add(e.cfg.ExitSlippageBps).
Add(e.cfg.CommissionRoundtripBps)
netEdge := rawEdge.Sub(cost)
adv := features.ADV(history, e.cfg.Lot, 20)
switch {
case e.cfg.RequireZeroCommission && e.cfg.CommissionRoundtripBps.IsPositive():
return candidate{}, false, nil
case !decimal.NewFromFloat(short.Mean).IsPositive() || !decimal.NewFromFloat(long.Mean).IsPositive():
return candidate{}, false, nil
case decimal.NewFromFloat(short.TStat).LessThan(e.cfg.MinTStat60):
return candidate{}, false, nil
case decimal.NewFromFloat(short.WinRate).LessThan(e.cfg.MinWinRate60):
return candidate{}, false, nil
case netEdge.LessThan(e.cfg.MinNetEdgeBps):
return candidate{}, false, nil
case e.cfg.AssumedSpreadBps.GreaterThan(e.cfg.MaxSpreadBps):
return candidate{}, false, nil
case e.cfg.AssumedTickBps.GreaterThan(e.cfg.MaxTickBps):
return candidate{}, false, nil
case adv.LessThan(e.cfg.MinADVRUB):
return candidate{}, false, nil
}
entry := candles[exitIndex-1]
exit := candles[exitIndex]
buy := entry.Close.Mul(decimal.NewFromInt(1).Add(money.FromBps(e.cfg.EntrySlippageBps)))
sell := exit.Open.Mul(decimal.NewFromInt(1).Sub(money.FromBps(e.cfg.ExitSlippageBps)))
gap, err := features.OvernightReturn(exit.Open, entry.Close)
if err != nil {
return candidate{}, false, err
}
q05Abs := decimal.NewFromFloat(features.Quantile(returns, 0.05))
if q05Abs.IsNegative() {
q05Abs = q05Abs.Neg()
}
return candidate{
instrumentUID: instrumentUID,
entry: entry,
exit: exit,
buy: buy,
sell: sell,
netEdge: netEdge,
adv: adv,
q05Abs: q05Abs,
overnightGap: gap,
capacity: adv.Mul(e.cfg.MaxParticipationRate),
}, true, nil
}
func prepareCandles(candlesByInstrument map[string][]domain.Candle) map[string][]domain.Candle {
prepared := make(map[string][]domain.Candle, len(candlesByInstrument))
for instrumentUID, candles := range candlesByInstrument {
cp := append([]domain.Candle(nil), candles...)
sort.Slice(cp, func(i, j int) bool {
return cp[i].TradeDate.Before(cp[j].TradeDate)
})
prepared[instrumentUID] = cp
}
return prepared
}
func (r Result) Write(outputDir string) error {
if outputDir == "" {
outputDir = "./backtest_out"
}
if err := os.MkdirAll(outputDir, 0o750); err != nil {
return err
}
summary, err := json.MarshalIndent(r.Metrics, "", " ")
if err != nil {
return err
}
if err := os.WriteFile(filepath.Join(outputDir, "summary.json"), summary, 0o600); err != nil {
return err
}
if err := writeTrades(filepath.Join(outputDir, "trades.csv"), r.Trades); err != nil {
return err
}
return writeEquity(filepath.Join(outputDir, "equity.csv"), r.Equity)
}
func LoadCandlesCSV(reader io.Reader) (map[string][]domain.Candle, error) {
r := csv.NewReader(reader)
r.FieldsPerRecord = -1
records, err := r.ReadAll()
if err != nil {
return nil, err
}
out := make(map[string][]domain.Candle)
for i, record := range records {
if i == 0 && len(record) > 0 && record[0] == "instrument_uid" {
continue
}
if len(record) < 7 {
return nil, fmt.Errorf("line %d: expected 7 fields", i+1)
}
date, err := parseCandleTime(record[1])
if err != nil {
return nil, err
}
open, err := decimal.NewFromString(record[2])
if err != nil {
return nil, err
}
high, err := decimal.NewFromString(record[3])
if err != nil {
return nil, err
}
low, err := decimal.NewFromString(record[4])
if err != nil {
return nil, err
}
closePrice, err := decimal.NewFromString(record[5])
if err != nil {
return nil, err
}
volume, err := decimal.NewFromString(record[6])
if err != nil {
return nil, err
}
candle := domain.Candle{
InstrumentUID: record[0],
TradeDate: date,
Open: open,
High: high,
Low: low,
Close: closePrice,
VolumeLots: volume,
Source: "csv",
LoadedAt: time.Now().UTC(),
}
out[candle.InstrumentUID] = append(out[candle.InstrumentUID], candle)
}
return out, nil
}
func parseCandleTime(raw string) (time.Time, error) {
layouts := []string{
time.RFC3339,
"2006-01-02 15:04:05",
"2006-01-02T15:04:05",
"2006-01-02",
}
var lastErr error
for _, layout := range layouts {
parsed, err := time.Parse(layout, raw)
if err == nil {
return parsed.UTC(), nil
}
lastErr = err
}
return time.Time{}, lastErr
}
func sameDate(a, b time.Time) bool {
return dateOnly(a).Equal(dateOnly(b))
}
func dateOnly(t time.Time) time.Time {
y, m, d := t.UTC().Date()
return time.Date(y, m, d, 0, 0, 0, 0, time.UTC)
}
func writeTrades(path string, trades []Trade) error {
// #nosec G304 -- path is the user-selected backtest output destination.
f, err := os.Create(path)
if err != nil {
return err
}
defer func() {
_ = f.Close()
}()
w := csv.NewWriter(f)
defer w.Flush()
if err := w.Write([]string{"instrument_uid", "entry_date", "exit_date", "buy_price", "sell_price", "return", "lots", "notional", "net_pnl", "spread_bps", "slippage_bps", "overnight_gap", "capacity_rub"}); err != nil {
return err
}
for _, trade := range trades {
if err := w.Write([]string{
trade.InstrumentUID,
trade.EntryDate,
trade.ExitDate,
trade.BuyPrice.String(),
trade.SellPrice.String(),
trade.Return.String(),
fmt.Sprintf("%d", trade.Lots),
trade.Notional.String(),
trade.NetPnL.String(),
trade.SpreadBps.String(),
trade.SlippageBps.String(),
trade.OvernightGap.String(),
trade.CapacityRUB.String(),
}); err != nil {
return err
}
}
return w.Error()
}
func writeEquity(path string, points []Point) error {
// #nosec G304 -- path is the user-selected backtest output destination.
f, err := os.Create(path)
if err != nil {
return err
}
defer func() {
_ = f.Close()
}()
w := csv.NewWriter(f)
defer w.Flush()
if err := w.Write([]string{"date", "equity", "return"}); err != nil {
return err
}
for _, point := range points {
if err := w.Write([]string{point.Date, point.Equity.String(), point.Return.String()}); err != nil {
return err
}
}
return w.Error()
}
func ParseDecimalFlag(raw string) (decimal.Decimal, error) {
if raw == "" {
return decimal.Zero, nil
}
return decimal.NewFromString(raw)
}
+70
View File
@@ -0,0 +1,70 @@
package backtest
import (
"testing"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
)
func TestBacktestNoLookAheadWithFutureOnlyEdge(t *testing.T) {
var candles []domain.Candle
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
for i := 0; i < 80; i++ {
open := decimal.NewFromInt(100)
if i == 79 {
open = decimal.NewFromInt(110)
}
candles = append(candles, domain.Candle{
InstrumentUID: "uid",
TradeDate: start.AddDate(0, 0, i),
Open: open,
High: open,
Low: open,
Close: decimal.NewFromInt(100),
})
}
result, err := New(Config{}).Run(map[string][]domain.Candle{"uid": candles})
if err != nil {
t.Fatal(err)
}
if len(result.Trades) != 0 {
t.Fatalf("future-only edge leaked into signals: %d trades", len(result.Trades))
}
}
func TestMinuteExecutionRequiresReachableLimitAndParticipation(t *testing.T) {
engine := New(Config{
Lot: 10,
MaxParticipationRate: decimal.NewFromFloat(0.10),
})
entryDate := time.Date(2024, 1, 2, 18, 25, 0, 0, time.UTC)
exitDate := time.Date(2024, 1, 3, 10, 5, 0, 0, time.UTC)
c := candidate{
instrumentUID: "uid",
entry: domain.Candle{TradeDate: entryDate},
exit: domain.Candle{TradeDate: exitDate},
buy: decimal.NewFromInt(100),
sell: decimal.NewFromInt(105),
}
minutes := []domain.Candle{
{TradeDate: entryDate, Low: decimal.NewFromInt(99), High: decimal.NewFromInt(101), VolumeLots: decimal.NewFromInt(20)},
{TradeDate: exitDate, Low: decimal.NewFromInt(104), High: decimal.NewFromInt(106), VolumeLots: decimal.NewFromInt(20)},
}
lots, capacity, ok := engine.minuteExecution(c, minutes, 5)
if !ok {
t.Fatal("expected minute execution")
}
if lots != 2 {
t.Fatalf("lots=%d, want 2", lots)
}
if !capacity.Equal(decimal.NewFromInt(2000)) {
t.Fatalf("capacity=%s, want 2000", capacity)
}
c.sell = decimal.NewFromInt(110)
if _, _, ok := engine.minuteExecution(c, minutes, 5); ok {
t.Fatal("sell limit should be unreachable")
}
}
+213
View File
@@ -0,0 +1,213 @@
package backtest
import (
"math"
"sort"
"github.com/shopspring/decimal"
)
type Metrics struct {
TotalReturn float64 `json:"total_return"`
CAGR float64 `json:"cagr"`
AnnualizedVolatility float64 `json:"annualized_volatility"`
SharpeRatio float64 `json:"sharpe_ratio"`
SortinoRatio float64 `json:"sortino_ratio"`
MaxDrawdown float64 `json:"max_drawdown"`
CalmarRatio float64 `json:"calmar_ratio"`
WinRate float64 `json:"win_rate"`
AverageTradeReturn float64 `json:"average_trade_return"`
MedianTradeReturn float64 `json:"median_trade_return"`
ProfitFactor float64 `json:"profit_factor"`
AverageSpreadBps float64 `json:"average_spread_bps"`
AverageSlippageBps float64 `json:"average_slippage_bps"`
NumberOfTrades int `json:"number_of_trades"`
WorstOvernightGap float64 `json:"worst_overnight_gap"`
VaR95 float64 `json:"var_95"`
CVaR95 float64 `json:"cvar_95"`
CapacityEstimate float64 `json:"capacity_estimate"`
}
func ComputeMetrics(points []Point, trades []Trade) Metrics {
if len(points) == 0 {
return Metrics{}
}
start, _ := points[0].Equity.Float64()
end, _ := points[len(points)-1].Equity.Float64()
returns := make([]float64, 0, len(points)-1)
for _, point := range points[1:] {
r, _ := point.Return.Float64()
returns = append(returns, r)
}
tradeReturns := make([]float64, 0, len(trades))
spreads := make([]float64, 0, len(trades))
slippages := make([]float64, 0, len(trades))
profits := 0.0
losses := 0.0
wins := 0
worstGap := 0.0
capacity := 0.0
for _, trade := range trades {
r, _ := trade.Return.Float64()
tradeReturns = append(tradeReturns, r)
spread, _ := trade.SpreadBps.Float64()
spreads = append(spreads, spread)
slippage, _ := trade.SlippageBps.Float64()
slippages = append(slippages, slippage)
if r > 0 {
wins++
profits += r
} else {
losses += r
}
gap, _ := trade.OvernightGap.Float64()
if gap < worstGap {
worstGap = gap
}
tradeCapacity, _ := trade.CapacityRUB.Float64()
if tradeCapacity > 0 && (capacity == 0 || tradeCapacity < capacity) {
capacity = tradeCapacity
}
}
totalReturn := 0.0
if start > 0 {
totalReturn = end/start - 1
}
vol := stddev(returns) * math.Sqrt(252)
meanReturn := mean(returns)
sharpe := 0.0
if std := stddev(returns); std > 0 {
sharpe = meanReturn / std * math.Sqrt(252)
}
sortino := 0.0
if down := downsideStddev(returns); down > 0 {
sortino = meanReturn / down * math.Sqrt(252)
}
tradingDays := math.Max(float64(len(returns)), 1)
cagr := 0.0
if start > 0 && end > 0 {
cagr = math.Pow(end/start, 252/tradingDays) - 1
}
maxDD := maxDrawdown(points)
calmar := 0.0
if maxDD != 0 {
calmar = cagr / math.Abs(maxDD)
}
pf := 0.0
if losses != 0 {
pf = profits / math.Abs(losses)
}
var95 := percentile(returns, 0.05)
cvar95 := conditionalMean(returns, var95)
return Metrics{
TotalReturn: totalReturn,
CAGR: cagr,
AnnualizedVolatility: vol,
SharpeRatio: sharpe,
SortinoRatio: sortino,
MaxDrawdown: maxDD,
CalmarRatio: calmar,
WinRate: ratio(wins, len(tradeReturns)),
AverageTradeReturn: mean(tradeReturns),
MedianTradeReturn: percentile(tradeReturns, 0.50),
ProfitFactor: pf,
AverageSpreadBps: mean(spreads),
AverageSlippageBps: mean(slippages),
NumberOfTrades: len(trades),
WorstOvernightGap: worstGap,
VaR95: var95,
CVaR95: cvar95,
CapacityEstimate: capacity,
}
}
func maxDrawdown(points []Point) float64 {
peak := 0.0
maxDD := 0.0
for _, point := range points {
e, _ := point.Equity.Float64()
if e > peak {
peak = e
}
if peak > 0 {
dd := e/peak - 1
if dd < maxDD {
maxDD = dd
}
}
}
return maxDD
}
func mean(values []float64) float64 {
if len(values) == 0 {
return 0
}
sum := 0.0
for _, value := range values {
sum += value
}
return sum / float64(len(values))
}
func stddev(values []float64) float64 {
if len(values) < 2 {
return 0
}
m := mean(values)
sum := 0.0
for _, value := range values {
diff := value - m
sum += diff * diff
}
return math.Sqrt(sum / float64(len(values)-1))
}
func downsideStddev(values []float64) float64 {
var downs []float64
for _, value := range values {
if value < 0 {
downs = append(downs, value)
}
}
return stddev(downs)
}
func percentile(values []float64, q float64) float64 {
if len(values) == 0 {
return 0
}
cp := append([]float64(nil), values...)
sort.Float64s(cp)
pos := q * float64(len(cp)-1)
lower := int(math.Floor(pos))
upper := int(math.Ceil(pos))
if lower == upper {
return cp[lower]
}
weight := pos - float64(lower)
return cp[lower]*(1-weight) + cp[upper]*weight
}
func conditionalMean(values []float64, threshold float64) float64 {
var selected []float64
for _, value := range values {
if value <= threshold {
selected = append(selected, value)
}
}
return mean(selected)
}
func ratio(n, d int) float64 {
if d == 0 {
return 0
}
return float64(n) / float64(d)
}
func point(date string, equity, ret string) Point {
e, _ := decimal.NewFromString(equity)
r, _ := decimal.NewFromString(ret)
return Point{Date: date, Equity: e, Return: r}
}
+14
View File
@@ -0,0 +1,14 @@
package backtest
import "testing"
func TestMetrics(t *testing.T) {
got := ComputeMetrics([]Point{
point("START", "100", "0"),
point("2024-01-02", "110", "0.10"),
point("2024-01-03", "99", "-0.10"),
}, []Trade{{Return: point("", "0", "0.10").Return}, {Return: point("", "0", "-0.10").Return}})
if got.NumberOfTrades != 2 || got.WinRate != 0.5 || got.MaxDrawdown >= 0 || got.VaR95 >= 0 {
t.Fatalf("unexpected metrics: %+v", got)
}
}
+234
View File
@@ -0,0 +1,234 @@
package config
import (
"errors"
"fmt"
"time"
"github.com/caarlos0/env/v11"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/timeutil"
)
const liveTradeAck = "I_ACCEPT_RISK"
const maxQuoteDepth = 50
type Config struct {
App AppConfig `envPrefix:"APP_"`
TInvest TInvestConfig `envPrefix:"TINVEST_"`
DB DBConfig `envPrefix:"DB_"`
Telegram TelegramConfig `envPrefix:"TELEGRAM_"`
Strategy StrategyConfig `envPrefix:"STRATEGY_"`
Execution ExecutionConfig `envPrefix:"EXEC_"`
Risk RiskConfig `envPrefix:"RISK_"`
Liquidity LiquidityConfig `envPrefix:"LIQ_"`
Commission CommissionConfig `envPrefix:"COMM_"`
Backtest BacktestConfig `envPrefix:"BT_"`
Live LiveConfig `envPrefix:"LIVE_"`
Location *time.Location `env:"-"`
}
type AppConfig struct {
Mode domain.Mode `env:"MODE,required"`
Timezone string `env:"TIMEZONE" envDefault:"Europe/Moscow"`
LogLevel string `env:"LOG_LEVEL" envDefault:"info"`
HealthcheckAddr string `env:"HEALTHCHECK_ADDR" envDefault:":3300"`
ShutdownTimeoutSec int `env:"SHUTDOWN_TIMEOUT_SEC" envDefault:"30"`
}
type TInvestConfig struct {
Token string `env:"TOKEN"`
AccountID string `env:"ACCOUNT_ID"`
Endpoint string `env:"ENDPOINT" envDefault:"invest-public-api.tinkoff.ru:443"`
AppName string `env:"APP_NAME" envDefault:"overnight-trading-bot"`
RequestTimeoutSec int `env:"REQUEST_TIMEOUT_SEC" envDefault:"10"`
RetryCount int `env:"RETRY_COUNT" envDefault:"3"`
RetryBackoffSec int `env:"RETRY_BACKOFF_SEC" envDefault:"2"`
UseSandbox bool `env:"USE_SANDBOX" envDefault:"false"`
}
type DBConfig struct {
DSN string `env:"DSN"`
MaxOpenConns int `env:"MAX_OPEN_CONNS" envDefault:"20"`
MaxIdleConns int `env:"MAX_IDLE_CONNS" envDefault:"5"`
ConnMaxLifetimeMin int `env:"CONN_MAX_LIFETIME_MIN" envDefault:"30"`
MigrationsAutoApply bool `env:"MIGRATIONS_AUTO_APPLY" envDefault:"true"`
}
type TelegramConfig struct {
BotToken string `env:"BOT_TOKEN"`
ChatID int64 `env:"CHAT_ID"`
NotifyInfo bool `env:"NOTIFY_INFO" envDefault:"true"`
NotifyWarn bool `env:"NOTIFY_WARN" envDefault:"true"`
NotifyAlert bool `env:"NOTIFY_ALERT" envDefault:"true"`
NotifyReport bool `env:"NOTIFY_REPORT" envDefault:"true"`
}
type StrategyConfig struct {
RollingShort int `env:"ROLLING_SHORT" envDefault:"60"`
RollingLong int `env:"ROLLING_LONG" envDefault:"252"`
EWMALambda float64 `env:"EWMA_LAMBDA" envDefault:"0.08"`
MinTStat60 decimal.Decimal `env:"MIN_TSTAT_60" envDefault:"1.25"`
MinWinRate60 decimal.Decimal `env:"MIN_WIN_RATE_60" envDefault:"0.55"`
MinNetEdgeBps decimal.Decimal `env:"MIN_NET_EDGE_BPS" envDefault:"10"`
RiskBufferBps decimal.Decimal `env:"RISK_BUFFER_BPS" envDefault:"5"`
MaxPositions int `env:"MAX_POSITIONS" envDefault:"5"`
}
type ExecutionConfig struct {
EntrySignalTime timeutil.TimeOfDay `env:"ENTRY_SIGNAL_TIME" envDefault:"18:10:00"`
EntryWindowStart timeutil.TimeOfDay `env:"ENTRY_WINDOW_START" envDefault:"18:20:00"`
EntryWindowEnd timeutil.TimeOfDay `env:"ENTRY_WINDOW_END" envDefault:"18:38:30"`
NoNewEntryAfter timeutil.TimeOfDay `env:"NO_NEW_ENTRY_AFTER" envDefault:"18:38:30"`
ExitWatchStart timeutil.TimeOfDay `env:"EXIT_WATCH_START" envDefault:"09:50:00"`
ExitNotBefore timeutil.TimeOfDay `env:"EXIT_NOT_BEFORE" envDefault:"10:03:00"`
ExitWindowStart timeutil.TimeOfDay `env:"EXIT_WINDOW_START" envDefault:"10:05:00"`
ExitWindowEnd timeutil.TimeOfDay `env:"EXIT_WINDOW_END" envDefault:"10:25:00"`
HardExitDeadline timeutil.TimeOfDay `env:"HARD_EXIT_DEADLINE" envDefault:"10:45:00"`
MinTimeToCloseSec int `env:"MIN_TIME_TO_CLOSE_SEC" envDefault:"90"`
AllowMarketOrders bool `env:"ALLOW_MARKET_ORDERS" envDefault:"false"`
MaxEntryOrderAttempts int `env:"MAX_ENTRY_ORDER_ATTEMPTS" envDefault:"3"`
MaxExitOrderAttempts int `env:"MAX_EXIT_ORDER_ATTEMPTS" envDefault:"3"`
PassiveImproveTicks int `env:"PASSIVE_IMPROVE_TICKS" envDefault:"1"`
QuoteDepth int32 `env:"QUOTE_DEPTH" envDefault:"20"`
MaxQuoteAgeSec int `env:"MAX_QUOTE_AGE_SEC" envDefault:"3"`
OrderPollIntervalMS int `env:"ORDER_POLL_INTERVAL_MS" envDefault:"500"`
}
type RiskConfig struct {
UseMargin bool `env:"USE_MARGIN" envDefault:"false"`
AllowShort bool `env:"ALLOW_SHORT" envDefault:"false"`
MaxTotalExposurePct decimal.Decimal `env:"MAX_TOTAL_EXPOSURE_PCT" envDefault:"0.50"`
MaxPositionPct decimal.Decimal `env:"MAX_POSITION_PCT" envDefault:"0.10"`
MaxDailyLossPct decimal.Decimal `env:"MAX_DAILY_LOSS_PCT" envDefault:"0.01"`
MaxWeeklyLossPct decimal.Decimal `env:"MAX_WEEKLY_LOSS_PCT" envDefault:"0.03"`
MaxMonthlyDrawdownPct decimal.Decimal `env:"MAX_MONTHLY_DRAWDOWN_PCT" envDefault:"0.07"`
MaxOpenPositions int `env:"MAX_OPEN_POSITIONS" envDefault:"5"`
MaxAvgSlippageBps10Trades decimal.Decimal `env:"MAX_AVG_SLIPPAGE_BPS_10_TRADES" envDefault:"15"`
APIOutageHaltSec int `env:"API_OUTAGE_HALT_SEC" envDefault:"180"`
MaxClockDriftSec int `env:"MAX_CLOCK_DRIFT_SEC" envDefault:"2"`
ReconciliationWindowHours int `env:"RECONCILIATION_WINDOW_HOURS" envDefault:"72"`
ReconciliationSkewSec int `env:"RECONCILIATION_SKEW_SEC" envDefault:"10"`
CashUsageBuffer decimal.Decimal `env:"CASH_USAGE_BUFFER" envDefault:"0.95"`
RiskBudgetPerInstrumentPct decimal.Decimal `env:"RISK_BUDGET_PER_INSTRUMENT_PCT" envDefault:"0.005"`
MinOrderNotionalRUB decimal.Decimal `env:"MIN_ORDER_NOTIONAL_RUB" envDefault:"1000"`
}
type LiquidityConfig struct {
MinADVRUB decimal.Decimal `env:"MIN_ADV_RUB" envDefault:"5000000"`
MaxParticipationRate decimal.Decimal `env:"MAX_PARTICIPATION_RATE" envDefault:"0.01"`
MaxSpreadBpsDefault decimal.Decimal `env:"MAX_SPREAD_BPS_DEFAULT" envDefault:"20"`
MaxSpreadBpsMoneyMarket decimal.Decimal `env:"MAX_SPREAD_BPS_MONEY_MARKET" envDefault:"5"`
MaxSpreadBpsBondFunds decimal.Decimal `env:"MAX_SPREAD_BPS_BOND_FUNDS" envDefault:"10"`
MaxSpreadBpsEquityFunds decimal.Decimal `env:"MAX_SPREAD_BPS_EQUITY_FUNDS" envDefault:"25"`
MaxTickBps decimal.Decimal `env:"MAX_TICK_BPS" envDefault:"10"`
}
type CommissionConfig struct {
RequireZeroCommission bool `env:"REQUIRE_ZERO_COMMISSION" envDefault:"true"`
QuarantineOnNonZero bool `env:"QUARANTINE_ON_NONZERO" envDefault:"true"`
FreeOrderCountPolicy string `env:"FREE_ORDER_COUNT_POLICY" envDefault:"submitted"`
}
type BacktestConfig struct {
DateFrom string `env:"DATE_FROM"`
DateTo string `env:"DATE_TO"`
EntrySlippageBps decimal.Decimal `env:"ENTRY_SLIPPAGE_BPS" envDefault:"8"`
ExitSlippageBps decimal.Decimal `env:"EXIT_SLIPPAGE_BPS" envDefault:"8"`
CommissionRoundtripBps decimal.Decimal `env:"COMMISSION_ROUNDTRIP_BPS" envDefault:"0"`
UseMinuteModel bool `env:"USE_MINUTE_MODEL" envDefault:"false"`
OutputDir string `env:"OUTPUT_DIR" envDefault:"./backtest_out"`
}
type LiveConfig struct {
TradeAck string `env:"TRADE_ACK"`
}
func Load() (Config, error) {
var cfg Config
if err := env.Parse(&cfg); err != nil {
return Config{}, err
}
if err := cfg.Validate(); err != nil {
return Config{}, err
}
return cfg, nil
}
func (c *Config) Validate() error {
if c.App.Mode == "" {
return errors.New("APP_MODE is required")
}
loc, err := time.LoadLocation(c.App.Timezone)
if err != nil {
return fmt.Errorf("load timezone %q: %w", c.App.Timezone, err)
}
if c.App.Timezone != "Europe/Moscow" {
return fmt.Errorf("APP_TIMEZONE must be Europe/Moscow, got %q", c.App.Timezone)
}
c.Location = loc
if c.App.ShutdownTimeoutSec <= 0 {
return errors.New("APP_SHUTDOWN_TIMEOUT_SEC must be positive")
}
if c.Execution.AllowMarketOrders {
return errors.New("EXEC_ALLOW_MARKET_ORDERS must remain false: strategy is LIMIT-only")
}
if c.Execution.QuoteDepth <= 0 || c.Execution.QuoteDepth > maxQuoteDepth {
return fmt.Errorf("EXEC_QUOTE_DEPTH must be between 1 and %d", maxQuoteDepth)
}
if c.Execution.OrderPollIntervalMS <= 0 {
return errors.New("EXEC_ORDER_POLL_INTERVAL_MS must be positive")
}
if c.Risk.UseMargin {
return errors.New("RISK_USE_MARGIN must remain false")
}
if c.Risk.AllowShort {
return errors.New("RISK_ALLOW_SHORT must remain false")
}
if c.Risk.APIOutageHaltSec <= 0 {
return errors.New("RISK_API_OUTAGE_HALT_SEC must be positive")
}
if c.Risk.ReconciliationWindowHours <= 0 {
return errors.New("RISK_RECONCILIATION_WINDOW_HOURS must be positive")
}
if c.Risk.ReconciliationSkewSec < 0 {
return errors.New("RISK_RECONCILIATION_SKEW_SEC must be non-negative")
}
if c.Commission.FreeOrderCountPolicy != "submitted" {
return fmt.Errorf("COMM_FREE_ORDER_COUNT_POLICY must be submitted, got %q", c.Commission.FreeOrderCountPolicy)
}
if err := c.validateWindows(); err != nil {
return err
}
if c.App.Mode != domain.ModeBacktest && c.DB.DSN == "" {
return errors.New("DB_DSN is required outside backtest mode")
}
if (c.App.Mode == domain.ModeSandbox || c.App.Mode == domain.ModeLiveReadonly || c.App.Mode == domain.ModeLiveTrade) && c.TInvest.Token == "" {
return fmt.Errorf("TINVEST_TOKEN is required for APP_MODE=%s", c.App.Mode)
}
if c.TInvest.UseSandbox && c.App.Mode != domain.ModeSandbox {
return errors.New("TINVEST_USE_SANDBOX=true is only valid with APP_MODE=sandbox")
}
if c.App.Mode == domain.ModeLiveTrade && c.Live.TradeAck != liveTradeAck {
return fmt.Errorf("LIVE_TRADE_ACK=%s is required for APP_MODE=live_trade", liveTradeAck)
}
return nil
}
func (c Config) validateWindows() error {
if c.Execution.EntryWindowStart.Duration >= c.Execution.EntryWindowEnd.Duration ||
c.Execution.EntryWindowEnd.Duration > c.Execution.NoNewEntryAfter.Duration {
return errors.New("entry windows must satisfy EXEC_ENTRY_WINDOW_START < EXEC_ENTRY_WINDOW_END <= EXEC_NO_NEW_ENTRY_AFTER")
}
if c.Execution.ExitWatchStart.Duration > c.Execution.ExitNotBefore.Duration ||
c.Execution.ExitNotBefore.Duration > c.Execution.ExitWindowStart.Duration ||
c.Execution.ExitWindowStart.Duration >= c.Execution.ExitWindowEnd.Duration ||
c.Execution.ExitWindowEnd.Duration > c.Execution.HardExitDeadline.Duration {
return errors.New("exit windows must be monotonic from EXEC_EXIT_WATCH_START to EXEC_HARD_EXIT_DEADLINE")
}
return nil
}
+311
View File
@@ -0,0 +1,311 @@
package domain
import (
"fmt"
"strings"
"time"
"github.com/shopspring/decimal"
)
type Mode string
const (
ModeBacktest Mode = "backtest"
ModePaper Mode = "paper"
ModeSandbox Mode = "sandbox"
ModeLiveReadonly Mode = "live_readonly"
ModeLiveTrade Mode = "live_trade"
)
func ParseMode(raw string) (Mode, error) {
mode := Mode(strings.TrimSpace(raw))
switch mode {
case ModeBacktest, ModePaper, ModeSandbox, ModeLiveReadonly, ModeLiveTrade:
return mode, nil
default:
return "", fmt.Errorf("unsupported app mode %q", raw)
}
}
func (m Mode) AllowsBrokerOrders() bool {
return m == ModeSandbox || m == ModeLiveTrade
}
func (m *Mode) UnmarshalText(text []byte) error {
mode, err := ParseMode(string(text))
if err != nil {
return err
}
*m = mode
return nil
}
type Side string
const (
SideBuy Side = "BUY"
SideSell Side = "SELL"
)
type OrderType string
const (
OrderTypeLimit OrderType = "LIMIT"
)
type OrderStatus string
const (
OrderStatusNew OrderStatus = "NEW"
OrderStatusSent OrderStatus = "SENT"
OrderStatusPartiallyFilled OrderStatus = "PARTIALLY_FILLED"
OrderStatusFilled OrderStatus = "FILLED"
OrderStatusCancelled OrderStatus = "CANCELLED"
OrderStatusRejected OrderStatus = "REJECTED"
OrderStatusExpired OrderStatus = "EXPIRED"
OrderStatusFailed OrderStatus = "FAILED"
)
type SignalDecision string
const (
DecisionEnter SignalDecision = "ENTER"
DecisionSkip SignalDecision = "SKIP"
DecisionReject SignalDecision = "REJECT"
)
type PositionStatus string
const (
PositionNoPosition PositionStatus = "NO_POSITION"
PositionEntrySignalled PositionStatus = "ENTRY_SIGNALLED"
PositionEntryOrderSent PositionStatus = "ENTRY_ORDER_SENT"
PositionEntryPartiallyFilled PositionStatus = "ENTRY_PARTIALLY_FILLED"
PositionEntryFilled PositionStatus = "ENTRY_FILLED"
PositionHoldingOvernight PositionStatus = "HOLDING_OVERNIGHT"
PositionExitOrderSent PositionStatus = "EXIT_ORDER_SENT"
PositionExitPartiallyFilled PositionStatus = "EXIT_PARTIALLY_FILLED"
PositionExitFilled PositionStatus = "EXIT_FILLED"
PositionExitFailed PositionStatus = "EXIT_FAILED"
PositionQuarantine PositionStatus = "QUARANTINE"
)
type SystemState string
const (
StateInit SystemState = "INIT"
StateSyncInstruments SystemState = "SYNC_INSTRUMENTS"
StateSyncMarketData SystemState = "SYNC_MARKET_DATA"
StateGenerateSignals SystemState = "GENERATE_SIGNALS"
StateWaitEntryWindow SystemState = "WAIT_ENTRY_WINDOW"
StatePlaceEntryOrders SystemState = "PLACE_ENTRY_ORDERS"
StateMonitorEntryOrders SystemState = "MONITOR_ENTRY_ORDERS"
StateHoldOvernight SystemState = "HOLD_OVERNIGHT"
StateWaitExitWindow SystemState = "WAIT_EXIT_WINDOW"
StatePlaceExitOrders SystemState = "PLACE_EXIT_ORDERS"
StateMonitorExitOrders SystemState = "MONITOR_EXIT_ORDERS"
StateReconcile SystemState = "RECONCILE"
StateReport SystemState = "REPORT"
StateSleep SystemState = "SLEEP"
StateHalted SystemState = "HALTED"
)
type Severity string
const (
SeverityInfo Severity = "INFO"
SeverityWarn Severity = "WARN"
SeverityAlert Severity = "ALERT"
SeverityCritical Severity = "CRITICAL"
)
type TradingStatus string
const (
TradingStatusNormal TradingStatus = "NORMAL_TRADING"
TradingStatusClosed TradingStatus = "CLOSED"
TradingStatusUnknown TradingStatus = "UNKNOWN"
)
type Instrument struct {
InstrumentUID string
Figi string
Ticker string
ClassCode string
Name string
Lot int64
MinPriceIncrement decimal.Decimal
Currency string
Enabled bool
FundType string
ExpectedCommissionBpsPerSide decimal.Decimal
FreeOrderLimitPerDay int
Quarantine bool
QuarantineReason string
ExcludeReason string
UpdatedAt time.Time
}
func (i Instrument) MetadataValid() bool {
return i.InstrumentUID != "" &&
!strings.HasPrefix(i.InstrumentUID, "PENDING:") &&
i.Lot > 0 &&
i.MinPriceIncrement.IsPositive() &&
strings.EqualFold(i.Currency, "RUB")
}
type Candle struct {
InstrumentUID string
TradeDate time.Time
Open decimal.Decimal
High decimal.Decimal
Low decimal.Decimal
Close decimal.Decimal
VolumeLots decimal.Decimal
Source string
LoadedAt time.Time
}
type FeatureSet struct {
InstrumentUID string
TradeDate time.Time
ROn decimal.Decimal
RDay decimal.Decimal
MuOn60 decimal.Decimal
MuOn252 decimal.Decimal
SigmaOn60 decimal.Decimal
TStatOn60 decimal.Decimal
WinOn60 decimal.Decimal
EWMAOn decimal.Decimal
SpreadBps decimal.Decimal
HalfSpreadBps decimal.Decimal
TickBps decimal.Decimal
ADV20 decimal.Decimal
ExpectedCostBps decimal.Decimal
NetEdgeBps decimal.Decimal
EntryIntervalVolume decimal.Decimal
ExitIntervalVolume decimal.Decimal
CalculatedAt time.Time
}
type Signal struct {
ID int64
TradeDate time.Time
InstrumentUID string
Decision SignalDecision
Score decimal.Decimal
NetEdgeBps decimal.Decimal
TargetNotional decimal.Decimal
TargetLots int64
RejectReason string
ContextJSON string
CreatedAt time.Time
}
type Order struct {
ClientOrderID string
BrokerOrderID string
AccountIDHash string
InstrumentUID string
TradeDate time.Time
Side Side
OrderType OrderType
LimitPrice decimal.Decimal
QuantityLots int64
FilledLots int64
AvgFillPrice decimal.Decimal
Status OrderStatus
Commission decimal.Decimal
AttemptNo int
RawStateJSON string
CreatedAt time.Time
UpdatedAt time.Time
}
type Position struct {
ID int64
AccountIDHash string
InstrumentUID string
OpenTradeDate time.Time
Lots int64
Lot int64
ExitFilledLots int64
AvgBuyPrice decimal.Decimal
AvgSellPrice decimal.Decimal
Status PositionStatus
GrossPnL decimal.Decimal
NetPnL decimal.Decimal
CommissionTotal decimal.Decimal
RealizedEdgeBps decimal.Decimal
OpenedAt *time.Time
ClosedAt *time.Time
UpdatedAt time.Time
}
type RiskEvent struct {
ID int64
TS time.Time
Severity Severity
EventType string
InstrumentUID string
Message string
ContextJSON string
}
type Holding struct {
InstrumentUID string
QuantityLots int64
AveragePrice decimal.Decimal
MarketValue decimal.Decimal
}
type Portfolio struct {
Equity decimal.Decimal
Cash decimal.Decimal
Holdings []Holding
CheckedAt time.Time
}
type OrderBookLevel struct {
Price decimal.Decimal
QuantityLots int64
}
type OrderBook struct {
InstrumentUID string
Bids []OrderBookLevel
Asks []OrderBookLevel
Time time.Time
ReceivedAt time.Time
}
func (o OrderBook) BestBid() (decimal.Decimal, bool) {
if len(o.Bids) == 0 {
return decimal.Zero, false
}
return o.Bids[0].Price, true
}
func (o OrderBook) BestAsk() (decimal.Decimal, bool) {
if len(o.Asks) == 0 {
return decimal.Zero, false
}
return o.Asks[0].Price, true
}
type Operation struct {
ID string
InstrumentUID string
Type string
Payment decimal.Decimal
Commission decimal.Decimal
ExecutedAt time.Time
}
type ReconciliationDiff struct {
Kind string
InstrumentUID string
Message string
Critical bool
}
+405
View File
@@ -0,0 +1,405 @@
package execution
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/repository"
)
var ErrBrokerOrdersDisabled = errors.New("broker orders are disabled for current mode")
var ErrEmptyOrderBook = errors.New("order book has no usable bid/ask")
type Gateway interface {
PostLimitOrder(ctx context.Context, accountID, instrumentUID string, side domain.Side, lots int64, price decimal.Decimal, clientOrderID string) (domain.Order, error)
CancelOrder(ctx context.Context, accountID, orderID string) error
GetOrderState(ctx context.Context, accountID, orderID string) (domain.Order, error)
}
type Engine struct {
mode domain.Mode
accountID string
gateway Gateway
store repository.Repository
maxQuoteAge time.Duration
mu sync.Map
}
type MonitorConfig struct {
Deadline time.Time
PollInterval time.Duration
MaxAttempts int
RepostAfter time.Duration
Instrument domain.Instrument
ImproveTicks int
Quote func(ctx context.Context, instrumentUID string) (domain.OrderBook, error)
}
func NewEngine(mode domain.Mode, accountID string, gateway Gateway, store repository.Repository) Engine {
return Engine{mode: mode, accountID: accountID, gateway: gateway, store: store}
}
func (e *Engine) SetMaxQuoteAge(maxQuoteAge time.Duration) {
e.maxQuoteAge = maxQuoteAge
}
func (e *Engine) PlaceEntry(ctx context.Context, accountIDHash string, instrument domain.Instrument, tradeDate time.Time, lots int64, book domain.OrderBook, improveTicks int, attempt int) (domain.Order, error) {
if err := e.checkQuoteFresh(book); err != nil {
return domain.Order{}, err
}
bid, ask, err := bestBidAsk(book)
if err != nil {
return domain.Order{}, err
}
price, err := LimitBuyPrice(bid, ask, instrument.MinPriceIncrement, improveTicks)
if err != nil {
return domain.Order{}, err
}
return e.PlaceLimit(ctx, domain.Order{
ClientOrderID: ClientOrderID(tradeDate, instrument.InstrumentUID, domain.SideBuy, attempt),
AccountIDHash: accountIDHash,
InstrumentUID: instrument.InstrumentUID,
TradeDate: tradeDate,
Side: domain.SideBuy,
OrderType: domain.OrderTypeLimit,
LimitPrice: price,
QuantityLots: lots,
Status: domain.OrderStatusNew,
AttemptNo: attempt,
RawStateJSON: "{}",
})
}
func (e *Engine) PlaceExit(ctx context.Context, accountIDHash string, instrument domain.Instrument, tradeDate time.Time, lots int64, book domain.OrderBook, improveTicks int, attempt int) (domain.Order, error) {
if err := e.checkQuoteFresh(book); err != nil {
return domain.Order{}, err
}
bid, ask, err := bestBidAsk(book)
if err != nil {
return domain.Order{}, err
}
price, err := LimitSellPrice(bid, ask, instrument.MinPriceIncrement, improveTicks)
if err != nil {
return domain.Order{}, err
}
return e.PlaceLimit(ctx, domain.Order{
ClientOrderID: ClientOrderID(tradeDate, instrument.InstrumentUID, domain.SideSell, attempt),
AccountIDHash: accountIDHash,
InstrumentUID: instrument.InstrumentUID,
TradeDate: tradeDate,
Side: domain.SideSell,
OrderType: domain.OrderTypeLimit,
LimitPrice: price,
QuantityLots: lots,
Status: domain.OrderStatusNew,
AttemptNo: attempt,
RawStateJSON: "{}",
})
}
func (e *Engine) PlaceLimit(ctx context.Context, order domain.Order) (domain.Order, error) {
if e.store != nil {
existing, err := e.findExisting(ctx, order)
if err != nil {
return domain.Order{}, err
}
if existing.ClientOrderID != "" {
return existing, nil
}
}
if !e.mode.AllowsBrokerOrders() {
order.Status = domain.OrderStatusNew
if e.store != nil {
return order, e.store.UpsertOrder(ctx, order)
}
return order, ErrBrokerOrdersDisabled
}
if e.gateway == nil {
return domain.Order{}, errors.New("gateway is nil")
}
lock := e.lockFor(order.InstrumentUID)
lock.Lock()
defer lock.Unlock()
posted, err := e.gateway.PostLimitOrder(ctx, e.accountID, order.InstrumentUID, order.Side, order.QuantityLots, order.LimitPrice, order.ClientOrderID)
if err != nil {
order.Status = domain.OrderStatusFailed
if e.store != nil {
_ = e.store.UpsertOrder(ctx, order)
}
return domain.Order{}, err
}
posted.ClientOrderID = order.ClientOrderID
posted.AccountIDHash = order.AccountIDHash
posted.InstrumentUID = order.InstrumentUID
posted.Side = order.Side
posted.OrderType = order.OrderType
posted.LimitPrice = order.LimitPrice
posted.QuantityLots = order.QuantityLots
posted.AttemptNo = order.AttemptNo
posted.TradeDate = order.TradeDate
posted.CreatedAt = time.Now().UTC()
posted.UpdatedAt = posted.CreatedAt
if e.store != nil {
if err := e.store.RunInTx(ctx, func(ctx context.Context, repo repository.Repository) error {
if err := repo.UpsertOrder(ctx, posted); err != nil {
return fmt.Errorf("persist posted order: %w", err)
}
return repo.IncrementFreeOrders(ctx, order.TradeDate, order.InstrumentUID, 1)
}); err != nil {
return domain.Order{}, err
}
}
return posted, nil
}
func (e *Engine) findExisting(ctx context.Context, order domain.Order) (domain.Order, error) {
orders, err := e.store.ListOrders(ctx, order.AccountIDHash, order.TradeDate, order.TradeDate)
if err != nil {
return domain.Order{}, err
}
for _, existing := range orders {
if existing.ClientOrderID == order.ClientOrderID &&
existing.Status != domain.OrderStatusFailed &&
existing.Status != domain.OrderStatusRejected {
return existing, nil
}
}
return domain.Order{}, nil
}
func (e *Engine) Refresh(ctx context.Context, order domain.Order) (domain.Order, error) {
if e.gateway == nil {
return domain.Order{}, errors.New("gateway is nil")
}
lock := e.lockFor(order.InstrumentUID)
lock.Lock()
defer lock.Unlock()
state, err := e.gateway.GetOrderState(ctx, e.accountID, order.BrokerOrderID)
if err != nil {
return domain.Order{}, err
}
state.ClientOrderID = order.ClientOrderID
state.AccountIDHash = order.AccountIDHash
state.InstrumentUID = order.InstrumentUID
state.TradeDate = order.TradeDate
state.Side = order.Side
state.OrderType = order.OrderType
state.LimitPrice = order.LimitPrice
state.QuantityLots = order.QuantityLots
state.AttemptNo = order.AttemptNo
if e.store != nil {
if err := e.store.UpsertOrder(ctx, state); err != nil {
return domain.Order{}, err
}
}
return state, nil
}
func (e *Engine) Cancel(ctx context.Context, order domain.Order) error {
if e.gateway == nil {
return errors.New("gateway is nil")
}
lock := e.lockFor(order.InstrumentUID)
lock.Lock()
defer lock.Unlock()
if err := e.gateway.CancelOrder(ctx, e.accountID, order.BrokerOrderID); err != nil {
return err
}
if e.store != nil {
return e.store.UpdateOrderStatus(ctx, order.ClientOrderID, domain.OrderStatusCancelled, order.FilledLots, order.RawStateJSON)
}
return nil
}
func (e *Engine) MonitorUntil(ctx context.Context, order domain.Order, cfg MonitorConfig) (domain.Order, error) {
if cfg.PollInterval <= 0 {
cfg.PollInterval = 500 * time.Millisecond
}
if cfg.MaxAttempts <= 0 {
cfg.MaxAttempts = 1
}
lastPost := time.Now()
current := order
aggregate := order
seen := map[string]domain.Order{order.ClientOrderID: order}
ticker := time.NewTicker(cfg.PollInterval)
defer ticker.Stop()
for {
previous := seen[current.ClientOrderID]
refreshed, err := e.Refresh(ctx, current)
if err != nil {
return aggregate, err
}
aggregate = mergeAggregateFill(aggregate, previous, refreshed)
seen[current.ClientOrderID] = refreshed
current = mergeOrderState(current, refreshed)
aggregate.Status = current.Status
aggregate.UpdatedAt = current.UpdatedAt
aggregate.RawStateJSON = current.RawStateJSON
if aggregate.FilledLots >= aggregate.QuantityLots {
aggregate.Status = domain.OrderStatusFilled
return aggregate, nil
}
if isTerminal(current.Status) {
return aggregate, nil
}
if !cfg.Deadline.IsZero() && !time.Now().Before(cfg.Deadline) {
if err := e.Cancel(ctx, current); err != nil {
return aggregate, err
}
aggregate.Status = domain.OrderStatusExpired
if e.store != nil {
if err := e.store.UpdateOrderStatus(ctx, current.ClientOrderID, aggregate.Status, current.FilledLots, current.RawStateJSON); err != nil {
return aggregate, err
}
}
return aggregate, nil
}
shouldRepost := cfg.RepostAfter > 0 &&
time.Since(lastPost) >= cfg.RepostAfter &&
current.AttemptNo < cfg.MaxAttempts &&
aggregate.FilledLots < aggregate.QuantityLots &&
cfg.Quote != nil
if shouldRepost {
next, err := e.repost(ctx, current, cfg, aggregate.QuantityLots-aggregate.FilledLots)
if err != nil {
return aggregate, err
}
current = next
seen[current.ClientOrderID] = current
lastPost = time.Now()
continue
}
select {
case <-ctx.Done():
return aggregate, ctx.Err()
case <-ticker.C:
}
}
}
func (e *Engine) repost(ctx context.Context, order domain.Order, cfg MonitorConfig, remaining int64) (domain.Order, error) {
if err := e.Cancel(ctx, order); err != nil {
return domain.Order{}, err
}
if remaining <= 0 {
order.Status = domain.OrderStatusFilled
return order, nil
}
book, err := cfg.Quote(ctx, order.InstrumentUID)
if err != nil {
return domain.Order{}, err
}
attempt := order.AttemptNo + 1
switch order.Side {
case domain.SideBuy:
return e.PlaceEntry(ctx, order.AccountIDHash, cfg.Instrument, order.TradeDate, remaining, book, cfg.ImproveTicks, attempt)
case domain.SideSell:
return e.PlaceExit(ctx, order.AccountIDHash, cfg.Instrument, order.TradeDate, remaining, book, cfg.ImproveTicks, attempt)
default:
return domain.Order{}, fmt.Errorf("unsupported side %s", order.Side)
}
}
func (e *Engine) checkQuoteFresh(book domain.OrderBook) error {
if e.maxQuoteAge <= 0 {
return nil
}
receivedAt := book.ReceivedAt
if receivedAt.IsZero() {
receivedAt = book.Time
}
if receivedAt.IsZero() {
return fmt.Errorf("quote timestamp is missing")
}
age := time.Since(receivedAt)
if age > e.maxQuoteAge {
return fmt.Errorf("quote age %s exceeds %s", age, e.maxQuoteAge)
}
return nil
}
func (e *Engine) lockFor(instrumentUID string) *sync.Mutex {
value, _ := e.mu.LoadOrStore(instrumentUID, &sync.Mutex{})
lock, ok := value.(*sync.Mutex)
if !ok {
panic("execution lock has unexpected type")
}
return lock
}
func bestBidAsk(book domain.OrderBook) (decimal.Decimal, decimal.Decimal, error) {
bid, ok := book.BestBid()
if !ok {
return decimal.Zero, decimal.Zero, ErrEmptyOrderBook
}
ask, ok := book.BestAsk()
if !ok {
return decimal.Zero, decimal.Zero, ErrEmptyOrderBook
}
return bid, ask, nil
}
func isTerminal(status domain.OrderStatus) bool {
switch status {
case domain.OrderStatusFilled, domain.OrderStatusCancelled, domain.OrderStatusRejected, domain.OrderStatusExpired, domain.OrderStatusFailed:
return true
default:
return false
}
}
func mergeOrderState(base, state domain.Order) domain.Order {
base.BrokerOrderID = state.BrokerOrderID
base.FilledLots = state.FilledLots
base.AvgFillPrice = state.AvgFillPrice
base.Status = state.Status
base.Commission = state.Commission
base.RawStateJSON = state.RawStateJSON
base.UpdatedAt = state.UpdatedAt
return base
}
func mergeAggregateFill(aggregate, previous, current domain.Order) domain.Order {
deltaLots := current.FilledLots - previous.FilledLots
if deltaLots > 0 {
deltaAvg := fillDeltaAvg(previous, current, deltaLots)
previousValue := aggregate.AvgFillPrice.Mul(decimal.NewFromInt(aggregate.FilledLots))
deltaValue := deltaAvg.Mul(decimal.NewFromInt(deltaLots))
aggregate.FilledLots += deltaLots
aggregate.AvgFillPrice = previousValue.Add(deltaValue).Div(decimal.NewFromInt(aggregate.FilledLots))
}
deltaCommission := current.Commission.Sub(previous.Commission)
if deltaCommission.IsPositive() {
aggregate.Commission = aggregate.Commission.Add(deltaCommission)
}
return aggregate
}
func fillDeltaAvg(previous, current domain.Order, deltaLots int64) decimal.Decimal {
if deltaLots <= 0 {
return decimal.Zero
}
if previous.FilledLots <= 0 {
if current.AvgFillPrice.IsPositive() {
return current.AvgFillPrice
}
return current.LimitPrice
}
currentValue := current.AvgFillPrice.Mul(decimal.NewFromInt(current.FilledLots))
previousValue := previous.AvgFillPrice.Mul(decimal.NewFromInt(previous.FilledLots))
if currentValue.GreaterThan(previousValue) {
return currentValue.Sub(previousValue).Div(decimal.NewFromInt(deltaLots))
}
if current.AvgFillPrice.IsPositive() {
return current.AvgFillPrice
}
return current.LimitPrice
}
+58
View File
@@ -0,0 +1,58 @@
package execution
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"regexp"
"strings"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/money"
)
var nonIDChar = regexp.MustCompile(`[^A-Za-z0-9_-]+`)
func LimitBuyPrice(bestBid, bestAsk, tick decimal.Decimal, improveTicks int) (decimal.Decimal, error) {
if improveTicks < 0 {
improveTicks = 0
}
if !tick.IsPositive() {
return decimal.Zero, money.ErrInvalidTick
}
candidate := bestBid.Add(tick.Mul(decimal.NewFromInt(int64(improveTicks))))
upper := bestAsk.Sub(tick)
if candidate.LessThanOrEqual(upper) {
return money.RoundToTick(candidate, tick, money.RoundFloor)
}
return money.RoundToTick(bestBid, tick, money.RoundFloor)
}
func LimitSellPrice(bestBid, bestAsk, tick decimal.Decimal, improveTicks int) (decimal.Decimal, error) {
if improveTicks < 0 {
improveTicks = 0
}
if !tick.IsPositive() {
return decimal.Zero, money.ErrInvalidTick
}
candidate := bestAsk.Sub(tick.Mul(decimal.NewFromInt(int64(improveTicks))))
lower := bestBid.Add(tick)
if candidate.GreaterThanOrEqual(lower) {
return money.RoundToTick(candidate, tick, money.RoundCeil)
}
return money.RoundToTick(bestAsk, tick, money.RoundCeil)
}
func ClientOrderID(tradeDate time.Time, instrumentUID string, side domain.Side, attempt int) string {
base := fmt.Sprintf("%s|%s|%s|%d", tradeDate.Format("20060102"), instrumentUID, side, attempt)
sum := sha256.Sum256([]byte(base))
suffix := hex.EncodeToString(sum[:])[:8]
cleanUID := nonIDChar.ReplaceAllString(instrumentUID, "_")
if len(cleanUID) > 24 {
cleanUID = cleanUID[:24]
}
return strings.ToLower(fmt.Sprintf("otb-%s-%s-%s-%02d-%s", tradeDate.Format("20060102"), cleanUID, side, attempt, suffix))
}
+49
View File
@@ -0,0 +1,49 @@
package execution
import (
"testing"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
)
func ed(raw string) decimal.Decimal {
v, err := decimal.NewFromString(raw)
if err != nil {
panic(err)
}
return v
}
func TestLimitPricesDoNotCross(t *testing.T) {
buy, err := LimitBuyPrice(ed("100"), ed("100.03"), ed("0.01"), 1)
if err != nil {
t.Fatal(err)
}
if !buy.Equal(ed("100.01")) {
t.Fatalf("buy=%s", buy)
}
sell, err := LimitSellPrice(ed("100"), ed("100.03"), ed("0.01"), 1)
if err != nil {
t.Fatal(err)
}
if !sell.Equal(ed("100.02")) {
t.Fatalf("sell=%s", sell)
}
tightBuy, _ := LimitBuyPrice(ed("100"), ed("100.01"), ed("0.01"), 1)
if !tightBuy.Equal(ed("100")) {
t.Fatalf("tight buy=%s", tightBuy)
}
}
func TestClientOrderIDDeterministic(t *testing.T) {
date := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)
a := ClientOrderID(date, "uid", domain.SideBuy, 1)
b := ClientOrderID(date, "uid", domain.SideBuy, 1)
c := ClientOrderID(date, "uid", domain.SideBuy, 2)
if a != b || a == c {
t.Fatalf("unexpected ids: %s %s %s", a, b, c)
}
}
+137
View File
@@ -0,0 +1,137 @@
package execution
import (
"context"
"testing"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/testutil"
"overnight-trading-bot/internal/tinvest"
)
func TestClientOrderIDIncludesAttempt(t *testing.T) {
date := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC)
first := ClientOrderID(date, "uid:TRUR", domain.SideBuy, 1)
second := ClientOrderID(date, "uid:TRUR", domain.SideBuy, 1)
third := ClientOrderID(date, "uid:TRUR", domain.SideBuy, 2)
if first != second {
t.Fatalf("client order id is not deterministic: %s != %s", first, second)
}
if first == third {
t.Fatalf("attempt is not part of client order id: %s", first)
}
}
func TestPlaceLimitSuppressesDuplicateSubmit(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
gateway := tinvest.NewFakeGateway()
engine := NewEngine(domain.ModeSandbox, "account", gateway, repo)
tradeDate := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC)
order := domain.Order{
ClientOrderID: "order-1",
AccountIDHash: "hash",
InstrumentUID: "uid",
TradeDate: tradeDate,
Side: domain.SideBuy,
OrderType: domain.OrderTypeLimit,
LimitPrice: decimal.NewFromInt(100),
QuantityLots: 1,
Status: domain.OrderStatusNew,
AttemptNo: 1,
}
first, err := engine.PlaceLimit(ctx, order)
if err != nil {
t.Fatal(err)
}
second, err := engine.PlaceLimit(ctx, order)
if err != nil {
t.Fatal(err)
}
if first.BrokerOrderID != second.BrokerOrderID {
t.Fatalf("duplicate submit posted a new broker order: %s != %s", first.BrokerOrderID, second.BrokerOrderID)
}
if got := len(gateway.Orders); got != 1 {
t.Fatalf("broker posts=%d, want 1", got)
}
sent, err := repo.GetFreeOrdersSent(ctx, tradeDate, "uid")
if err != nil {
t.Fatal(err)
}
if sent != 1 {
t.Fatalf("free order counter=%d, want 1", sent)
}
}
func TestPlaceEntryRejectsStaleQuote(t *testing.T) {
ctx := context.Background()
engine := NewEngine(domain.ModeSandbox, "account", tinvest.NewFakeGateway(), testutil.NewMemoryRepository())
engine.SetMaxQuoteAge(time.Second)
_, err := engine.PlaceEntry(ctx, "hash", domain.Instrument{
InstrumentUID: "uid",
Lot: 1,
MinPriceIncrement: decimal.NewFromInt(1),
}, time.Now().UTC(), 1, domain.OrderBook{
InstrumentUID: "uid",
Bids: []domain.OrderBookLevel{{Price: decimal.NewFromInt(99), QuantityLots: 10}},
Asks: []domain.OrderBookLevel{{Price: decimal.NewFromInt(101), QuantityLots: 10}},
ReceivedAt: time.Now().UTC().Add(-2 * time.Second),
}, 1, 1)
if err == nil {
t.Fatal("expected stale quote error")
}
}
func TestMonitorUntilRepostsAndExpiresAtDeadline(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
gateway := tinvest.NewFakeGateway()
engine := NewEngine(domain.ModeSandbox, "account", gateway, repo)
instrument := domain.Instrument{
InstrumentUID: "uid",
Lot: 1,
MinPriceIncrement: decimal.NewFromInt(1),
}
book := domain.OrderBook{
InstrumentUID: "uid",
Bids: []domain.OrderBookLevel{{Price: decimal.NewFromInt(99), QuantityLots: 10}},
Asks: []domain.OrderBookLevel{{Price: decimal.NewFromInt(101), QuantityLots: 10}},
ReceivedAt: time.Now().UTC(),
}
tradeDate := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC)
order, err := engine.PlaceEntry(ctx, "hash", instrument, tradeDate, 3, book, 1, 1)
if err != nil {
t.Fatal(err)
}
monitored, err := engine.MonitorUntil(ctx, order, MonitorConfig{
Deadline: time.Now().Add(20 * time.Millisecond),
PollInterval: time.Millisecond,
MaxAttempts: 2,
RepostAfter: time.Nanosecond,
Instrument: instrument,
ImproveTicks: 1,
Quote: func(context.Context, string) (domain.OrderBook, error) {
book.ReceivedAt = time.Now().UTC()
return book, nil
},
})
if err != nil {
t.Fatal(err)
}
if monitored.Status != domain.OrderStatusExpired {
t.Fatalf("status=%s, want EXPIRED", monitored.Status)
}
if got := len(gateway.Orders); got < 2 {
t.Fatalf("broker orders=%d, want repost attempt", got)
}
sent, err := repo.GetFreeOrdersSent(ctx, tradeDate, "uid")
if err != nil {
t.Fatal(err)
}
if sent != 2 {
t.Fatalf("free order counter=%d, want 2", sent)
}
}
+148
View File
@@ -0,0 +1,148 @@
package features
import (
"context"
"fmt"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/repository"
"overnight-trading-bot/internal/timeutil"
)
type PipelineConfig struct {
RollingShort int
RollingLong int
EWMALambda float64
RiskBufferBps decimal.Decimal
EntrySlippageBps decimal.Decimal
ExitSlippageBps decimal.Decimal
CommissionRoundtripBps decimal.Decimal
EntryWindow timeutil.Window
ExitWindow timeutil.Window
Location *time.Location
}
type Pipeline struct {
repo repository.Repository
cfg PipelineConfig
}
func NewPipeline(repo repository.Repository, cfg PipelineConfig) Pipeline {
return Pipeline{repo: repo, cfg: cfg}
}
func (p Pipeline) Recompute(ctx context.Context, instrument domain.Instrument, tradeDate time.Time, spread SpreadResult) (domain.FeatureSet, error) {
from := tradeDate.AddDate(0, 0, -p.cfg.RollingLong-5)
candles, err := p.repo.ListDailyCandles(ctx, instrument.InstrumentUID, from, tradeDate)
if err != nil {
return domain.FeatureSet{}, err
}
entryVolume, err := p.intervalVolume(ctx, instrument, tradeDate, p.cfg.EntryWindow)
if err != nil {
return domain.FeatureSet{}, err
}
exitVolume, err := p.intervalVolume(ctx, instrument, tradeDate.AddDate(0, 0, 1), p.cfg.ExitWindow)
if err != nil {
return domain.FeatureSet{}, err
}
feature, err := Compute(instrument, candles, tradeDate, spread, p.cfg, entryVolume, exitVolume)
if err != nil {
return domain.FeatureSet{}, err
}
if err := p.repo.UpsertFeature(ctx, feature); err != nil {
return domain.FeatureSet{}, err
}
return feature, nil
}
func (p Pipeline) intervalVolume(ctx context.Context, instrument domain.Instrument, date time.Time, window timeutil.Window) (decimal.Decimal, error) {
if window.Start.Duration == 0 && window.End.Duration == 0 {
return decimal.Zero, nil
}
loc := p.cfg.Location
if loc == nil {
loc = time.UTC
}
from := window.Start.On(date, loc).UTC()
to := window.End.On(date, loc).UTC()
candles, err := p.repo.ListMinuteCandles(ctx, instrument.InstrumentUID, from, to)
if err != nil {
return decimal.Zero, err
}
return IntervalVolume(candles, instrument.Lot), nil
}
func Compute(instrument domain.Instrument, candles []domain.Candle, tradeDate time.Time, spread SpreadResult, cfg PipelineConfig, entryVolume, exitVolume decimal.Decimal) (domain.FeatureSet, error) {
if len(candles) < 2 {
return domain.FeatureSet{}, fmt.Errorf("need at least 2 candles, got %d", len(candles))
}
var overnight []float64
var lastROn decimal.Decimal
var lastRDay decimal.Decimal
for i := 1; i < len(candles); i++ {
rOn, err := OvernightReturn(candles[i].Open, candles[i-1].Close)
if err != nil {
return domain.FeatureSet{}, err
}
rDay, err := IntradayReturn(candles[i].Close, candles[i].Open)
if err != nil {
return domain.FeatureSet{}, err
}
onFloat, _ := rOn.Float64()
overnight = append(overnight, onFloat)
lastROn = rOn
lastRDay = rDay
}
short := Rolling(overnight, cfg.RollingShort, cfg.EWMALambda)
long := Rolling(overnight, cfg.RollingLong, cfg.EWMALambda)
adv := ADV(candles, instrument.Lot, 20)
rawEdgeBps := decimal.NewFromFloat(short.Mean).Mul(decimal.NewFromInt(10_000))
if !entryVolume.IsPositive() {
entryVolume = adv
}
if !exitVolume.IsPositive() {
exitVolume = adv
}
instrumentCommission := instrument.ExpectedCommissionBpsPerSide.Mul(decimal.NewFromInt(2))
expectedCost := spread.SpreadBps.
Add(cfg.EntrySlippageBps).
Add(cfg.ExitSlippageBps).
Add(cfg.CommissionRoundtripBps).
Add(instrumentCommission).
Add(cfg.RiskBufferBps)
return domain.FeatureSet{
InstrumentUID: instrument.InstrumentUID,
TradeDate: tradeDate,
ROn: lastROn,
RDay: lastRDay,
MuOn60: decimal.NewFromFloat(short.Mean),
MuOn252: decimal.NewFromFloat(long.Mean),
SigmaOn60: decimal.NewFromFloat(short.StdDev),
TStatOn60: decimal.NewFromFloat(short.TStat),
WinOn60: decimal.NewFromFloat(short.WinRate),
EWMAOn: decimal.NewFromFloat(short.EWMA),
SpreadBps: spread.SpreadBps,
HalfSpreadBps: spread.HalfSpreadBps,
TickBps: spread.TickBps,
ADV20: adv,
ExpectedCostBps: expectedCost,
NetEdgeBps: rawEdgeBps.Sub(expectedCost),
EntryIntervalVolume: entryVolume,
ExitIntervalVolume: exitVolume,
CalculatedAt: time.Now().UTC(),
}, nil
}
func IntervalVolume(candles []domain.Candle, lot int64) decimal.Decimal {
if lot <= 0 {
return decimal.Zero
}
total := decimal.Zero
for _, candle := range candles {
total = total.Add(candle.VolumeLots.Mul(decimal.NewFromInt(lot)).Mul(candle.Close))
}
return total
}
+57
View File
@@ -0,0 +1,57 @@
package features
import (
"testing"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
)
func TestComputeExpectedCostIncludesCommissionAndSlippage(t *testing.T) {
var candles []domain.Candle
start := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
for i := 0; i < 6; i++ {
price := decimal.NewFromInt(int64(100 + i))
candles = append(candles, domain.Candle{
InstrumentUID: "uid",
TradeDate: start.AddDate(0, 0, i),
Open: price,
Close: price,
VolumeLots: decimal.NewFromInt(1000),
})
}
got, err := Compute(domain.Instrument{
InstrumentUID: "uid",
Lot: 1,
ExpectedCommissionBpsPerSide: decimal.NewFromInt(1),
}, candles, start.AddDate(0, 0, 5), SpreadResult{SpreadBps: decimal.NewFromInt(10)}, PipelineConfig{
RollingShort: 2,
RollingLong: 2,
EWMALambda: 0.08,
RiskBufferBps: decimal.NewFromInt(5),
EntrySlippageBps: decimal.NewFromInt(2),
ExitSlippageBps: decimal.NewFromInt(3),
CommissionRoundtripBps: decimal.NewFromInt(4),
}, decimal.NewFromInt(10000), decimal.NewFromInt(9000))
if err != nil {
t.Fatal(err)
}
if !got.ExpectedCostBps.Equal(decimal.NewFromInt(26)) {
t.Fatalf("expected cost=%s, want 26", got.ExpectedCostBps)
}
if !got.EntryIntervalVolume.Equal(decimal.NewFromInt(10000)) || !got.ExitIntervalVolume.Equal(decimal.NewFromInt(9000)) {
t.Fatalf("interval volumes were not preserved: %+v", got)
}
}
func TestIntervalVolume(t *testing.T) {
got := IntervalVolume([]domain.Candle{
{Close: decimal.NewFromInt(100), VolumeLots: decimal.NewFromInt(10)},
{Close: decimal.NewFromInt(101), VolumeLots: decimal.NewFromInt(20)},
}, 2)
if !got.Equal(decimal.NewFromInt(6040)) {
t.Fatalf("interval volume=%s, want 6040", got)
}
}
+207
View File
@@ -0,0 +1,207 @@
package features
import (
"errors"
"math"
"sort"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/money"
)
var ErrInvalidPrice = errors.New("price must be positive")
func OvernightReturn(open, previousClose decimal.Decimal) (decimal.Decimal, error) {
if !open.IsPositive() || !previousClose.IsPositive() {
return decimal.Zero, ErrInvalidPrice
}
return open.Div(previousClose).Sub(decimal.NewFromInt(1)), nil
}
func IntradayReturn(close, open decimal.Decimal) (decimal.Decimal, error) {
if !close.IsPositive() || !open.IsPositive() {
return decimal.Zero, ErrInvalidPrice
}
return close.Div(open).Sub(decimal.NewFromInt(1)), nil
}
func LogReturn(to, from decimal.Decimal) (float64, error) {
if !to.IsPositive() || !from.IsPositive() {
return 0, ErrInvalidPrice
}
ratio, _ := to.Div(from).Float64()
return math.Log(ratio), nil
}
func CumulativeLinear(returns []decimal.Decimal) decimal.Decimal {
total := decimal.NewFromInt(1)
for _, r := range returns {
total = total.Mul(decimal.NewFromInt(1).Add(r))
}
return total.Sub(decimal.NewFromInt(1))
}
func CumulativeLog(logReturns []float64) float64 {
sum := 0.0
for _, r := range logReturns {
sum += r
}
return math.Exp(sum) - 1
}
type RollingResult struct {
Mean float64
StdDev float64
TStat float64
WinRate float64
EWMA float64
Available bool
}
func Rolling(values []float64, window int, lambda float64) RollingResult {
if window <= 0 || len(values) < window {
return RollingResult{}
}
sample := values[len(values)-window:]
mean := Mean(sample)
std := StdDev(sample)
win := WinRate(sample)
ewma := EWMA(values, lambda)
res := RollingResult{
Mean: mean,
StdDev: std,
WinRate: win,
EWMA: ewma,
Available: true,
}
if std > 0 {
res.TStat = mean / std * math.Sqrt(float64(window))
}
return res
}
func Mean(values []float64) float64 {
if len(values) == 0 {
return 0
}
sum := 0.0
for _, value := range values {
sum += value
}
return sum / float64(len(values))
}
func StdDev(values []float64) float64 {
if len(values) < 2 {
return 0
}
mean := Mean(values)
sum := 0.0
for _, value := range values {
diff := value - mean
sum += diff * diff
}
return math.Sqrt(sum / float64(len(values)-1))
}
func WinRate(values []float64) float64 {
if len(values) == 0 {
return 0
}
wins := 0
for _, value := range values {
if value > 0 {
wins++
}
}
return float64(wins) / float64(len(values))
}
func EWMA(values []float64, lambda float64) float64 {
if len(values) == 0 {
return 0
}
if lambda <= 0 || lambda > 1 {
lambda = 0.08
}
ewma := values[0]
for _, value := range values[1:] {
ewma = lambda*value + (1-lambda)*ewma
}
return ewma
}
type SpreadResult struct {
SpreadAbs decimal.Decimal
SpreadBps decimal.Decimal
HalfSpreadBps decimal.Decimal
TickBps decimal.Decimal
Mid decimal.Decimal
}
func Spread(bestBid, bestAsk, tick decimal.Decimal) (SpreadResult, error) {
if !bestBid.IsPositive() || !bestAsk.IsPositive() || bestAsk.LessThanOrEqual(bestBid) {
return SpreadResult{}, ErrInvalidPrice
}
mid := bestAsk.Add(bestBid).Div(decimal.NewFromInt(2))
spreadAbs := bestAsk.Sub(bestBid)
spreadBps, err := money.Bps(spreadAbs, mid)
if err != nil {
return SpreadResult{}, err
}
tickBps := decimal.Zero
if tick.IsPositive() {
tickBps, err = money.Bps(tick, mid)
if err != nil {
return SpreadResult{}, err
}
}
return SpreadResult{
SpreadAbs: spreadAbs,
SpreadBps: spreadBps,
HalfSpreadBps: spreadBps.Div(decimal.NewFromInt(2)),
TickBps: tickBps,
Mid: mid,
}, nil
}
func ADV(candles []domain.Candle, lot int64, window int) decimal.Decimal {
if lot <= 0 || window <= 0 || len(candles) == 0 {
return decimal.Zero
}
sort.Slice(candles, func(i, j int) bool {
return candles[i].TradeDate.Before(candles[j].TradeDate)
})
if len(candles) > window {
candles = candles[len(candles)-window:]
}
total := decimal.Zero
for _, candle := range candles {
total = total.Add(candle.VolumeLots.Mul(decimal.NewFromInt(lot)).Mul(candle.Close))
}
return total.Div(decimal.NewFromInt(int64(len(candles))))
}
func Quantile(values []float64, q float64) float64 {
if len(values) == 0 {
return 0
}
cp := append([]float64(nil), values...)
sort.Float64s(cp)
if q <= 0 {
return cp[0]
}
if q >= 1 {
return cp[len(cp)-1]
}
pos := q * float64(len(cp)-1)
lower := int(math.Floor(pos))
upper := int(math.Ceil(pos))
if lower == upper {
return cp[lower]
}
weight := pos - float64(lower)
return cp[lower]*(1-weight) + cp[upper]*weight
}
+38
View File
@@ -0,0 +1,38 @@
package features
import (
"math"
"testing"
"github.com/shopspring/decimal"
)
func dec(raw string) decimal.Decimal {
v, err := decimal.NewFromString(raw)
if err != nil {
panic(err)
}
return v
}
func TestReturnsAndLogIdentity(t *testing.T) {
rOn, err := OvernightReturn(dec("102"), dec("100"))
if err != nil {
t.Fatal(err)
}
if !rOn.Equal(dec("0.02")) {
t.Fatalf("overnight return=%s", rOn)
}
rDay, err := IntradayReturn(dec("105"), dec("102"))
if err != nil {
t.Fatal(err)
}
if !rDay.Round(10).Equal(dec("0.0294117647")) {
t.Fatalf("intraday return=%s", rDay)
}
linear := CumulativeLinear([]decimal.Decimal{dec("0.01"), dec("-0.02"), dec("0.03")})
logs := []float64{math.Log(1.01), math.Log(0.98), math.Log(1.03)}
if math.Abs(linear.InexactFloat64()-CumulativeLog(logs)) > 1e-10 {
t.Fatalf("linear/log cumulative mismatch")
}
}
+30
View File
@@ -0,0 +1,30 @@
package features
import (
"math"
"testing"
)
func TestRollingStats(t *testing.T) {
values := []float64{0.01, -0.01, 0.02, 0.03}
got := Rolling(values, 4, 0.5)
if !got.Available {
t.Fatal("expected rolling result")
}
if math.Abs(got.Mean-0.0125) > 1e-12 {
t.Fatalf("mean=%f", got.Mean)
}
if math.Abs(got.WinRate-0.75) > 1e-12 {
t.Fatalf("win=%f", got.WinRate)
}
if got.StdDev <= 0 || got.TStat <= 0 {
t.Fatalf("std/tstat invalid: %+v", got)
}
}
func TestRollingSigmaZero(t *testing.T) {
got := Rolling([]float64{0.01, 0.01, 0.01}, 3, 0.08)
if got.StdDev != 0 || got.TStat != 0 {
t.Fatalf("expected zero sigma/tstat, got %+v", got)
}
}
+13
View File
@@ -0,0 +1,13 @@
package features
import "testing"
func TestSpread(t *testing.T) {
got, err := Spread(dec("99"), dec("101"), dec("0.1"))
if err != nil {
t.Fatal(err)
}
if !got.Mid.Equal(dec("100")) || !got.SpreadBps.Equal(dec("200")) || !got.HalfSpreadBps.Equal(dec("100")) || !got.TickBps.Equal(dec("10")) {
t.Fatalf("unexpected spread: %+v", got)
}
}
+109
View File
@@ -0,0 +1,109 @@
package healthcheck
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/http"
"time"
"overnight-trading-bot/internal/timeutil"
"overnight-trading-bot/internal/tinvest"
)
type Service struct {
db *sql.DB
gateway tinvest.Gateway
maxDrift time.Duration
server *http.Server
}
func New(db *sql.DB, gateway tinvest.Gateway, maxDrift time.Duration) *Service {
return &Service{db: db, gateway: gateway, maxDrift: maxDrift}
}
func (s *Service) Start(addr string) {
mux := http.NewServeMux()
mux.HandleFunc("/health", s.handleHealth)
mux.HandleFunc("/ready", s.handleReady)
s.server = &http.Server{Addr: addr, Handler: mux, ReadHeaderTimeout: 3 * time.Second}
go func() {
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
// HTTP health errors are intentionally surfaced through /ready and logs by caller.
return
}
}()
}
func (s *Service) Shutdown(ctx context.Context) error {
if s.server == nil {
return nil
}
return s.server.Shutdown(ctx)
}
func (s *Service) Check(ctx context.Context) map[string]string {
status := map[string]string{"status": "ok"}
if s.db != nil {
if err := s.db.PingContext(ctx); err != nil {
status["status"] = "fail"
status["db"] = err.Error()
} else {
status["db"] = "ok"
}
}
if s.gateway != nil {
serverTime, err := s.gateway.GetServerTime(ctx)
if err != nil {
status["status"] = "fail"
status["api"] = err.Error()
} else {
status["api"] = "ok"
drift := timeutil.Drift(time.Now().UTC(), serverTime)
status["clock_drift"] = drift.String()
if s.maxDrift > 0 && drift > s.maxDrift {
status["status"] = "fail"
status["clock"] = fmt.Sprintf("drift %s exceeds %s", drift, s.maxDrift)
}
}
}
return status
}
func (s *Service) handleHealth(w http.ResponseWriter, _ *http.Request) {
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
}
func (s *Service) handleReady(w http.ResponseWriter, r *http.Request) {
status := s.Check(r.Context())
code := http.StatusOK
if status["status"] != "ok" {
code = http.StatusServiceUnavailable
}
writeJSON(w, code, status)
}
func CheckEndpoint(ctx context.Context, url string) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return err
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode >= 300 {
return fmt.Errorf("healthcheck returned %s", resp.Status)
}
return nil
}
func writeJSON(w http.ResponseWriter, code int, value any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
_ = json.NewEncoder(w).Encode(value)
}
+61
View File
@@ -0,0 +1,61 @@
package instruments
import (
"context"
"fmt"
"strings"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/repository"
"overnight-trading-bot/internal/tinvest"
)
type Registry struct {
repo repository.Repository
gateway tinvest.Gateway
}
func NewRegistry(repo repository.Repository, gateway tinvest.Gateway) Registry {
return Registry{repo: repo, gateway: gateway}
}
func (r Registry) SyncMetadata(ctx context.Context) error {
instruments, err := r.repo.ListInstruments(ctx, true)
if err != nil {
return err
}
for _, instrument := range instruments {
if strings.HasPrefix(instrument.InstrumentUID, "PENDING:") || !instrument.MetadataValid() {
remote, err := r.gateway.GetInstrument(ctx, instrument.Ticker, instrument.ClassCode)
if err != nil {
return fmt.Errorf("sync %s: %w", instrument.Ticker, err)
}
remote.Enabled = instrument.Enabled && remote.Enabled
remote.FundType = instrument.FundType
remote.ExpectedCommissionBpsPerSide = instrument.ExpectedCommissionBpsPerSide
remote.FreeOrderLimitPerDay = instrument.FreeOrderLimitPerDay
remote.Quarantine = instrument.Quarantine
remote.QuarantineReason = instrument.QuarantineReason
remote.ExcludeReason = instrument.ExcludeReason
if err := r.repo.ReplaceInstrument(ctx, instrument.InstrumentUID, remote); err != nil {
return fmt.Errorf("replace synced instrument %s: %w", instrument.Ticker, err)
}
}
}
return nil
}
func CheckInstrument(instrument domain.Instrument, status domain.TradingStatus) error {
switch {
case !instrument.Enabled:
return fmt.Errorf("%s disabled", instrument.Ticker)
case instrument.Quarantine:
return fmt.Errorf("%s quarantined: %s", instrument.Ticker, instrument.QuarantineReason)
case !instrument.MetadataValid():
return fmt.Errorf("%s invalid metadata", instrument.Ticker)
case status != domain.TradingStatusNormal:
return fmt.Errorf("%s trading status %s", instrument.Ticker, status)
default:
return nil
}
}
+48
View File
@@ -0,0 +1,48 @@
package logging
import (
"io"
"log/slog"
"os"
"strings"
)
func New(level string, out io.Writer) *slog.Logger {
if out == nil {
out = os.Stdout
}
var slogLevel slog.Level
switch strings.ToLower(level) {
case "debug":
slogLevel = slog.LevelDebug
case "warn", "warning":
slogLevel = slog.LevelWarn
case "error":
slogLevel = slog.LevelError
default:
slogLevel = slog.LevelInfo
}
return slog.New(slog.NewJSONHandler(out, &slog.HandlerOptions{Level: slogLevel}))
}
type SDKLogger struct {
Logger *slog.Logger
}
func (l SDKLogger) Infof(template string, args ...any) {
if l.Logger != nil {
l.Logger.Info(template, "args", args)
}
}
func (l SDKLogger) Errorf(template string, args ...any) {
if l.Logger != nil {
l.Logger.Error(template, "args", args)
}
}
func (l SDKLogger) Fatalf(template string, args ...any) {
if l.Logger != nil {
l.Logger.Error(template, "args", args)
}
}
+67
View File
@@ -0,0 +1,67 @@
package marketdata
import (
"context"
"fmt"
"time"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/repository"
"overnight-trading-bot/internal/tinvest"
)
type Loader struct {
repo repository.Repository
gateway tinvest.Gateway
}
func NewLoader(repo repository.Repository, gateway tinvest.Gateway) Loader {
return Loader{repo: repo, gateway: gateway}
}
func (l Loader) BackfillDaily(ctx context.Context, instruments []domain.Instrument, from, to time.Time) error {
for _, instrument := range instruments {
if !instrument.Enabled || instrument.Quarantine {
continue
}
candles, err := l.gateway.GetCandles(ctx, instrument.InstrumentUID, "day", from, to)
if err != nil {
return fmt.Errorf("load candles %s: %w", instrument.Ticker, err)
}
if err := l.repo.UpsertDailyCandles(ctx, candles); err != nil {
return fmt.Errorf("persist candles %s: %w", instrument.Ticker, err)
}
}
return nil
}
func (l Loader) BackfillMinute(ctx context.Context, instruments []domain.Instrument, from, to time.Time) error {
for _, instrument := range instruments {
if !instrument.Enabled || instrument.Quarantine {
continue
}
candles, err := l.gateway.GetCandles(ctx, instrument.InstrumentUID, "minute", from, to)
if err != nil {
return fmt.Errorf("load minute candles %s: %w", instrument.Ticker, err)
}
if err := l.repo.UpsertMinuteCandles(ctx, candles); err != nil {
return fmt.Errorf("persist minute candles %s: %w", instrument.Ticker, err)
}
}
return nil
}
func (l Loader) LatestQuote(ctx context.Context, instrumentUID string, depth int32, maxAge time.Duration) (domain.OrderBook, error) {
book, err := l.gateway.GetOrderBook(ctx, instrumentUID, depth)
if err != nil {
return domain.OrderBook{}, err
}
age := time.Since(book.ReceivedAt)
if book.ReceivedAt.IsZero() {
age = time.Since(book.Time)
}
if maxAge > 0 && age > maxAge {
return domain.OrderBook{}, fmt.Errorf("quote age %s exceeds %s", age, maxAge)
}
return book, nil
}
+117
View File
@@ -0,0 +1,117 @@
package money
import (
"errors"
"github.com/shopspring/decimal"
pb "github.com/russianinvestments/invest-api-go-sdk/proto"
)
var (
ErrInvalidTick = errors.New("tick must be positive")
ErrInvalidBase = errors.New("base must be positive")
)
type RoundMode int
const (
RoundNearest RoundMode = iota
RoundFloor
RoundCeil
)
func QuotationToDecimal(q *pb.Quotation) decimal.Decimal {
if q == nil {
return decimal.Zero
}
return decimal.NewFromInt(q.GetUnits()).Add(decimal.New(int64(q.GetNano()), -9))
}
func DecimalToQuotation(d decimal.Decimal) *pb.Quotation {
units := d.Truncate(0)
nano := d.Sub(units).Mul(decimal.NewFromInt(1_000_000_000)).Round(0)
if nano.Equal(decimal.NewFromInt(1_000_000_000)) {
units = units.Add(decimal.NewFromInt(1))
nano = decimal.Zero
}
if nano.Equal(decimal.NewFromInt(-1_000_000_000)) {
units = units.Sub(decimal.NewFromInt(1))
nano = decimal.Zero
}
nanoPart := nano.IntPart()
if nanoPart < -999_999_999 || nanoPart > 999_999_999 {
panic("decimal quotation nano is out of protobuf range")
}
return &pb.Quotation{
Units: units.IntPart(),
Nano: int32(nanoPart), // #nosec G115 -- nanoPart is bounded above.
}
}
func MoneyValueToDecimal(v *pb.MoneyValue) decimal.Decimal {
if v == nil {
return decimal.Zero
}
return decimal.NewFromInt(v.GetUnits()).Add(decimal.New(int64(v.GetNano()), -9))
}
func Bps(part, base decimal.Decimal) (decimal.Decimal, error) {
if !base.IsPositive() {
return decimal.Zero, ErrInvalidBase
}
return part.Div(base).Mul(decimal.NewFromInt(10_000)), nil
}
func FromBps(bps decimal.Decimal) decimal.Decimal {
return bps.Div(decimal.NewFromInt(10_000))
}
func RoundToTick(price, tick decimal.Decimal, mode RoundMode) (decimal.Decimal, error) {
if !tick.IsPositive() {
return decimal.Zero, ErrInvalidTick
}
steps := price.Div(tick)
switch mode {
case RoundFloor:
steps = steps.Floor()
case RoundCeil:
steps = steps.Ceil()
default:
steps = steps.Round(0)
}
return steps.Mul(tick), nil
}
func Min(values ...decimal.Decimal) decimal.Decimal {
if len(values) == 0 {
return decimal.Zero
}
min := values[0]
for _, value := range values[1:] {
if value.LessThan(min) {
min = value
}
}
return min
}
func Max(values ...decimal.Decimal) decimal.Decimal {
if len(values) == 0 {
return decimal.Zero
}
max := values[0]
for _, value := range values[1:] {
if value.GreaterThan(max) {
max = value
}
}
return max
}
func Abs(value decimal.Decimal) decimal.Decimal {
if value.IsNegative() {
return value.Neg()
}
return value
}
+39
View File
@@ -0,0 +1,39 @@
package money
import (
"testing"
"github.com/shopspring/decimal"
)
func d(raw string) decimal.Decimal {
v, err := decimal.NewFromString(raw)
if err != nil {
panic(err)
}
return v
}
func TestRoundToTick(t *testing.T) {
tests := []struct {
price string
tick string
mode RoundMode
want string
}{
{"10.12346", "0.0001", RoundNearest, "10.1235"},
{"10.126", "0.01", RoundFloor, "10.12"},
{"10.126", "0.01", RoundCeil, "10.13"},
{"10.24", "0.5", RoundNearest, "10"},
{"10.26", "0.5", RoundNearest, "10.5"},
}
for _, tt := range tests {
got, err := RoundToTick(d(tt.price), d(tt.tick), tt.mode)
if err != nil {
t.Fatal(err)
}
if !got.Equal(d(tt.want)) {
t.Fatalf("RoundToTick(%s,%s)=%s want %s", tt.price, tt.tick, got, tt.want)
}
}
}
+219
View File
@@ -0,0 +1,219 @@
package notify
import (
"context"
"errors"
"fmt"
"log/slog"
"strings"
"time"
tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5"
"overnight-trading-bot/internal/domain"
)
type Notifier interface {
Info(ctx context.Context, msg string) error
Warn(ctx context.Context, msg string) error
Alert(ctx context.Context, msg string) error
Report(ctx context.Context, msg string) error
Close() error
}
type Noop struct{}
func (Noop) Info(context.Context, string) error { return nil }
func (Noop) Warn(context.Context, string) error { return nil }
func (Noop) Alert(context.Context, string) error { return nil }
func (Noop) Report(context.Context, string) error { return nil }
func (Noop) Close() error { return nil }
type TelegramConfig struct {
BotToken string
ChatID int64
NotifyInfo bool
NotifyWarn bool
NotifyAlert bool
NotifyReport bool
AuditSink AuditSink
}
type AuditSink interface {
InsertRiskEvent(ctx context.Context, event domain.RiskEvent) error
}
type Telegram struct {
cfg TelegramConfig
bot *tgbotapi.BotAPI
log *slog.Logger
queue chan outbound
done chan struct{}
closed chan struct{}
}
type outbound struct {
level domain.Severity
text string
}
func NewTelegram(cfg TelegramConfig, log *slog.Logger) (Notifier, error) {
if cfg.BotToken == "" || cfg.ChatID == 0 {
return Noop{}, nil
}
bot, err := tgbotapi.NewBotAPI(cfg.BotToken)
if err != nil {
return nil, err
}
t := &Telegram{
cfg: cfg,
bot: bot,
log: log,
queue: make(chan outbound, 256),
done: make(chan struct{}),
closed: make(chan struct{}),
}
go t.dispatch()
return t, nil
}
func (t *Telegram) Info(ctx context.Context, msg string) error {
if !t.cfg.NotifyInfo {
return nil
}
return t.enqueue(ctx, domain.SeverityInfo, msg, false)
}
func (t *Telegram) Warn(ctx context.Context, msg string) error {
if !t.cfg.NotifyWarn {
return nil
}
return t.enqueue(ctx, domain.SeverityWarn, msg, false)
}
func (t *Telegram) Alert(ctx context.Context, msg string) error {
if !t.cfg.NotifyAlert {
return nil
}
return t.enqueue(ctx, domain.SeverityAlert, msg, true)
}
func (t *Telegram) Report(ctx context.Context, msg string) error {
if !t.cfg.NotifyReport {
return nil
}
return t.enqueueText(ctx, domain.SeverityInfo, formatMessage("[REPORT]", msg), true)
}
func (t *Telegram) Close() error {
close(t.done)
<-t.closed
return nil
}
func (t *Telegram) enqueue(ctx context.Context, level domain.Severity, msg string, mustDeliver bool) error {
return t.enqueueText(ctx, level, formatMessage(prefix(level), msg), mustDeliver)
}
func (t *Telegram) enqueueText(ctx context.Context, level domain.Severity, text string, mustDeliver bool) error {
item := outbound{level: level, text: text}
if mustDeliver {
select {
case t.queue <- item:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
select {
case t.queue <- item:
default:
if t.log != nil {
t.log.Warn("telegram queue full; dropping non-critical notification", "level", level)
}
if t.cfg.AuditSink != nil {
_ = t.cfg.AuditSink.InsertRiskEvent(ctx, domain.RiskEvent{
TS: time.Now().UTC(),
Severity: domain.SeverityWarn,
EventType: "notification_dropped",
Message: fmt.Sprintf("telegram queue full; dropped %s notification", level),
ContextJSON: "{}",
})
}
}
return nil
}
func (t *Telegram) dispatch() {
defer close(t.closed)
for {
select {
case item := <-t.queue:
t.send(item)
case <-t.done:
for {
select {
case item := <-t.queue:
t.send(item)
default:
return
}
}
}
}
}
func (t *Telegram) send(item outbound) {
msg := tgbotapi.NewMessage(t.cfg.ChatID, item.text)
for attempt := 0; attempt < 3; attempt++ {
if _, err := t.bot.Send(msg); err != nil {
delay := telegramRetryDelay(err, attempt)
if t.log != nil {
t.log.Warn("telegram send failed", "attempt", attempt+1, "err", err, "retry_in", delay)
}
timer := time.NewTimer(delay)
select {
case <-timer.C:
case <-t.done:
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
return
}
continue
}
return
}
}
func telegramRetryDelay(err error, attempt int) time.Duration {
var apiErr tgbotapi.Error
if errors.As(err, &apiErr) && apiErr.RetryAfter > 0 {
return time.Duration(apiErr.RetryAfter) * time.Second
}
var apiErrPtr *tgbotapi.Error
if errors.As(err, &apiErrPtr) && apiErrPtr != nil && apiErrPtr.RetryAfter > 0 {
return time.Duration(apiErrPtr.RetryAfter) * time.Second
}
return time.Duration(attempt+1) * time.Second
}
func prefix(level domain.Severity) string {
switch level {
case domain.SeverityInfo:
return "[INFO]"
case domain.SeverityWarn:
return "[WARN]"
case domain.SeverityAlert:
return "[ALERT]"
default:
return fmt.Sprintf("[%s]", strings.ToUpper(string(level)))
}
}
func formatMessage(prefixValue, msg string) string {
return prefixValue + " " + msg
}
+49
View File
@@ -0,0 +1,49 @@
package notify
import (
"testing"
"time"
tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5"
"overnight-trading-bot/internal/domain"
)
func TestPrefix(t *testing.T) {
tests := map[domain.Severity]string{
domain.SeverityInfo: "[INFO]",
domain.SeverityWarn: "[WARN]",
domain.SeverityAlert: "[ALERT]",
domain.Severity("alert"): "[ALERT]",
}
for severity, want := range tests {
if got := prefix(severity); got != want {
t.Fatalf("prefix(%s)=%s, want %s", severity, got, want)
}
}
}
func TestFormatReportPrefix(t *testing.T) {
if got := formatMessage("[REPORT]", "daily"); got != "[REPORT] daily" {
t.Fatalf("message=%s", got)
}
}
func TestTelegramRetryDelayUsesRetryAfter(t *testing.T) {
err := &tgbotapi.Error{
Code: 429,
ResponseParameters: tgbotapi.ResponseParameters{
RetryAfter: 7,
},
}
if got := telegramRetryDelay(err, 0); got != 7*time.Second {
t.Fatalf("delay=%s, want 7s", got)
}
if got := telegramRetryDelay(assertErr{}, 1); got != 2*time.Second {
t.Fatalf("fallback delay=%s, want 2s", got)
}
}
type assertErr struct{}
func (assertErr) Error() string { return "boom" }
+93
View File
@@ -0,0 +1,93 @@
package position
import (
"context"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/money"
"overnight-trading-bot/internal/repository"
)
type Manager struct {
repo repository.Repository
}
func NewManager(repo repository.Repository) Manager {
return Manager{repo: repo}
}
func (m Manager) OnEntryFill(ctx context.Context, accountIDHash string, instrument domain.Instrument, order domain.Order) (domain.Position, error) {
now := time.Now().UTC()
lot := instrument.Lot
if lot <= 0 {
lot = 1
}
pos := domain.Position{
AccountIDHash: accountIDHash,
InstrumentUID: order.InstrumentUID,
OpenTradeDate: order.TradeDate,
Lots: order.FilledLots,
Lot: lot,
AvgBuyPrice: order.AvgFillPrice,
CommissionTotal: order.Commission,
Status: domain.PositionHoldingOvernight,
OpenedAt: &now,
UpdatedAt: now,
}
if pos.Lots < order.QuantityLots {
pos.Status = domain.PositionEntryPartiallyFilled
}
if err := m.repo.UpsertPosition(ctx, pos); err != nil {
return domain.Position{}, err
}
return pos, nil
}
func (m Manager) OnExitFill(ctx context.Context, pos domain.Position, exitOrder domain.Order) (domain.Position, error) {
now := time.Now().UTC()
lot := pos.Lot
if lot <= 0 {
lot = 1
}
executedLots := min(exitOrder.FilledLots, pos.Lots)
if executedLots < 0 {
executedLots = 0
}
previousExitLots := pos.ExitFilledLots
pos.ExitFilledLots += executedLots
if executedLots > 0 {
previousValue := pos.AvgSellPrice.Mul(decimal.NewFromInt(previousExitLots))
newValue := exitOrder.AvgFillPrice.Mul(decimal.NewFromInt(executedLots))
pos.AvgSellPrice = previousValue.Add(newValue).Div(decimal.NewFromInt(pos.ExitFilledLots))
}
pos.CommissionTotal = pos.CommissionTotal.Add(exitOrder.Commission)
executedUnits := decimal.NewFromInt(executedLots).Mul(decimal.NewFromInt(lot))
pos.GrossPnL = pos.GrossPnL.Add(exitOrder.AvgFillPrice.Sub(pos.AvgBuyPrice).Mul(executedUnits))
pos.NetPnL = pos.GrossPnL.Sub(pos.CommissionTotal)
if pos.AvgBuyPrice.IsPositive() {
baseLots := pos.ExitFilledLots
if baseLots <= 0 {
baseLots = pos.Lots
}
base := pos.AvgBuyPrice.Mul(decimal.NewFromInt(baseLots)).Mul(decimal.NewFromInt(lot))
edge, _ := money.Bps(pos.NetPnL, base)
pos.RealizedEdgeBps = edge
}
pos.Status = domain.PositionExitFilled
if executedLots < pos.Lots {
pos.Lots -= executedLots
pos.Status = domain.PositionExitPartiallyFilled
pos.ClosedAt = nil
} else {
pos.Lots = 0
pos.ClosedAt = &now
}
pos.UpdatedAt = now
if err := m.repo.UpsertPosition(ctx, pos); err != nil {
return domain.Position{}, err
}
return pos, nil
}
+141
View File
@@ -0,0 +1,141 @@
package position
import (
"context"
"testing"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/testutil"
)
func TestOnEntryFillKeepsBuyCommission(t *testing.T) {
ctx := context.Background()
manager := NewManager(testutil.NewMemoryRepository())
pos, err := manager.OnEntryFill(ctx, "hash", domain.Instrument{Lot: 1}, domain.Order{
InstrumentUID: "uid",
TradeDate: time.Now().UTC(),
QuantityLots: 10,
FilledLots: 10,
AvgFillPrice: decimal.NewFromInt(100),
Commission: decimal.NewFromInt(3),
})
if err != nil {
t.Fatal(err)
}
if !pos.CommissionTotal.Equal(decimal.NewFromInt(3)) {
t.Fatalf("commission=%s, want 3", pos.CommissionTotal)
}
}
func TestOnExitFillPartialUsesExecutedLots(t *testing.T) {
ctx := context.Background()
manager := NewManager(testutil.NewMemoryRepository())
openAt := time.Now().UTC()
pos := domain.Position{
AccountIDHash: "hash",
InstrumentUID: "uid",
OpenTradeDate: openAt,
Lots: 10,
Lot: 1,
AvgBuyPrice: decimal.NewFromInt(100),
Status: domain.PositionHoldingOvernight,
CommissionTotal: decimal.NewFromInt(2),
OpenedAt: &openAt,
}
updated, err := manager.OnExitFill(ctx, pos, domain.Order{
InstrumentUID: "uid",
FilledLots: 4,
AvgFillPrice: decimal.NewFromInt(110),
Commission: decimal.NewFromInt(1),
})
if err != nil {
t.Fatal(err)
}
if updated.Status != domain.PositionExitPartiallyFilled || updated.ClosedAt != nil {
t.Fatalf("unexpected partial status/closed_at: %+v", updated)
}
if updated.Lots != 6 {
t.Fatalf("remaining lots=%d, want 6", updated.Lots)
}
if !updated.GrossPnL.Equal(decimal.NewFromInt(40)) {
t.Fatalf("gross pnl=%s, want 40", updated.GrossPnL)
}
if updated.ExitFilledLots != 4 || !updated.AvgSellPrice.Equal(decimal.NewFromInt(110)) {
t.Fatalf("exit aggregation lots=%d avg=%s", updated.ExitFilledLots, updated.AvgSellPrice)
}
second, err := manager.OnExitFill(ctx, updated, domain.Order{
InstrumentUID: "uid",
FilledLots: 3,
AvgFillPrice: decimal.NewFromInt(120),
})
if err != nil {
t.Fatal(err)
}
wantAvg := decimal.NewFromInt(800).Div(decimal.NewFromInt(7))
if second.ExitFilledLots != 7 || !second.AvgSellPrice.Equal(wantAvg) {
t.Fatalf("weighted avg sell=%s lots=%d, want %s/7", second.AvgSellPrice, second.ExitFilledLots, wantAvg)
}
}
func TestOnExitFillUsesInstrumentLotForAbsolutePnL(t *testing.T) {
ctx := context.Background()
manager := NewManager(testutil.NewMemoryRepository())
openAt := time.Now().UTC()
pos := domain.Position{
AccountIDHash: "hash",
InstrumentUID: "uid",
OpenTradeDate: openAt,
Lots: 4,
Lot: 10,
AvgBuyPrice: decimal.NewFromInt(100),
Status: domain.PositionHoldingOvernight,
CommissionTotal: decimal.NewFromInt(2),
OpenedAt: &openAt,
}
updated, err := manager.OnExitFill(ctx, pos, domain.Order{
InstrumentUID: "uid",
FilledLots: 4,
AvgFillPrice: decimal.NewFromInt(105),
Commission: decimal.NewFromInt(3),
})
if err != nil {
t.Fatal(err)
}
if !updated.GrossPnL.Equal(decimal.NewFromInt(200)) {
t.Fatalf("gross pnl=%s, want 200", updated.GrossPnL)
}
if !updated.NetPnL.Equal(decimal.NewFromInt(195)) {
t.Fatalf("net pnl=%s, want 195", updated.NetPnL)
}
}
func TestOnExitFillUsesLotInRealizedEdgeCommissionBase(t *testing.T) {
ctx := context.Background()
manager := NewManager(testutil.NewMemoryRepository())
openAt := time.Now().UTC()
pos := domain.Position{
AccountIDHash: "hash",
InstrumentUID: "uid",
OpenTradeDate: openAt,
Lots: 1,
Lot: 100,
AvgBuyPrice: decimal.NewFromInt(100),
Status: domain.PositionHoldingOvernight,
OpenedAt: &openAt,
}
updated, err := manager.OnExitFill(ctx, pos, domain.Order{
InstrumentUID: "uid",
FilledLots: 1,
AvgFillPrice: decimal.NewFromInt(100),
Commission: decimal.NewFromInt(10),
})
if err != nil {
t.Fatal(err)
}
if !updated.RealizedEdgeBps.Equal(decimal.NewFromInt(-10)) {
t.Fatalf("realized edge=%s, want -10 bps", updated.RealizedEdgeBps)
}
}
+230
View File
@@ -0,0 +1,230 @@
package reconciliation
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/money"
"overnight-trading-bot/internal/repository"
"overnight-trading-bot/internal/tinvest"
)
type Engine struct {
repo repository.Repository
gateway tinvest.Gateway
accountID string
accountIDHash string
window time.Duration
inFlightGrace time.Duration
}
func New(repo repository.Repository, gateway tinvest.Gateway, accountID, accountIDHash string) Engine {
return Engine{repo: repo, gateway: gateway, accountID: accountID, accountIDHash: accountIDHash, window: 72 * time.Hour}
}
func (e Engine) WithWindow(window time.Duration) Engine {
if window > 0 {
e.window = window
}
return e
}
func (e Engine) WithInFlightGrace(grace time.Duration) Engine {
if grace >= 0 {
e.inFlightGrace = grace
}
return e
}
func (e Engine) Run(ctx context.Context) ([]domain.ReconciliationDiff, error) {
localOrders, err := e.repo.ListActiveOrders(ctx, e.accountIDHash)
if err != nil {
return nil, err
}
brokerOrders, err := e.gateway.GetActiveOrders(ctx, e.accountID)
if err != nil {
return nil, err
}
now := time.Now().UTC()
localByBroker := make(map[string]domain.Order, len(localOrders))
brokerByID := make(map[string]domain.Order, len(brokerOrders))
for _, order := range localOrders {
if order.BrokerOrderID != "" {
localByBroker[order.BrokerOrderID] = order
}
}
var diffs []domain.ReconciliationDiff
for _, brokerOrder := range brokerOrders {
brokerByID[brokerOrder.BrokerOrderID] = brokerOrder
if _, ok := localByBroker[brokerOrder.BrokerOrderID]; !ok {
diffs = append(diffs, domain.ReconciliationDiff{
Kind: "unknown_active_order",
InstrumentUID: brokerOrder.InstrumentUID,
Message: fmt.Sprintf("broker order %s is not known locally", brokerOrder.BrokerOrderID),
Critical: true,
})
}
}
for _, localOrder := range localOrders {
if e.isInFlight(localOrder, now) {
continue
}
if localOrder.BrokerOrderID == "" {
diffs = append(diffs, domain.ReconciliationDiff{
Kind: "local_order_without_broker_id",
InstrumentUID: localOrder.InstrumentUID,
Message: fmt.Sprintf("local order %s is active without broker order id", localOrder.ClientOrderID),
Critical: true,
})
continue
}
if _, ok := brokerByID[localOrder.BrokerOrderID]; !ok {
diffs = append(diffs, domain.ReconciliationDiff{
Kind: "missing_local_order",
InstrumentUID: localOrder.InstrumentUID,
Message: fmt.Sprintf("local active order %s/%s is not active at broker", localOrder.ClientOrderID, localOrder.BrokerOrderID),
Critical: true,
})
}
}
localPositions, err := e.repo.ListOpenPositions(ctx, e.accountIDHash)
if err != nil {
return nil, err
}
portfolio, err := e.gateway.GetPortfolio(ctx, e.accountID)
if err != nil {
return nil, err
}
brokerLots := make(map[string]int64, len(portfolio.Holdings))
for _, holding := range portfolio.Holdings {
brokerLots[holding.InstrumentUID] += holding.QuantityLots
}
for _, pos := range localPositions {
if brokerLots[pos.InstrumentUID] != pos.Lots {
diffs = append(diffs, domain.ReconciliationDiff{
Kind: "position_lots_mismatch",
InstrumentUID: pos.InstrumentUID,
Message: fmt.Sprintf("local lots=%d broker lots=%d", pos.Lots, brokerLots[pos.InstrumentUID]),
Critical: true,
})
}
}
localLots := make(map[string]int64, len(localPositions))
for _, pos := range localPositions {
localLots[pos.InstrumentUID] += pos.Lots
}
for instrumentUID, lots := range brokerLots {
if lots > 0 && localLots[instrumentUID] == 0 {
diffs = append(diffs, domain.ReconciliationDiff{
Kind: "unknown_broker_position",
InstrumentUID: instrumentUID,
Message: fmt.Sprintf("broker holds %d lots but local position is absent", lots),
Critical: true,
})
}
}
from := now.Add(-e.window)
recentOrders, err := e.repo.ListOrders(ctx, e.accountIDHash, from, now)
if err != nil {
return nil, err
}
operations, err := e.gateway.GetOperations(ctx, e.accountID, from, now)
if err != nil {
return nil, err
}
diffs = append(diffs, compareOperations(recentOrders, operations)...)
raw, _ := json.Marshal(diffs)
if err := e.repo.InsertReconciliation(ctx, now, string(raw), len(diffs) > 0); err != nil {
return nil, err
}
return diffs, nil
}
func (e Engine) isInFlight(order domain.Order, now time.Time) bool {
if e.inFlightGrace <= 0 || order.CreatedAt.IsZero() {
return false
}
return order.CreatedAt.After(now.Add(-e.inFlightGrace))
}
func HasCritical(diffs []domain.ReconciliationDiff) bool {
for _, diff := range diffs {
if diff.Critical {
return true
}
}
return false
}
func compareOperations(orders []domain.Order, operations []domain.Operation) []domain.ReconciliationDiff {
var diffs []domain.ReconciliationDiff
localCommissionByInstrument := make(map[string]decimal.Decimal)
localTraded := make(map[string]bool)
for _, order := range orders {
if order.Status == domain.OrderStatusFilled || order.Status == domain.OrderStatusPartiallyFilled {
localCommissionByInstrument[order.InstrumentUID] = localCommissionByInstrument[order.InstrumentUID].Add(order.Commission)
localTraded[order.InstrumentUID] = true
}
}
brokerCommissionByInstrument := make(map[string]decimal.Decimal)
brokerTraded := make(map[string]bool)
for _, op := range operations {
if !op.Commission.IsZero() {
brokerCommissionByInstrument[op.InstrumentUID] = brokerCommissionByInstrument[op.InstrumentUID].Add(op.Commission)
}
if isTradeOperation(op.Type) {
brokerTraded[op.InstrumentUID] = true
}
}
instruments := make(map[string]struct{}, len(localCommissionByInstrument)+len(brokerCommissionByInstrument))
for instrumentUID := range localCommissionByInstrument {
instruments[instrumentUID] = struct{}{}
}
for instrumentUID := range brokerCommissionByInstrument {
instruments[instrumentUID] = struct{}{}
}
for instrumentUID := range instruments {
localCommission := localCommissionByInstrument[instrumentUID]
brokerCommission := brokerCommissionByInstrument[instrumentUID]
if diff := money.Abs(localCommission.Sub(brokerCommission)); diff.GreaterThan(decimal.NewFromFloat(0.01)) {
diffs = append(diffs, domain.ReconciliationDiff{
Kind: "commission_mismatch",
InstrumentUID: instrumentUID,
Message: fmt.Sprintf("local commission=%s broker commission=%s", localCommission.StringFixed(2), brokerCommission.StringFixed(2)),
Critical: true,
})
}
}
for instrumentUID := range brokerTraded {
if instrumentUID != "" && !localTraded[instrumentUID] {
diffs = append(diffs, domain.ReconciliationDiff{
Kind: "unknown_broker_operation",
InstrumentUID: instrumentUID,
Message: "broker has executed operation without local filled order",
Critical: true,
})
}
}
for instrumentUID := range localTraded {
if !brokerTraded[instrumentUID] {
diffs = append(diffs, domain.ReconciliationDiff{
Kind: "missing_broker_operation",
InstrumentUID: instrumentUID,
Message: "local filled order has no matching broker operation in reconciliation window",
Critical: true,
})
}
}
return diffs
}
func isTradeOperation(raw string) bool {
raw = strings.ToUpper(raw)
return strings.Contains(raw, "OPERATION_TYPE_BUY") || strings.Contains(raw, "OPERATION_TYPE_SELL")
}
+131
View File
@@ -0,0 +1,131 @@
package reconciliation
import (
"context"
"testing"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/testutil"
"overnight-trading-bot/internal/tinvest"
)
func TestReconciliationFindsCriticalDiffs(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
gateway := tinvest.NewFakeGateway()
now := time.Now().UTC()
if err := repo.UpsertOrder(ctx, domain.Order{
ClientOrderID: "local",
BrokerOrderID: "broker-missing",
AccountIDHash: "hash",
InstrumentUID: "uid-local",
TradeDate: now,
Side: domain.SideBuy,
OrderType: domain.OrderTypeLimit,
QuantityLots: 1,
Status: domain.OrderStatusSent,
}); err != nil {
t.Fatal(err)
}
gateway.Orders["broker-unknown"] = domain.Order{
ClientOrderID: "unknown",
BrokerOrderID: "broker-unknown",
AccountIDHash: "hash",
InstrumentUID: "uid-broker",
QuantityLots: 1,
Status: domain.OrderStatusSent,
}
if err := repo.UpsertPosition(ctx, domain.Position{
AccountIDHash: "hash",
InstrumentUID: "uid-local",
OpenTradeDate: now,
Lots: 2,
Status: domain.PositionHoldingOvernight,
}); err != nil {
t.Fatal(err)
}
gateway.Portfolio = domain.Portfolio{
Equity: decimal.NewFromInt(100000),
Cash: decimal.NewFromInt(90000),
Holdings: []domain.Holding{
{InstrumentUID: "uid-local", QuantityLots: 1},
{InstrumentUID: "uid-broker-only", QuantityLots: 3},
},
}
diffs, err := New(repo, gateway, "account", "hash").Run(ctx)
if err != nil {
t.Fatal(err)
}
wantKinds := map[string]bool{
"unknown_active_order": false,
"missing_local_order": false,
"position_lots_mismatch": false,
"unknown_broker_position": false,
}
for _, diff := range diffs {
if _, ok := wantKinds[diff.Kind]; ok {
wantKinds[diff.Kind] = true
}
}
for kind, seen := range wantKinds {
if !seen {
t.Fatalf("missing diff kind %s in %+v", kind, diffs)
}
}
if !HasCritical(diffs) {
t.Fatalf("expected critical diffs")
}
}
func TestCompareOperationsCommissionPerInstrument(t *testing.T) {
orders := []domain.Order{
{InstrumentUID: "TRUR", Status: domain.OrderStatusFilled, Commission: decimal.NewFromInt(2)},
{InstrumentUID: "TGLD", Status: domain.OrderStatusFilled, Commission: decimal.NewFromInt(1)},
}
operations := []domain.Operation{
{InstrumentUID: "TRUR", Type: "OPERATION_TYPE_BUY", Commission: decimal.NewFromInt(1)},
{InstrumentUID: "TGLD", Type: "OPERATION_TYPE_BUY", Commission: decimal.NewFromInt(2)},
}
diffs := compareOperations(orders, operations)
seen := map[string]bool{}
for _, diff := range diffs {
if diff.Kind == "commission_mismatch" {
seen[diff.InstrumentUID] = true
}
}
if !seen["TRUR"] || !seen["TGLD"] {
t.Fatalf("expected per-instrument commission diffs, got %+v", diffs)
}
}
func TestReconciliationSkipsFreshInFlightLocalOrders(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
gateway := tinvest.NewFakeGateway()
now := time.Now().UTC()
if err := repo.UpsertOrder(ctx, domain.Order{
ClientOrderID: "fresh",
AccountIDHash: "hash",
InstrumentUID: "uid",
TradeDate: now,
Side: domain.SideBuy,
OrderType: domain.OrderTypeLimit,
QuantityLots: 1,
Status: domain.OrderStatusSent,
CreatedAt: now,
}); err != nil {
t.Fatal(err)
}
diffs, err := New(repo, gateway, "account", "hash").WithInFlightGrace(10 * time.Second).Run(ctx)
if err != nil {
t.Fatal(err)
}
for _, diff := range diffs {
if diff.Kind == "local_order_without_broker_id" || diff.Kind == "missing_local_order" {
t.Fatalf("fresh in-flight order produced diff: %+v", diffs)
}
}
}
+46
View File
@@ -0,0 +1,46 @@
package report
import (
"fmt"
"strings"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
)
type DailyInput struct {
Date time.Time
Mode domain.Mode
Signals []domain.Signal
Positions []domain.Position
AverageSpreadBps decimal.Decimal
AverageSlipBps decimal.Decimal
RiskStatus string
}
func ComposeDaily(input DailyInput) string {
var b strings.Builder
fmt.Fprintf(&b, "Дата: %s\n", input.Date.Format("2006-01-02"))
fmt.Fprintf(&b, "Режим: %s\n", input.Mode)
fmt.Fprintf(&b, "Сигналы: %d\n", len(input.Signals))
for _, signal := range input.Signals {
fmt.Fprintf(&b, "- %s %s edge=%s reason=%s\n", signal.InstrumentUID, signal.Decision, signal.NetEdgeBps.StringFixed(2), signal.RejectReason)
}
gross := decimal.Zero
net := decimal.Zero
commission := decimal.Zero
for _, pos := range input.Positions {
gross = gross.Add(pos.GrossPnL)
net = net.Add(pos.NetPnL)
commission = commission.Add(pos.CommissionTotal)
}
fmt.Fprintf(&b, "Gross PnL: %s\n", gross.StringFixed(2))
fmt.Fprintf(&b, "Net PnL: %s\n", net.StringFixed(2))
fmt.Fprintf(&b, "Комиссии: %s\n", commission.StringFixed(2))
fmt.Fprintf(&b, "Средний spread: %s bps\n", input.AverageSpreadBps.StringFixed(2))
fmt.Fprintf(&b, "Среднее проскальзывание: %s bps\n", input.AverageSlipBps.StringFixed(2))
fmt.Fprintf(&b, "Risk: %s", input.RiskStatus)
return b.String()
}
@@ -0,0 +1,13 @@
DROP TABLE IF EXISTS reconciliations;
DROP TABLE IF EXISTS daily_reports;
DROP TABLE IF EXISTS system_state;
DROP TABLE IF EXISTS free_order_counters;
DROP TABLE IF EXISTS risk_events;
DROP TABLE IF EXISTS positions;
DROP TABLE IF EXISTS orders;
DROP TABLE IF EXISTS signals;
DROP TABLE IF EXISTS features;
DROP TABLE IF EXISTS candles_minute;
DROP TABLE IF EXISTS candles_daily;
DROP TABLE IF EXISTS instruments;
DROP TABLE IF EXISTS schema_meta;
@@ -0,0 +1,181 @@
CREATE TABLE IF NOT EXISTS schema_meta (
meta_key VARCHAR(64) PRIMARY KEY,
meta_value VARCHAR(255) NOT NULL,
updated_at DATETIME(3) NOT NULL DEFAULT CURRENT_TIMESTAMP(3) ON UPDATE CURRENT_TIMESTAMP(3)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
CREATE TABLE IF NOT EXISTS instruments (
instrument_uid VARCHAR(128) PRIMARY KEY,
figi VARCHAR(64),
ticker VARCHAR(32) NOT NULL,
class_code VARCHAR(32) NOT NULL DEFAULT 'TQTF',
name VARCHAR(255) NOT NULL DEFAULT '',
lot BIGINT NOT NULL DEFAULT 1,
min_price_increment DECIMAL(20,8) NOT NULL DEFAULT 0,
currency VARCHAR(8) NOT NULL DEFAULT 'RUB',
enabled TINYINT(1) NOT NULL DEFAULT 1,
fund_type VARCHAR(64) NOT NULL DEFAULT '',
expected_commission_bps_per_side DECIMAL(12,4) NOT NULL DEFAULT 0,
free_order_limit_per_day INT NOT NULL DEFAULT 0 COMMENT '0 means no configured free-order cap',
quarantine TINYINT(1) NOT NULL DEFAULT 0,
quarantine_reason TEXT,
exclude_reason TEXT,
updated_at DATETIME(3) NOT NULL DEFAULT CURRENT_TIMESTAMP(3) ON UPDATE CURRENT_TIMESTAMP(3),
UNIQUE KEY ux_instruments_ticker_class (ticker, class_code)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
CREATE TABLE IF NOT EXISTS candles_daily (
instrument_uid VARCHAR(128) NOT NULL,
trade_date DATE NOT NULL,
open DECIMAL(20,8) NOT NULL,
high DECIMAL(20,8) NOT NULL,
low DECIMAL(20,8) NOT NULL,
close DECIMAL(20,8) NOT NULL,
volume_lots DECIMAL(20,8) NOT NULL DEFAULT 0,
source VARCHAR(32) NOT NULL,
loaded_at DATETIME(3) NOT NULL,
PRIMARY KEY (instrument_uid, trade_date),
CONSTRAINT fk_candles_daily_instrument FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
CREATE TABLE IF NOT EXISTS candles_minute (
instrument_uid VARCHAR(128) NOT NULL,
ts DATETIME(3) NOT NULL,
open DECIMAL(20,8) NOT NULL,
high DECIMAL(20,8) NOT NULL,
low DECIMAL(20,8) NOT NULL,
close DECIMAL(20,8) NOT NULL,
volume_lots DECIMAL(20,8) NOT NULL DEFAULT 0,
source VARCHAR(32) NOT NULL,
loaded_at DATETIME(3) NOT NULL,
PRIMARY KEY (instrument_uid, ts),
CONSTRAINT fk_candles_minute_instrument FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
CREATE TABLE IF NOT EXISTS features (
instrument_uid VARCHAR(128) NOT NULL,
trade_date DATE NOT NULL,
r_on DECIMAL(20,10) NOT NULL DEFAULT 0,
r_day DECIMAL(20,10) NOT NULL DEFAULT 0,
mu_on_60 DECIMAL(20,10) NOT NULL DEFAULT 0,
mu_on_252 DECIMAL(20,10) NOT NULL DEFAULT 0,
sigma_on_60 DECIMAL(20,10) NOT NULL DEFAULT 0,
tstat_on_60 DECIMAL(20,10) NOT NULL DEFAULT 0,
win_on_60 DECIMAL(20,10) NOT NULL DEFAULT 0,
ewma_on DECIMAL(20,10) NOT NULL DEFAULT 0,
spread_bps DECIMAL(12,4) NOT NULL DEFAULT 0,
half_spread_bps DECIMAL(12,4) NOT NULL DEFAULT 0,
tick_bps DECIMAL(12,4) NOT NULL DEFAULT 0,
adv_20 DECIMAL(20,8) NOT NULL DEFAULT 0,
expected_cost_bps DECIMAL(12,4) NOT NULL DEFAULT 0,
net_edge_bps DECIMAL(12,4) NOT NULL DEFAULT 0,
entry_interval_volume DECIMAL(20,8) NOT NULL DEFAULT 0,
exit_interval_volume DECIMAL(20,8) NOT NULL DEFAULT 0,
calculated_at DATETIME(3) NOT NULL,
PRIMARY KEY (instrument_uid, trade_date),
CONSTRAINT fk_features_instrument FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
CREATE TABLE IF NOT EXISTS signals (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
trade_date DATE NOT NULL,
instrument_uid VARCHAR(128) NOT NULL,
decision ENUM('ENTER','SKIP','REJECT') NOT NULL,
score DECIMAL(20,10) NOT NULL DEFAULT 0,
net_edge_bps DECIMAL(12,4) NOT NULL DEFAULT 0,
target_notional DECIMAL(20,8) NOT NULL DEFAULT 0,
target_lots BIGINT NOT NULL DEFAULT 0,
reject_reason VARCHAR(128),
context_json JSON,
created_at DATETIME(3) NOT NULL,
UNIQUE KEY ux_signals_date_instr (trade_date, instrument_uid),
CONSTRAINT fk_signals_instrument FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
CREATE TABLE IF NOT EXISTS orders (
client_order_id VARCHAR(128) PRIMARY KEY,
broker_order_id VARCHAR(128),
account_id_hash VARCHAR(128) NOT NULL,
instrument_uid VARCHAR(128) NOT NULL,
trade_date DATE NOT NULL,
side ENUM('BUY','SELL') NOT NULL,
order_type ENUM('LIMIT') NOT NULL,
limit_price DECIMAL(20,8) NOT NULL DEFAULT 0,
quantity_lots BIGINT NOT NULL,
filled_lots BIGINT NOT NULL DEFAULT 0,
avg_fill_price DECIMAL(20,8) NOT NULL DEFAULT 0,
status ENUM('NEW','SENT','PARTIALLY_FILLED','FILLED','CANCELLED','REJECTED','EXPIRED','FAILED') NOT NULL,
commission DECIMAL(20,8) NOT NULL DEFAULT 0,
attempt_no INT NOT NULL DEFAULT 1,
raw_state_json JSON,
created_at DATETIME(3) NOT NULL,
updated_at DATETIME(3) NOT NULL,
UNIQUE KEY ux_orders_broker_order_id (broker_order_id),
KEY ix_orders_active (account_id_hash, status),
CONSTRAINT fk_orders_instrument FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
CREATE TABLE IF NOT EXISTS positions (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
account_id_hash VARCHAR(128) NOT NULL,
instrument_uid VARCHAR(128) NOT NULL,
open_trade_date DATE NOT NULL,
lots BIGINT NOT NULL,
avg_buy_price DECIMAL(20,8) NOT NULL DEFAULT 0,
avg_sell_price DECIMAL(20,8) NOT NULL DEFAULT 0,
status ENUM('NO_POSITION','ENTRY_SIGNALLED','ENTRY_ORDER_SENT','ENTRY_PARTIALLY_FILLED','ENTRY_FILLED','HOLDING_OVERNIGHT','EXIT_ORDER_SENT','EXIT_PARTIALLY_FILLED','EXIT_FILLED','EXIT_FAILED','QUARANTINE') NOT NULL,
gross_pnl DECIMAL(20,8) NOT NULL DEFAULT 0,
net_pnl DECIMAL(20,8) NOT NULL DEFAULT 0,
commission_total DECIMAL(20,8) NOT NULL DEFAULT 0,
realized_edge_bps DECIMAL(12,4) NOT NULL DEFAULT 0,
opened_at DATETIME(3),
closed_at DATETIME(3),
updated_at DATETIME(3) NOT NULL,
KEY ix_positions_open (account_id_hash, status),
CONSTRAINT fk_positions_instrument FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
CREATE TABLE IF NOT EXISTS risk_events (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
ts DATETIME(3) NOT NULL,
severity ENUM('INFO','WARN','ALERT','CRITICAL') NOT NULL,
event_type VARCHAR(128) NOT NULL,
instrument_uid VARCHAR(128),
message TEXT NOT NULL,
raw_context_json JSON,
KEY ix_risk_events_ts (ts)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
CREATE TABLE IF NOT EXISTS free_order_counters (
trade_date DATE NOT NULL,
instrument_uid VARCHAR(128) NOT NULL,
orders_sent INT NOT NULL DEFAULT 0,
updated_at DATETIME(3) NOT NULL DEFAULT CURRENT_TIMESTAMP(3) ON UPDATE CURRENT_TIMESTAMP(3),
PRIMARY KEY (trade_date, instrument_uid),
CONSTRAINT fk_free_orders_instrument FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
CREATE TABLE IF NOT EXISTS system_state (
id TINYINT NOT NULL PRIMARY KEY,
state ENUM('INIT','SYNC_INSTRUMENTS','SYNC_MARKET_DATA','GENERATE_SIGNALS','WAIT_ENTRY_WINDOW','PLACE_ENTRY_ORDERS','MONITOR_ENTRY_ORDERS','HOLD_OVERNIGHT','WAIT_EXIT_WINDOW','PLACE_EXIT_ORDERS','MONITOR_EXIT_ORDERS','RECONCILE','REPORT','SLEEP','HALTED') NOT NULL,
mode ENUM('backtest','paper','sandbox','live_readonly','live_trade') NOT NULL,
halted TINYINT(1) NOT NULL DEFAULT 0,
halt_reason TEXT,
last_heartbeat DATETIME(3) NOT NULL,
context_json JSON
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
CREATE TABLE IF NOT EXISTS reconciliations (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
ts DATETIME(3) NOT NULL,
has_diff TINYINT(1) NOT NULL,
diff_json JSON,
KEY ix_reconciliations_ts (ts)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
INSERT INTO schema_meta(meta_key, meta_value) VALUES ('schema_version', '0001')
ON DUPLICATE KEY UPDATE meta_value=VALUES(meta_value);
INSERT INTO system_state(id, state, mode, halted, last_heartbeat, context_json)
VALUES (1, 'INIT', 'paper', 0, UTC_TIMESTAMP(3), JSON_OBJECT())
ON DUPLICATE KEY UPDATE id=id;
@@ -0,0 +1,7 @@
DELETE FROM instruments WHERE instrument_uid IN (
'PENDING:TRUR','PENDING:TGLD','PENDING:TBRU','PENDING:TDIV','PENDING:TMON',
'PENDING:TOFZ','PENDING:TLCB','PENDING:TITR','PENDING:TRND','PENDING:TMOS'
);
UPDATE schema_meta SET meta_value='0001' WHERE meta_key='schema_version';
@@ -0,0 +1,24 @@
INSERT INTO instruments (
instrument_uid, ticker, class_code, name, lot, min_price_increment, currency,
enabled, fund_type, expected_commission_bps_per_side, free_order_limit_per_day,
quarantine, exclude_reason, updated_at
) VALUES
('PENDING:TRUR', 'TRUR', 'TQTF', 'TRUR', 1, 0.0001, 'RUB', 1, 'mixed', 0, 15, 0, NULL, UTC_TIMESTAMP(3)),
('PENDING:TGLD', 'TGLD', 'TQTF', 'TGLD', 1, 0.0001, 'RUB', 1, 'commodity', 0, 15, 0, NULL, UTC_TIMESTAMP(3)),
('PENDING:TBRU', 'TBRU', 'TQTF', 'TBRU', 1, 0.0001, 'RUB', 1, 'bonds', 0, 0, 0, NULL, UTC_TIMESTAMP(3)),
('PENDING:TDIV', 'TDIV', 'TQTF', 'TDIV', 1, 0.0001, 'RUB', 1, 'equity_income', 0, 0, 0, NULL, UTC_TIMESTAMP(3)),
('PENDING:TMON', 'TMON', 'TQTF', 'TMON', 1, 0.0001, 'RUB', 1, 'money_market', 0, 0, 0, NULL, UTC_TIMESTAMP(3)),
('PENDING:TOFZ', 'TOFZ', 'TQTF', 'TOFZ', 1, 0.0001, 'RUB', 1, 'bonds', 0, 0, 0, NULL, UTC_TIMESTAMP(3)),
('PENDING:TLCB', 'TLCB', 'TQTF', 'TLCB', 1, 0.0001, 'RUB', 1, 'corporate_bonds', 0, 0, 0, NULL, UTC_TIMESTAMP(3)),
('PENDING:TITR', 'TITR', 'TQTF', 'TITR', 1, 0.0001, 'RUB', 1, 'equity', 0, 0, 0, NULL, UTC_TIMESTAMP(3)),
('PENDING:TRND', 'TRND', 'TQTF', 'TRND', 1, 0.0001, 'RUB', 1, 'equity', 0, 0, 0, NULL, UTC_TIMESTAMP(3)),
('PENDING:TMOS', 'TMOS', 'TQTF', 'TMOS', 1, 0.0001, 'RUB', 0, 'equity', 0, 0, 0, 'Excluded by default due to possible non-zero sell-side fee', UTC_TIMESTAMP(3))
ON DUPLICATE KEY UPDATE
enabled=VALUES(enabled),
fund_type=VALUES(fund_type),
expected_commission_bps_per_side=VALUES(expected_commission_bps_per_side),
free_order_limit_per_day=VALUES(free_order_limit_per_day),
exclude_reason=VALUES(exclude_reason),
updated_at=UTC_TIMESTAMP(3);
UPDATE schema_meta SET meta_value='0002' WHERE meta_key='schema_version';
@@ -0,0 +1,29 @@
ALTER TABLE free_order_counters DROP FOREIGN KEY fk_free_orders_instrument;
ALTER TABLE free_order_counters ADD CONSTRAINT fk_free_orders_instrument
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid);
ALTER TABLE positions DROP FOREIGN KEY fk_positions_instrument;
ALTER TABLE positions ADD CONSTRAINT fk_positions_instrument
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid);
ALTER TABLE orders DROP FOREIGN KEY fk_orders_instrument;
ALTER TABLE orders ADD CONSTRAINT fk_orders_instrument
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid);
ALTER TABLE signals DROP FOREIGN KEY fk_signals_instrument;
ALTER TABLE signals ADD CONSTRAINT fk_signals_instrument
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid);
ALTER TABLE features DROP FOREIGN KEY fk_features_instrument;
ALTER TABLE features ADD CONSTRAINT fk_features_instrument
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid);
ALTER TABLE candles_minute DROP FOREIGN KEY fk_candles_minute_instrument;
ALTER TABLE candles_minute ADD CONSTRAINT fk_candles_minute_instrument
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid);
ALTER TABLE candles_daily DROP FOREIGN KEY fk_candles_daily_instrument;
ALTER TABLE candles_daily ADD CONSTRAINT fk_candles_daily_instrument
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid);
UPDATE schema_meta SET meta_value='0002' WHERE meta_key='schema_version';
@@ -0,0 +1,29 @@
ALTER TABLE candles_daily DROP FOREIGN KEY fk_candles_daily_instrument;
ALTER TABLE candles_daily ADD CONSTRAINT fk_candles_daily_instrument
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE;
ALTER TABLE candles_minute DROP FOREIGN KEY fk_candles_minute_instrument;
ALTER TABLE candles_minute ADD CONSTRAINT fk_candles_minute_instrument
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE;
ALTER TABLE features DROP FOREIGN KEY fk_features_instrument;
ALTER TABLE features ADD CONSTRAINT fk_features_instrument
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE;
ALTER TABLE signals DROP FOREIGN KEY fk_signals_instrument;
ALTER TABLE signals ADD CONSTRAINT fk_signals_instrument
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE;
ALTER TABLE orders DROP FOREIGN KEY fk_orders_instrument;
ALTER TABLE orders ADD CONSTRAINT fk_orders_instrument
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE;
ALTER TABLE positions DROP FOREIGN KEY fk_positions_instrument;
ALTER TABLE positions ADD CONSTRAINT fk_positions_instrument
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE;
ALTER TABLE free_order_counters DROP FOREIGN KEY fk_free_orders_instrument;
ALTER TABLE free_order_counters ADD CONSTRAINT fk_free_orders_instrument
FOREIGN KEY (instrument_uid) REFERENCES instruments(instrument_uid) ON UPDATE CASCADE;
UPDATE schema_meta SET meta_value='0003' WHERE meta_key='schema_version';
@@ -0,0 +1,5 @@
DROP TABLE IF EXISTS daily_reports;
ALTER TABLE positions DROP INDEX ux_positions_trade;
ALTER TABLE positions DROP COLUMN exit_filled_lots;
UPDATE schema_meta SET meta_value='0003' WHERE meta_key='schema_version';
@@ -0,0 +1,11 @@
ALTER TABLE positions ADD COLUMN exit_filled_lots BIGINT NOT NULL DEFAULT 0 AFTER lots;
ALTER TABLE positions ADD UNIQUE KEY ux_positions_trade (account_id_hash, instrument_uid, open_trade_date);
CREATE TABLE IF NOT EXISTS daily_reports (
report_date DATE NOT NULL,
account_id_hash VARCHAR(128) NOT NULL,
sent_at DATETIME(3) NOT NULL,
PRIMARY KEY (report_date, account_id_hash)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
UPDATE schema_meta SET meta_value='0004' WHERE meta_key='schema_version';
@@ -0,0 +1,7 @@
ALTER TABLE risk_events
MODIFY severity ENUM('INFO','WARN','ALERT','CRITICAL','REPORT') NOT NULL;
ALTER TABLE instruments
MODIFY free_order_limit_per_day INT NOT NULL DEFAULT 0;
UPDATE schema_meta SET meta_value='0004' WHERE meta_key='schema_version';
@@ -0,0 +1,20 @@
UPDATE instruments
SET free_order_limit_per_day=0
WHERE ticker NOT IN ('TRUR', 'TGLD') AND free_order_limit_per_day=15;
ALTER TABLE instruments
MODIFY free_order_limit_per_day INT NOT NULL DEFAULT 0 COMMENT '0 means no configured free-order cap';
UPDATE risk_events
SET
severity='INFO',
event_type=CASE
WHEN event_type LIKE 'report_%' THEN event_type
ELSE CONCAT('report_', event_type)
END
WHERE severity='REPORT';
ALTER TABLE risk_events
MODIFY severity ENUM('INFO','WARN','ALERT','CRITICAL') NOT NULL;
UPDATE schema_meta SET meta_value='0005' WHERE meta_key='schema_version';
@@ -0,0 +1,3 @@
ALTER TABLE positions DROP COLUMN lot_size;
UPDATE schema_meta SET meta_value='0005' WHERE meta_key='schema_version';
@@ -0,0 +1,8 @@
ALTER TABLE positions ADD COLUMN lot_size BIGINT NOT NULL DEFAULT 1 AFTER lots;
UPDATE positions p
JOIN instruments i ON i.instrument_uid = p.instrument_uid
SET p.lot_size = i.lot
WHERE p.lot_size = 1 AND i.lot > 1;
UPDATE schema_meta SET meta_value='0006' WHERE meta_key='schema_version';
@@ -0,0 +1,8 @@
package migrations
import "embed"
// FS contains SQL migrations used by both the daemon and cmd/migrate.
//
//go:embed *.sql
var FS embed.FS
+61
View File
@@ -0,0 +1,61 @@
package mysql
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/golang-migrate/migrate/v4"
migratemysql "github.com/golang-migrate/migrate/v4/database/mysql"
"github.com/golang-migrate/migrate/v4/source/iofs"
"overnight-trading-bot/internal/repository/migrations"
)
func ApplyMigrations(ctx context.Context, db *sql.DB) error {
if err := ctx.Err(); err != nil {
return err
}
driver, err := migratemysql.WithInstance(db, &migratemysql.Config{})
if err != nil {
return fmt.Errorf("create mysql migration driver: %w", err)
}
source, err := iofs.New(migrations.FS, ".")
if err != nil {
return fmt.Errorf("create iofs migration source: %w", err)
}
m, err := migrate.NewWithInstance("iofs", source, "mysql", driver)
if err != nil {
return fmt.Errorf("create migrate instance: %w", err)
}
defer func() {
_, _ = m.Close()
}()
if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
return fmt.Errorf("apply migrations: %w", err)
}
return nil
}
func RollbackAll(db *sql.DB) error {
driver, err := migratemysql.WithInstance(db, &migratemysql.Config{})
if err != nil {
return fmt.Errorf("create mysql migration driver: %w", err)
}
source, err := iofs.New(migrations.FS, ".")
if err != nil {
return fmt.Errorf("create iofs migration source: %w", err)
}
m, err := migrate.NewWithInstance("iofs", source, "mysql", driver)
if err != nil {
return fmt.Errorf("create migrate instance: %w", err)
}
defer func() {
_, _ = m.Close()
}()
if err := m.Down(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
return fmt.Errorf("rollback migrations: %w", err)
}
return nil
}
+790
View File
@@ -0,0 +1,790 @@
package mysql
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
_ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/repository"
)
var _ repository.Repository = (*Repository)(nil)
type Repository struct {
db *sqlx.DB
tx *sqlx.Tx
}
func NewRepository(db *sqlx.DB) *Repository {
return &Repository{db: db}
}
func (r *Repository) RunInTx(ctx context.Context, fn func(ctx context.Context, repo repository.Repository) error) error {
if r.tx != nil {
return fn(ctx, r)
}
tx, err := r.db.BeginTxx(ctx, nil)
if err != nil {
return err
}
txRepo := &Repository{db: r.db, tx: tx}
if err := fn(ctx, txRepo); err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
return fmt.Errorf("%w; rollback: %v", err, rbErr)
}
return err
}
return tx.Commit()
}
func (r *Repository) execer() sqlx.ExtContext {
if r.tx != nil {
return r.tx
}
return r.db
}
func (r *Repository) selectContext(ctx context.Context, dest any, query string, args ...any) error {
if r.tx != nil {
return r.tx.SelectContext(ctx, dest, query, args...)
}
return r.db.SelectContext(ctx, dest, query, args...)
}
func (r *Repository) getContext(ctx context.Context, dest any, query string, args ...any) error {
if r.tx != nil {
return r.tx.GetContext(ctx, dest, query, args...)
}
return r.db.GetContext(ctx, dest, query, args...)
}
func (r *Repository) UpsertInstrument(ctx context.Context, instrument domain.Instrument) error {
if instrument.UpdatedAt.IsZero() {
instrument.UpdatedAt = time.Now().UTC()
}
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
INSERT INTO instruments (
instrument_uid, figi, ticker, class_code, name, lot, min_price_increment, currency,
enabled, fund_type, expected_commission_bps_per_side, free_order_limit_per_day,
quarantine, quarantine_reason, exclude_reason, updated_at
) VALUES (
:instrument_uid, :figi, :ticker, :class_code, :name, :lot, :min_price_increment, :currency,
:enabled, :fund_type, :expected_commission_bps_per_side, :free_order_limit_per_day,
:quarantine, :quarantine_reason, :exclude_reason, :updated_at
) ON DUPLICATE KEY UPDATE
instrument_uid=VALUES(instrument_uid),
figi=VALUES(figi),
name=VALUES(name),
lot=VALUES(lot),
min_price_increment=VALUES(min_price_increment),
currency=VALUES(currency),
enabled=VALUES(enabled),
fund_type=VALUES(fund_type),
expected_commission_bps_per_side=VALUES(expected_commission_bps_per_side),
free_order_limit_per_day=VALUES(free_order_limit_per_day),
quarantine=VALUES(quarantine),
quarantine_reason=VALUES(quarantine_reason),
exclude_reason=VALUES(exclude_reason),
updated_at=VALUES(updated_at)`, instrumentRowFromDomain(instrument))
return err
}
func (r *Repository) ReplaceInstrument(ctx context.Context, oldInstrumentUID string, instrument domain.Instrument) error {
if oldInstrumentUID == "" || oldInstrumentUID == instrument.InstrumentUID {
return r.UpsertInstrument(ctx, instrument)
}
return r.RunInTx(ctx, func(ctx context.Context, repo repository.Repository) error {
txRepo, ok := repo.(*Repository)
if !ok {
return errors.New("unexpected repository implementation")
}
return txRepo.replaceInstrument(ctx, oldInstrumentUID, instrument)
})
}
func (r *Repository) replaceInstrument(ctx context.Context, oldInstrumentUID string, instrument domain.Instrument) error {
if instrument.UpdatedAt.IsZero() {
instrument.UpdatedAt = time.Now().UTC()
}
exists, err := r.instrumentExists(ctx, instrument.InstrumentUID)
if err != nil {
return err
}
if exists {
if err := r.mergeInstrumentUID(ctx, oldInstrumentUID, instrument.InstrumentUID); err != nil {
return err
}
return r.UpsertInstrument(ctx, instrument)
}
result, err := sqlx.NamedExecContext(ctx, r.execer(), `
UPDATE instruments SET
instrument_uid=:instrument_uid,
figi=:figi,
ticker=:ticker,
class_code=:class_code,
name=:name,
lot=:lot,
min_price_increment=:min_price_increment,
currency=:currency,
enabled=:enabled,
fund_type=:fund_type,
expected_commission_bps_per_side=:expected_commission_bps_per_side,
free_order_limit_per_day=:free_order_limit_per_day,
quarantine=:quarantine,
quarantine_reason=:quarantine_reason,
exclude_reason=:exclude_reason,
updated_at=:updated_at
WHERE instrument_uid=:old_instrument_uid`, replaceInstrumentRowFromDomain(oldInstrumentUID, instrument))
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return r.UpsertInstrument(ctx, instrument)
}
return nil
}
func (r *Repository) instrumentExists(ctx context.Context, instrumentUID string) (bool, error) {
var count int
if err := r.getContext(ctx, &count, `SELECT COUNT(*) FROM instruments WHERE instrument_uid=?`, instrumentUID); err != nil {
return false, err
}
return count > 0, nil
}
func (r *Repository) mergeInstrumentUID(ctx context.Context, oldInstrumentUID, newInstrumentUID string) error {
if oldInstrumentUID == newInstrumentUID {
return nil
}
if err := r.mergeDailyCandles(ctx, oldInstrumentUID, newInstrumentUID); err != nil {
return err
}
if err := r.mergeMinuteCandles(ctx, oldInstrumentUID, newInstrumentUID); err != nil {
return err
}
if err := r.mergeFeatures(ctx, oldInstrumentUID, newInstrumentUID); err != nil {
return err
}
if err := r.mergeSignals(ctx, oldInstrumentUID, newInstrumentUID); err != nil {
return err
}
if err := r.mergeFreeOrders(ctx, oldInstrumentUID, newInstrumentUID); err != nil {
return err
}
for _, table := range []string{"orders", "positions", "risk_events"} {
if _, err := r.execer().ExecContext(ctx, fmt.Sprintf(`UPDATE %s SET instrument_uid=? WHERE instrument_uid=?`, table), newInstrumentUID, oldInstrumentUID); err != nil {
return err
}
}
_, err := r.execer().ExecContext(ctx, `DELETE FROM instruments WHERE instrument_uid=?`, oldInstrumentUID)
return err
}
func (r *Repository) mergeDailyCandles(ctx context.Context, oldInstrumentUID, newInstrumentUID string) error {
_, err := r.execer().ExecContext(ctx, `
INSERT INTO candles_daily (instrument_uid, trade_date, open, high, low, close, volume_lots, source, loaded_at)
SELECT ?, trade_date, open, high, low, close, volume_lots, source, loaded_at
FROM candles_daily WHERE instrument_uid=?
ON DUPLICATE KEY UPDATE
open=VALUES(open), high=VALUES(high), low=VALUES(low), close=VALUES(close),
volume_lots=VALUES(volume_lots), source=VALUES(source), loaded_at=VALUES(loaded_at)`, newInstrumentUID, oldInstrumentUID)
if err != nil {
return err
}
_, err = r.execer().ExecContext(ctx, `DELETE FROM candles_daily WHERE instrument_uid=?`, oldInstrumentUID)
return err
}
func (r *Repository) mergeMinuteCandles(ctx context.Context, oldInstrumentUID, newInstrumentUID string) error {
_, err := r.execer().ExecContext(ctx, `
INSERT INTO candles_minute (instrument_uid, ts, open, high, low, close, volume_lots, source, loaded_at)
SELECT ?, ts, open, high, low, close, volume_lots, source, loaded_at
FROM candles_minute WHERE instrument_uid=?
ON DUPLICATE KEY UPDATE
open=VALUES(open), high=VALUES(high), low=VALUES(low), close=VALUES(close),
volume_lots=VALUES(volume_lots), source=VALUES(source), loaded_at=VALUES(loaded_at)`, newInstrumentUID, oldInstrumentUID)
if err != nil {
return err
}
_, err = r.execer().ExecContext(ctx, `DELETE FROM candles_minute WHERE instrument_uid=?`, oldInstrumentUID)
return err
}
func (r *Repository) mergeFeatures(ctx context.Context, oldInstrumentUID, newInstrumentUID string) error {
_, err := r.execer().ExecContext(ctx, `
INSERT INTO features (
instrument_uid, trade_date, r_on, r_day, mu_on_60, mu_on_252, sigma_on_60,
tstat_on_60, win_on_60, ewma_on, spread_bps, half_spread_bps, tick_bps,
adv_20, expected_cost_bps, net_edge_bps, entry_interval_volume,
exit_interval_volume, calculated_at
)
SELECT
?, trade_date, r_on, r_day, mu_on_60, mu_on_252, sigma_on_60,
tstat_on_60, win_on_60, ewma_on, spread_bps, half_spread_bps, tick_bps,
adv_20, expected_cost_bps, net_edge_bps, entry_interval_volume,
exit_interval_volume, calculated_at
FROM features WHERE instrument_uid=?
ON DUPLICATE KEY UPDATE
r_on=VALUES(r_on), r_day=VALUES(r_day), mu_on_60=VALUES(mu_on_60),
mu_on_252=VALUES(mu_on_252), sigma_on_60=VALUES(sigma_on_60),
tstat_on_60=VALUES(tstat_on_60), win_on_60=VALUES(win_on_60),
ewma_on=VALUES(ewma_on), spread_bps=VALUES(spread_bps),
half_spread_bps=VALUES(half_spread_bps), tick_bps=VALUES(tick_bps),
adv_20=VALUES(adv_20), expected_cost_bps=VALUES(expected_cost_bps),
net_edge_bps=VALUES(net_edge_bps), entry_interval_volume=VALUES(entry_interval_volume),
exit_interval_volume=VALUES(exit_interval_volume), calculated_at=VALUES(calculated_at)`, newInstrumentUID, oldInstrumentUID)
if err != nil {
return err
}
_, err = r.execer().ExecContext(ctx, `DELETE FROM features WHERE instrument_uid=?`, oldInstrumentUID)
return err
}
func (r *Repository) mergeSignals(ctx context.Context, oldInstrumentUID, newInstrumentUID string) error {
_, err := r.execer().ExecContext(ctx, `
INSERT INTO signals (
trade_date, instrument_uid, decision, score, net_edge_bps, target_notional,
target_lots, reject_reason, context_json, created_at
)
SELECT trade_date, ?, decision, score, net_edge_bps, target_notional,
target_lots, reject_reason, context_json, created_at
FROM signals WHERE instrument_uid=?
ON DUPLICATE KEY UPDATE
decision=VALUES(decision), score=VALUES(score), net_edge_bps=VALUES(net_edge_bps),
target_notional=VALUES(target_notional), target_lots=VALUES(target_lots),
reject_reason=VALUES(reject_reason), context_json=VALUES(context_json),
created_at=VALUES(created_at)`, newInstrumentUID, oldInstrumentUID)
if err != nil {
return err
}
_, err = r.execer().ExecContext(ctx, `DELETE FROM signals WHERE instrument_uid=?`, oldInstrumentUID)
return err
}
func (r *Repository) mergeFreeOrders(ctx context.Context, oldInstrumentUID, newInstrumentUID string) error {
_, err := r.execer().ExecContext(ctx, `
INSERT INTO free_order_counters (trade_date, instrument_uid, orders_sent)
SELECT trade_date, ?, orders_sent FROM free_order_counters WHERE instrument_uid=?
ON DUPLICATE KEY UPDATE orders_sent=GREATEST(orders_sent, VALUES(orders_sent))`, newInstrumentUID, oldInstrumentUID)
if err != nil {
return err
}
_, err = r.execer().ExecContext(ctx, `DELETE FROM free_order_counters WHERE instrument_uid=?`, oldInstrumentUID)
return err
}
func (r *Repository) ListInstruments(ctx context.Context, includeDisabled bool) ([]domain.Instrument, error) {
query := `SELECT * FROM instruments`
if !includeDisabled {
query += ` WHERE enabled=1`
}
query += ` ORDER BY ticker`
var rows []instrumentRow
if err := r.selectContext(ctx, &rows, query); err != nil {
return nil, err
}
out := make([]domain.Instrument, 0, len(rows))
for _, row := range rows {
out = append(out, row.domain())
}
return out, nil
}
func (r *Repository) QuarantineInstrument(ctx context.Context, instrumentUID, reason string) error {
_, err := r.execer().ExecContext(ctx, `
UPDATE instruments SET quarantine=1, quarantine_reason=?, updated_at=UTC_TIMESTAMP(3)
WHERE instrument_uid=?`, reason, instrumentUID)
return err
}
func (r *Repository) UpsertDailyCandles(ctx context.Context, candles []domain.Candle) error {
for _, candle := range candles {
if candle.LoadedAt.IsZero() {
candle.LoadedAt = time.Now().UTC()
}
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
INSERT INTO candles_daily (
instrument_uid, trade_date, open, high, low, close, volume_lots, source, loaded_at
) VALUES (
:instrument_uid, :trade_date, :open, :high, :low, :close, :volume_lots, :source, :loaded_at
) ON DUPLICATE KEY UPDATE
open=VALUES(open), high=VALUES(high), low=VALUES(low), close=VALUES(close),
volume_lots=VALUES(volume_lots), source=VALUES(source), loaded_at=VALUES(loaded_at)`, candleRowFromDomain(candle))
if err != nil {
return err
}
}
return nil
}
func (r *Repository) ListDailyCandles(ctx context.Context, instrumentUID string, from, to time.Time) ([]domain.Candle, error) {
var rows []candleRow
if err := r.selectContext(ctx, &rows, `
SELECT * FROM candles_daily
WHERE instrument_uid=? AND trade_date BETWEEN ? AND ?
ORDER BY trade_date`, instrumentUID, dateOnly(from), dateOnly(to)); err != nil {
return nil, err
}
out := make([]domain.Candle, 0, len(rows))
for _, row := range rows {
out = append(out, row.domain())
}
return out, nil
}
func (r *Repository) UpsertMinuteCandles(ctx context.Context, candles []domain.Candle) error {
for _, candle := range candles {
if candle.LoadedAt.IsZero() {
candle.LoadedAt = time.Now().UTC()
}
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
INSERT INTO candles_minute (
instrument_uid, ts, open, high, low, close, volume_lots, source, loaded_at
) VALUES (
:instrument_uid, :trade_date, :open, :high, :low, :close, :volume_lots, :source, :loaded_at
) ON DUPLICATE KEY UPDATE
open=VALUES(open), high=VALUES(high), low=VALUES(low), close=VALUES(close),
volume_lots=VALUES(volume_lots), source=VALUES(source), loaded_at=VALUES(loaded_at)`, candleRowFromDomain(candle))
if err != nil {
return err
}
}
return nil
}
func (r *Repository) ListMinuteCandles(ctx context.Context, instrumentUID string, from, to time.Time) ([]domain.Candle, error) {
var rows []candleRow
if err := r.selectContext(ctx, &rows, `
SELECT instrument_uid, ts AS trade_date, open, high, low, close, volume_lots, source, loaded_at
FROM candles_minute
WHERE instrument_uid=? AND ts BETWEEN ? AND ?
ORDER BY ts`, instrumentUID, from.UTC(), to.UTC()); err != nil {
return nil, err
}
out := make([]domain.Candle, 0, len(rows))
for _, row := range rows {
out = append(out, row.domain())
}
return out, nil
}
func (r *Repository) UpsertFeature(ctx context.Context, feature domain.FeatureSet) error {
if feature.CalculatedAt.IsZero() {
feature.CalculatedAt = time.Now().UTC()
}
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
INSERT INTO features (
instrument_uid, trade_date, r_on, r_day, mu_on_60, mu_on_252, sigma_on_60,
tstat_on_60, win_on_60, ewma_on, spread_bps, half_spread_bps, tick_bps,
adv_20, expected_cost_bps, net_edge_bps, entry_interval_volume,
exit_interval_volume, calculated_at
) VALUES (
:instrument_uid, :trade_date, :r_on, :r_day, :mu_on_60, :mu_on_252, :sigma_on_60,
:tstat_on_60, :win_on_60, :ewma_on, :spread_bps, :half_spread_bps, :tick_bps,
:adv_20, :expected_cost_bps, :net_edge_bps, :entry_interval_volume,
:exit_interval_volume, :calculated_at
) ON DUPLICATE KEY UPDATE
r_on=VALUES(r_on), r_day=VALUES(r_day), mu_on_60=VALUES(mu_on_60),
mu_on_252=VALUES(mu_on_252), sigma_on_60=VALUES(sigma_on_60),
tstat_on_60=VALUES(tstat_on_60), win_on_60=VALUES(win_on_60),
ewma_on=VALUES(ewma_on), spread_bps=VALUES(spread_bps),
half_spread_bps=VALUES(half_spread_bps), tick_bps=VALUES(tick_bps),
adv_20=VALUES(adv_20), expected_cost_bps=VALUES(expected_cost_bps),
net_edge_bps=VALUES(net_edge_bps), entry_interval_volume=VALUES(entry_interval_volume),
exit_interval_volume=VALUES(exit_interval_volume), calculated_at=VALUES(calculated_at)`, featureRowFromDomain(feature))
return err
}
func (r *Repository) GetFeature(ctx context.Context, instrumentUID string, tradeDate time.Time) (domain.FeatureSet, error) {
var row featureRow
if err := r.getContext(ctx, &row, `SELECT * FROM features WHERE instrument_uid=? AND trade_date=?`, instrumentUID, dateOnly(tradeDate)); err != nil {
return domain.FeatureSet{}, err
}
return row.domain(), nil
}
func (r *Repository) UpsertSignal(ctx context.Context, signal domain.Signal) error {
if signal.CreatedAt.IsZero() {
signal.CreatedAt = time.Now().UTC()
}
if signal.ContextJSON == "" {
signal.ContextJSON = "{}"
}
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
INSERT INTO signals (
trade_date, instrument_uid, decision, score, net_edge_bps, target_notional,
target_lots, reject_reason, context_json, created_at
) VALUES (
:trade_date, :instrument_uid, :decision, :score, :net_edge_bps, :target_notional,
:target_lots, :reject_reason, :context_json, :created_at
) ON DUPLICATE KEY UPDATE
decision=VALUES(decision), score=VALUES(score), net_edge_bps=VALUES(net_edge_bps),
target_notional=VALUES(target_notional), target_lots=VALUES(target_lots),
reject_reason=VALUES(reject_reason), context_json=VALUES(context_json),
created_at=VALUES(created_at)`, signalRowFromDomain(signal))
return err
}
func (r *Repository) ListSignals(ctx context.Context, tradeDate time.Time) ([]domain.Signal, error) {
var rows []signalRow
if err := r.selectContext(ctx, &rows, `SELECT * FROM signals WHERE trade_date=? ORDER BY id`, dateOnly(tradeDate)); err != nil {
return nil, err
}
out := make([]domain.Signal, 0, len(rows))
for _, row := range rows {
out = append(out, row.domain())
}
return out, nil
}
func (r *Repository) UpsertOrder(ctx context.Context, order domain.Order) error {
now := time.Now().UTC()
if order.CreatedAt.IsZero() {
order.CreatedAt = now
}
order.UpdatedAt = now
if order.RawStateJSON == "" {
order.RawStateJSON = "{}"
}
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
INSERT INTO orders (
client_order_id, broker_order_id, account_id_hash, instrument_uid, trade_date,
side, order_type, limit_price, quantity_lots, filled_lots, avg_fill_price,
status, commission, attempt_no, raw_state_json, created_at, updated_at
) VALUES (
:client_order_id, :broker_order_id, :account_id_hash, :instrument_uid, :trade_date,
:side, :order_type, :limit_price, :quantity_lots, :filled_lots, :avg_fill_price,
:status, :commission, :attempt_no, :raw_state_json, :created_at, :updated_at
) ON DUPLICATE KEY UPDATE
broker_order_id=VALUES(broker_order_id), filled_lots=VALUES(filled_lots),
avg_fill_price=VALUES(avg_fill_price), status=VALUES(status),
commission=VALUES(commission), raw_state_json=VALUES(raw_state_json),
updated_at=VALUES(updated_at)`, orderRowFromDomain(order))
return err
}
func (r *Repository) UpdateOrderStatus(ctx context.Context, clientOrderID string, status domain.OrderStatus, filledLots int64, rawJSON string) error {
if rawJSON == "" {
rawJSON = "{}"
}
_, err := r.execer().ExecContext(ctx, `
UPDATE orders SET status=?, filled_lots=?, raw_state_json=?, updated_at=UTC_TIMESTAMP(3)
WHERE client_order_id=?`, status, filledLots, rawJSON, clientOrderID)
return err
}
func (r *Repository) ListActiveOrders(ctx context.Context, accountIDHash string) ([]domain.Order, error) {
var rows []orderRow
if err := r.selectContext(ctx, &rows, `
SELECT * FROM orders
WHERE account_id_hash=? AND status IN ('NEW','SENT','PARTIALLY_FILLED')
ORDER BY created_at`, accountIDHash); err != nil {
return nil, err
}
out := make([]domain.Order, 0, len(rows))
for _, row := range rows {
out = append(out, row.domain())
}
return out, nil
}
func (r *Repository) ListOrders(ctx context.Context, accountIDHash string, from, to time.Time) ([]domain.Order, error) {
var rows []orderRow
if err := r.selectContext(ctx, &rows, `
SELECT * FROM orders
WHERE account_id_hash=? AND trade_date BETWEEN ? AND ?
ORDER BY created_at`, accountIDHash, dateOnly(from), dateOnly(to)); err != nil {
return nil, err
}
out := make([]domain.Order, 0, len(rows))
for _, row := range rows {
out = append(out, row.domain())
}
return out, nil
}
func (r *Repository) UpsertPosition(ctx context.Context, position domain.Position) error {
if position.UpdatedAt.IsZero() {
position.UpdatedAt = time.Now().UTC()
}
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
INSERT INTO positions (
id, account_id_hash, instrument_uid, open_trade_date, lots, lot_size, exit_filled_lots,
avg_buy_price, avg_sell_price, status, gross_pnl, net_pnl, commission_total,
realized_edge_bps, opened_at, closed_at, updated_at
) VALUES (
NULLIF(:id, 0), :account_id_hash, :instrument_uid, :open_trade_date, :lots, :lot_size, :exit_filled_lots,
:avg_buy_price, :avg_sell_price, :status, :gross_pnl, :net_pnl, :commission_total,
:realized_edge_bps, :opened_at, :closed_at, :updated_at
) ON DUPLICATE KEY UPDATE
lots=VALUES(lots), lot_size=VALUES(lot_size), exit_filled_lots=VALUES(exit_filled_lots), avg_buy_price=VALUES(avg_buy_price), avg_sell_price=VALUES(avg_sell_price),
status=VALUES(status), gross_pnl=VALUES(gross_pnl), net_pnl=VALUES(net_pnl),
commission_total=VALUES(commission_total), realized_edge_bps=VALUES(realized_edge_bps),
opened_at=VALUES(opened_at), closed_at=VALUES(closed_at), updated_at=VALUES(updated_at)`, positionRowFromDomain(position))
return err
}
func (r *Repository) ListOpenPositions(ctx context.Context, accountIDHash string) ([]domain.Position, error) {
var rows []positionRow
if err := r.selectContext(ctx, &rows, `
SELECT * FROM positions
WHERE account_id_hash=? AND status NOT IN ('NO_POSITION','EXIT_FILLED','QUARANTINE')
ORDER BY updated_at`, accountIDHash); err != nil {
return nil, err
}
out := make([]domain.Position, 0, len(rows))
for _, row := range rows {
out = append(out, row.domain())
}
return out, nil
}
func (r *Repository) ListPositions(ctx context.Context, accountIDHash string, from, to time.Time) ([]domain.Position, error) {
var rows []positionRow
if err := r.selectContext(ctx, &rows, `
SELECT * FROM positions
WHERE account_id_hash=? AND open_trade_date BETWEEN ? AND ?
ORDER BY updated_at`, accountIDHash, dateOnly(from), dateOnly(to)); err != nil {
return nil, err
}
out := make([]domain.Position, 0, len(rows))
for _, row := range rows {
out = append(out, row.domain())
}
return out, nil
}
func (r *Repository) InsertRiskEvent(ctx context.Context, event domain.RiskEvent) error {
if event.TS.IsZero() {
event.TS = time.Now().UTC()
}
if event.ContextJSON == "" {
event.ContextJSON = "{}"
}
_, err := sqlx.NamedExecContext(ctx, r.execer(), `
INSERT INTO risk_events (ts, severity, event_type, instrument_uid, message, raw_context_json)
VALUES (:ts, :severity, :event_type, :instrument_uid, :message, :raw_context_json)`, riskEventRowFromDomain(event))
return err
}
func (r *Repository) GetFreeOrdersSent(ctx context.Context, tradeDate time.Time, instrumentUID string) (int, error) {
var sent int
err := r.getContext(ctx, &sent, `
SELECT orders_sent FROM free_order_counters WHERE trade_date=? AND instrument_uid=?`, dateOnly(tradeDate), instrumentUID)
if errors.Is(err, sql.ErrNoRows) {
return 0, nil
}
return sent, err
}
func (r *Repository) IncrementFreeOrders(ctx context.Context, tradeDate time.Time, instrumentUID string, delta int) error {
_, err := r.execer().ExecContext(ctx, `
INSERT INTO free_order_counters (trade_date, instrument_uid, orders_sent)
VALUES (?, ?, ?)
ON DUPLICATE KEY UPDATE orders_sent=orders_sent+VALUES(orders_sent)`, dateOnly(tradeDate), instrumentUID, delta)
return err
}
func (r *Repository) GetSystemState(ctx context.Context) (domain.SystemState, bool, string, error) {
var row struct {
State string `db:"state"`
Halted bool `db:"halted"`
HaltReason sql.NullString `db:"halt_reason"`
}
if err := r.getContext(ctx, &row, `SELECT state, halted, halt_reason FROM system_state WHERE id=1`); err != nil {
return "", false, "", err
}
return domain.SystemState(row.State), row.Halted, row.HaltReason.String, nil
}
func (r *Repository) SaveSystemState(ctx context.Context, state domain.SystemState, mode domain.Mode, halted bool, reason string, contextJSON string) error {
if contextJSON == "" {
contextJSON = "{}"
}
_, err := r.execer().ExecContext(ctx, `
INSERT INTO system_state (id, state, mode, halted, halt_reason, last_heartbeat, context_json)
VALUES (1, ?, ?, ?, ?, UTC_TIMESTAMP(3), ?)
ON DUPLICATE KEY UPDATE
state=VALUES(state), mode=VALUES(mode), halted=VALUES(halted),
halt_reason=VALUES(halt_reason), last_heartbeat=VALUES(last_heartbeat),
context_json=VALUES(context_json)`, state, mode, halted, nullableString(reason), contextJSON)
return err
}
func (r *Repository) Unhalt(ctx context.Context, reason string) error {
return r.RunInTx(ctx, func(ctx context.Context, repo repository.Repository) error {
state, halted, haltReason, err := repo.GetSystemState(ctx)
if err != nil {
return err
}
if !halted && state != domain.StateHalted {
return fmt.Errorf("system is not halted")
}
if err := repo.InsertRiskEvent(ctx, domain.RiskEvent{
TS: time.Now().UTC(),
Severity: domain.SeverityInfo,
EventType: "manual_unhalt",
Message: fmt.Sprintf("%s (previous halt: %s)", reason, haltReason),
}); err != nil {
return err
}
mode := domain.ModePaper
if txRepo, ok := repo.(*Repository); ok {
currentMode, err := txRepo.getSystemMode(ctx)
if err != nil {
return err
}
mode = currentMode
}
return repo.SaveSystemState(ctx, domain.StateInit, mode, false, "", `{"manual_unhalt":true}`)
})
}
func (r *Repository) getSystemMode(ctx context.Context) (domain.Mode, error) {
var raw string
if err := r.getContext(ctx, &raw, `SELECT mode FROM system_state WHERE id=1`); err != nil {
return "", err
}
mode, err := domain.ParseMode(raw)
if err != nil {
return "", err
}
return mode, nil
}
func (r *Repository) WasDailyReportSent(ctx context.Context, reportDate time.Time, accountIDHash string) (bool, error) {
var count int
if err := r.getContext(ctx, &count, `
SELECT COUNT(*) FROM daily_reports WHERE report_date=? AND account_id_hash=?`, dateOnly(reportDate), accountIDHash); err != nil {
return false, err
}
return count > 0, nil
}
func (r *Repository) MarkDailyReportSent(ctx context.Context, reportDate time.Time, accountIDHash string) error {
_, err := r.execer().ExecContext(ctx, `
INSERT INTO daily_reports (report_date, account_id_hash, sent_at)
VALUES (?, ?, UTC_TIMESTAMP(3))
ON DUPLICATE KEY UPDATE sent_at=sent_at`, dateOnly(reportDate), accountIDHash)
return err
}
func (r *Repository) InsertReconciliation(ctx context.Context, ts time.Time, diffJSON string, hasDiff bool) error {
if ts.IsZero() {
ts = time.Now().UTC()
}
if diffJSON == "" {
diffJSON = "[]"
}
_, err := r.execer().ExecContext(ctx, `
INSERT INTO reconciliations (ts, has_diff, diff_json)
VALUES (?, ?, ?)`, ts, hasDiff, diffJSON)
return err
}
func dateOnly(t time.Time) time.Time {
y, m, d := t.UTC().Date()
return time.Date(y, m, d, 0, 0, 0, 0, time.UTC)
}
func nullableString(s string) any {
if s == "" {
return nil
}
return s
}
type instrumentRow struct {
InstrumentUID string `db:"instrument_uid"`
Figi sql.NullString `db:"figi"`
Ticker string `db:"ticker"`
ClassCode string `db:"class_code"`
Name string `db:"name"`
Lot int64 `db:"lot"`
MinPriceIncrement decimal.Decimal `db:"min_price_increment"`
Currency string `db:"currency"`
Enabled bool `db:"enabled"`
FundType string `db:"fund_type"`
ExpectedCommissionBpsPerSide decimal.Decimal `db:"expected_commission_bps_per_side"`
FreeOrderLimitPerDay int `db:"free_order_limit_per_day"`
Quarantine bool `db:"quarantine"`
QuarantineReason sql.NullString `db:"quarantine_reason"`
ExcludeReason sql.NullString `db:"exclude_reason"`
UpdatedAt time.Time `db:"updated_at"`
}
func instrumentRowFromDomain(instrument domain.Instrument) instrumentRow {
return instrumentRow{
InstrumentUID: instrument.InstrumentUID,
Figi: sql.NullString{String: instrument.Figi, Valid: instrument.Figi != ""},
Ticker: instrument.Ticker,
ClassCode: instrument.ClassCode,
Name: instrument.Name,
Lot: instrument.Lot,
MinPriceIncrement: instrument.MinPriceIncrement,
Currency: instrument.Currency,
Enabled: instrument.Enabled,
FundType: instrument.FundType,
ExpectedCommissionBpsPerSide: instrument.ExpectedCommissionBpsPerSide,
FreeOrderLimitPerDay: instrument.FreeOrderLimitPerDay,
Quarantine: instrument.Quarantine,
QuarantineReason: sql.NullString{String: instrument.QuarantineReason, Valid: instrument.QuarantineReason != ""},
ExcludeReason: sql.NullString{String: instrument.ExcludeReason, Valid: instrument.ExcludeReason != ""},
UpdatedAt: instrument.UpdatedAt,
}
}
func replaceInstrumentRowFromDomain(oldInstrumentUID string, instrument domain.Instrument) map[string]any {
row := instrumentRowFromDomain(instrument)
return map[string]any{
"instrument_uid": row.InstrumentUID,
"figi": row.Figi,
"ticker": row.Ticker,
"class_code": row.ClassCode,
"name": row.Name,
"lot": row.Lot,
"min_price_increment": row.MinPriceIncrement,
"currency": row.Currency,
"enabled": row.Enabled,
"fund_type": row.FundType,
"expected_commission_bps_per_side": row.ExpectedCommissionBpsPerSide,
"free_order_limit_per_day": row.FreeOrderLimitPerDay,
"quarantine": row.Quarantine,
"quarantine_reason": row.QuarantineReason,
"exclude_reason": row.ExcludeReason,
"updated_at": row.UpdatedAt,
"old_instrument_uid": oldInstrumentUID,
}
}
func (r instrumentRow) domain() domain.Instrument {
return domain.Instrument{
InstrumentUID: r.InstrumentUID,
Figi: r.Figi.String,
Ticker: r.Ticker,
ClassCode: r.ClassCode,
Name: r.Name,
Lot: r.Lot,
MinPriceIncrement: r.MinPriceIncrement,
Currency: r.Currency,
Enabled: r.Enabled,
FundType: r.FundType,
ExpectedCommissionBpsPerSide: r.ExpectedCommissionBpsPerSide,
FreeOrderLimitPerDay: r.FreeOrderLimitPerDay,
Quarantine: r.Quarantine,
QuarantineReason: r.QuarantineReason.String,
ExcludeReason: r.ExcludeReason.String,
UpdatedAt: r.UpdatedAt,
}
}
@@ -0,0 +1,114 @@
//go:build integration
package mysql
import (
"context"
"testing"
"time"
"github.com/jmoiron/sqlx"
"github.com/shopspring/decimal"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/mariadb"
"overnight-trading-bot/internal/domain"
)
func TestRepositoryMariaDBMigrationsAndRoundTrip(t *testing.T) {
ctx := context.Background()
container, err := mariadb.Run(ctx,
"mariadb:11.4",
mariadb.WithDatabase("overnight_bot"),
mariadb.WithUsername("bot"),
mariadb.WithPassword("bot"),
)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
if err := testcontainers.TerminateContainer(container); err != nil {
t.Logf("terminate mariadb: %v", err)
}
})
dsn, err := container.ConnectionString(ctx, "parseTime=true", "loc=UTC", "multiStatements=true")
if err != nil {
t.Fatal(err)
}
db, err := sqlx.Open("mysql", dsn)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
_ = db.Close()
})
if err := db.PingContext(ctx); err != nil {
t.Fatal(err)
}
if err := ApplyMigrations(ctx, db.DB); err != nil {
t.Fatal(err)
}
repo := NewRepository(db)
instrument := domain.Instrument{
InstrumentUID: "uid-trur",
Ticker: "TRUR",
ClassCode: "TQTF",
Name: "TRUR",
Lot: 1,
MinPriceIncrement: decimal.NewFromFloat(0.0001),
Currency: "RUB",
Enabled: true,
}
if err := repo.ReplaceInstrument(ctx, "PENDING:TRUR", instrument); err != nil {
t.Fatal(err)
}
tradeDate := time.Date(2026, 6, 7, 0, 0, 0, 0, time.UTC)
position := domain.Position{
AccountIDHash: "hash",
InstrumentUID: "uid-trur",
OpenTradeDate: tradeDate,
Lots: 10,
AvgBuyPrice: decimal.NewFromInt(100),
Status: domain.PositionHoldingOvernight,
}
if err := repo.UpsertPosition(ctx, position); err != nil {
t.Fatal(err)
}
position.Lots = 8
position.ExitFilledLots = 2
if err := repo.UpsertPosition(ctx, position); err != nil {
t.Fatal(err)
}
var count int
if err := db.GetContext(ctx, &count, `
SELECT COUNT(*) FROM positions WHERE account_id_hash='hash' AND instrument_uid='uid-trur' AND open_trade_date=?`, tradeDate); err != nil {
t.Fatal(err)
}
if count != 1 {
t.Fatalf("positions count=%d, want 1", count)
}
if err := repo.MarkDailyReportSent(ctx, tradeDate, "hash"); err != nil {
t.Fatal(err)
}
sent, err := repo.WasDailyReportSent(ctx, tradeDate, "hash")
if err != nil {
t.Fatal(err)
}
if !sent {
t.Fatalf("daily report marker was not persisted")
}
if err := repo.UpsertOrder(ctx, domain.Order{
ClientOrderID: "bad",
AccountIDHash: "hash",
InstrumentUID: "missing",
TradeDate: tradeDate,
Side: domain.SideBuy,
OrderType: domain.OrderTypeLimit,
LimitPrice: decimal.NewFromInt(100),
QuantityLots: 1,
Status: domain.OrderStatusSent,
RawStateJSON: "{}",
}); err == nil {
t.Fatalf("expected FK failure for missing instrument")
}
}
+338
View File
@@ -0,0 +1,338 @@
package mysql
import (
"database/sql"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
)
type candleRow struct {
InstrumentUID string `db:"instrument_uid"`
TradeDate time.Time `db:"trade_date"`
Open decimal.Decimal `db:"open"`
High decimal.Decimal `db:"high"`
Low decimal.Decimal `db:"low"`
Close decimal.Decimal `db:"close"`
VolumeLots decimal.Decimal `db:"volume_lots"`
Source string `db:"source"`
LoadedAt time.Time `db:"loaded_at"`
}
func candleRowFromDomain(candle domain.Candle) candleRow {
return candleRow{
InstrumentUID: candle.InstrumentUID,
TradeDate: dateOnly(candle.TradeDate),
Open: candle.Open,
High: candle.High,
Low: candle.Low,
Close: candle.Close,
VolumeLots: candle.VolumeLots,
Source: candle.Source,
LoadedAt: candle.LoadedAt,
}
}
func (r candleRow) domain() domain.Candle {
return domain.Candle{
InstrumentUID: r.InstrumentUID,
TradeDate: r.TradeDate,
Open: r.Open,
High: r.High,
Low: r.Low,
Close: r.Close,
VolumeLots: r.VolumeLots,
Source: r.Source,
LoadedAt: r.LoadedAt,
}
}
type featureRow struct {
InstrumentUID string `db:"instrument_uid"`
TradeDate time.Time `db:"trade_date"`
ROn decimal.Decimal `db:"r_on"`
RDay decimal.Decimal `db:"r_day"`
MuOn60 decimal.Decimal `db:"mu_on_60"`
MuOn252 decimal.Decimal `db:"mu_on_252"`
SigmaOn60 decimal.Decimal `db:"sigma_on_60"`
TStatOn60 decimal.Decimal `db:"tstat_on_60"`
WinOn60 decimal.Decimal `db:"win_on_60"`
EWMAOn decimal.Decimal `db:"ewma_on"`
SpreadBps decimal.Decimal `db:"spread_bps"`
HalfSpreadBps decimal.Decimal `db:"half_spread_bps"`
TickBps decimal.Decimal `db:"tick_bps"`
ADV20 decimal.Decimal `db:"adv_20"`
ExpectedCostBps decimal.Decimal `db:"expected_cost_bps"`
NetEdgeBps decimal.Decimal `db:"net_edge_bps"`
EntryIntervalVolume decimal.Decimal `db:"entry_interval_volume"`
ExitIntervalVolume decimal.Decimal `db:"exit_interval_volume"`
CalculatedAt time.Time `db:"calculated_at"`
}
func featureRowFromDomain(feature domain.FeatureSet) featureRow {
return featureRow{
InstrumentUID: feature.InstrumentUID,
TradeDate: dateOnly(feature.TradeDate),
ROn: feature.ROn,
RDay: feature.RDay,
MuOn60: feature.MuOn60,
MuOn252: feature.MuOn252,
SigmaOn60: feature.SigmaOn60,
TStatOn60: feature.TStatOn60,
WinOn60: feature.WinOn60,
EWMAOn: feature.EWMAOn,
SpreadBps: feature.SpreadBps,
HalfSpreadBps: feature.HalfSpreadBps,
TickBps: feature.TickBps,
ADV20: feature.ADV20,
ExpectedCostBps: feature.ExpectedCostBps,
NetEdgeBps: feature.NetEdgeBps,
EntryIntervalVolume: feature.EntryIntervalVolume,
ExitIntervalVolume: feature.ExitIntervalVolume,
CalculatedAt: feature.CalculatedAt,
}
}
func (r featureRow) domain() domain.FeatureSet {
return domain.FeatureSet{
InstrumentUID: r.InstrumentUID,
TradeDate: r.TradeDate,
ROn: r.ROn,
RDay: r.RDay,
MuOn60: r.MuOn60,
MuOn252: r.MuOn252,
SigmaOn60: r.SigmaOn60,
TStatOn60: r.TStatOn60,
WinOn60: r.WinOn60,
EWMAOn: r.EWMAOn,
SpreadBps: r.SpreadBps,
HalfSpreadBps: r.HalfSpreadBps,
TickBps: r.TickBps,
ADV20: r.ADV20,
ExpectedCostBps: r.ExpectedCostBps,
NetEdgeBps: r.NetEdgeBps,
EntryIntervalVolume: r.EntryIntervalVolume,
ExitIntervalVolume: r.ExitIntervalVolume,
CalculatedAt: r.CalculatedAt,
}
}
type signalRow struct {
ID int64 `db:"id"`
TradeDate time.Time `db:"trade_date"`
InstrumentUID string `db:"instrument_uid"`
Decision string `db:"decision"`
Score decimal.Decimal `db:"score"`
NetEdgeBps decimal.Decimal `db:"net_edge_bps"`
TargetNotional decimal.Decimal `db:"target_notional"`
TargetLots int64 `db:"target_lots"`
RejectReason sql.NullString `db:"reject_reason"`
ContextJSON sql.NullString `db:"context_json"`
CreatedAt time.Time `db:"created_at"`
}
func signalRowFromDomain(signal domain.Signal) signalRow {
return signalRow{
ID: signal.ID,
TradeDate: dateOnly(signal.TradeDate),
InstrumentUID: signal.InstrumentUID,
Decision: string(signal.Decision),
Score: signal.Score,
NetEdgeBps: signal.NetEdgeBps,
TargetNotional: signal.TargetNotional,
TargetLots: signal.TargetLots,
RejectReason: sql.NullString{String: signal.RejectReason, Valid: signal.RejectReason != ""},
ContextJSON: sql.NullString{String: signal.ContextJSON, Valid: signal.ContextJSON != ""},
CreatedAt: signal.CreatedAt,
}
}
func (r signalRow) domain() domain.Signal {
return domain.Signal{
ID: r.ID,
TradeDate: r.TradeDate,
InstrumentUID: r.InstrumentUID,
Decision: domain.SignalDecision(r.Decision),
Score: r.Score,
NetEdgeBps: r.NetEdgeBps,
TargetNotional: r.TargetNotional,
TargetLots: r.TargetLots,
RejectReason: r.RejectReason.String,
ContextJSON: r.ContextJSON.String,
CreatedAt: r.CreatedAt,
}
}
type orderRow struct {
ClientOrderID string `db:"client_order_id"`
BrokerOrderID sql.NullString `db:"broker_order_id"`
AccountIDHash string `db:"account_id_hash"`
InstrumentUID string `db:"instrument_uid"`
TradeDate time.Time `db:"trade_date"`
Side string `db:"side"`
OrderType string `db:"order_type"`
LimitPrice decimal.Decimal `db:"limit_price"`
QuantityLots int64 `db:"quantity_lots"`
FilledLots int64 `db:"filled_lots"`
AvgFillPrice decimal.Decimal `db:"avg_fill_price"`
Status string `db:"status"`
Commission decimal.Decimal `db:"commission"`
AttemptNo int `db:"attempt_no"`
RawStateJSON sql.NullString `db:"raw_state_json"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
}
func orderRowFromDomain(order domain.Order) orderRow {
return orderRow{
ClientOrderID: order.ClientOrderID,
BrokerOrderID: sql.NullString{
String: order.BrokerOrderID,
Valid: order.BrokerOrderID != "",
},
AccountIDHash: order.AccountIDHash,
InstrumentUID: order.InstrumentUID,
TradeDate: dateOnly(order.TradeDate),
Side: string(order.Side),
OrderType: string(order.OrderType),
LimitPrice: order.LimitPrice,
QuantityLots: order.QuantityLots,
FilledLots: order.FilledLots,
AvgFillPrice: order.AvgFillPrice,
Status: string(order.Status),
Commission: order.Commission,
AttemptNo: order.AttemptNo,
RawStateJSON: sql.NullString{
String: order.RawStateJSON,
Valid: order.RawStateJSON != "",
},
CreatedAt: order.CreatedAt,
UpdatedAt: order.UpdatedAt,
}
}
func (r orderRow) domain() domain.Order {
return domain.Order{
ClientOrderID: r.ClientOrderID,
BrokerOrderID: r.BrokerOrderID.String,
AccountIDHash: r.AccountIDHash,
InstrumentUID: r.InstrumentUID,
TradeDate: r.TradeDate,
Side: domain.Side(r.Side),
OrderType: domain.OrderType(r.OrderType),
LimitPrice: r.LimitPrice,
QuantityLots: r.QuantityLots,
FilledLots: r.FilledLots,
AvgFillPrice: r.AvgFillPrice,
Status: domain.OrderStatus(r.Status),
Commission: r.Commission,
AttemptNo: r.AttemptNo,
RawStateJSON: r.RawStateJSON.String,
CreatedAt: r.CreatedAt,
UpdatedAt: r.UpdatedAt,
}
}
type positionRow struct {
ID int64 `db:"id"`
AccountIDHash string `db:"account_id_hash"`
InstrumentUID string `db:"instrument_uid"`
OpenTradeDate time.Time `db:"open_trade_date"`
Lots int64 `db:"lots"`
Lot int64 `db:"lot_size"`
ExitFilledLots int64 `db:"exit_filled_lots"`
AvgBuyPrice decimal.Decimal `db:"avg_buy_price"`
AvgSellPrice decimal.Decimal `db:"avg_sell_price"`
Status string `db:"status"`
GrossPnL decimal.Decimal `db:"gross_pnl"`
NetPnL decimal.Decimal `db:"net_pnl"`
CommissionTotal decimal.Decimal `db:"commission_total"`
RealizedEdgeBps decimal.Decimal `db:"realized_edge_bps"`
OpenedAt sql.NullTime `db:"opened_at"`
ClosedAt sql.NullTime `db:"closed_at"`
UpdatedAt time.Time `db:"updated_at"`
}
func positionRowFromDomain(position domain.Position) positionRow {
lot := position.Lot
if lot <= 0 {
lot = 1
}
return positionRow{
ID: position.ID,
AccountIDHash: position.AccountIDHash,
InstrumentUID: position.InstrumentUID,
OpenTradeDate: dateOnly(position.OpenTradeDate),
Lots: position.Lots,
Lot: lot,
ExitFilledLots: position.ExitFilledLots,
AvgBuyPrice: position.AvgBuyPrice,
AvgSellPrice: position.AvgSellPrice,
Status: string(position.Status),
GrossPnL: position.GrossPnL,
NetPnL: position.NetPnL,
CommissionTotal: position.CommissionTotal,
RealizedEdgeBps: position.RealizedEdgeBps,
OpenedAt: nullableTime(position.OpenedAt),
ClosedAt: nullableTime(position.ClosedAt),
UpdatedAt: position.UpdatedAt,
}
}
func (r positionRow) domain() domain.Position {
return domain.Position{
ID: r.ID,
AccountIDHash: r.AccountIDHash,
InstrumentUID: r.InstrumentUID,
OpenTradeDate: r.OpenTradeDate,
Lots: r.Lots,
Lot: r.Lot,
ExitFilledLots: r.ExitFilledLots,
AvgBuyPrice: r.AvgBuyPrice,
AvgSellPrice: r.AvgSellPrice,
Status: domain.PositionStatus(r.Status),
GrossPnL: r.GrossPnL,
NetPnL: r.NetPnL,
CommissionTotal: r.CommissionTotal,
RealizedEdgeBps: r.RealizedEdgeBps,
OpenedAt: timePtr(r.OpenedAt),
ClosedAt: timePtr(r.ClosedAt),
UpdatedAt: r.UpdatedAt,
}
}
type riskEventRow struct {
TS time.Time `db:"ts"`
Severity string `db:"severity"`
EventType string `db:"event_type"`
InstrumentUID sql.NullString `db:"instrument_uid"`
Message string `db:"message"`
ContextJSON string `db:"raw_context_json"`
}
func riskEventRowFromDomain(event domain.RiskEvent) riskEventRow {
return riskEventRow{
TS: event.TS,
Severity: string(event.Severity),
EventType: event.EventType,
InstrumentUID: sql.NullString{String: event.InstrumentUID, Valid: event.InstrumentUID != ""},
Message: event.Message,
ContextJSON: event.ContextJSON,
}
}
func nullableTime(t *time.Time) sql.NullTime {
if t == nil {
return sql.NullTime{}
}
return sql.NullTime{Time: *t, Valid: true}
}
func timePtr(t sql.NullTime) *time.Time {
if !t.Valid {
return nil
}
return &t.Time
}
+49
View File
@@ -0,0 +1,49 @@
package repository
import (
"context"
"time"
"overnight-trading-bot/internal/domain"
)
type Repository interface {
RunInTx(ctx context.Context, fn func(ctx context.Context, repo Repository) error) error
UpsertInstrument(ctx context.Context, instrument domain.Instrument) error
ReplaceInstrument(ctx context.Context, oldInstrumentUID string, instrument domain.Instrument) error
ListInstruments(ctx context.Context, includeDisabled bool) ([]domain.Instrument, error)
QuarantineInstrument(ctx context.Context, instrumentUID, reason string) error
UpsertDailyCandles(ctx context.Context, candles []domain.Candle) error
ListDailyCandles(ctx context.Context, instrumentUID string, from, to time.Time) ([]domain.Candle, error)
UpsertMinuteCandles(ctx context.Context, candles []domain.Candle) error
ListMinuteCandles(ctx context.Context, instrumentUID string, from, to time.Time) ([]domain.Candle, error)
UpsertFeature(ctx context.Context, feature domain.FeatureSet) error
GetFeature(ctx context.Context, instrumentUID string, tradeDate time.Time) (domain.FeatureSet, error)
UpsertSignal(ctx context.Context, signal domain.Signal) error
ListSignals(ctx context.Context, tradeDate time.Time) ([]domain.Signal, error)
UpsertOrder(ctx context.Context, order domain.Order) error
UpdateOrderStatus(ctx context.Context, clientOrderID string, status domain.OrderStatus, filledLots int64, rawJSON string) error
ListActiveOrders(ctx context.Context, accountIDHash string) ([]domain.Order, error)
ListOrders(ctx context.Context, accountIDHash string, from, to time.Time) ([]domain.Order, error)
UpsertPosition(ctx context.Context, position domain.Position) error
ListOpenPositions(ctx context.Context, accountIDHash string) ([]domain.Position, error)
ListPositions(ctx context.Context, accountIDHash string, from, to time.Time) ([]domain.Position, error)
InsertRiskEvent(ctx context.Context, event domain.RiskEvent) error
GetFreeOrdersSent(ctx context.Context, tradeDate time.Time, instrumentUID string) (int, error)
IncrementFreeOrders(ctx context.Context, tradeDate time.Time, instrumentUID string, delta int) error
GetSystemState(ctx context.Context) (domain.SystemState, bool, string, error)
SaveSystemState(ctx context.Context, state domain.SystemState, mode domain.Mode, halted bool, reason string, contextJSON string) error
Unhalt(ctx context.Context, reason string) error
WasDailyReportSent(ctx context.Context, reportDate time.Time, accountIDHash string) (bool, error)
MarkDailyReportSent(ctx context.Context, reportDate time.Time, accountIDHash string) error
InsertReconciliation(ctx context.Context, ts time.Time, diffJSON string, hasDiff bool) error
}
+27
View File
@@ -0,0 +1,27 @@
package risk
import (
"context"
"errors"
"testing"
"time"
"overnight-trading-bot/internal/domain"
)
func TestFreeOrderBudgetSubmittedPolicy(t *testing.T) {
ctx := context.Background()
store := NewMemoryFreeOrderStore()
budget := NewFreeOrderBudget(store)
instr := domain.Instrument{InstrumentUID: "uid", FreeOrderLimitPerDay: 2}
date := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)
if _, err := budget.Check(ctx, date, instr, 2); err != nil {
t.Fatal(err)
}
if err := budget.Submitted(ctx, date, instr.InstrumentUID); err != nil {
t.Fatal(err)
}
if _, err := budget.Check(ctx, date, instr, 2); !errors.Is(err, ErrFreeOrderBudget) {
t.Fatalf("expected ErrFreeOrderBudget, got %v", err)
}
}
+127
View File
@@ -0,0 +1,127 @@
package risk
import (
"context"
"fmt"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
)
type EventSink interface {
InsertRiskEvent(ctx context.Context, event domain.RiskEvent) error
SaveSystemState(ctx context.Context, state domain.SystemState, mode domain.Mode, halted bool, reason string, contextJSON string) error
}
type Manager struct {
sink EventSink
cfg ManagerConfig
}
type ManagerConfig struct {
MaxDailyLossPct decimal.Decimal
MaxWeeklyLossPct decimal.Decimal
MaxMonthlyDrawdownPct decimal.Decimal
MaxAvgSlippageBps10Trades decimal.Decimal
MaxOpenPositions int
MinTimeToClose time.Duration
MaxQuoteAge time.Duration
}
type PreTradeInput struct {
Portfolio domain.Portfolio
OpenPositions int
DailyPnL decimal.Decimal
WeeklyPnL decimal.Decimal
MonthlyDrawdownPct decimal.Decimal
AvgSlippageBps10 decimal.Decimal
TradingStatus domain.TradingStatus
QuoteReceivedAt time.Time
Now time.Time
MarketClose time.Time
DatabaseUnavailable bool
UnknownBrokerOrder bool
UnknownBrokerHolding bool
}
type PreTradeResult struct {
Allowed bool
Reason string
}
func NewManager(sink EventSink, cfg ManagerConfig) Manager {
return Manager{sink: sink, cfg: cfg}
}
func (m Manager) Halt(ctx context.Context, mode domain.Mode, eventType, reason string, instrumentUID string) error {
if m.sink == nil {
return nil
}
event := domain.RiskEvent{
TS: time.Now().UTC(),
Severity: domain.SeverityCritical,
EventType: eventType,
InstrumentUID: instrumentUID,
Message: reason,
}
if err := m.sink.InsertRiskEvent(ctx, event); err != nil {
return fmt.Errorf("insert halt risk event: %w", err)
}
if err := m.sink.SaveSystemState(ctx, domain.StateHalted, mode, true, reason, "{}"); err != nil {
return fmt.Errorf("persist halt state: %w", err)
}
return nil
}
func (m Manager) PreTradeCheck(input PreTradeInput) PreTradeResult {
now := input.Now
if now.IsZero() {
now = time.Now().UTC()
}
switch {
case input.DatabaseUnavailable:
return reject("database_unavailable")
case input.UnknownBrokerOrder:
return reject("unknown_broker_order")
case input.UnknownBrokerHolding:
return reject("unknown_broker_position")
case input.TradingStatus == domain.TradingStatusUnknown:
return reject("trading_status_unknown_before_order")
case input.TradingStatus != domain.TradingStatusNormal:
return reject("trading_status_not_normal")
case m.cfg.MaxOpenPositions > 0 && input.OpenPositions >= m.cfg.MaxOpenPositions:
return reject("max_open_positions")
case DailyLossBreached(input.DailyPnL, input.Portfolio.Equity, m.cfg.MaxDailyLossPct):
return reject("max_daily_loss")
case DailyLossBreached(input.WeeklyPnL, input.Portfolio.Equity, m.cfg.MaxWeeklyLossPct):
return reject("max_weekly_loss")
case m.cfg.MaxMonthlyDrawdownPct.IsPositive() && input.MonthlyDrawdownPct.GreaterThanOrEqual(m.cfg.MaxMonthlyDrawdownPct):
return reject("max_monthly_drawdown")
case m.cfg.MaxAvgSlippageBps10Trades.IsPositive() && input.AvgSlippageBps10.GreaterThan(m.cfg.MaxAvgSlippageBps10Trades):
return reject("max_avg_slippage_bps_10_trades")
case m.cfg.MaxQuoteAge > 0 && !input.QuoteReceivedAt.IsZero() && now.Sub(input.QuoteReceivedAt) > m.cfg.MaxQuoteAge:
return reject("quote_age_too_high")
case m.cfg.MinTimeToClose > 0 && !input.MarketClose.IsZero() && input.MarketClose.Sub(now) < m.cfg.MinTimeToClose:
return reject("min_time_to_close_sec")
default:
return PreTradeResult{Allowed: true}
}
}
func DailyLossBreached(pnl, equity, maxLossPct decimal.Decimal) bool {
if !equity.IsPositive() || !maxLossPct.IsPositive() {
return false
}
limit := equity.Mul(maxLossPct).Neg()
return pnl.LessThanOrEqual(limit)
}
func CommissionBreached(actualCommission decimal.Decimal, requireZero bool) bool {
return requireZero && actualCommission.IsPositive()
}
func reject(reason string) PreTradeResult {
return PreTradeResult{Allowed: false, Reason: reason}
}
+166
View File
@@ -0,0 +1,166 @@
package risk
import (
"context"
"errors"
"sync"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/money"
)
var (
ErrNoSizingCapacity = errors.New("no sizing capacity")
ErrFreeOrderBudget = errors.New("free order budget is insufficient")
)
type SizingConfig struct {
MaxPositionPct decimal.Decimal
MaxTotalExposurePct decimal.Decimal
MaxParticipationRate decimal.Decimal
CashUsageBuffer decimal.Decimal
RiskBudgetPerInstrumentPct decimal.Decimal
MinOrderNotionalRUB decimal.Decimal
}
type SizingInput struct {
Portfolio domain.Portfolio
SelectedInstruments int
LimitPrice decimal.Decimal
Lot int64
EntryIntervalVolume decimal.Decimal
ExitIntervalVolume decimal.Decimal
Q05OvernightAbs decimal.Decimal
}
type SizingResult struct {
TargetNotional decimal.Decimal
Lots int64
Reason string
Limits map[string]decimal.Decimal
}
type Sizer struct {
cfg SizingConfig
sizeFactor decimal.Decimal
}
func NewSizer(cfg SizingConfig) Sizer {
return Sizer{cfg: cfg, sizeFactor: decimal.NewFromInt(1)}
}
func (s Sizer) WithSizeFactor(factor decimal.Decimal) Sizer {
if !factor.IsPositive() {
factor = decimal.NewFromInt(1)
}
s.sizeFactor = factor
return s
}
func (s Sizer) Size(input SizingInput) SizingResult {
limits := make(map[string]decimal.Decimal, 6)
if input.SelectedInstruments <= 0 {
input.SelectedInstruments = 1
}
capLimit := input.Portfolio.Equity.Mul(s.cfg.MaxPositionPct)
exposureLimit := input.Portfolio.Equity.Mul(s.cfg.MaxTotalExposurePct).
Div(decimal.NewFromInt(int64(input.SelectedInstruments)))
liquidityLimit := money.Min(input.EntryIntervalVolume, input.ExitIntervalVolume).
Mul(s.cfg.MaxParticipationRate)
cashLimit := input.Portfolio.Cash.Mul(s.cfg.CashUsageBuffer)
riskLimit := capLimit
if input.Q05OvernightAbs.IsPositive() {
riskBudget := input.Portfolio.Equity.Mul(s.cfg.RiskBudgetPerInstrumentPct)
riskLimit = riskBudget.Div(input.Q05OvernightAbs)
}
limits["cap"] = capLimit
limits["exposure"] = exposureLimit
limits["liquidity"] = liquidityLimit
limits["risk"] = riskLimit
limits["cash"] = cashLimit
sizeFactor := s.effectiveSizeFactor()
limits["size_factor"] = sizeFactor
target := money.Min(capLimit, exposureLimit, liquidityLimit, riskLimit, cashLimit).Mul(sizeFactor)
if !target.IsPositive() || !input.LimitPrice.IsPositive() || input.Lot <= 0 {
return SizingResult{Reason: "non_positive_limit", Limits: limits}
}
lotNotional := input.LimitPrice.Mul(decimal.NewFromInt(input.Lot))
lots := target.Div(lotNotional).Floor().IntPart()
notional := lotNotional.Mul(decimal.NewFromInt(lots))
if lots < 1 {
return SizingResult{TargetNotional: notional, Lots: lots, Reason: "lots_below_one", Limits: limits}
}
if notional.LessThan(s.cfg.MinOrderNotionalRUB) {
return SizingResult{TargetNotional: notional, Lots: 0, Reason: "min_order_notional", Limits: limits}
}
return SizingResult{TargetNotional: notional, Lots: lots, Limits: limits}
}
func (s Sizer) effectiveSizeFactor() decimal.Decimal {
if !s.sizeFactor.IsPositive() {
return decimal.NewFromInt(1)
}
return s.sizeFactor
}
type FreeOrderStore interface {
GetFreeOrdersSent(ctx context.Context, tradeDate time.Time, instrumentUID string) (int, error)
IncrementFreeOrders(ctx context.Context, tradeDate time.Time, instrumentUID string, delta int) error
}
type FreeOrderBudget struct {
store FreeOrderStore
}
func NewFreeOrderBudget(store FreeOrderStore) FreeOrderBudget {
return FreeOrderBudget{store: store}
}
func (b FreeOrderBudget) Check(ctx context.Context, tradeDate time.Time, instr domain.Instrument, ordersNeeded int) (int, error) {
if instr.FreeOrderLimitPerDay <= 0 {
return 0, nil
}
sent, err := b.store.GetFreeOrdersSent(ctx, tradeDate, instr.InstrumentUID)
if err != nil {
return 0, err
}
remaining := instr.FreeOrderLimitPerDay - sent
if remaining < ordersNeeded {
return remaining, ErrFreeOrderBudget
}
return remaining, nil
}
func (b FreeOrderBudget) Submitted(ctx context.Context, tradeDate time.Time, instrumentUID string) error {
return b.store.IncrementFreeOrders(ctx, tradeDate, instrumentUID, 1)
}
type MemoryFreeOrderStore struct {
mu sync.Mutex
counts map[string]int
}
func NewMemoryFreeOrderStore() *MemoryFreeOrderStore {
return &MemoryFreeOrderStore{counts: make(map[string]int)}
}
func (s *MemoryFreeOrderStore) GetFreeOrdersSent(_ context.Context, tradeDate time.Time, instrumentUID string) (int, error) {
s.mu.Lock()
defer s.mu.Unlock()
return s.counts[freeOrderKey(tradeDate, instrumentUID)], nil
}
func (s *MemoryFreeOrderStore) IncrementFreeOrders(_ context.Context, tradeDate time.Time, instrumentUID string, delta int) error {
s.mu.Lock()
defer s.mu.Unlock()
s.counts[freeOrderKey(tradeDate, instrumentUID)] += delta
return nil
}
func freeOrderKey(tradeDate time.Time, instrumentUID string) string {
return tradeDate.Format("2006-01-02") + "|" + instrumentUID
}
+172
View File
@@ -0,0 +1,172 @@
package risk
import (
"testing"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
)
func rd(raw string) decimal.Decimal {
v, err := decimal.NewFromString(raw)
if err != nil {
panic(err)
}
return v
}
func TestSizerTakesMinimumOfLimits(t *testing.T) {
sizer := NewSizer(SizingConfig{
MaxPositionPct: rd("0.10"),
MaxTotalExposurePct: rd("0.50"),
MaxParticipationRate: rd("0.01"),
CashUsageBuffer: rd("0.95"),
RiskBudgetPerInstrumentPct: rd("0.005"),
MinOrderNotionalRUB: rd("1000"),
})
got := sizer.Size(SizingInput{
Portfolio: domain.Portfolio{Equity: rd("100000"), Cash: rd("90000")},
SelectedInstruments: 5,
LimitPrice: rd("100"),
Lot: 1,
EntryIntervalVolume: rd("1000000"),
ExitIntervalVolume: rd("1000000"),
Q05OvernightAbs: rd("0.05"),
})
if got.Lots != 100 || !got.TargetNotional.Equal(rd("10000")) {
t.Fatalf("unexpected sizing: %+v", got)
}
}
func TestSizerMinOrderGate(t *testing.T) {
sizer := NewSizer(SizingConfig{
MaxPositionPct: rd("0.10"),
MaxTotalExposurePct: rd("0.50"),
MaxParticipationRate: rd("0.01"),
CashUsageBuffer: rd("0.95"),
RiskBudgetPerInstrumentPct: rd("0.005"),
MinOrderNotionalRUB: rd("1000"),
})
got := sizer.Size(SizingInput{
Portfolio: domain.Portfolio{Equity: rd("10000"), Cash: rd("10000")},
SelectedInstruments: 1,
LimitPrice: rd("999"),
Lot: 1,
EntryIntervalVolume: rd("1000000"),
ExitIntervalVolume: rd("1000000"),
Q05OvernightAbs: rd("0.05"),
})
if got.Lots != 0 || got.Reason != "min_order_notional" {
t.Fatalf("unexpected min order gate: %+v", got)
}
}
func TestSizerBindingLimits(t *testing.T) {
sizer := NewSizer(SizingConfig{
MaxPositionPct: rd("0.10"),
MaxTotalExposurePct: rd("0.50"),
MaxParticipationRate: rd("0.01"),
CashUsageBuffer: rd("0.95"),
RiskBudgetPerInstrumentPct: rd("0.005"),
MinOrderNotionalRUB: rd("1"),
})
tests := []struct {
name string
input SizingInput
want decimal.Decimal
}{
{
name: "cap",
input: SizingInput{
Portfolio: domain.Portfolio{Equity: rd("100000"), Cash: rd("100000")},
SelectedInstruments: 1,
LimitPrice: rd("100"),
Lot: 1,
EntryIntervalVolume: rd("5000000"),
ExitIntervalVolume: rd("5000000"),
},
want: rd("10000"),
},
{
name: "exposure",
input: SizingInput{
Portfolio: domain.Portfolio{Equity: rd("100000"), Cash: rd("100000")},
SelectedInstruments: 10,
LimitPrice: rd("100"),
Lot: 1,
EntryIntervalVolume: rd("5000000"),
ExitIntervalVolume: rd("5000000"),
},
want: rd("5000"),
},
{
name: "liquidity",
input: SizingInput{
Portfolio: domain.Portfolio{Equity: rd("100000"), Cash: rd("100000")},
SelectedInstruments: 1,
LimitPrice: rd("100"),
Lot: 1,
EntryIntervalVolume: rd("300000"),
ExitIntervalVolume: rd("500000"),
},
want: rd("3000"),
},
{
name: "risk",
input: SizingInput{
Portfolio: domain.Portfolio{Equity: rd("100000"), Cash: rd("100000")},
SelectedInstruments: 1,
LimitPrice: rd("100"),
Lot: 1,
EntryIntervalVolume: rd("5000000"),
ExitIntervalVolume: rd("5000000"),
Q05OvernightAbs: rd("0.10"),
},
want: rd("5000"),
},
{
name: "cash",
input: SizingInput{
Portfolio: domain.Portfolio{Equity: rd("100000"), Cash: rd("2000")},
SelectedInstruments: 1,
LimitPrice: rd("100"),
Lot: 1,
EntryIntervalVolume: rd("5000000"),
ExitIntervalVolume: rd("5000000"),
},
want: rd("1900"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := sizer.Size(tt.input)
if !got.TargetNotional.Equal(tt.want) {
t.Fatalf("target=%s, want %s limits=%v", got.TargetNotional, tt.want, got.Limits)
}
})
}
}
func TestSizerAppliesSizeReductionFactor(t *testing.T) {
sizer := NewSizer(SizingConfig{
MaxPositionPct: rd("1"),
MaxTotalExposurePct: rd("1"),
MaxParticipationRate: rd("1"),
CashUsageBuffer: rd("1"),
RiskBudgetPerInstrumentPct: rd("1"),
MinOrderNotionalRUB: rd("1"),
}).WithSizeFactor(rd("0.5"))
got := sizer.Size(SizingInput{
Portfolio: domain.Portfolio{Equity: rd("10000"), Cash: rd("10000")},
SelectedInstruments: 1,
LimitPrice: rd("100"),
Lot: 1,
EntryIntervalVolume: rd("10000"),
ExitIntervalVolume: rd("10000"),
Q05OvernightAbs: rd("1"),
})
if got.Lots != 50 || !got.TargetNotional.Equal(rd("5000")) {
t.Fatalf("unexpected reduced sizing: %+v", got)
}
}
+905
View File
@@ -0,0 +1,905 @@
package scheduler
import (
"context"
"database/sql"
"errors"
"fmt"
"log/slog"
"sort"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/execution"
"overnight-trading-bot/internal/features"
"overnight-trading-bot/internal/instruments"
"overnight-trading-bot/internal/marketdata"
"overnight-trading-bot/internal/money"
"overnight-trading-bot/internal/notify"
"overnight-trading-bot/internal/position"
"overnight-trading-bot/internal/reconciliation"
"overnight-trading-bot/internal/report"
"overnight-trading-bot/internal/repository"
"overnight-trading-bot/internal/risk"
"overnight-trading-bot/internal/signal"
"overnight-trading-bot/internal/statemachine"
"overnight-trading-bot/internal/timeutil"
"overnight-trading-bot/internal/tinvest"
)
const (
sizeReductionWindowTrades = 20
sizeReductionFactor = 0.5
)
type Config struct {
Mode domain.Mode
Location *time.Location
RollingLong int
TickInterval time.Duration
EntrySignalTime timeutil.TimeOfDay
EntryWindowStart timeutil.TimeOfDay
EntryWindowEnd timeutil.TimeOfDay
NoNewEntryAfter timeutil.TimeOfDay
ExitWatchStart timeutil.TimeOfDay
ExitWindowStart timeutil.TimeOfDay
ExitWindowEnd timeutil.TimeOfDay
HardExitDeadline timeutil.TimeOfDay
QuoteDepth int32
MaxQuoteAge time.Duration
OrderPollInterval time.Duration
PassiveImproveTicks int
MaxEntryOrderAttempts int
MaxExitOrderAttempts int
MinTimeToClose time.Duration
MaxClockDrift time.Duration
APIOutageHalt time.Duration
}
type Services struct {
Repo repository.Repository
Gateway tinvest.Gateway
Registry instruments.Registry
MarketData marketdata.Loader
Features features.Pipeline
Signals signal.Engine
Sizer risk.Sizer
FreeOrders risk.FreeOrderBudget
Risk risk.Manager
Execution *execution.Engine
Positions position.Manager
Reconcile reconciliation.Engine
Notifier notify.Notifier
AccountID string
AccountIDHash string
Log *slog.Logger
}
type Scheduler struct {
clock timeutil.Clock
sm statemachine.System
cfg Config
svc Services
infraFailedSince time.Time
}
func New(clock timeutil.Clock, sm statemachine.System, cfg Config, svc Services) Scheduler {
if cfg.TickInterval <= 0 {
cfg.TickInterval = 30 * time.Second
}
if cfg.Location == nil {
cfg.Location = time.UTC
}
return Scheduler{clock: clock, sm: sm, cfg: cfg, svc: svc}
}
func (s *Scheduler) Run(ctx context.Context) error {
for {
if err := s.Step(ctx); err != nil {
if errors.Is(err, statemachine.ErrSystemHalted) {
s.logWarn("scheduler paused in HALT", "err", err)
} else if err := s.halt(ctx, "scheduler_error", err.Error(), ""); err != nil {
return err
}
}
if !s.clock.Sleep(ctx.Done(), s.cfg.TickInterval) {
return ctx.Err()
}
}
}
func (s *Scheduler) Step(ctx context.Context) error {
if err := s.checkInfrastructure(ctx); err != nil {
return err
}
now := s.clock.Now().In(s.cfg.Location)
phase := s.phase(now)
switch phase {
case domain.StateWaitExitWindow:
return s.waitExit(ctx, now)
case domain.StatePlaceExitOrders:
return s.placeExitOrders(ctx, now)
case domain.StateMonitorExitOrders:
return s.monitorExitOrders(ctx, now)
case domain.StateReconcile:
return s.failOpenPositionsAtHardDeadline(ctx)
case domain.StateGenerateSignals:
return s.prepareSignals(ctx, now)
case domain.StatePlaceEntryOrders:
return s.placeEntryOrders(ctx, now)
case domain.StateMonitorEntryOrders:
return s.monitorEntryOrders(ctx, now)
case domain.StateHoldOvernight:
return s.holdOvernight(ctx)
default:
return s.sm.Heartbeat(ctx, domain.StateSleep)
}
}
func (s Scheduler) phase(now time.Time) domain.SystemState {
tod := sinceMidnight(now)
switch {
case tod >= s.cfg.ExitWatchStart.Duration && tod < s.cfg.ExitWindowStart.Duration:
return domain.StateWaitExitWindow
case tod >= s.cfg.ExitWindowStart.Duration && tod < s.cfg.ExitWindowEnd.Duration:
return domain.StatePlaceExitOrders
case tod >= s.cfg.ExitWindowEnd.Duration && tod < s.cfg.HardExitDeadline.Duration:
return domain.StateMonitorExitOrders
case tod >= s.cfg.HardExitDeadline.Duration && tod < s.cfg.EntrySignalTime.Duration:
return domain.StateReconcile
case tod >= s.cfg.EntrySignalTime.Duration && tod < s.cfg.EntryWindowStart.Duration:
return domain.StateGenerateSignals
case tod >= s.cfg.EntryWindowStart.Duration && tod < s.cfg.NoNewEntryAfter.Duration:
return domain.StatePlaceEntryOrders
case tod >= s.cfg.NoNewEntryAfter.Duration:
return domain.StateHoldOvernight
default:
return domain.StateSleep
}
}
func (s *Scheduler) prepareSignals(ctx context.Context, now time.Time) error {
if err := s.transitionSequence(ctx,
domain.StateInit,
domain.StateSyncInstruments,
domain.StateSyncMarketData,
domain.StateGenerateSignals,
); err != nil {
return err
}
if err := s.svc.Registry.SyncMetadata(ctx); err != nil {
return err
}
tradeDate := tradingDate(now)
instrumentsList, err := s.svc.Repo.ListInstruments(ctx, false)
if err != nil {
return err
}
if err := s.svc.MarketData.BackfillDaily(ctx, instrumentsList, tradeDate.AddDate(0, 0, -s.cfg.RollingLong-10), tradeDate); err != nil {
return err
}
minuteFrom := s.cfg.EntryWindowStart.On(tradeDate, s.cfg.Location)
minuteTo := s.cfg.ExitWindowEnd.On(tradeDate.AddDate(0, 0, 1), s.cfg.Location)
if err := s.svc.MarketData.BackfillMinute(ctx, instrumentsList, minuteFrom, minuteTo); err != nil {
s.logWarn("minute backfill failed; liquidity will fall back to ADV", "err", err)
}
if err := s.applySizeReductionRule(ctx, tradeDate, false); err != nil {
return err
}
portfolio, err := s.svc.Gateway.GetPortfolio(ctx, s.svc.AccountID)
if err != nil {
return err
}
openPositions, err := s.svc.Repo.ListOpenPositions(ctx, s.svc.AccountIDHash)
if err != nil {
return err
}
for _, instrument := range instrumentsList {
if err := s.generateInstrumentSignal(ctx, now, tradeDate, portfolio, len(openPositions), instrument); err != nil {
return err
}
}
return s.transitionTo(ctx, domain.StateWaitEntryWindow)
}
func (s Scheduler) generateInstrumentSignal(ctx context.Context, now, tradeDate time.Time, portfolio domain.Portfolio, openPositionCount int, instrument domain.Instrument) error {
book, err := s.svc.MarketData.LatestQuote(ctx, instrument.InstrumentUID, s.cfg.QuoteDepth, s.cfg.MaxQuoteAge)
if err != nil {
return s.saveRejectedSignal(ctx, tradeDate, instrument, "quote_unavailable", err)
}
spread, err := spreadFromBook(book, instrument.MinPriceIncrement)
if err != nil {
return s.saveRejectedSignal(ctx, tradeDate, instrument, "spread_unavailable", err)
}
tradingStatus, err := s.svc.Gateway.GetTradingStatus(ctx, instrument.InstrumentUID)
if err != nil {
tradingStatus = domain.TradingStatusUnknown
}
feature, err := s.svc.Features.Recompute(ctx, instrument, tradeDate, spread)
if err != nil {
return s.saveRejectedSignal(ctx, tradeDate, instrument, "features_unavailable", err)
}
remaining, err := s.svc.FreeOrders.Check(ctx, tradeDate, instrument, 1)
freeOrderOK := err == nil
sig := s.svc.Signals.Evaluate(signal.Candidate{
Instrument: instrument,
Features: feature,
TradingStatus: tradingStatus,
FreeOrderOK: freeOrderOK,
OpenPositions: openPositionCount,
TradeDate: tradeDate,
ExtraContext: map[string]any{
"free_orders_remaining": remaining,
"quote_time": book.Time.Format(time.RFC3339),
},
})
if sig.Decision == domain.DecisionEnter {
sized, sizingErr := s.sizeSignal(ctx, portfolio, instrument, feature, book, 1)
switch {
case sizingErr != nil:
sig.Decision = domain.DecisionReject
sig.RejectReason = sizingErr.Error()
case sized.Lots <= 0:
sig.Decision = domain.DecisionReject
sig.RejectReason = sized.Reason
default:
sig.TargetLots = sized.Lots
sig.TargetNotional = sized.TargetNotional
}
}
if err := s.svc.Repo.UpsertSignal(ctx, sig); err != nil {
return err
}
return s.notifySignal(ctx, now, sig)
}
func (s Scheduler) saveRejectedSignal(ctx context.Context, tradeDate time.Time, instrument domain.Instrument, reason string, cause error) error {
sig := domain.Signal{
TradeDate: tradeDate,
InstrumentUID: instrument.InstrumentUID,
Decision: domain.DecisionReject,
RejectReason: reason,
ContextJSON: fmt.Sprintf(`{"error":%q}`, cause.Error()),
CreatedAt: s.nowUTC(),
}
return s.svc.Repo.UpsertSignal(ctx, sig)
}
func (s Scheduler) sizeSignal(_ context.Context, portfolio domain.Portfolio, instrument domain.Instrument, feature domain.FeatureSet, book domain.OrderBook, selected int) (risk.SizingResult, error) {
bid, ask, err := bestBidAsk(book)
if err != nil {
return risk.SizingResult{}, err
}
price, err := execution.LimitBuyPrice(bid, ask, instrument.MinPriceIncrement, s.cfg.PassiveImproveTicks)
if err != nil {
return risk.SizingResult{}, err
}
return s.svc.Sizer.Size(risk.SizingInput{
Portfolio: portfolio,
SelectedInstruments: selected,
LimitPrice: price,
Lot: instrument.Lot,
EntryIntervalVolume: feature.EntryIntervalVolume,
ExitIntervalVolume: feature.ExitIntervalVolume,
Q05OvernightAbs: money.Abs(feature.SigmaOn60).Mul(decimal.NewFromFloat(1.65)),
}), nil
}
func (s Scheduler) placeEntryOrders(ctx context.Context, now time.Time) error {
if err := s.transitionTo(ctx, domain.StatePlaceEntryOrders); err != nil {
return err
}
tradeDate := tradingDate(now)
signals, err := s.svc.Repo.ListSignals(ctx, tradeDate)
if err != nil {
return err
}
existing, err := s.svc.Repo.ListOrders(ctx, s.svc.AccountIDHash, tradeDate, tradeDate)
if err != nil {
return err
}
openPositions, err := s.svc.Repo.ListOpenPositions(ctx, s.svc.AccountIDHash)
if err != nil {
return err
}
instrumentByUID, err := s.instrumentMap(ctx)
if err != nil {
return err
}
for _, sig := range signals {
if sig.Decision != domain.DecisionEnter || sig.TargetLots <= 0 || hasOrder(existing, sig.InstrumentUID, domain.SideBuy) {
continue
}
instrument, ok := instrumentByUID[sig.InstrumentUID]
if !ok {
return fmt.Errorf("instrument %s is not in registry", sig.InstrumentUID)
}
book, err := s.svc.MarketData.LatestQuote(ctx, sig.InstrumentUID, s.cfg.QuoteDepth, s.cfg.MaxQuoteAge)
if err != nil {
return err
}
tradingStatus, err := s.svc.Gateway.GetTradingStatus(ctx, sig.InstrumentUID)
if err != nil {
tradingStatus = domain.TradingStatusUnknown
}
portfolio, err := s.svc.Gateway.GetPortfolio(ctx, s.svc.AccountID)
if err != nil {
return err
}
pre := s.svc.Risk.PreTradeCheck(risk.PreTradeInput{
Portfolio: portfolio,
OpenPositions: len(openPositions),
TradingStatus: tradingStatus,
QuoteReceivedAt: book.ReceivedAt,
Now: now.UTC(),
MarketClose: s.cfg.EntryWindowEnd.On(now, s.cfg.Location).UTC(),
})
if !pre.Allowed {
if err := s.svc.Repo.InsertRiskEvent(ctx, domain.RiskEvent{
Severity: domain.SeverityWarn,
EventType: "pre_trade_reject",
InstrumentUID: sig.InstrumentUID,
Message: pre.Reason,
ContextJSON: "{}",
}); err != nil {
return err
}
continue
}
placed, err := s.svc.Execution.PlaceEntry(ctx, s.svc.AccountIDHash, instrument, tradeDate, sig.TargetLots, book, s.cfg.PassiveImproveTicks, 1)
if err != nil && !errors.Is(err, execution.ErrBrokerOrdersDisabled) {
return err
}
_ = s.svc.Notifier.Info(ctx, fmt.Sprintf("entry order %s %s lots=%d status=%s", instrument.Ticker, placed.Side, placed.QuantityLots, placed.Status))
existing = append(existing, placed)
}
return s.transitionTo(ctx, domain.StateMonitorEntryOrders)
}
func (s Scheduler) monitorEntryOrders(ctx context.Context, now time.Time) error {
if err := s.transitionTo(ctx, domain.StateMonitorEntryOrders); err != nil {
return err
}
orders, err := s.svc.Repo.ListActiveOrders(ctx, s.svc.AccountIDHash)
if err != nil {
return err
}
instrumentByUID, err := s.instrumentMap(ctx)
if err != nil {
return err
}
deadline := s.cfg.NoNewEntryAfter.On(now, s.cfg.Location).UTC()
for _, order := range orders {
if order.Side != domain.SideBuy || order.BrokerOrderID == "" {
continue
}
instrument, ok := instrumentByUID[order.InstrumentUID]
if !ok {
return fmt.Errorf("instrument %s is not in registry", order.InstrumentUID)
}
monitored, err := s.svc.Execution.MonitorUntil(ctx, order, execution.MonitorConfig{
Deadline: deadline,
PollInterval: s.cfg.OrderPollInterval,
MaxAttempts: s.cfg.MaxEntryOrderAttempts,
RepostAfter: repostAfter(now, deadline, s.cfg.MaxEntryOrderAttempts, s.cfg.OrderPollInterval),
Instrument: instrument,
ImproveTicks: s.cfg.PassiveImproveTicks,
Quote: func(ctx context.Context, instrumentUID string) (domain.OrderBook, error) {
return s.svc.MarketData.LatestQuote(ctx, instrumentUID, s.cfg.QuoteDepth, s.cfg.MaxQuoteAge)
},
})
if err != nil {
return err
}
if monitored.FilledLots > order.FilledLots || monitored.Commission.GreaterThan(order.Commission) {
pos, err := s.svc.Positions.OnEntryFill(ctx, s.svc.AccountIDHash, instrument, monitored)
if err != nil {
return err
}
_ = s.svc.Notifier.Info(ctx, fmt.Sprintf("entry fill %s lots=%d status=%s", monitored.InstrumentUID, monitored.FilledLots, pos.Status))
}
}
if sinceMidnight(s.nowUTC().In(s.cfg.Location)) >= s.cfg.NoNewEntryAfter.Duration {
if err := s.cancelActiveOrders(ctx, domain.SideBuy, domain.OrderStatusCancelled, "entry_window_closed"); err != nil {
return err
}
return s.transitionTo(ctx, domain.StateHoldOvernight)
}
return nil
}
func (s Scheduler) waitExit(ctx context.Context, _ time.Time) error {
return s.transitionTo(ctx, domain.StateWaitExitWindow)
}
func (s Scheduler) holdOvernight(ctx context.Context) error {
if err := s.cancelActiveOrders(ctx, domain.SideBuy, domain.OrderStatusCancelled, "entry_window_closed"); err != nil {
return err
}
return s.transitionTo(ctx, domain.StateHoldOvernight)
}
func (s Scheduler) placeExitOrders(ctx context.Context, now time.Time) error {
if err := s.transitionTo(ctx, domain.StatePlaceExitOrders); err != nil {
return err
}
positionsList, err := s.svc.Repo.ListOpenPositions(ctx, s.svc.AccountIDHash)
if err != nil {
return err
}
existing, err := s.svc.Repo.ListOrders(ctx, s.svc.AccountIDHash, tradingDate(now).AddDate(0, 0, -1), tradingDate(now))
if err != nil {
return err
}
instrumentByUID, err := s.instrumentMap(ctx)
if err != nil {
return err
}
for _, pos := range positionsList {
if pos.Lots <= 0 || hasOrder(existing, pos.InstrumentUID, domain.SideSell) {
continue
}
instrument, ok := instrumentByUID[pos.InstrumentUID]
if !ok {
return fmt.Errorf("instrument %s is not in registry", pos.InstrumentUID)
}
book, err := s.svc.MarketData.LatestQuote(ctx, pos.InstrumentUID, s.cfg.QuoteDepth, s.cfg.MaxQuoteAge)
if err != nil {
return err
}
tradingStatus, err := s.svc.Gateway.GetTradingStatus(ctx, pos.InstrumentUID)
if err != nil {
tradingStatus = domain.TradingStatusUnknown
}
portfolio, err := s.svc.Gateway.GetPortfolio(ctx, s.svc.AccountID)
if err != nil {
return err
}
pre := s.svc.Risk.PreTradeCheck(risk.PreTradeInput{
Portfolio: portfolio,
OpenPositions: len(positionsList),
TradingStatus: tradingStatus,
QuoteReceivedAt: book.ReceivedAt,
Now: now.UTC(),
MarketClose: s.cfg.HardExitDeadline.On(now, s.cfg.Location).UTC(),
})
if !pre.Allowed {
return fmt.Errorf("exit pre-trade rejected: %s", pre.Reason)
}
placed, err := s.svc.Execution.PlaceExit(ctx, s.svc.AccountIDHash, instrument, pos.OpenTradeDate, pos.Lots, book, s.cfg.PassiveImproveTicks, 1)
if err != nil && !errors.Is(err, execution.ErrBrokerOrdersDisabled) {
return err
}
pos.Status = domain.PositionExitOrderSent
if err := s.svc.Repo.UpsertPosition(ctx, pos); err != nil {
return err
}
_ = s.svc.Notifier.Info(ctx, fmt.Sprintf("exit order %s lots=%d status=%s", instrument.Ticker, placed.QuantityLots, placed.Status))
existing = append(existing, placed)
}
return s.transitionTo(ctx, domain.StateMonitorExitOrders)
}
func (s Scheduler) monitorExitOrders(ctx context.Context, now time.Time) error {
if err := s.transitionTo(ctx, domain.StateMonitorExitOrders); err != nil {
return err
}
orders, err := s.svc.Repo.ListActiveOrders(ctx, s.svc.AccountIDHash)
if err != nil {
return err
}
openPositions, err := s.svc.Repo.ListOpenPositions(ctx, s.svc.AccountIDHash)
if err != nil {
return err
}
positionByInstrument := make(map[string]domain.Position, len(openPositions))
for _, pos := range openPositions {
positionByInstrument[pos.InstrumentUID] = pos
}
instrumentByUID, err := s.instrumentMap(ctx)
if err != nil {
return err
}
deadline := s.cfg.HardExitDeadline.On(now, s.cfg.Location).UTC()
for _, order := range orders {
if order.Side != domain.SideSell || order.BrokerOrderID == "" {
continue
}
instrument, ok := instrumentByUID[order.InstrumentUID]
if !ok {
return fmt.Errorf("instrument %s is not in registry", order.InstrumentUID)
}
monitored, err := s.svc.Execution.MonitorUntil(ctx, order, execution.MonitorConfig{
Deadline: deadline,
PollInterval: s.cfg.OrderPollInterval,
MaxAttempts: s.cfg.MaxExitOrderAttempts,
RepostAfter: repostAfter(now, deadline, s.cfg.MaxExitOrderAttempts, s.cfg.OrderPollInterval),
Instrument: instrument,
ImproveTicks: s.cfg.PassiveImproveTicks,
Quote: func(ctx context.Context, instrumentUID string) (domain.OrderBook, error) {
return s.svc.MarketData.LatestQuote(ctx, instrumentUID, s.cfg.QuoteDepth, s.cfg.MaxQuoteAge)
},
})
if err != nil {
return err
}
if monitored.FilledLots > order.FilledLots || monitored.Commission.GreaterThan(order.Commission) {
fill := exitFillDelta(order, monitored)
if fill.FilledLots <= 0 && fill.Commission.IsZero() {
continue
}
pos, ok := positionByInstrument[monitored.InstrumentUID]
if !ok {
return fmt.Errorf("exit fill for unknown local position %s", monitored.InstrumentUID)
}
updated, err := s.svc.Positions.OnExitFill(ctx, pos, fill)
if err != nil {
return err
}
positionByInstrument[monitored.InstrumentUID] = updated
_ = s.svc.Notifier.Info(ctx, fmt.Sprintf("exit fill %s lots=%d status=%s pnl=%s", monitored.InstrumentUID, monitored.FilledLots, updated.Status, updated.NetPnL.StringFixed(2)))
}
}
if sinceMidnight(s.nowUTC().In(s.cfg.Location)) >= s.cfg.HardExitDeadline.Duration {
return s.failOpenPositionsAtHardDeadline(ctx)
}
return nil
}
func (s *Scheduler) reconcileAndReport(ctx context.Context, now time.Time) error {
if err := s.transitionTo(ctx, domain.StateReconcile); err != nil {
return err
}
diffs, err := s.svc.Reconcile.Run(ctx)
if err != nil {
return err
}
if reconciliation.HasCritical(diffs) {
return s.halt(ctx, "reconciliation_critical", "critical reconciliation diff", "")
}
tradeDate := tradingDate(now)
sent, err := s.svc.Repo.WasDailyReportSent(ctx, tradeDate, s.svc.AccountIDHash)
if err != nil {
return err
}
if sent {
s.logWarn("daily report already sent; skipping duplicate", "date", tradeDate.Format("2006-01-02"))
return s.transitionTo(ctx, domain.StateSleep)
}
signals, err := s.svc.Repo.ListSignals(ctx, tradeDate)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
positionsList, err := s.svc.Repo.ListPositions(ctx, s.svc.AccountIDHash, tradeDate.AddDate(0, 0, -1), tradeDate)
if err != nil {
return err
}
if err := s.applySizeReductionRule(ctx, tradeDate, true); err != nil {
return err
}
if err := s.transitionTo(ctx, domain.StateReport); err != nil {
return err
}
msg := report.ComposeDaily(report.DailyInput{
Date: tradeDate,
Mode: s.cfg.Mode,
Signals: signals,
Positions: positionsList,
RiskStatus: "ok",
})
if err := s.svc.Notifier.Report(ctx, msg); err != nil {
return err
}
if err := s.svc.Repo.MarkDailyReportSent(ctx, tradeDate, s.svc.AccountIDHash); err != nil {
return err
}
return s.transitionTo(ctx, domain.StateSleep)
}
func (s *Scheduler) applySizeReductionRule(ctx context.Context, tradeDate time.Time, emitEvent bool) error {
averageError, count, ok, err := s.averageExpectedErrorBps(ctx, tradeDate, sizeReductionWindowTrades)
if err != nil {
return err
}
if !ok || count < sizeReductionWindowTrades || averageError.GreaterThanOrEqual(decimal.NewFromInt(-10)) {
s.svc.Sizer = s.svc.Sizer.WithSizeFactor(decimal.NewFromInt(1))
return nil
}
factor := decimal.NewFromFloat(sizeReductionFactor)
s.svc.Sizer = s.svc.Sizer.WithSizeFactor(factor)
if !emitEvent {
return nil
}
return s.svc.Repo.InsertRiskEvent(ctx, domain.RiskEvent{
Severity: domain.SeverityWarn,
EventType: "size_reduction_rule_triggered",
Message: fmt.Sprintf("average expected_error_bps over %d trades is %s; sizing factor set to %s", count, averageError.StringFixed(2), factor.String()),
ContextJSON: fmt.Sprintf(`{"average_expected_error_bps":%q,"trades":%d,"size_factor":%q}`, averageError.String(), count, factor.String()),
})
}
func (s Scheduler) averageExpectedErrorBps(ctx context.Context, tradeDate time.Time, limit int) (decimal.Decimal, int, bool, error) {
if limit <= 0 {
return decimal.Zero, 0, false, nil
}
positionsList, err := s.svc.Repo.ListPositions(ctx, s.svc.AccountIDHash, tradeDate.AddDate(0, 0, -120), tradeDate)
if err != nil {
return decimal.Zero, 0, false, err
}
sort.Slice(positionsList, func(i, j int) bool {
return positionsList[i].UpdatedAt.After(positionsList[j].UpdatedAt)
})
signalsByDate := make(map[string][]domain.Signal)
var errorsBps []decimal.Decimal
for _, pos := range positionsList {
if pos.Status != domain.PositionExitFilled {
continue
}
key := tradingDate(pos.OpenTradeDate).Format("2006-01-02")
signals, ok := signalsByDate[key]
if !ok {
signals, err = s.svc.Repo.ListSignals(ctx, tradingDate(pos.OpenTradeDate))
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return decimal.Zero, 0, false, err
}
signalsByDate[key] = signals
}
for _, sig := range signals {
if sig.InstrumentUID != pos.InstrumentUID || sig.Decision != domain.DecisionEnter {
continue
}
errorsBps = append(errorsBps, pos.RealizedEdgeBps.Sub(sig.NetEdgeBps))
break
}
if len(errorsBps) == limit {
break
}
}
if len(errorsBps) == 0 {
return decimal.Zero, 0, false, nil
}
sum := decimal.Zero
for _, value := range errorsBps {
sum = sum.Add(value)
}
return sum.Div(decimal.NewFromInt(int64(len(errorsBps)))), len(errorsBps), true, nil
}
func (s *Scheduler) checkInfrastructure(ctx context.Context) error {
if s.cfg.MaxClockDrift <= 0 || s.svc.Gateway == nil {
return nil
}
serverTime, err := s.svc.Gateway.GetServerTime(ctx)
if err != nil {
if s.cfg.Mode == domain.ModePaper {
return nil
}
return s.recordInfrastructureFailure(fmt.Errorf("server_time_unavailable: %w", err))
}
drift := timeutil.Drift(s.nowUTC(), serverTime)
if drift > s.cfg.MaxClockDrift {
return s.recordInfrastructureFailure(fmt.Errorf("server_clock_drift_too_high: %s > %s", drift, s.cfg.MaxClockDrift))
}
s.infraFailedSince = time.Time{}
return nil
}
func (s *Scheduler) recordInfrastructureFailure(err error) error {
now := s.nowUTC()
if s.infraFailedSince.IsZero() {
s.infraFailedSince = now
s.logWarn("infrastructure check failed; waiting for outage threshold", "err", err, "threshold", s.cfg.APIOutageHalt)
return nil
}
if s.cfg.APIOutageHalt <= 0 || now.Sub(s.infraFailedSince) >= s.cfg.APIOutageHalt {
return err
}
s.logWarn("infrastructure check still failing", "err", err, "elapsed", now.Sub(s.infraFailedSince), "threshold", s.cfg.APIOutageHalt)
return nil
}
func (s Scheduler) cancelActiveOrders(ctx context.Context, side domain.Side, fallbackStatus domain.OrderStatus, reason string) error {
orders, err := s.svc.Repo.ListActiveOrders(ctx, s.svc.AccountIDHash)
if err != nil {
return err
}
cancelled := 0
for _, order := range orders {
if order.Side != side {
continue
}
if order.BrokerOrderID != "" && s.cfg.Mode.AllowsBrokerOrders() {
if err := s.svc.Execution.Cancel(ctx, order); err != nil {
return fmt.Errorf("cancel %s order %s: %w", side, order.ClientOrderID, err)
}
cancelled++
continue
}
if err := s.svc.Repo.UpdateOrderStatus(ctx, order.ClientOrderID, fallbackStatus, order.FilledLots, order.RawStateJSON); err != nil {
return fmt.Errorf("mark %s order %s %s: %w", side, order.ClientOrderID, fallbackStatus, err)
}
cancelled++
}
if cancelled == 0 {
return nil
}
if err := s.svc.Repo.InsertRiskEvent(ctx, domain.RiskEvent{
Severity: domain.SeverityWarn,
EventType: reason,
Message: fmt.Sprintf("cancelled %d active %s orders at window boundary", cancelled, side),
ContextJSON: "{}",
}); err != nil {
return err
}
return nil
}
func (s Scheduler) failOpenPositionsAtHardDeadline(ctx context.Context) error {
if err := s.cancelActiveOrders(ctx, domain.SideSell, domain.OrderStatusExpired, "hard_exit_deadline_cancel"); err != nil {
return err
}
positionsList, err := s.svc.Repo.ListOpenPositions(ctx, s.svc.AccountIDHash)
if err != nil {
return err
}
var failed []domain.Position
now := s.nowUTC()
for _, pos := range positionsList {
switch pos.Status {
case domain.PositionHoldingOvernight, domain.PositionExitPartiallyFilled, domain.PositionExitOrderSent:
pos.Status = domain.PositionExitFailed
pos.UpdatedAt = now
if err := s.svc.Repo.UpsertPosition(ctx, pos); err != nil {
return err
}
failed = append(failed, pos)
_ = s.svc.Notifier.Alert(ctx, fmt.Sprintf("exit_failed: %s lots=%d", pos.InstrumentUID, pos.Lots))
default:
}
}
if len(failed) == 0 {
return s.reconcileAndReport(ctx, s.nowUTC().In(s.cfg.Location))
}
return s.svc.Risk.Halt(ctx, s.cfg.Mode, "hard_exit_deadline_missed", fmt.Sprintf("%d positions remain open after hard deadline", len(failed)), "")
}
func (s Scheduler) nowUTC() time.Time {
if s.clock != nil {
return s.clock.Now().UTC()
}
return time.Now().UTC()
}
func repostAfter(now, deadline time.Time, attempts int, poll time.Duration) time.Duration {
if attempts <= 1 {
return 0
}
if poll <= 0 {
poll = 500 * time.Millisecond
}
remaining := deadline.Sub(now)
if remaining <= 0 {
return poll
}
after := remaining / time.Duration(attempts)
if after < poll {
return poll
}
return after
}
func (s Scheduler) transitionSequence(ctx context.Context, states ...domain.SystemState) error {
for _, state := range states {
if err := s.transitionTo(ctx, state); err != nil {
return err
}
}
return nil
}
func (s Scheduler) transitionTo(ctx context.Context, to domain.SystemState) error {
from, halted, reason, err := s.svc.Repo.GetSystemState(ctx)
if err != nil {
return err
}
if halted || from == domain.StateHalted {
return fmt.Errorf("%w: %s", statemachine.ErrSystemHalted, reason)
}
if from == to {
return s.sm.Heartbeat(ctx, to)
}
if err := s.sm.Transition(ctx, from, to); err != nil {
if errors.Is(err, statemachine.ErrIllegalTransition) {
return s.sm.Heartbeat(ctx, to)
}
return err
}
return nil
}
func (s Scheduler) halt(ctx context.Context, eventType, reason, instrumentUID string) error {
_ = s.svc.Notifier.Alert(ctx, fmt.Sprintf("%s: %s", eventType, reason))
return s.svc.Risk.Halt(ctx, s.cfg.Mode, eventType, reason, instrumentUID)
}
func (s Scheduler) notifySignal(ctx context.Context, _ time.Time, sig domain.Signal) error {
return s.svc.Notifier.Info(ctx, fmt.Sprintf("signal %s decision=%s edge=%s reason=%s lots=%d", sig.InstrumentUID, sig.Decision, sig.NetEdgeBps.StringFixed(2), sig.RejectReason, sig.TargetLots))
}
func (s Scheduler) instrumentMap(ctx context.Context) (map[string]domain.Instrument, error) {
instrumentsList, err := s.svc.Repo.ListInstruments(ctx, false)
if err != nil {
return nil, err
}
out := make(map[string]domain.Instrument, len(instrumentsList))
for _, instrument := range instrumentsList {
out[instrument.InstrumentUID] = instrument
}
return out, nil
}
func (s Scheduler) logWarn(msg string, args ...any) {
if s.svc.Log != nil {
s.svc.Log.Warn(msg, args...)
}
}
func exitFillDelta(previous, current domain.Order) domain.Order {
fill := current
fill.FilledLots = current.FilledLots - previous.FilledLots
if fill.FilledLots < 0 {
fill.FilledLots = 0
}
fill.Commission = current.Commission.Sub(previous.Commission)
if fill.Commission.IsNegative() {
fill.Commission = decimal.Zero
}
if fill.FilledLots > 0 {
currentValue := current.AvgFillPrice.Mul(decimal.NewFromInt(current.FilledLots))
previousValue := previous.AvgFillPrice.Mul(decimal.NewFromInt(previous.FilledLots))
fill.AvgFillPrice = currentValue.Sub(previousValue).Div(decimal.NewFromInt(fill.FilledLots))
}
return fill
}
func spreadFromBook(book domain.OrderBook, tick decimal.Decimal) (features.SpreadResult, error) {
bid, ask, err := bestBidAsk(book)
if err != nil {
return features.SpreadResult{}, err
}
return features.Spread(bid, ask, tick)
}
func bestBidAsk(book domain.OrderBook) (decimal.Decimal, decimal.Decimal, error) {
bid, ok := book.BestBid()
if !ok {
return decimal.Zero, decimal.Zero, execution.ErrEmptyOrderBook
}
ask, ok := book.BestAsk()
if !ok {
return decimal.Zero, decimal.Zero, execution.ErrEmptyOrderBook
}
return bid, ask, nil
}
func hasOrder(orders []domain.Order, instrumentUID string, side domain.Side) bool {
for _, order := range orders {
if order.InstrumentUID == instrumentUID && order.Side == side && order.Status != domain.OrderStatusFailed && order.Status != domain.OrderStatusRejected {
return true
}
}
return false
}
func sinceMidnight(t time.Time) time.Duration {
h, m, s := t.Clock()
return time.Duration(h)*time.Hour + time.Duration(m)*time.Minute + time.Duration(s)*time.Second
}
func tradingDate(t time.Time) time.Time {
y, m, d := t.Date()
return time.Date(y, m, d, 0, 0, 0, 0, time.UTC)
}
+287
View File
@@ -0,0 +1,287 @@
package scheduler
import (
"context"
"testing"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/execution"
"overnight-trading-bot/internal/reconciliation"
"overnight-trading-bot/internal/risk"
"overnight-trading-bot/internal/statemachine"
"overnight-trading-bot/internal/testutil"
"overnight-trading-bot/internal/timeutil"
"overnight-trading-bot/internal/tinvest"
)
func TestPhaseUsesMoscowWindows(t *testing.T) {
loc := time.FixedZone("MSK", 3*60*60)
s := Scheduler{cfg: Config{
Location: loc,
EntrySignalTime: mustTOD("18:10:00"),
EntryWindowStart: mustTOD("18:20:00"),
NoNewEntryAfter: mustTOD("18:38:30"),
ExitWatchStart: mustTOD("09:50:00"),
ExitWindowStart: mustTOD("10:05:00"),
ExitWindowEnd: mustTOD("10:25:00"),
HardExitDeadline: mustTOD("10:45:00"),
}}
tests := []struct {
at string
want domain.SystemState
}{
{"2026-06-06T09:55:00+03:00", domain.StateWaitExitWindow},
{"2026-06-06T10:10:00+03:00", domain.StatePlaceExitOrders},
{"2026-06-06T10:30:00+03:00", domain.StateMonitorExitOrders},
{"2026-06-06T11:00:00+03:00", domain.StateReconcile},
{"2026-06-06T18:15:00+03:00", domain.StateGenerateSignals},
{"2026-06-06T18:25:00+03:00", domain.StatePlaceEntryOrders},
{"2026-06-06T19:00:00+03:00", domain.StateHoldOvernight},
}
for _, tt := range tests {
t.Run(tt.at, func(t *testing.T) {
at, err := time.Parse(time.RFC3339, tt.at)
if err != nil {
t.Fatal(err)
}
if got := s.phase(at.In(loc)); got != tt.want {
t.Fatalf("phase=%s, want %s", got, tt.want)
}
})
}
}
func TestInfrastructureOutageRequiresThreshold(t *testing.T) {
gateway := tinvest.NewFakeGateway()
gateway.ServerTime = time.Now().UTC().Add(-10 * time.Second)
s := &Scheduler{
cfg: Config{
Mode: domain.ModeSandbox,
MaxClockDrift: 2 * time.Second,
APIOutageHalt: 180 * time.Second,
},
svc: Services{Gateway: gateway},
}
if err := s.checkInfrastructure(context.Background()); err != nil {
t.Fatalf("first infrastructure failure should be tolerated: %v", err)
}
s.infraFailedSince = time.Now().UTC().Add(-181 * time.Second)
if err := s.checkInfrastructure(context.Background()); err == nil {
t.Fatalf("expected outage after threshold")
}
}
func TestReconcileAndReportIsIdempotentPerDate(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
gateway := tinvest.NewFakeGateway()
notifier := &countNotifier{}
recon := reconciliation.New(repo, gateway, "account", "hash")
s := Scheduler{
cfg: Config{Mode: domain.ModePaper, Location: time.UTC},
sm: statemachine.New(repo, domain.ModePaper),
svc: Services{
Repo: repo,
Gateway: gateway,
Reconcile: recon,
Notifier: notifier,
Risk: risk.NewManager(repo, risk.ManagerConfig{}),
AccountID: "account",
AccountIDHash: "hash",
},
}
now := time.Date(2026, 6, 7, 12, 0, 0, 0, time.UTC)
if err := s.reconcileAndReport(ctx, now); err != nil {
t.Fatal(err)
}
if err := s.reconcileAndReport(ctx, now); err != nil {
t.Fatal(err)
}
if notifier.reports != 1 {
t.Fatalf("reports sent=%d, want 1", notifier.reports)
}
}
func TestExitFillDeltaUsesOnlyNewlyExecutedLots(t *testing.T) {
previous := domain.Order{
FilledLots: 2,
AvgFillPrice: decimal.NewFromInt(100),
Commission: decimal.NewFromFloat(0.50),
}
current := domain.Order{
FilledLots: 4,
AvgFillPrice: decimal.NewFromInt(110),
Commission: decimal.NewFromFloat(1.25),
}
fill := exitFillDelta(previous, current)
if fill.FilledLots != 2 {
t.Fatalf("delta filled lots=%d, want 2", fill.FilledLots)
}
if !fill.AvgFillPrice.Equal(decimal.NewFromInt(120)) {
t.Fatalf("delta avg fill price=%s, want 120", fill.AvgFillPrice)
}
if !fill.Commission.Equal(decimal.NewFromFloat(0.75)) {
t.Fatalf("delta commission=%s, want 0.75", fill.Commission)
}
}
func TestHardDeadlineMarksOpenPositionFailedAndHalts(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
openDate := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC)
if err := repo.UpsertPosition(ctx, domain.Position{
AccountIDHash: "hash",
InstrumentUID: "uid",
OpenTradeDate: openDate,
Lots: 1,
Lot: 1,
Status: domain.PositionHoldingOvernight,
}); err != nil {
t.Fatal(err)
}
notifier := &countNotifier{}
s := Scheduler{
cfg: Config{Mode: domain.ModePaper, Location: time.UTC},
svc: Services{
Repo: repo,
Risk: risk.NewManager(repo, risk.ManagerConfig{}),
Notifier: notifier,
AccountIDHash: "hash",
},
}
if err := s.failOpenPositionsAtHardDeadline(ctx); err != nil {
t.Fatal(err)
}
if !repo.Halted || repo.State != domain.StateHalted {
t.Fatalf("system not halted: state=%s halted=%v", repo.State, repo.Halted)
}
positions, err := repo.ListOpenPositions(ctx, "hash")
if err != nil {
t.Fatal(err)
}
if len(positions) != 1 || positions[0].Status != domain.PositionExitFailed {
t.Fatalf("positions=%+v, want EXIT_FAILED", positions)
}
if notifier.alerts != 1 {
t.Fatalf("alerts=%d, want 1", notifier.alerts)
}
}
func TestHoldOvernightCancelsActiveBuyOrders(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
tradeDate := time.Date(2026, 6, 6, 0, 0, 0, 0, time.UTC)
if err := repo.UpsertOrder(ctx, domain.Order{
ClientOrderID: "buy",
AccountIDHash: "hash",
InstrumentUID: "uid",
TradeDate: tradeDate,
Side: domain.SideBuy,
OrderType: domain.OrderTypeLimit,
QuantityLots: 1,
Status: domain.OrderStatusNew,
}); err != nil {
t.Fatal(err)
}
s := Scheduler{
cfg: Config{Mode: domain.ModePaper, Location: time.UTC},
sm: statemachine.New(repo, domain.ModePaper),
svc: Services{
Repo: repo,
Execution: &execution.Engine{},
AccountIDHash: "hash",
},
}
if err := s.holdOvernight(ctx); err != nil {
t.Fatal(err)
}
orders, err := repo.ListOrders(ctx, "hash", tradeDate, tradeDate)
if err != nil {
t.Fatal(err)
}
if len(orders) != 1 || orders[0].Status != domain.OrderStatusCancelled {
t.Fatalf("orders=%+v, want CANCELLED", orders)
}
}
func TestSizeReductionRuleCutsSizerAfterBadExpectedErrors(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
tradeDate := time.Date(2026, 6, 30, 0, 0, 0, 0, time.UTC)
for i := 0; i < sizeReductionWindowTrades; i++ {
date := tradeDate.AddDate(0, 0, -i)
if err := repo.UpsertSignal(ctx, domain.Signal{
TradeDate: date,
InstrumentUID: "uid",
Decision: domain.DecisionEnter,
NetEdgeBps: decimal.NewFromInt(20),
}); err != nil {
t.Fatal(err)
}
if err := repo.UpsertPosition(ctx, domain.Position{
AccountIDHash: "hash",
InstrumentUID: "uid",
OpenTradeDate: date,
Lot: 1,
Status: domain.PositionExitFilled,
RealizedEdgeBps: decimal.Zero,
UpdatedAt: date.Add(time.Hour),
}); err != nil {
t.Fatal(err)
}
}
s := Scheduler{
svc: Services{
Repo: repo,
AccountIDHash: "hash",
Sizer: risk.NewSizer(risk.SizingConfig{
MaxPositionPct: decimal.NewFromInt(1),
MaxTotalExposurePct: decimal.NewFromInt(1),
MaxParticipationRate: decimal.NewFromInt(1),
CashUsageBuffer: decimal.NewFromInt(1),
RiskBudgetPerInstrumentPct: decimal.NewFromInt(1),
MinOrderNotionalRUB: decimal.NewFromInt(1),
}),
},
}
if err := s.applySizeReductionRule(ctx, tradeDate, true); err != nil {
t.Fatal(err)
}
sized := s.svc.Sizer.Size(risk.SizingInput{
Portfolio: domain.Portfolio{Equity: decimal.NewFromInt(10_000), Cash: decimal.NewFromInt(10_000)},
SelectedInstruments: 1,
LimitPrice: decimal.NewFromInt(100),
Lot: 1,
EntryIntervalVolume: decimal.NewFromInt(10_000),
ExitIntervalVolume: decimal.NewFromInt(10_000),
Q05OvernightAbs: decimal.NewFromInt(1),
})
if sized.Lots != 50 {
t.Fatalf("lots=%d, want reduced 50", sized.Lots)
}
if len(repo.RiskEvents) != 1 || repo.RiskEvents[0].EventType != "size_reduction_rule_triggered" {
t.Fatalf("risk events=%+v", repo.RiskEvents)
}
}
func mustTOD(raw string) timeutil.TimeOfDay {
tod, err := timeutil.ParseTimeOfDay(raw)
if err != nil {
panic(err)
}
return tod
}
type countNotifier struct {
reports int
alerts int
}
func (n *countNotifier) Info(context.Context, string) error { return nil }
func (n *countNotifier) Warn(context.Context, string) error { return nil }
func (n *countNotifier) Alert(context.Context, string) error { n.alerts++; return nil }
func (n *countNotifier) Report(context.Context, string) error { n.reports++; return nil }
func (n *countNotifier) Close() error { return nil }
+152
View File
@@ -0,0 +1,152 @@
package signal
import (
"encoding/json"
"strings"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
)
const (
ReasonDisabled = "instrument_disabled"
ReasonQuarantine = "instrument_quarantine"
ReasonMetadataInvalid = "metadata_invalid"
ReasonTradingStatus = "trading_status_not_normal"
ReasonCommission = "commission_nonzero"
ReasonMuShort = "mu_on_60_non_positive"
ReasonMuLong = "mu_on_252_non_positive"
ReasonSigmaZero = "sigma_on_60_zero"
ReasonTStat = "tstat_on_60_below_threshold"
ReasonWinRate = "win_on_60_below_threshold"
ReasonNetEdge = "net_edge_bps_below_threshold"
ReasonSpread = "spread_bps_above_limit"
ReasonTick = "tick_bps_above_limit"
ReasonADV = "adv_20_below_limit"
ReasonFreeOrders = "free_order_budget_insufficient"
ReasonMaxPositions = "max_positions_reached"
)
type Config struct {
MinTStat60 decimal.Decimal
MinWinRate60 decimal.Decimal
MinNetEdgeBps decimal.Decimal
MinADVRUB decimal.Decimal
MaxSpreadBpsDefault decimal.Decimal
MaxSpreadBpsMoneyMarket decimal.Decimal
MaxSpreadBpsBondFunds decimal.Decimal
MaxSpreadBpsEquityFunds decimal.Decimal
MaxTickBps decimal.Decimal
RequireZeroCommission bool
MaxPositions int
}
type Candidate struct {
Instrument domain.Instrument
Features domain.FeatureSet
TradingStatus domain.TradingStatus
FreeOrderOK bool
OpenPositions int
TradeDate time.Time
ExtraContext map[string]any
}
type Engine struct {
cfg Config
}
func New(cfg Config) Engine {
return Engine{cfg: cfg}
}
func (e Engine) Evaluate(c Candidate) domain.Signal {
reason := e.firstRejectReason(c)
decision := domain.DecisionEnter
if reason != "" {
decision = domain.DecisionReject
}
if isSkipReason(reason) {
decision = domain.DecisionSkip
}
context := map[string]any{
"ticker": c.Instrument.Ticker,
"fund_type": c.Instrument.FundType,
"trading_status": c.TradingStatus,
"spread_limit": e.spreadLimit(c.Instrument).String(),
}
for k, v := range c.ExtraContext {
context[k] = v
}
raw, _ := json.Marshal(context)
return domain.Signal{
TradeDate: c.TradeDate,
InstrumentUID: c.Instrument.InstrumentUID,
Decision: decision,
Score: c.Features.NetEdgeBps,
NetEdgeBps: c.Features.NetEdgeBps,
RejectReason: reason,
ContextJSON: string(raw),
CreatedAt: time.Now().UTC(),
}
}
func isSkipReason(reason string) bool {
return reason == ReasonFreeOrders || reason == ReasonMaxPositions
}
func (e Engine) firstRejectReason(c Candidate) string {
instr := c.Instrument
features := c.Features
switch {
case !instr.Enabled:
return ReasonDisabled
case instr.Quarantine:
return ReasonQuarantine
case !instr.MetadataValid():
return ReasonMetadataInvalid
case c.TradingStatus != domain.TradingStatusNormal:
return ReasonTradingStatus
case e.cfg.RequireZeroCommission && instr.ExpectedCommissionBpsPerSide.IsPositive():
return ReasonCommission
case !features.MuOn60.IsPositive():
return ReasonMuShort
case !features.MuOn252.IsPositive():
return ReasonMuLong
case !features.SigmaOn60.IsPositive():
return ReasonSigmaZero
case features.TStatOn60.LessThan(e.cfg.MinTStat60):
return ReasonTStat
case features.WinOn60.LessThan(e.cfg.MinWinRate60):
return ReasonWinRate
case features.NetEdgeBps.LessThan(e.cfg.MinNetEdgeBps):
return ReasonNetEdge
case features.SpreadBps.GreaterThan(e.spreadLimit(instr)):
return ReasonSpread
case features.TickBps.GreaterThan(e.cfg.MaxTickBps):
return ReasonTick
case features.ADV20.LessThan(e.cfg.MinADVRUB):
return ReasonADV
case !c.FreeOrderOK:
return ReasonFreeOrders
case e.cfg.MaxPositions > 0 && c.OpenPositions >= e.cfg.MaxPositions:
return ReasonMaxPositions
default:
return ""
}
}
func (e Engine) spreadLimit(instr domain.Instrument) decimal.Decimal {
fundType := strings.ToLower(instr.FundType)
switch {
case strings.Contains(fundType, "money"):
return e.cfg.MaxSpreadBpsMoneyMarket
case strings.Contains(fundType, "bond"):
return e.cfg.MaxSpreadBpsBondFunds
case strings.Contains(fundType, "equity"):
return e.cfg.MaxSpreadBpsEquityFunds
default:
return e.cfg.MaxSpreadBpsDefault
}
}
+87
View File
@@ -0,0 +1,87 @@
package signal
import (
"testing"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
)
func sd(raw string) decimal.Decimal {
v, err := decimal.NewFromString(raw)
if err != nil {
panic(err)
}
return v
}
func baseCandidate() Candidate {
return Candidate{
TradeDate: time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC),
Instrument: domain.Instrument{
InstrumentUID: "uid",
Ticker: "TRUR",
ClassCode: "TQTF",
Lot: 1,
MinPriceIncrement: sd("0.01"),
Currency: "RUB",
Enabled: true,
},
Features: domain.FeatureSet{
MuOn60: sd("0.002"),
MuOn252: sd("0.001"),
SigmaOn60: sd("0.01"),
TStatOn60: sd("2"),
WinOn60: sd("0.60"),
NetEdgeBps: sd("20"),
SpreadBps: sd("5"),
TickBps: sd("1"),
ADV20: sd("10000000"),
},
TradingStatus: domain.TradingStatusNormal,
FreeOrderOK: true,
}
}
func TestEngineEnter(t *testing.T) {
engine := New(Config{
MinTStat60: sd("1.25"),
MinWinRate60: sd("0.55"),
MinNetEdgeBps: sd("10"),
MinADVRUB: sd("5000000"),
MaxSpreadBpsDefault: sd("20"),
MaxSpreadBpsMoneyMarket: sd("5"),
MaxSpreadBpsBondFunds: sd("10"),
MaxSpreadBpsEquityFunds: sd("25"),
MaxTickBps: sd("10"),
RequireZeroCommission: true,
MaxPositions: 5,
})
sig := engine.Evaluate(baseCandidate())
if sig.Decision != domain.DecisionEnter || sig.RejectReason != "" {
t.Fatalf("unexpected signal: %+v", sig)
}
}
func TestEngineFirstRejectReason(t *testing.T) {
engine := New(Config{MinTStat60: sd("1.25"), MinWinRate60: sd("0.55"), MinNetEdgeBps: sd("10"), MinADVRUB: sd("5000000"), MaxSpreadBpsDefault: sd("20"), MaxTickBps: sd("10"), RequireZeroCommission: true})
c := baseCandidate()
c.Features.MuOn60 = decimal.Zero
c.Features.NetEdgeBps = decimal.Zero
sig := engine.Evaluate(c)
if sig.RejectReason != ReasonMuShort {
t.Fatalf("reason=%s", sig.RejectReason)
}
}
func TestEngineUsesSkipForCapacityReasons(t *testing.T) {
engine := New(Config{MinTStat60: sd("1.25"), MinWinRate60: sd("0.55"), MinNetEdgeBps: sd("10"), MinADVRUB: sd("5000000"), MaxSpreadBpsDefault: sd("20"), MaxTickBps: sd("10"), RequireZeroCommission: true, MaxPositions: 1})
c := baseCandidate()
c.OpenPositions = 1
sig := engine.Evaluate(c)
if sig.Decision != domain.DecisionSkip || sig.RejectReason != ReasonMaxPositions {
t.Fatalf("unexpected skip signal: %+v", sig)
}
}
+120
View File
@@ -0,0 +1,120 @@
package statemachine
import (
"context"
"errors"
"fmt"
"time"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/reconciliation"
"overnight-trading-bot/internal/repository"
)
var (
ErrIllegalTransition = errors.New("illegal system transition")
ErrSystemHalted = errors.New("system is halted")
)
type System struct {
repo repository.Repository
mode domain.Mode
}
func New(repo repository.Repository, mode domain.Mode) System {
return System{repo: repo, mode: mode}
}
func (s System) Recover(ctx context.Context, reconcile reconciliation.Engine) (domain.SystemState, error) {
state, halted, reason, err := s.repo.GetSystemState(ctx)
if err != nil {
return "", err
}
if halted || state == domain.StateHalted {
return domain.StateHalted, fmt.Errorf("system halted: %s", reason)
}
switch state {
case domain.StatePlaceEntryOrders, domain.StateMonitorEntryOrders,
domain.StatePlaceExitOrders, domain.StateMonitorExitOrders,
domain.StateHoldOvernight:
diffs, err := reconcile.Run(ctx)
if err != nil {
return "", err
}
if reconciliation.HasCritical(diffs) {
if err := s.Halt(ctx, "critical reconciliation diff during recovery"); err != nil {
return "", err
}
return domain.StateHalted, errors.New("critical reconciliation diff during recovery")
}
return state, nil
case domain.StateInit, domain.StateSyncInstruments, domain.StateSyncMarketData, domain.StateGenerateSignals:
return domain.StateInit, s.persist(ctx, domain.StateInit, false, "")
default:
return state, nil
}
}
func (s System) Transition(ctx context.Context, from, to domain.SystemState) error {
current, halted, reason, err := s.repo.GetSystemState(ctx)
if err != nil {
return err
}
if (halted || current == domain.StateHalted) && to != domain.StateHalted {
return fmt.Errorf("%w: %s", ErrSystemHalted, reason)
}
if !legalTransition(from, to) {
return fmt.Errorf("%w: %s -> %s", ErrIllegalTransition, from, to)
}
return s.persist(ctx, to, false, "")
}
func (s System) Halt(ctx context.Context, reason string) error {
return s.persist(ctx, domain.StateHalted, true, reason)
}
func (s System) Heartbeat(ctx context.Context, state domain.SystemState) error {
current, halted, reason, err := s.repo.GetSystemState(ctx)
if err != nil {
return err
}
if halted || current == domain.StateHalted {
return s.repo.SaveSystemState(ctx, domain.StateHalted, s.mode, true, reason, fmt.Sprintf(`{"heartbeat":"%s"}`, time.Now().UTC().Format(time.RFC3339Nano)))
}
return s.repo.SaveSystemState(ctx, state, s.mode, false, "", fmt.Sprintf(`{"heartbeat":"%s"}`, time.Now().UTC().Format(time.RFC3339Nano)))
}
func (s System) persist(ctx context.Context, state domain.SystemState, halted bool, reason string) error {
return s.repo.SaveSystemState(ctx, state, s.mode, halted, reason, "{}")
}
func legalTransition(from, to domain.SystemState) bool {
if from == to {
return true
}
if to == domain.StateHalted {
return true
}
allowed := map[domain.SystemState][]domain.SystemState{
domain.StateInit: {domain.StateSyncInstruments, domain.StateWaitExitWindow},
domain.StateSyncInstruments: {domain.StateSyncMarketData},
domain.StateSyncMarketData: {domain.StateGenerateSignals},
domain.StateGenerateSignals: {domain.StateWaitEntryWindow},
domain.StateWaitEntryWindow: {domain.StatePlaceEntryOrders, domain.StateSleep},
domain.StatePlaceEntryOrders: {domain.StateMonitorEntryOrders, domain.StateReconcile},
domain.StateMonitorEntryOrders: {domain.StateHoldOvernight, domain.StateReconcile},
domain.StateHoldOvernight: {domain.StateWaitExitWindow},
domain.StateWaitExitWindow: {domain.StatePlaceExitOrders},
domain.StatePlaceExitOrders: {domain.StateMonitorExitOrders, domain.StateReconcile},
domain.StateMonitorExitOrders: {domain.StateReconcile},
domain.StateReconcile: {domain.StateReport, domain.StateHalted},
domain.StateReport: {domain.StateSleep},
domain.StateSleep: {domain.StateInit, domain.StateWaitExitWindow, domain.StateGenerateSignals},
}
for _, candidate := range allowed[from] {
if candidate == to {
return true
}
}
return false
}
+93
View File
@@ -0,0 +1,93 @@
package statemachine
import (
"context"
"errors"
"testing"
"time"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/reconciliation"
"overnight-trading-bot/internal/testutil"
"overnight-trading-bot/internal/tinvest"
)
func TestHeartbeatDoesNotClearHalt(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
system := New(repo, domain.ModeLiveTrade)
if err := system.Halt(ctx, "manual kill switch"); err != nil {
t.Fatal(err)
}
if err := system.Heartbeat(ctx, domain.StateSleep); err != nil {
t.Fatal(err)
}
state, halted, reason, err := repo.GetSystemState(ctx)
if err != nil {
t.Fatal(err)
}
if state != domain.StateHalted || !halted || reason != "manual kill switch" {
t.Fatalf("halt was not sticky: state=%s halted=%v reason=%q", state, halted, reason)
}
}
func TestTransitionBlockedWhileHalted(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
system := New(repo, domain.ModePaper)
if err := system.Halt(ctx, "risk"); err != nil {
t.Fatal(err)
}
err := system.Transition(ctx, domain.StateHalted, domain.StateInit)
if !errors.Is(err, ErrSystemHalted) {
t.Fatalf("expected ErrSystemHalted, got %v", err)
}
}
func TestUnhaltPreservesMode(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
if err := repo.SaveSystemState(ctx, domain.StateHalted, domain.ModeLiveTrade, true, "risk", "{}"); err != nil {
t.Fatal(err)
}
if err := repo.Unhalt(ctx, "checked"); err != nil {
t.Fatal(err)
}
_, halted, _, err := repo.GetSystemState(ctx)
if err != nil {
t.Fatal(err)
}
if halted || repo.Mode != domain.ModeLiveTrade {
t.Fatalf("unhalt did not preserve mode: halted=%v mode=%s", halted, repo.Mode)
}
}
func TestRecoverFromMonitorEntryHaltsOnCriticalReconciliationDiff(t *testing.T) {
ctx := context.Background()
repo := testutil.NewMemoryRepository()
if err := repo.SaveSystemState(ctx, domain.StateMonitorEntryOrders, domain.ModePaper, false, "", "{}"); err != nil {
t.Fatal(err)
}
if err := repo.UpsertOrder(ctx, domain.Order{
ClientOrderID: "local",
BrokerOrderID: "broker-missing",
AccountIDHash: "hash",
InstrumentUID: "uid",
TradeDate: time.Now().UTC(),
Side: domain.SideBuy,
OrderType: domain.OrderTypeLimit,
QuantityLots: 1,
Status: domain.OrderStatusSent,
CreatedAt: time.Now().UTC().Add(-time.Minute),
}); err != nil {
t.Fatal(err)
}
system := New(repo, domain.ModePaper)
state, err := system.Recover(ctx, reconciliation.New(repo, tinvest.NewFakeGateway(), "account", "hash"))
if err == nil {
t.Fatal("expected critical reconciliation error")
}
if state != domain.StateHalted || !repo.Halted {
t.Fatalf("state=%s halted=%v, want HALTED", state, repo.Halted)
}
}
+344
View File
@@ -0,0 +1,344 @@
package testutil
import (
"context"
"fmt"
"sort"
"sync"
"time"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/repository"
)
var _ repository.Repository = (*MemoryRepository)(nil)
type MemoryRepository struct {
mu sync.Mutex
Instruments map[string]domain.Instrument
Daily map[string][]domain.Candle
Minute map[string][]domain.Candle
Features map[string]domain.FeatureSet
Signals map[string]domain.Signal
Orders map[string]domain.Order
Positions map[int64]domain.Position
RiskEvents []domain.RiskEvent
FreeOrders map[string]int
Reports map[string]bool
State domain.SystemState
Mode domain.Mode
Halted bool
HaltReason string
nextPositionID int64
}
func NewMemoryRepository() *MemoryRepository {
return &MemoryRepository{
Instruments: make(map[string]domain.Instrument),
Daily: make(map[string][]domain.Candle),
Minute: make(map[string][]domain.Candle),
Features: make(map[string]domain.FeatureSet),
Signals: make(map[string]domain.Signal),
Orders: make(map[string]domain.Order),
Positions: make(map[int64]domain.Position),
FreeOrders: make(map[string]int),
Reports: make(map[string]bool),
State: domain.StateInit,
Mode: domain.ModePaper,
nextPositionID: 1,
}
}
func (r *MemoryRepository) RunInTx(ctx context.Context, fn func(context.Context, repository.Repository) error) error {
return fn(ctx, r)
}
func (r *MemoryRepository) UpsertInstrument(_ context.Context, instrument domain.Instrument) error {
r.mu.Lock()
defer r.mu.Unlock()
r.Instruments[instrument.InstrumentUID] = instrument
return nil
}
func (r *MemoryRepository) ReplaceInstrument(_ context.Context, oldInstrumentUID string, instrument domain.Instrument) error {
r.mu.Lock()
defer r.mu.Unlock()
delete(r.Instruments, oldInstrumentUID)
r.Instruments[instrument.InstrumentUID] = instrument
return nil
}
func (r *MemoryRepository) ListInstruments(_ context.Context, includeDisabled bool) ([]domain.Instrument, error) {
r.mu.Lock()
defer r.mu.Unlock()
out := make([]domain.Instrument, 0, len(r.Instruments))
for _, instrument := range r.Instruments {
if includeDisabled || instrument.Enabled {
out = append(out, instrument)
}
}
sort.Slice(out, func(i, j int) bool { return out[i].Ticker < out[j].Ticker })
return out, nil
}
func (r *MemoryRepository) QuarantineInstrument(_ context.Context, instrumentUID, reason string) error {
r.mu.Lock()
defer r.mu.Unlock()
instrument := r.Instruments[instrumentUID]
instrument.Quarantine = true
instrument.QuarantineReason = reason
r.Instruments[instrumentUID] = instrument
return nil
}
func (r *MemoryRepository) UpsertDailyCandles(_ context.Context, candles []domain.Candle) error {
r.mu.Lock()
defer r.mu.Unlock()
for _, candle := range candles {
r.Daily[candle.InstrumentUID] = upsertCandle(r.Daily[candle.InstrumentUID], candle, false)
}
return nil
}
func (r *MemoryRepository) ListDailyCandles(_ context.Context, instrumentUID string, from, to time.Time) ([]domain.Candle, error) {
r.mu.Lock()
defer r.mu.Unlock()
return filterCandles(r.Daily[instrumentUID], from, to), nil
}
func (r *MemoryRepository) UpsertMinuteCandles(_ context.Context, candles []domain.Candle) error {
r.mu.Lock()
defer r.mu.Unlock()
for _, candle := range candles {
r.Minute[candle.InstrumentUID] = upsertCandle(r.Minute[candle.InstrumentUID], candle, true)
}
return nil
}
func (r *MemoryRepository) ListMinuteCandles(_ context.Context, instrumentUID string, from, to time.Time) ([]domain.Candle, error) {
r.mu.Lock()
defer r.mu.Unlock()
return filterCandles(r.Minute[instrumentUID], from, to), nil
}
func (r *MemoryRepository) UpsertFeature(_ context.Context, feature domain.FeatureSet) error {
r.mu.Lock()
defer r.mu.Unlock()
r.Features[featureKey(feature.InstrumentUID, feature.TradeDate)] = feature
return nil
}
func (r *MemoryRepository) GetFeature(_ context.Context, instrumentUID string, tradeDate time.Time) (domain.FeatureSet, error) {
r.mu.Lock()
defer r.mu.Unlock()
return r.Features[featureKey(instrumentUID, tradeDate)], nil
}
func (r *MemoryRepository) UpsertSignal(_ context.Context, signal domain.Signal) error {
r.mu.Lock()
defer r.mu.Unlock()
r.Signals[featureKey(signal.InstrumentUID, signal.TradeDate)] = signal
return nil
}
func (r *MemoryRepository) ListSignals(_ context.Context, tradeDate time.Time) ([]domain.Signal, error) {
r.mu.Lock()
defer r.mu.Unlock()
var out []domain.Signal
for _, signal := range r.Signals {
if sameDate(signal.TradeDate, tradeDate) {
out = append(out, signal)
}
}
sort.Slice(out, func(i, j int) bool { return out[i].InstrumentUID < out[j].InstrumentUID })
return out, nil
}
func (r *MemoryRepository) UpsertOrder(_ context.Context, order domain.Order) error {
r.mu.Lock()
defer r.mu.Unlock()
r.Orders[order.ClientOrderID] = order
return nil
}
func (r *MemoryRepository) UpdateOrderStatus(_ context.Context, clientOrderID string, status domain.OrderStatus, filledLots int64, rawJSON string) error {
r.mu.Lock()
defer r.mu.Unlock()
order := r.Orders[clientOrderID]
order.Status = status
order.FilledLots = filledLots
order.RawStateJSON = rawJSON
r.Orders[clientOrderID] = order
return nil
}
func (r *MemoryRepository) ListActiveOrders(_ context.Context, accountIDHash string) ([]domain.Order, error) {
r.mu.Lock()
defer r.mu.Unlock()
var out []domain.Order
for _, order := range r.Orders {
if order.AccountIDHash == accountIDHash && (order.Status == domain.OrderStatusNew || order.Status == domain.OrderStatusSent || order.Status == domain.OrderStatusPartiallyFilled) {
out = append(out, order)
}
}
return out, nil
}
func (r *MemoryRepository) ListOrders(_ context.Context, accountIDHash string, from, to time.Time) ([]domain.Order, error) {
r.mu.Lock()
defer r.mu.Unlock()
var out []domain.Order
for _, order := range r.Orders {
if order.AccountIDHash == accountIDHash && !order.TradeDate.Before(dateOnly(from)) && !order.TradeDate.After(dateOnly(to)) {
out = append(out, order)
}
}
return out, nil
}
func (r *MemoryRepository) UpsertPosition(_ context.Context, position domain.Position) error {
r.mu.Lock()
defer r.mu.Unlock()
for id, existing := range r.Positions {
if existing.AccountIDHash == position.AccountIDHash &&
existing.InstrumentUID == position.InstrumentUID &&
sameDate(existing.OpenTradeDate, position.OpenTradeDate) {
position.ID = id
r.Positions[id] = position
return nil
}
}
if position.ID == 0 {
position.ID = r.nextPositionID
r.nextPositionID++
}
r.Positions[position.ID] = position
return nil
}
func (r *MemoryRepository) ListOpenPositions(_ context.Context, accountIDHash string) ([]domain.Position, error) {
r.mu.Lock()
defer r.mu.Unlock()
var out []domain.Position
for _, pos := range r.Positions {
if pos.AccountIDHash == accountIDHash && pos.Status != domain.PositionNoPosition && pos.Status != domain.PositionExitFilled && pos.Status != domain.PositionQuarantine {
out = append(out, pos)
}
}
return out, nil
}
func (r *MemoryRepository) ListPositions(_ context.Context, accountIDHash string, from, to time.Time) ([]domain.Position, error) {
r.mu.Lock()
defer r.mu.Unlock()
var out []domain.Position
for _, pos := range r.Positions {
if pos.AccountIDHash == accountIDHash && !pos.OpenTradeDate.Before(dateOnly(from)) && !pos.OpenTradeDate.After(dateOnly(to)) {
out = append(out, pos)
}
}
return out, nil
}
func (r *MemoryRepository) InsertRiskEvent(_ context.Context, event domain.RiskEvent) error {
r.mu.Lock()
defer r.mu.Unlock()
r.RiskEvents = append(r.RiskEvents, event)
return nil
}
func (r *MemoryRepository) GetFreeOrdersSent(_ context.Context, tradeDate time.Time, instrumentUID string) (int, error) {
r.mu.Lock()
defer r.mu.Unlock()
return r.FreeOrders[featureKey(instrumentUID, tradeDate)], nil
}
func (r *MemoryRepository) IncrementFreeOrders(_ context.Context, tradeDate time.Time, instrumentUID string, delta int) error {
r.mu.Lock()
defer r.mu.Unlock()
r.FreeOrders[featureKey(instrumentUID, tradeDate)] += delta
return nil
}
func (r *MemoryRepository) GetSystemState(_ context.Context) (domain.SystemState, bool, string, error) {
r.mu.Lock()
defer r.mu.Unlock()
return r.State, r.Halted, r.HaltReason, nil
}
func (r *MemoryRepository) SaveSystemState(_ context.Context, state domain.SystemState, mode domain.Mode, halted bool, reason string, _ string) error {
r.mu.Lock()
defer r.mu.Unlock()
r.State = state
r.Mode = mode
r.Halted = halted
r.HaltReason = reason
return nil
}
func (r *MemoryRepository) Unhalt(_ context.Context, reason string) error {
r.mu.Lock()
defer r.mu.Unlock()
if !r.Halted && r.State != domain.StateHalted {
return fmt.Errorf("system is not halted")
}
r.RiskEvents = append(r.RiskEvents, domain.RiskEvent{Severity: domain.SeverityInfo, EventType: "manual_unhalt", Message: reason})
r.State = domain.StateInit
r.Halted = false
r.HaltReason = ""
return nil
}
func (r *MemoryRepository) WasDailyReportSent(_ context.Context, reportDate time.Time, accountIDHash string) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
return r.Reports[accountIDHash+"|"+dateOnly(reportDate).Format("2006-01-02")], nil
}
func (r *MemoryRepository) MarkDailyReportSent(_ context.Context, reportDate time.Time, accountIDHash string) error {
r.mu.Lock()
defer r.mu.Unlock()
r.Reports[accountIDHash+"|"+dateOnly(reportDate).Format("2006-01-02")] = true
return nil
}
func (r *MemoryRepository) InsertReconciliation(_ context.Context, _ time.Time, _ string, _ bool) error {
return nil
}
func upsertCandle(candles []domain.Candle, candle domain.Candle, minute bool) []domain.Candle {
for i, existing := range candles {
if (!minute && sameDate(existing.TradeDate, candle.TradeDate)) || (minute && existing.TradeDate.Equal(candle.TradeDate)) {
candles[i] = candle
return candles
}
}
return append(candles, candle)
}
func filterCandles(candles []domain.Candle, from, to time.Time) []domain.Candle {
var out []domain.Candle
for _, candle := range candles {
if !candle.TradeDate.Before(from) && !candle.TradeDate.After(to) {
out = append(out, candle)
}
}
sort.Slice(out, func(i, j int) bool { return out[i].TradeDate.Before(out[j].TradeDate) })
return out
}
func featureKey(instrumentUID string, date time.Time) string {
return instrumentUID + "|" + dateOnly(date).Format("2006-01-02")
}
func sameDate(a, b time.Time) bool {
return dateOnly(a).Equal(dateOnly(b))
}
func dateOnly(t time.Time) time.Time {
y, m, d := t.UTC().Date()
return time.Date(y, m, d, 0, 0, 0, 0, time.UTC)
}
+93
View File
@@ -0,0 +1,93 @@
package timeutil
import (
"fmt"
"time"
)
type Clock interface {
Now() time.Time
Sleep(ctxDone <-chan struct{}, d time.Duration) bool
}
type RealClock struct {
Loc *time.Location
}
func (c RealClock) Now() time.Time {
now := time.Now()
if c.Loc != nil {
return now.In(c.Loc)
}
return now
}
func (c RealClock) Sleep(ctxDone <-chan struct{}, d time.Duration) bool {
timer := time.NewTimer(d)
defer timer.Stop()
select {
case <-timer.C:
return true
case <-ctxDone:
return false
}
}
type TimeOfDay struct {
Duration time.Duration
}
func ParseTimeOfDay(raw string) (TimeOfDay, error) {
parsed, err := time.Parse("15:04:05", raw)
if err != nil {
return TimeOfDay{}, fmt.Errorf("parse time of day %q: %w", raw, err)
}
return TimeOfDay{
Duration: time.Duration(parsed.Hour())*time.Hour +
time.Duration(parsed.Minute())*time.Minute +
time.Duration(parsed.Second())*time.Second,
}, nil
}
func (t *TimeOfDay) UnmarshalText(text []byte) error {
parsed, err := ParseTimeOfDay(string(text))
if err != nil {
return err
}
*t = parsed
return nil
}
func (t TimeOfDay) String() string {
total := int64(t.Duration.Seconds())
h := total / 3600
m := (total % 3600) / 60
s := total % 60
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
}
func (t TimeOfDay) On(date time.Time, loc *time.Location) time.Time {
local := date.In(loc)
y, m, d := local.Date()
midnight := time.Date(y, m, d, 0, 0, 0, 0, loc)
return midnight.Add(t.Duration)
}
type Window struct {
Start TimeOfDay
End TimeOfDay
}
func (w Window) Contains(now time.Time, loc *time.Location) bool {
start := w.Start.On(now, loc)
end := w.End.On(now, loc)
return !now.Before(start) && now.Before(end)
}
func Drift(local, server time.Time) time.Duration {
d := local.Sub(server)
if d < 0 {
return -d
}
return d
}
+182
View File
@@ -0,0 +1,182 @@
package tinvest
import (
"context"
"errors"
"sync"
"time"
"github.com/shopspring/decimal"
"overnight-trading-bot/internal/domain"
)
var ErrNotFound = errors.New("not found")
type Gateway interface {
GetInstrument(ctx context.Context, ticker, classCode string) (domain.Instrument, error)
GetCandles(ctx context.Context, instrumentUID string, interval string, from, to time.Time) ([]domain.Candle, error)
GetOrderBook(ctx context.Context, instrumentUID string, depth int32) (domain.OrderBook, error)
GetTradingStatus(ctx context.Context, instrumentUID string) (domain.TradingStatus, error)
PostLimitOrder(ctx context.Context, accountID, instrumentUID string, side domain.Side, lots int64, price decimal.Decimal, clientOrderID string) (domain.Order, error)
CancelOrder(ctx context.Context, accountID, orderID string) error
GetOrderState(ctx context.Context, accountID, orderID string) (domain.Order, error)
GetActiveOrders(ctx context.Context, accountID string) ([]domain.Order, error)
GetPortfolio(ctx context.Context, accountID string) (domain.Portfolio, error)
GetOperations(ctx context.Context, accountID string, from, to time.Time) ([]domain.Operation, error)
GetServerTime(ctx context.Context) (time.Time, error)
}
type FakeGateway struct {
mu sync.Mutex
Instruments map[string]domain.Instrument
Candles map[string][]domain.Candle
OrderBooks map[string]domain.OrderBook
Statuses map[string]domain.TradingStatus
Orders map[string]domain.Order
Portfolio domain.Portfolio
Operations []domain.Operation
ServerTime time.Time
}
func NewFakeGateway() *FakeGateway {
return &FakeGateway{
Instruments: make(map[string]domain.Instrument),
Candles: make(map[string][]domain.Candle),
OrderBooks: make(map[string]domain.OrderBook),
Statuses: make(map[string]domain.TradingStatus),
Orders: make(map[string]domain.Order),
Portfolio: domain.Portfolio{
Equity: decimal.NewFromInt(100_000),
Cash: decimal.NewFromInt(100_000),
CheckedAt: time.Now().UTC(),
},
}
}
func (f *FakeGateway) GetInstrument(_ context.Context, ticker, classCode string) (domain.Instrument, error) {
f.mu.Lock()
defer f.mu.Unlock()
for _, instrument := range f.Instruments {
if instrument.Ticker == ticker && instrument.ClassCode == classCode {
return instrument, nil
}
}
return domain.Instrument{}, ErrNotFound
}
func (f *FakeGateway) GetCandles(_ context.Context, instrumentUID string, _ string, from, to time.Time) ([]domain.Candle, error) {
f.mu.Lock()
defer f.mu.Unlock()
var out []domain.Candle
for _, candle := range f.Candles[instrumentUID] {
if !candle.TradeDate.Before(from) && !candle.TradeDate.After(to) {
out = append(out, candle)
}
}
return out, nil
}
func (f *FakeGateway) GetOrderBook(_ context.Context, instrumentUID string, _ int32) (domain.OrderBook, error) {
f.mu.Lock()
defer f.mu.Unlock()
book, ok := f.OrderBooks[instrumentUID]
if !ok {
return domain.OrderBook{}, ErrNotFound
}
return book, nil
}
func (f *FakeGateway) GetTradingStatus(_ context.Context, instrumentUID string) (domain.TradingStatus, error) {
f.mu.Lock()
defer f.mu.Unlock()
status, ok := f.Statuses[instrumentUID]
if !ok {
return domain.TradingStatusNormal, nil
}
return status, nil
}
func (f *FakeGateway) PostLimitOrder(_ context.Context, accountID, instrumentUID string, side domain.Side, lots int64, price decimal.Decimal, clientOrderID string) (domain.Order, error) {
f.mu.Lock()
defer f.mu.Unlock()
order := domain.Order{
ClientOrderID: clientOrderID,
BrokerOrderID: "fake-" + clientOrderID,
AccountIDHash: accountID,
InstrumentUID: instrumentUID,
Side: side,
OrderType: domain.OrderTypeLimit,
LimitPrice: price,
QuantityLots: lots,
Status: domain.OrderStatusSent,
RawStateJSON: "{}",
CreatedAt: time.Now().UTC(),
UpdatedAt: time.Now().UTC(),
}
f.Orders[order.BrokerOrderID] = order
return order, nil
}
func (f *FakeGateway) CancelOrder(_ context.Context, _ string, orderID string) error {
f.mu.Lock()
defer f.mu.Unlock()
order, ok := f.Orders[orderID]
if !ok {
return ErrNotFound
}
order.Status = domain.OrderStatusCancelled
order.UpdatedAt = time.Now().UTC()
f.Orders[orderID] = order
return nil
}
func (f *FakeGateway) GetOrderState(_ context.Context, _ string, orderID string) (domain.Order, error) {
f.mu.Lock()
defer f.mu.Unlock()
order, ok := f.Orders[orderID]
if !ok {
return domain.Order{}, ErrNotFound
}
return order, nil
}
func (f *FakeGateway) GetActiveOrders(_ context.Context, _ string) ([]domain.Order, error) {
f.mu.Lock()
defer f.mu.Unlock()
out := make([]domain.Order, 0)
for _, order := range f.Orders {
if order.Status == domain.OrderStatusSent || order.Status == domain.OrderStatusPartiallyFilled {
out = append(out, order)
}
}
return out, nil
}
func (f *FakeGateway) GetPortfolio(_ context.Context, _ string) (domain.Portfolio, error) {
f.mu.Lock()
defer f.mu.Unlock()
f.Portfolio.CheckedAt = time.Now().UTC()
return f.Portfolio, nil
}
func (f *FakeGateway) GetOperations(_ context.Context, _ string, from, to time.Time) ([]domain.Operation, error) {
f.mu.Lock()
defer f.mu.Unlock()
var out []domain.Operation
for _, op := range f.Operations {
if !op.ExecutedAt.Before(from) && !op.ExecutedAt.After(to) {
out = append(out, op)
}
}
return out, nil
}
func (f *FakeGateway) GetServerTime(context.Context) (time.Time, error) {
f.mu.Lock()
defer f.mu.Unlock()
if f.ServerTime.IsZero() {
return time.Now().UTC(), nil
}
return f.ServerTime, nil
}
+456
View File
@@ -0,0 +1,456 @@
package tinvest
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"time"
"github.com/russianinvestments/invest-api-go-sdk/investgo"
pb "github.com/russianinvestments/invest-api-go-sdk/proto"
"github.com/shopspring/decimal"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"overnight-trading-bot/internal/domain"
"overnight-trading-bot/internal/logging"
"overnight-trading-bot/internal/money"
)
type Options struct {
Token string
AccountID string
Endpoint string
AppName string
RetryCount int
RetryBackoff time.Duration
Logger *slog.Logger
}
type RealGateway struct {
client *investgo.Client
instruments *investgo.InstrumentsServiceClient
marketData *investgo.MarketDataServiceClient
orders *investgo.OrdersServiceClient
operations *investgo.OperationsServiceClient
users *investgo.UsersServiceClient
retryAttempts int
retryBackoff time.Duration
}
func NewRealGateway(ctx context.Context, opts Options) (*RealGateway, error) {
if opts.Token == "" {
return nil, fmt.Errorf("tinvest token is required")
}
client, err := investgo.NewClient(ctx, investgo.Config{
EndPoint: opts.Endpoint,
Token: opts.Token,
AppName: opts.AppName,
AccountId: opts.AccountID,
MaxRetries: 0,
}, logging.SDKLogger{Logger: opts.Logger})
if err != nil {
return nil, err
}
return &RealGateway{
client: client,
instruments: client.NewInstrumentsServiceClient(),
marketData: client.NewMarketDataServiceClient(),
orders: client.NewOrdersServiceClient(),
operations: client.NewOperationsServiceClient(),
users: client.NewUsersServiceClient(),
retryAttempts: opts.RetryCount,
retryBackoff: opts.RetryBackoff,
}, nil
}
func (g *RealGateway) Close() error {
if g.client == nil || g.client.Conn == nil {
return nil
}
return g.client.Conn.Close()
}
func (g *RealGateway) GetInstrument(ctx context.Context, ticker, classCode string) (domain.Instrument, error) {
if err := ctx.Err(); err != nil {
return domain.Instrument{}, err
}
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.EtfResponse, error) {
return g.instruments.EtfByTicker(ticker, classCode)
})
if err != nil {
return domain.Instrument{}, err
}
etf := resp.GetInstrument()
if etf == nil {
return domain.Instrument{}, ErrNotFound
}
return domain.Instrument{
InstrumentUID: etf.GetUid(),
Figi: etf.GetFigi(),
Ticker: etf.GetTicker(),
ClassCode: etf.GetClassCode(),
Name: etf.GetName(),
Lot: int64(etf.GetLot()),
MinPriceIncrement: money.QuotationToDecimal(etf.GetMinPriceIncrement()),
Currency: strings.ToUpper(etf.GetCurrency()),
Enabled: etf.GetApiTradeAvailableFlag() && etf.GetBuyAvailableFlag() && etf.GetSellAvailableFlag(),
UpdatedAt: time.Now().UTC(),
}, nil
}
func (g *RealGateway) GetCandles(ctx context.Context, instrumentUID string, interval string, from, to time.Time) ([]domain.Candle, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetCandlesResponse, error) {
return g.marketData.GetCandles(instrumentUID, candleInterval(interval), from, to, pb.GetCandlesRequest_CANDLE_SOURCE_EXCHANGE, 0)
})
if err != nil {
return nil, err
}
candles := resp.GetCandles()
out := make([]domain.Candle, 0, len(candles))
for _, candle := range candles {
out = append(out, domain.Candle{
InstrumentUID: instrumentUID,
TradeDate: candle.GetTime().AsTime().UTC(),
Open: money.QuotationToDecimal(candle.GetOpen()),
High: money.QuotationToDecimal(candle.GetHigh()),
Low: money.QuotationToDecimal(candle.GetLow()),
Close: money.QuotationToDecimal(candle.GetClose()),
VolumeLots: decimal.NewFromInt(candle.GetVolume()),
Source: "tinvest",
LoadedAt: time.Now().UTC(),
})
}
return out, nil
}
func (g *RealGateway) GetOrderBook(ctx context.Context, instrumentUID string, depth int32) (domain.OrderBook, error) {
if err := ctx.Err(); err != nil {
return domain.OrderBook{}, err
}
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrderBookResponse, error) {
return g.marketData.GetOrderBook(instrumentUID, depth)
})
if err != nil {
return domain.OrderBook{}, err
}
return domain.OrderBook{
InstrumentUID: instrumentUID,
Bids: orderBookLevels(resp.GetBids()),
Asks: orderBookLevels(resp.GetAsks()),
Time: resp.GetOrderbookTs().AsTime().UTC(),
ReceivedAt: time.Now().UTC(),
}, nil
}
func (g *RealGateway) GetTradingStatus(ctx context.Context, instrumentUID string) (domain.TradingStatus, error) {
if err := ctx.Err(); err != nil {
return domain.TradingStatusUnknown, err
}
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetTradingStatusResponse, error) {
return g.marketData.GetTradingStatus(instrumentUID)
})
if err != nil {
return domain.TradingStatusUnknown, err
}
if resp.GetTradingStatus() == pb.SecurityTradingStatus_SECURITY_TRADING_STATUS_NORMAL_TRADING &&
resp.GetLimitOrderAvailableFlag() &&
resp.GetApiTradeAvailableFlag() {
return domain.TradingStatusNormal, nil
}
return domain.TradingStatusClosed, nil
}
func (g *RealGateway) PostLimitOrder(ctx context.Context, accountID, instrumentUID string, side domain.Side, lots int64, price decimal.Decimal, clientOrderID string) (domain.Order, error) {
if err := ctx.Err(); err != nil {
return domain.Order{}, err
}
direction := pb.OrderDirection_ORDER_DIRECTION_BUY
if side == domain.SideSell {
direction = pb.OrderDirection_ORDER_DIRECTION_SELL
}
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.PostOrderResponse, error) {
return g.orders.PostOrder(&investgo.PostOrderRequest{
InstrumentId: instrumentUID,
Quantity: lots,
Price: money.DecimalToQuotation(price),
Direction: direction,
AccountId: accountID,
OrderType: pb.OrderType_ORDER_TYPE_LIMIT,
OrderId: clientOrderID,
TimeInForce: pb.TimeInForceType_TIME_IN_FORCE_DAY,
PriceType: pb.PriceType_PRICE_TYPE_CURRENCY,
})
})
if err != nil {
return domain.Order{}, err
}
return orderFromPostResponse(resp.PostOrderResponse, accountID, clientOrderID, side, price), nil
}
func (g *RealGateway) CancelOrder(ctx context.Context, accountID, orderID string) error {
if err := ctx.Err(); err != nil {
return err
}
return withRetry(ctx, g.retryAttempts, g.retryBackoff, func() error {
_, err := g.orders.CancelOrder(accountID, orderID, nil)
return err
})
}
func (g *RealGateway) GetOrderState(ctx context.Context, accountID, orderID string) (domain.Order, error) {
if err := ctx.Err(); err != nil {
return domain.Order{}, err
}
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrderStateResponse, error) {
return g.orders.GetOrderState(accountID, orderID, pb.PriceType_PRICE_TYPE_CURRENCY, nil)
})
if err != nil {
return domain.Order{}, err
}
return orderFromState(resp.OrderState, accountID), nil
}
func (g *RealGateway) GetActiveOrders(ctx context.Context, accountID string) ([]domain.Order, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetOrdersResponse, error) {
return g.orders.GetOrders(accountID, nil)
})
if err != nil {
return nil, err
}
states := resp.GetOrders()
out := make([]domain.Order, 0, len(states))
for _, state := range states {
out = append(out, orderFromState(state, accountID))
}
return out, nil
}
func (g *RealGateway) GetPortfolio(ctx context.Context, accountID string) (domain.Portfolio, error) {
if err := ctx.Err(); err != nil {
return domain.Portfolio{}, err
}
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.PortfolioResponse, error) {
return g.operations.GetPortfolio(accountID, pb.PortfolioRequest_RUB)
})
if err != nil {
return domain.Portfolio{}, err
}
positions := resp.GetPositions()
holdings := make([]domain.Holding, 0, len(positions))
for _, position := range positions {
holdings = append(holdings, domain.Holding{
InstrumentUID: position.GetInstrumentUid(),
QuantityLots: money.QuotationToDecimal(position.GetQuantity()).IntPart(),
AveragePrice: money.MoneyValueToDecimal(position.GetAveragePositionPrice()),
MarketValue: money.MoneyValueToDecimal(position.GetCurrentPrice()).Mul(money.QuotationToDecimal(position.GetQuantity())),
})
}
equity, err := rubMoneyValueToDecimal(resp.GetTotalAmountPortfolio())
if err != nil {
return domain.Portfolio{}, err
}
cash, err := rubMoneyValueToDecimal(resp.GetTotalAmountCurrencies())
if err != nil {
return domain.Portfolio{}, err
}
return domain.Portfolio{
Equity: equity,
Cash: cash,
Holdings: holdings,
CheckedAt: time.Now().UTC(),
}, nil
}
func (g *RealGateway) GetOperations(ctx context.Context, accountID string, from, to time.Time) ([]domain.Operation, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.OperationsResponse, error) {
return g.operations.GetOperations(&investgo.GetOperationsRequest{
AccountId: accountID,
From: from,
To: to,
})
})
if err != nil {
return nil, err
}
ops := resp.GetOperations()
out := make([]domain.Operation, 0, len(ops))
for _, op := range ops {
payment := money.MoneyValueToDecimal(op.GetPayment())
out = append(out, domain.Operation{
ID: op.GetId(),
InstrumentUID: op.GetInstrumentUid(),
Type: op.GetOperationType().String(),
Payment: payment,
Commission: operationCommission(op.GetOperationType(), payment),
ExecutedAt: op.GetDate().AsTime().UTC(),
})
}
return out, nil
}
func (g *RealGateway) GetServerTime(ctx context.Context) (time.Time, error) {
if err := ctx.Err(); err != nil {
return time.Time{}, err
}
resp, err := retryValue(ctx, g.retryAttempts, g.retryBackoff, func() (*investgo.GetInfoResponse, error) {
return g.users.GetInfo()
})
if err != nil {
return time.Time{}, err
}
if serverTime, ok := serverTimeFromHeader(resp.Header); ok {
return serverTime, nil
}
return time.Time{}, errors.New("server time is unavailable in response metadata")
}
func operationCommission(operationType pb.OperationType, payment decimal.Decimal) decimal.Decimal {
if operationType != pb.OperationType_OPERATION_TYPE_BROKER_FEE &&
operationType != pb.OperationType_OPERATION_TYPE_SERVICE_FEE &&
operationType != pb.OperationType_OPERATION_TYPE_SUCCESS_FEE {
return decimal.Zero
}
return money.Abs(payment)
}
func rubMoneyValueToDecimal(value *pb.MoneyValue) (decimal.Decimal, error) {
if value == nil {
return decimal.Zero, nil
}
if currency := strings.ToUpper(value.GetCurrency()); currency != "" && currency != "RUB" {
return decimal.Zero, fmt.Errorf("expected RUB money value, got %s", currency)
}
return money.MoneyValueToDecimal(value), nil
}
func serverTimeFromHeader(header map[string][]string) (time.Time, bool) {
for _, key := range []string{"date", "Date"} {
values := header[key]
if len(values) == 0 {
continue
}
parsed, err := http.ParseTime(values[0])
if err == nil {
return parsed.UTC(), true
}
}
return time.Time{}, false
}
func candleInterval(interval string) pb.CandleInterval {
switch strings.ToLower(interval) {
case "minute", "1m", "1min":
return pb.CandleInterval_CANDLE_INTERVAL_1_MIN
default:
return pb.CandleInterval_CANDLE_INTERVAL_DAY
}
}
func orderBookLevels(levels []*pb.Order) []domain.OrderBookLevel {
out := make([]domain.OrderBookLevel, 0, len(levels))
for _, level := range levels {
out = append(out, domain.OrderBookLevel{
Price: money.QuotationToDecimal(level.GetPrice()),
QuantityLots: level.GetQuantity(),
})
}
return out
}
func orderFromPostResponse(resp *pb.PostOrderResponse, accountID, clientOrderID string, side domain.Side, limitPrice decimal.Decimal) domain.Order {
if resp == nil {
return domain.Order{}
}
now := time.Now().UTC()
return domain.Order{
ClientOrderID: clientOrderID,
BrokerOrderID: resp.GetOrderId(),
AccountIDHash: accountID,
InstrumentUID: resp.GetInstrumentUid(),
Side: side,
OrderType: domain.OrderTypeLimit,
LimitPrice: limitPrice,
QuantityLots: resp.GetLotsRequested(),
FilledLots: resp.GetLotsExecuted(),
AvgFillPrice: limitPrice,
Status: mapOrderStatus(resp.GetExecutionReportStatus()),
Commission: money.MoneyValueToDecimal(resp.GetExecutedCommission()),
RawStateJSON: marshalProto(resp),
CreatedAt: now,
UpdatedAt: now,
}
}
func orderFromState(state *pb.OrderState, accountID string) domain.Order {
if state == nil {
return domain.Order{}
}
side := domain.SideBuy
if state.GetDirection() == pb.OrderDirection_ORDER_DIRECTION_SELL {
side = domain.SideSell
}
orderDate := time.Now().UTC()
if state.GetOrderDate() != nil {
orderDate = state.GetOrderDate().AsTime().UTC()
}
return domain.Order{
ClientOrderID: state.GetOrderRequestId(),
BrokerOrderID: state.GetOrderId(),
AccountIDHash: accountID,
InstrumentUID: state.GetInstrumentUid(),
Side: side,
OrderType: domain.OrderTypeLimit,
LimitPrice: money.MoneyValueToDecimal(state.GetInitialSecurityPrice()),
QuantityLots: state.GetLotsRequested(),
FilledLots: state.GetLotsExecuted(),
AvgFillPrice: money.MoneyValueToDecimal(state.GetAveragePositionPrice()),
Status: mapOrderStatus(state.GetExecutionReportStatus()),
Commission: money.MoneyValueToDecimal(state.GetExecutedCommission()),
RawStateJSON: marshalProto(state),
CreatedAt: orderDate,
UpdatedAt: time.Now().UTC(),
}
}
func mapOrderStatus(status pb.OrderExecutionReportStatus) domain.OrderStatus {
switch status {
case pb.OrderExecutionReportStatus_EXECUTION_REPORT_STATUS_FILL:
return domain.OrderStatusFilled
case pb.OrderExecutionReportStatus_EXECUTION_REPORT_STATUS_PARTIALLYFILL:
return domain.OrderStatusPartiallyFilled
case pb.OrderExecutionReportStatus_EXECUTION_REPORT_STATUS_CANCELLED:
return domain.OrderStatusCancelled
case pb.OrderExecutionReportStatus_EXECUTION_REPORT_STATUS_REJECTED:
return domain.OrderStatusRejected
case pb.OrderExecutionReportStatus_EXECUTION_REPORT_STATUS_NEW:
return domain.OrderStatusSent
default:
return domain.OrderStatusNew
}
}
func marshalProto(msg proto.Message) string {
if msg == nil {
return "{}"
}
raw, err := protojson.Marshal(msg)
if err != nil {
fallback, _ := json.Marshal(map[string]string{"marshal_error": err.Error()})
return string(fallback)
}
return string(raw)
}
+64
View File
@@ -0,0 +1,64 @@
package tinvest
import (
"context"
"time"
backofflib "github.com/cenkalti/backoff/v4"
)
func withRetry(ctx context.Context, attempts int, interval time.Duration, fn func() error) error {
if attempts <= 0 {
attempts = 1
}
if interval < 0 {
interval = 0
}
policy := backofflib.NewExponentialBackOff()
policy.InitialInterval = interval
policy.MaxInterval = interval * 8
policy.Multiplier = 2
policy.MaxElapsedTime = 0
policy.Reset()
var lastErr error
for attempt := 0; attempt < attempts; attempt++ {
if err := ctx.Err(); err != nil {
return err
}
if err := fn(); err != nil {
lastErr = err
} else {
return nil
}
if attempt == attempts-1 || interval <= 0 {
continue
}
timer := time.NewTimer(policy.NextBackOff())
select {
case <-timer.C:
case <-ctx.Done():
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
return ctx.Err()
}
}
return lastErr
}
func retryValue[T any](ctx context.Context, attempts int, interval time.Duration, fn func() (T, error)) (T, error) {
var out T
err := withRetry(ctx, attempts, interval, func() error {
var err error
out, err = fn()
return err
})
if err != nil {
var zero T
return zero, err
}
return out, nil
}
+41
View File
@@ -0,0 +1,41 @@
package tinvest
import (
"context"
"errors"
"testing"
"time"
)
func TestWithRetryRetriesUntilSuccess(t *testing.T) {
attempts := 0
err := withRetry(context.Background(), 3, 0, func() error {
attempts++
if attempts < 3 {
return errors.New("temporary")
}
return nil
})
if err != nil {
t.Fatal(err)
}
if attempts != 3 {
t.Fatalf("attempts=%d, want 3", attempts)
}
}
func TestWithRetryStopsOnContextCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
attempts := 0
err := withRetry(ctx, 3, time.Millisecond, func() error {
attempts++
return errors.New("temporary")
})
if !errors.Is(err, context.Canceled) {
t.Fatalf("err=%v, want context.Canceled", err)
}
if attempts != 0 {
t.Fatalf("attempts=%d, want 0", attempts)
}
}
+10
View File
@@ -0,0 +1,10 @@
package tinvest
import "context"
const sandboxEndpoint = "sandbox-invest-public-api.tinkoff.ru:443"
func NewSandboxGateway(ctx context.Context, opts Options) (*RealGateway, error) {
opts.Endpoint = sandboxEndpoint
return NewRealGateway(ctx, opts)
}